Merge branch 'master' of git://yap.git.sourceforge.net/gitroot/yap/yap-6.3

This commit is contained in:
Denys Duchier 2012-04-16 21:51:54 +02:00
commit b66b261972
169 changed files with 8751 additions and 1277262 deletions

View File

@ -490,7 +490,6 @@ Yap_HasOp(Atom a)
OpEntry *
Yap_OpPropForModule(Atom a, Term mod)
{ /* look property list of atom a for kind */
CACHE_REGS
AtomEntry *ae = RepAtom(a);
PropEntry *pp;
OpEntry *info = NULL;
@ -767,6 +766,7 @@ ExpandPredHash(void)
Prop
Yap_NewPredPropByFunctor(FunctorEntry *fe, Term cur_mod)
{
CACHE_REGS
PredEntry *p = (PredEntry *) Yap_AllocAtomSpace(sizeof(*p));
if (p == NULL) {
@ -902,6 +902,7 @@ Yap_NewThreadPred(PredEntry *ap USES_REGS)
Prop
Yap_NewPredPropByAtom(AtomEntry *ae, Term cur_mod)
{
CACHE_REGS
Prop p0;
PredEntry *p = (PredEntry *) Yap_AllocAtomSpace(sizeof(*p));

View File

@ -2053,6 +2053,7 @@ a_try(op_numbers opcode, CELL lab, CELL opr, int nofalts, int hascut, yamop *cod
yamop *newcp;
/* emit a special instruction and then a label for backpatching */
if (pass_no) {
CACHE_REGS
UInt size = (UInt)NEXTOP((yamop *)NULL,OtaLl);
if ((newcp = (yamop *)Yap_AllocCodeSpace(size)) == NULL) {
/* OOOPS, got in trouble, must do a longjmp and recover space */

View File

@ -561,6 +561,7 @@ X_API YAP_tag_t STD_PROTO(YAP_TagOfTerm,(Term));
X_API size_t STD_PROTO(YAP_ExportTerm,(Term, char *, size_t));
X_API size_t STD_PROTO(YAP_SizeOfExportedTerm,(char *));
X_API Term STD_PROTO(YAP_ImportTerm,(char *));
X_API int STD_PROTO(YAP_RequiresExtraStack,(size_t));
static UInt
current_arity(void)
@ -2705,7 +2706,6 @@ YAP_InitConsult(int mode, char *filename)
X_API IOSTREAM *
YAP_TermToStream(Term t)
{
CACHE_REGS
IOSTREAM *s;
BACKUP_MACHINE_REGS();
@ -2937,7 +2937,13 @@ YAP_Init(YAP_init_args *yap_init)
int restore_result;
int do_bootstrap = (yap_init->YapPrologBootFile != NULL);
CELL Trail = 0, Stack = 0, Heap = 0, Atts = 0;
static char boot_file[256];
char boot_file[256];
static int initialised = FALSE;
/* ignore repeated calls to YAP_Init */
if (initialised)
return YAP_BOOT_DONE_BEFOREHAND;
initialised = TRUE;
Yap_InitPageSize(); /* init memory page size, required by later functions */
#if defined(YAPOR_COPY) || defined(YAPOR_COW) || defined(YAPOR_SBA)
@ -3612,9 +3618,22 @@ YAP_ListToFloats(Term t, double *dblp, size_t sz)
if (!IsPairTerm(t))
return -1;
hd = HeadOfTerm(t);
if (!IsFloatTerm(hd))
return -1;
dblp[i++] = FloatOfTerm(hd);
if (IsFloatTerm(hd)) {
dblp[i++] = FloatOfTerm(hd);
} else {
extern double Yap_gmp_to_float(Term hd);
if (IsIntTerm(hd))
dblp[i++] = IntOfTerm(hd);
else if (IsLongIntTerm(hd))
dblp[i++] = LongIntOfTerm(hd);
#if USE_GMP
else if (IsBigIntTerm(hd))
dblp[i++] = Yap_gmp_to_float(hd);
#endif
else
return -1;
}
if (i == sz)
return sz;
t = TailOfTerm(t);
@ -4108,3 +4127,24 @@ YAP_ImportTerm(char * buf) {
return Yap_ImportTerm(buf);
}
X_API int
YAP_RequiresExtraStack(size_t sz) {
CACHE_REGS
if (sz < 16*1024)
sz = 16*1024;
if (H <= ASP-sz) {
return FALSE;
}
BACKUP_H();
while (H > ASP-sz) {
CACHE_REGS
RECOVER_H();
if (!dogc( 0, NULL PASS_REGS )) {
return -1;
}
BACKUP_H();
}
RECOVER_H();
return TRUE;
}

View File

@ -5107,6 +5107,8 @@ p_continue_static_clause( USES_REGS1 )
static void
add_code_in_lu_index(LogUpdIndex *cl, PredEntry *pp)
{
CACHE_REGS
char *code_end = (char *)cl + cl->ClSize;
Yap_inform_profiler_of_clause(cl, code_end, pp, GPROF_LU_INDEX);
cl = cl->ChildIndex;
@ -5119,6 +5121,7 @@ add_code_in_lu_index(LogUpdIndex *cl, PredEntry *pp)
static void
add_code_in_static_index(StaticIndex *cl, PredEntry *pp)
{
CACHE_REGS
char *code_end = (char *)cl + cl->ClSize;
Yap_inform_profiler_of_clause(cl, code_end, pp, GPROF_STATIC_INDEX);
cl = cl->ChildIndex;
@ -5131,6 +5134,7 @@ add_code_in_static_index(StaticIndex *cl, PredEntry *pp)
static void
add_code_in_pred(PredEntry *pp) {
CACHE_REGS
yamop *clcode;
PELOCK(49,pp);
@ -5202,6 +5206,7 @@ add_code_in_pred(PredEntry *pp) {
void
Yap_dump_code_area_for_profiler(void) {
CACHE_REGS
ModEntry *me = CurrentModules;
while (me) {

View File

@ -1887,6 +1887,7 @@ Yap_new_ludbe(Term t, PredEntry *pe, UInt nargs)
static LogUpdClause *
record_lu(PredEntry *pe, Term t, int position)
{
CACHE_REGS
LogUpdClause *cl;
if ((cl = new_lu_db_entry(t, pe)) == NULL) {

View File

@ -1066,6 +1066,7 @@ do_goal(Term t, yamop *CodeAdr, int arity, CELL *pt, int top USES_REGS)
S = CellPtr (RepPredProp (PredPropByFunc (Yap_MkFunctor(AtomCall, 1),0))); /* A1 mishaps */
out = exec_absmi(top PASS_REGS);
Yap_flush();
// if (out) {
// out = Yap_GetFromSlot(sl);
// }

View File

@ -168,6 +168,7 @@ RBfree(rb_red_blk_node *ptr)
static rb_red_blk_node *
RBTreeCreate(void) {
CACHE_REGS
rb_red_blk_node* temp;
/* see the comment in the rb_red_blk_tree structure in red_black_tree.h */
@ -210,6 +211,7 @@ RBTreeCreate(void) {
static void
LeftRotate(rb_red_blk_node* x) {
CACHE_REGS
rb_red_blk_node* y;
rb_red_blk_node* nil=LOCAL_ProfilerNil;
@ -266,6 +268,7 @@ LeftRotate(rb_red_blk_node* x) {
static void
RightRotate(rb_red_blk_node* y) {
CACHE_REGS
rb_red_blk_node* x;
rb_red_blk_node* nil=LOCAL_ProfilerNil;
@ -318,6 +321,7 @@ RightRotate(rb_red_blk_node* y) {
static void
TreeInsertHelp(rb_red_blk_node* z) {
CACHE_REGS
/* This function should only be called by InsertRBTree (see above) */
rb_red_blk_node* x;
rb_red_blk_node* y;
@ -369,6 +373,7 @@ TreeInsertHelp(rb_red_blk_node* z) {
static rb_red_blk_node *
RBTreeInsert(yamop *key, yamop *lim) {
CACHE_REGS
rb_red_blk_node * y;
rb_red_blk_node * x;
rb_red_blk_node * newNode;
@ -440,6 +445,7 @@ RBTreeInsert(yamop *key, yamop *lim) {
static rb_red_blk_node*
RBExactQuery(yamop* q) {
CACHE_REGS
rb_red_blk_node* x;
rb_red_blk_node* nil=LOCAL_ProfilerNil;
@ -460,6 +466,7 @@ RBExactQuery(yamop* q) {
static rb_red_blk_node*
RBLookup(yamop *entry) {
CACHE_REGS
rb_red_blk_node *current;
if (!LOCAL_ProfilerRoot)
@ -495,6 +502,7 @@ RBLookup(yamop *entry) {
/***********************************************************************/
static void RBDeleteFixUp(rb_red_blk_node* x) {
CACHE_REGS
rb_red_blk_node* root=LOCAL_ProfilerRoot->left;
rb_red_blk_node *w;
@ -574,6 +582,7 @@ static void RBDeleteFixUp(rb_red_blk_node* x) {
static rb_red_blk_node*
TreeSuccessor(rb_red_blk_node* x) {
CACHE_REGS
rb_red_blk_node* y;
rb_red_blk_node* nil=LOCAL_ProfilerNil;
rb_red_blk_node* root=LOCAL_ProfilerRoot;
@ -612,6 +621,7 @@ TreeSuccessor(rb_red_blk_node* x) {
static void
RBDelete(rb_red_blk_node* z){
CACHE_REGS
rb_red_blk_node* y;
rb_red_blk_node* x;
rb_red_blk_node* nil=LOCAL_ProfilerNil;
@ -664,7 +674,8 @@ RBDelete(rb_red_blk_node* z){
char *set_profile_dir(char *);
char *set_profile_dir(char *name){
int size=0;
CACHE_REGS
int size=0;
if (name!=NULL) {
size=strlen(name)+1;
@ -687,8 +698,9 @@ return LOCAL_DIRNAME;
char *profile_names(int);
char *profile_names(int k) {
static char *FNAME=NULL;
int size=200;
CACHE_REGS
static char *FNAME=NULL;
int size=200;
if (LOCAL_DIRNAME==NULL) set_profile_dir(NULL);
size=strlen(LOCAL_DIRNAME)+40;
@ -709,6 +721,7 @@ int size=200;
void del_profile_files(void);
void del_profile_files() {
CACHE_REGS
if (LOCAL_DIRNAME!=NULL) {
remove(profile_names(PROFPREDS_FILE));
remove(profile_names(PROFILING_FILE));
@ -717,6 +730,7 @@ void del_profile_files() {
void
Yap_inform_profiler_of_clause__(void *code_start, void *code_end, PredEntry *pe,gprof_info index_code) {
CACHE_REGS
buf_ptr b;
buf_extra e;
LOCAL_ProfOn = TRUE;
@ -742,6 +756,7 @@ static Int profend( USES_REGS1 );
static void
clean_tree(rb_red_blk_node* node) {
CACHE_REGS
if (node == LOCAL_ProfilerNil)
return;
clean_tree(node->left);
@ -751,6 +766,7 @@ clean_tree(rb_red_blk_node* node) {
static void
reset_tree(void) {
CACHE_REGS
clean_tree(LOCAL_ProfilerRoot);
Yap_FreeCodeSpace((char *)LOCAL_ProfilerNil);
LOCAL_ProfilerNil = LOCAL_ProfilerRoot = NULL;
@ -760,6 +776,7 @@ reset_tree(void) {
static int
InitProfTree(void)
{
CACHE_REGS
if (LOCAL_ProfilerRoot)
reset_tree();
while (!(LOCAL_ProfilerRoot = RBTreeCreate())) {
@ -773,6 +790,7 @@ InitProfTree(void)
static void RemoveCode(CODEADDR clau)
{
CACHE_REGS
rb_red_blk_node* x, *node;
PredEntry *pp;
UInt count;
@ -958,6 +976,7 @@ prof_alrm(int signo, siginfo_t *si, void *scv)
void
Yap_InformOfRemoval(void *clau)
{
CACHE_REGS
LOCAL_ProfOn = TRUE;
if (LOCAL_FPreds != NULL) {
/* just store info about what is going on */
@ -1048,6 +1067,7 @@ static Int profinit( USES_REGS1 )
static Int start_profilers(int msec)
{
CACHE_REGS
struct itimerval t;
struct sigaction sa;
@ -1157,6 +1177,7 @@ static Int profres0( USES_REGS1 ) {
void
Yap_InitLowProf(void)
{
CACHE_REGS
#if LOW_PROF
LOCAL_ProfCalls = 0;
LOCAL_ProfilerOn = FALSE;

View File

@ -718,6 +718,11 @@ AdjustScannerStacks(TokEntry **tksp, VarEntry **vep USES_REGS)
TokEntry *tktmp;
switch (tks->Tok) {
case Number_tok:
if (IsApplTerm(tks->TokInfo)) {
tks->TokInfo = AdjustAppl(tks->TokInfo PASS_REGS);
}
break;
case Var_tok:
case String_tok:
if (IsOldTrail(tks->TokInfo))

View File

@ -1888,6 +1888,7 @@ emit_single_switch_case(ClauseDef *min, struct intermediates *cint, int first, i
static UInt
suspend_indexing(ClauseDef *min, ClauseDef *max, PredEntry *ap, struct intermediates *cint)
{
CACHE_REGS
UInt tcls = ap->cs.p_code.NOfClauses;
UInt cls = (max-min)+1;

View File

@ -993,6 +993,15 @@ p_read_module_preds( USES_REGS1 )
return TRUE;
}
static void
ReInitCatch(void)
{
Term t = Yap_MkNewApplTerm(PredHandleThrow->FunctorOfPred, PredHandleThrow->ArityOfPE);
YAP_RunGoalOnce(t);
}
static Int
p_read_program( USES_REGS1 )
{
@ -1016,7 +1025,7 @@ p_read_program( USES_REGS1 )
Sclose( stream );
/* back to the top level we go */
Yap_CloseSlots(PASS_REGS1);
ReInitCatch();
Yap_RestartYap( 3 );
return TRUE;
}
@ -1030,6 +1039,7 @@ Yap_Restore(char *s, char *lib_dir)
return -1;
read_module(stream);
Sclose( stream );
ReInitCatch();
return DO_ONLY_CODE;
}

View File

@ -1619,6 +1619,7 @@ InteractSIGINT(int ch) {
static int
ProcessSIGINT(void)
{
CACHE_REGS
int ch, out;
LOCAL_PrologMode |= AsyncIntMode;

View File

@ -52,17 +52,7 @@ send_tracer_message(char *start, char *name, Int arity, char *mname, CELL *args)
if (args) {
for (i= 0; i < arity; i++) {
if (i > 0) fprintf(GLOBAL_stderr, ",");
#if DEBUG
#if COROUTINING
Yap_Portray_delays = TRUE;
#endif
#endif
Yap_plwrite(args[i], NULL, 15, Handle_vars_f, 1200);
#if DEBUG
#if COROUTINING
Yap_Portray_delays = FALSE;
#endif
#endif
Yap_plwrite(args[i], NULL, 15, Handle_vars_f|AttVar_Portray_f, 1200);
}
if (arity) {
fprintf(GLOBAL_stderr, ")");

View File

@ -4255,7 +4255,7 @@ p_is_list_or_partial_list( USES_REGS1 )
}
static Term
numbervar(Int id)
numbervar(Int id USES_REGS)
{
Term ts[1];
ts[0] = MkIntegerTerm(id);
@ -4263,7 +4263,7 @@ numbervar(Int id)
}
static Term
numbervar_singleton(void)
numbervar_singleton(USES_REGS1)
{
Term ts[1];
ts[0] = MkIntegerTerm(-1);
@ -4356,9 +4356,9 @@ static Int numbervars_in_complex_term(register CELL *pt0, register CELL *pt0_end
derefa_body(d0, ptd0, vars_in_term_unk, vars_in_term_nvar);
/* do or pt2 are unbound */
if (singles)
*ptd0 = numbervar_singleton();
*ptd0 = numbervar_singleton( PASS_REGS1 );
else
*ptd0 = numbervar(numbv++);
*ptd0 = numbervar(numbv++ PASS_REGS);
/* leave an empty slot to fill in later */
if (H+1024 > ASP) {
goto global_overflow;
@ -4450,10 +4450,10 @@ Yap_NumberVars( Term inp, Int numbv, int handle_singles ) /* numbervariables in
CELL *ptd0 = VarOfTerm(t);
TrailTerm(TR++) = (CELL)ptd0;
if (handle_singles) {
*ptd0 = numbervar_singleton();
*ptd0 = numbervar_singleton( PASS_REGS1 );
return numbv;
} else {
*ptd0 = numbervar(numbv);
*ptd0 = numbervar(numbv PASS_REGS);
return numbv+1;
}
} else if (IsPrimitiveTerm(t)) {

164
C/write.c
View File

@ -66,7 +66,7 @@ typedef struct rewind_term {
typedef struct write_globs {
void *stream;
int Quote_illegal, Ignore_ops, Handle_vars, Use_portray;
int Quote_illegal, Ignore_ops, Handle_vars, Use_portray, Portray_delays;
int Keep_terms;
int Write_Loops;
int Write_strings;
@ -90,28 +90,50 @@ STATIC_PROTO(void writeTerm, (Term, int, int, int, struct write_globs *, struct
#define wrputc(X,WF) Sputcode(X,WF) /* writes a character */
/*
protect bracket from merging with previoous character.
avoid stuff like not (2,3) -> not(2,3) or
*/
static void
protect_open_number(struct write_globs *wglb, int minus_required)
wropen_bracket(struct write_globs *wglb, int protect)
{
wrf stream = wglb->stream;
if (lastw == symbol && last_minus && !minus_required) {
if (!wglb->Ignore_ops) {
/* protect against collating - with number, and getting - 1 ^2 as (-(1))^2 */
wrputc(' ', wglb->stream);
}
wrputc('(', wglb->stream);
} else if (lastw == alphanum) {
wrputc(' ', stream);
}
if (lastw != separator && protect)
wrputc(' ', stream);
wrputc('(', stream);
lastw = separator;
}
static void
protect_close_number(struct write_globs *wglb, int minus_required)
wrclose_bracket(struct write_globs *wglb, int protect)
{
if (lastw == symbol && last_minus && !minus_required) {
wrputc(')', wglb->stream);
lastw = separator;
wrf stream = wglb->stream;
wrputc(')', stream);
lastw = separator;
}
static int
protect_open_number(struct write_globs *wglb, int lm, int minus_required)
{
wrf stream = wglb->stream;
if (lastw == symbol && lm && !minus_required) {
wropen_bracket(wglb, TRUE);
return TRUE;
} else if (lastw == alphanum ||
(lastw == symbol && minus_required)) {
wrputc(' ', stream);
}
return FALSE;
}
static void
protect_close_number(struct write_globs *wglb, int used_bracket)
{
if (used_bracket) {
wrclose_bracket(wglb, TRUE);
} else {
lastw = alphanum;
}
@ -125,8 +147,9 @@ wrputn(Int n, struct write_globs *wglb) /* writes an integer */
wrf stream = wglb->stream;
char s[256], *s1=s; /* that should be enough for most integers */
int has_minus = (n < 0);
int ob;
protect_open_number(wglb, has_minus);
ob = protect_open_number(wglb, last_minus, has_minus);
#if HAVE_SNPRINTF
snprintf(s, 256, Int_FORMAT, n);
#else
@ -134,7 +157,7 @@ wrputn(Int n, struct write_globs *wglb) /* writes an integer */
#endif
while (*s1)
wrputc(*s1++, stream);
protect_close_number(wglb, has_minus);
protect_close_number(wglb, ob);
}
#define wrputs(s, stream) Sfputs(s, stream)
@ -190,9 +213,10 @@ static void
write_mpint(MP_INT *big, struct write_globs *wglb) {
char *s;
int has_minus = mpz_sgn(big);
int ob;
s = ensure_space(3+mpz_sizeinbase(big, 10));
protect_open_number(wglb, has_minus);
ob = protect_open_number(wglb, last_minus, has_minus);
if (!s) {
s = mpz_get_str(NULL, 10, big);
if (!s)
@ -203,7 +227,7 @@ write_mpint(MP_INT *big, struct write_globs *wglb) {
mpz_get_str(s, 10, big);
wrputs(s,wglb->stream);
}
protect_close_number(wglb, has_minus);
protect_close_number(wglb, ob);
}
#endif
@ -271,6 +295,8 @@ wrputf(Float f, struct write_globs *wglb) /* writes a float */
char s[256];
wrf stream = wglb->stream;
int sgn;
int ob;
#if HAVE_ISNAN || defined(__WIN32)
if (isnan(f)) {
@ -291,7 +317,7 @@ wrputf(Float f, struct write_globs *wglb) /* writes a float */
return;
}
#endif
protect_open_number(wglb, sgn);
ob = protect_open_number(wglb, last_minus, sgn);
#if THREADS
/* old style writing */
int found_dot = FALSE, found_exp = FALSE;
@ -343,7 +369,7 @@ wrputf(Float f, struct write_globs *wglb) /* writes a float */
if (!buf) return;
wrputs(buf, stream);
#endif
protect_close_number(wglb, sgn);
protect_close_number(wglb, ob);
}
/* writes a data base reference */
@ -423,7 +449,7 @@ AtomIsSymbols(unsigned char *s) /* Is this atom just formed by symbols ? */
return(separator);
while ((ch = *s++) != '\0') {
if (Yap_chtype[ch] != SY)
return(alphanum);
return alphanum;
}
return symbol;
}
@ -669,15 +695,14 @@ write_var(CELL *t, struct write_globs *wglb, struct rewind_term *rwt)
/* make sure we don't get no creepy spaces where they shouldn't be */
lastw = separator;
if (IsAttVar(t)) {
#if defined(COROUTINING) && defined(DEBUG)
Int vcount = (t-H0);
if (Yap_Portray_delays) {
if (wglb->Portray_delays) {
exts ext = ExtFromCell(t);
struct rewind_term nrwt;
nrwt.parent = rwt;
nrwt.u.s.ptr = 0;
Yap_Portray_delays = FALSE;
wglb->Portray_delays = FALSE;
if (ext == attvars_ext) {
attvar_record *attv = RepAttVar(t);
CELL *l = &attv->Value; /* dirty low-level hack, check atts.h */
@ -691,14 +716,13 @@ write_var(CELL *t, struct write_globs *wglb, struct rewind_term *rwt)
l += 2;
writeTerm(from_pointer(l, &nrwt, wglb), 999, 1, FALSE, wglb, &nrwt);
restore_from_write(&nrwt, wglb);
wrputc(')', wglb->stream);
wrclose_bracket(wglb, TRUE);
}
Yap_Portray_delays = TRUE;
wglb->Portray_delays = TRUE;
return;
}
wrputc('D', wglb->stream);
wrputn(vcount,wglb);
#endif
} else {
wrputn(((Int) (t- H0)),wglb);
}
@ -790,6 +814,7 @@ write_list(Term t, int direction, int depth, struct write_globs *wglb, struct re
}
}
static void
writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, struct rewind_term *rwt)
/* term to write */
@ -823,8 +848,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
wrputs(",",wglb->stream);
writeTerm(from_pointer(RepPair(t)+1, &nrwt, wglb), 999, depth + 1, FALSE, wglb, &nrwt);
restore_from_write(&nrwt, wglb);
wrputc(')', wglb->stream);
lastw = separator;
wrclose_bracket(wglb, TRUE);
return;
}
if (wglb->Use_portray) {
@ -886,7 +910,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
int argno = 1;
CELL *p = ArgsOfSFTerm(t);
putAtom(atom, wglb->Quote_illegal, wglb);
wrputc('(', wglb->stream);
wropen_bracket(wglb, FALSE);
lastw = separator;
while (*p) {
Int sl = 0;
@ -904,8 +928,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
wrputc(',', wglb->stream);
argno++;
}
wrputc(')', wglb->stream);
lastw = separator;
wrclose_bracket(wglb, TRUE);
return;
}
#endif
@ -934,28 +957,22 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
!IsVarTerm(tright) && IsAtomTerm(tright) &&
Yap_IsOp(AtomOfTerm(tright));
if (op > p) {
/* avoid stuff such as \+ (a,b) being written as \+(a,b) */
if (lastw != separator && !rinfixarg)
wrputc(' ', wglb->stream);
wrputc('(', wglb->stream);
lastw = separator;
wropen_bracket(wglb, TRUE);
}
putAtom(atom, wglb->Quote_illegal, wglb);
if (bracket_right) {
wrputc('(', wglb->stream);
lastw = separator;
/* avoid stuff such as \+ (a,b) being written as \+(a,b) */
wropen_bracket(wglb, TRUE);
} else if (atom == AtomMinus) {
last_minus = TRUE;
}
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), rp, depth + 1, FALSE, wglb, &nrwt);
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), rp, depth + 1, TRUE, wglb, &nrwt);
restore_from_write(&nrwt, wglb);
if (bracket_right) {
wrputc(')', wglb->stream);
lastw = separator;
wrclose_bracket(wglb, TRUE);
}
if (op > p) {
wrputc(')', wglb->stream);
lastw = separator;
wrclose_bracket(wglb, TRUE);
}
} else if (!wglb->Ignore_ops &&
Arity == 1 &&
@ -963,29 +980,24 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
Term tleft = ArgOfTerm(1, t);
int bracket_left =
!IsVarTerm(tleft) && IsAtomTerm(tleft) &&
!IsVarTerm(tleft) &&
IsAtomTerm(tleft) &&
Yap_IsOp(AtomOfTerm(tleft));
if (op > p) {
/* avoid stuff such as \+ (a,b) being written as \+(a,b) */
if (lastw != separator && !rinfixarg)
wrputc(' ', wglb->stream);
wrputc('(', wglb->stream);
lastw = separator;
wropen_bracket(wglb, TRUE);
}
if (bracket_left) {
wrputc('(', wglb->stream);
lastw = separator;
wropen_bracket(wglb, TRUE);
}
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), lp, depth + 1, rinfixarg, wglb, &nrwt);
restore_from_write(&nrwt, wglb);
if (bracket_left) {
wrputc(')', wglb->stream);
lastw = separator;
wrclose_bracket(wglb, TRUE);
}
putAtom(atom, wglb->Quote_illegal, wglb);
if (op > p) {
wrputc(')', wglb->stream);
lastw = separator;
wrclose_bracket(wglb, TRUE);
}
} else if (!wglb->Ignore_ops &&
Arity == 2 && Yap_IsInfixOp(atom, &op, &lp,
@ -1001,41 +1013,36 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
if (op > p) {
/* avoid stuff such as \+ (a,b) being written as \+(a,b) */
if (lastw != separator && !rinfixarg)
wrputc(' ', wglb->stream);
wrputc('(', wglb->stream);
wropen_bracket(wglb, TRUE);
lastw = separator;
}
if (bracket_left) {
wrputc('(', wglb->stream);
lastw = separator;
wropen_bracket(wglb, TRUE);
}
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), lp, depth + 1, rinfixarg, wglb, &nrwt);
t = AbsAppl(restore_from_write(&nrwt, wglb)-1);
if (bracket_left) {
wrputc(')', wglb->stream);
lastw = separator;
wrclose_bracket(wglb, TRUE);
}
/* avoid quoting commas */
if (strcmp(RepAtom(atom)->StrOfAE,","))
putAtom(atom, wglb->Quote_illegal, wglb);
else {
/* avoid quoting commas and bars */
if (!strcmp(RepAtom(atom)->StrOfAE,",")) {
wrputc(',', wglb->stream);
lastw = separator;
}
if (bracket_right) {
wrputc('(', wglb->stream);
} else if (!strcmp(RepAtom(atom)->StrOfAE,"|")) {
wrputc('|', wglb->stream);
lastw = separator;
} else
putAtom(atom, wglb->Quote_illegal, wglb);
if (bracket_right) {
wropen_bracket(wglb, TRUE);
}
writeTerm(from_pointer(RepAppl(t)+2, &nrwt, wglb), rp, depth + 1, TRUE, wglb, &nrwt);
restore_from_write(&nrwt, wglb);
if (bracket_right) {
wrputc(')', wglb->stream);
lastw = separator;
wrclose_bracket(wglb, TRUE);
}
if (op > p) {
wrputc(')', wglb->stream);
lastw = separator;
wrclose_bracket(wglb, TRUE);
}
} else if (wglb->Handle_vars && functor == LOCAL_FunctorVar) {
Term ti = ArgOfTerm(1, t);
@ -1068,8 +1075,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
lastw = separator;
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), 999, depth + 1, FALSE, wglb, &nrwt);
restore_from_write(&nrwt, wglb);
wrputc(')', wglb->stream);
lastw = separator;
wrclose_bracket(wglb, TRUE);
}
} else if (!wglb->Ignore_ops && functor == FunctorBraces) {
wrputc('{', wglb->stream);
@ -1098,7 +1104,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
} else {
putAtom(atom, wglb->Quote_illegal, wglb);
lastw = separator;
wrputc('(', wglb->stream);
wropen_bracket(wglb, FALSE);
for (op = 1; op <= Arity; ++op) {
if (op == wglb->MaxArgs) {
wrputc('.', wglb->stream);
@ -1113,8 +1119,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
lastw = separator;
}
}
wrputc(')', wglb->stream);
lastw = separator;
wrclose_bracket(wglb, TRUE);
}
}
}
@ -1138,6 +1143,7 @@ Yap_plwrite(Term t, void *mywrite, int max_depth, int flags, int priority)
wglb.Quote_illegal = flags & Quote_illegal_f;
wglb.Handle_vars = flags & Handle_vars_f;
wglb.Use_portray = flags & Use_portray_f;
wglb.Portray_delays = flags & AttVar_Portray_f;
wglb.MaxDepth = max_depth;
wglb.MaxArgs = max_depth;
/* notice: we must have ASP well set when using portray, otherwise

View File

@ -498,6 +498,7 @@ void STD_PROTO(Yap_init_optyap_preds,(void));
/* pl-file.c */
struct PL_local_data *Yap_InitThreadIO(int wid);
void Yap_flush(void);
static inline
yamop *

View File

@ -159,6 +159,9 @@ typedef enum
#ifdef HAVE_LOCALE_H
#include <locale.h>
#endif
#ifdef HAVE_LIMITS_H /* get MAXPATHLEN */
#include <limits.h>
#endif
#include <setjmp.h>
#include <assert.h>
#if HAVE_SYS_PARAM_H

View File

@ -323,12 +323,6 @@ int STD_PROTO(Yap_growtrail_in_parser, (tr_fr_ptr *, TokEntry **, VarEntry **)
extern int errno;
#endif
#ifdef DEBUG
#if COROUTINING
extern int Yap_Portray_delays;
#endif
#endif
EXTERN inline UInt STD_PROTO(HashFunction, (unsigned char *));
EXTERN inline UInt STD_PROTO(WideHashFunction, (wchar_t *));

View File

@ -710,6 +710,7 @@ all: startup.yss
@ENABLE_CPLINT@ (cd packages/cplint; $(MAKE))
@ENABLE_CPLINT@ (cd packages/cplint/slipcase; $(MAKE))
@ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE))
@ENABLE_BDDLIB@ (cd packages/bdd; $(MAKE))
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE))
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE))
@ENABLE_JPL@ @INSTALL_DLLS@ (cd packages/jpl; $(MAKE))
@ -788,6 +789,7 @@ install_unix: startup.yss libYap.a
@ENABLE_CPLINT@ (cd packages/cplint/approx/simplecuddLPADs; $(MAKE) install)
@ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE) install)
@ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) install)
@ENABLE_BDDLIB@ (cd packages/bdd; $(MAKE) install)
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE) install)
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE) install)
@ -840,6 +842,7 @@ install_win32: startup.yss @ENABLE_WINCONSOLE@ pl-yap@EXEC_SUFFIX@
@ENABLE_CPLINT@ (cd packages/cplint; $(MAKE) install)
@ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE) install)
@ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) install)
@ENABLE_BDDLIB@ (cd packages/bdd; $(MAKE) install)
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE) install)
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE) install)
@ -904,6 +907,7 @@ clean: clean_docs
@ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) clean)
@ENABLE_CPLINT@ (cd packages/cplint/approx/simplecuddLPADs; $(MAKE) clean)
@ENABLE_CPLINT@ (cd packages/cplint; $(MAKE) clean)
@ENABLE_BDDLIB@ (cd packages/bdd; $(MAKE) clean)
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE) clean)
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE) clean)
@ENABLE_JPL@ @INSTALL_DLLS@ (cd packages/jpl; $(MAKE) clean)

View File

@ -257,6 +257,10 @@
#undef HAVE_WAITPID
#undef HAVE_MPZ_XOR
#if HAVE_GETHOSTNAME==1
#define HAS_GETHOSTNAME 1
#endif
#undef HAVE_SIGINFO
#undef HAVE_SIGSEGV
#undef HAVE_SIGPROF

22
configure vendored
View File

@ -625,6 +625,7 @@ ENABLE_REAL
ENABLE_MINISAT
CUDD_CPPFLAGS
CUDD_LDFLAGS
ENABLE_BDDLIB
ENABLE_CUDD
EXTRA_INCLUDES_FOR_WIN32
ENABLE_WINCONSOLE
@ -788,6 +789,7 @@ enable_depth_limit
enable_wam_profile
enable_low_level_tracer
enable_threads
enable_bddlib
enable_pthread_locking
enable_max_performance
enable_max_memory
@ -1462,6 +1464,7 @@ Optional Features:
--enable-wam-profile support low level profiling of abstract machine
--enable-low-level-tracer support support for procedure-call tracing
--enable-threads support system threads
--enable-bddlib dynamic bdd library
--enable-pthread-locking use pthread locking primitives for internal locking (requires threads)
--enable-max-performance try using the best flags for specific architecture
--enable-max-memory try using the best flags for using the memory to the most
@ -4486,6 +4489,13 @@ else
threads=no
fi
# Check whether --enable-bddlib was given.
if test "${enable_bddlib+set}" = set; then :
enableval=$enable_bddlib; dynamic_bdd="$enableval"
else
dynamic_bdd=no
fi
# Check whether --enable-pthread-locking was given.
if test "${enable_pthread_locking+set}" = set; then :
enableval=$enable_pthread_locking; pthreadlocking="$enableval"
@ -5000,7 +5010,14 @@ fi
if test "$yap_cv_cudd" = no
then
ENABLE_CUDD="@# "
ENABLE_BDDLIB="@# "
else
if test "$dynamic_bdd" = yes
then
ENABLE_BDDLIB=""
else
ENABLE_BDDLIB="@# "
fi
ENABLE_CUDD=""
fi
@ -9220,6 +9237,7 @@ fi
{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for gcc threaded code" >&5
@ -10454,6 +10472,7 @@ mkdir -p LGPL/clp
mkdir -p LGPL/swi_console
mkdir -p GPL
mkdir -p packages/
mkdir -p packages/bdd
mkdir -p packages/clib
mkdir -p packages/clib/sha1
mkdir -p packages/clib/maildrop
@ -10616,6 +10635,8 @@ ac_config_files="$ac_config_files packages/zlib/Makefile"
fi
if test "$ENABLE_CUDD" = ""; then
ac_config_files="$ac_config_files packages/bdd/Makefile"
ac_config_files="$ac_config_files packages/ProbLog/simplecudd/Makefile"
ac_config_files="$ac_config_files packages/ProbLog/simplecudd_lfi/Makefile"
@ -11398,6 +11419,7 @@ do
"packages/semweb/Makefile") CONFIG_FILES="$CONFIG_FILES packages/semweb/Makefile" ;;
"packages/sgml/Makefile") CONFIG_FILES="$CONFIG_FILES packages/sgml/Makefile" ;;
"packages/zlib/Makefile") CONFIG_FILES="$CONFIG_FILES packages/zlib/Makefile" ;;
"packages/bdd/Makefile") CONFIG_FILES="$CONFIG_FILES packages/bdd/Makefile" ;;
"packages/ProbLog/simplecudd/Makefile") CONFIG_FILES="$CONFIG_FILES packages/ProbLog/simplecudd/Makefile" ;;
"packages/ProbLog/simplecudd_lfi/Makefile") CONFIG_FILES="$CONFIG_FILES packages/ProbLog/simplecudd_lfi/Makefile" ;;
"packages/swi-minisat2/Makefile") CONFIG_FILES="$CONFIG_FILES packages/swi-minisat2/Makefile" ;;

View File

@ -158,6 +158,9 @@ AC_ARG_ENABLE(low-level-tracer,
AC_ARG_ENABLE(threads,
[ --enable-threads support system threads ],
threads="$enableval", threads=no)
AC_ARG_ENABLE(bddlib,
[ --enable-bddlib dynamic bdd library ],
dynamic_bdd="$enableval", dynamic_bdd=no)
AC_ARG_ENABLE(pthread-locking,
[ --enable-pthread-locking use pthread locking primitives for internal locking (requires threads) ],
pthreadlocking="$enableval", pthreadlocking=no)
@ -510,7 +513,14 @@ fi
if test "$yap_cv_cudd" = no
then
ENABLE_CUDD="@# "
ENABLE_BDDLIB="@# "
else
if test "$dynamic_bdd" = yes
then
ENABLE_BDDLIB=""
else
ENABLE_BDDLIB="@# "
fi
ENABLE_CUDD=""
fi
@ -1789,6 +1799,7 @@ AC_SUBST(ENABLE_WINCONSOLE)
AC_SUBST(EXTRA_INCLUDES_FOR_WIN32)
AC_SUBST(ENABLE_CUDD)
AC_SUBST(ENABLE_BDDLIB)
AC_SUBST(CUDD_LDFLAGS)
AC_SUBST(CUDD_CPPFLAGS)
AC_SUBST(ENABLE_MINISAT)
@ -2269,6 +2280,7 @@ mkdir -p LGPL/clp
mkdir -p LGPL/swi_console
mkdir -p GPL
mkdir -p packages/
mkdir -p packages/bdd
mkdir -p packages/clib
mkdir -p packages/clib/sha1
mkdir -p packages/clib/maildrop
@ -2392,6 +2404,7 @@ AC_CONFIG_FILES([packages/zlib/Makefile])
fi
if test "$ENABLE_CUDD" = ""; then
AC_CONFIG_FILES([packages/bdd/Makefile])
AC_CONFIG_FILES([packages/ProbLog/simplecudd/Makefile])
AC_CONFIG_FILES([packages/ProbLog/simplecudd_lfi/Makefile])
fi

View File

@ -75,6 +75,9 @@
#if HAVE_STRING_H
#include <string.h>
#endif
#if HAVE_IEEEFP_H
#include <ieeefp.h>
#endif
static void PROTO(do_top_goal,(YAP_Term));
static void PROTO(exec_top_level,(int, YAP_init_args *));

View File

@ -12635,6 +12635,13 @@ The path @var{Path} is a path starting at vertex @var{Vertex} in graph
The path @var{Path} is a path starting at vertex @var{Vertex} in graph
@var{Graph}.
@item dgraph_leaves(+@var{Graph}, ?@var{Vertices})
@findex dgraph_leaves/2
@snindex dgraph_leaves/2
@cnindex dgraph_leaves/2
The vertices @var{Vertices} have no outgoing edge in graph
@var{Graph}.
@end table
@node UnDGraphs, Lambda , DGraphs, Library
@ -16464,6 +16471,21 @@ then allow one to construct functors, and to obtain their name and arity.
Note that the functor is essentially a pair formed by an atom, and
arity.
Constructing terms in the stack may lead to overflow. The routine
@example
int YAP_RequiresExtraStack(size_t @var{min})
@end example
verifies whether you have at least @var{min} cells free in the stack,
and it returns true if it has to ensure enough memory by calling the
garbage collector and or stack shifter. The routine returns false if no
memory is needed, and a negative number if it cannot provide enough
memory.
You can set @var{min} to zero if you do not know how much room you need
but you do know you do not need a big chunk at a single go. Usually, the routine
would usually be called together with a long-jump to restart the
code. Slots can also be used if there is small state.
@node Unifying Terms, Manipulating Strings, Manipulating Terms, C-Interface
@section Unification

View File

@ -599,6 +599,8 @@ extern X_API size_t PROTO(YAP_SizeOfExportedTerm,(char *));
extern X_API YAP_Term PROTO(YAP_ImportTerm,(char *));
extern X_API int PROTO(YAP_RequiresExtraStack,(size_t));
#define YAP_InitCPred(N,A,F) YAP_UserCPredicate(N,F,A)
__END_DECLS

View File

@ -115,6 +115,7 @@ typedef enum {
#define YAP_BOOT_FROM_SAVED_CODE 1
#define YAP_BOOT_FROM_SAVED_STACKS 2
#define YAP_FULL_BOOT_FROM_PROLOG 4
#define YAP_BOOT_DONE_BEFOREHAND 8
#define YAP_BOOT_ERROR -1
#define YAP_WRITE_QUOTED 1

View File

@ -32,6 +32,7 @@
dgraph_min_paths/3,
dgraph_isomorphic/4,
dgraph_path/3,
dgraph_leaves/2,
dgraph_reachable/3
]).
@ -414,3 +415,13 @@ reachable([V|Vertices], Done0, DoneF, G, [V|EdgesF], Edges0) :-
rb_insert(Done0, V, [], Done1),
reachable(Kids, Done1, DoneI, G, EdgesF, EdgesI),
reachable(Vertices, DoneI, DoneF, G, EdgesI, Edges0).
dgraph_leaves(Graph, Vertices) :-
rb_visit(Graph, Pairs),
vertices_without_children(Pairs, Vertices).
vertices_without_children([], []).
vertices_without_children((V-[]).Pairs, V.Vertices) :-
vertices_without_children(Pairs, Vertices).
vertices_without_children(_V-[_|_].Pairs, Vertices) :-
vertices_without_children(Pairs, Vertices).

View File

@ -25,7 +25,10 @@
:- ensure_loaded(library(lists)).
:- load_foreign_files([matlab], ['eng','mx','ut'], init_matlab).
tell_warning :-
print_message(warning,functionality(matlab)).
:- ( catch(load_foreign_files([matlab], ['eng','mx','ut'], init_matlab),_,fail) -> true ; tell_warning).
matlab_eval_sequence(S) :-
atomic_concat(S,S1),

View File

@ -4671,6 +4671,11 @@ EndPredDefs
#if __YAP_PROLOG__
void Yap_flush(void)
{
flush_output(0);
}
void *
Yap_GetStreamHandle(Atom at)
{ GET_LD

View File

@ -16,6 +16,11 @@ YAPLIBDIR=@libdir@/Yap
#
SHAREDIR=$(ROOTDIR)/share/Yap
#
# where YAP should store documentation
#
DOCDIR=$(ROOTDIR)/share/doc/Yap
EXDIR=$(DOCDIR)/examples/CLPBN
#
#
# You shouldn't need to change what follows.
#
@ -35,6 +40,7 @@ CLPBN_EXDIR = $(srcdir)/examples
CLPBN_PROGRAMS= \
$(CLPBN_SRCDIR)/aggregates.yap \
$(CLPBN_SRCDIR)/bdd.yap \
$(CLPBN_SRCDIR)/bnt.yap \
$(CLPBN_SRCDIR)/bp.yap \
$(CLPBN_SRCDIR)/connected.yap \
@ -48,6 +54,7 @@ CLPBN_PROGRAMS= \
$(CLPBN_SRCDIR)/graphviz.yap \
$(CLPBN_SRCDIR)/ground_factors.yap \
$(CLPBN_SRCDIR)/hmm.yap \
$(CLPBN_SRCDIR)/horus.yap \
$(CLPBN_SRCDIR)/jt.yap \
$(CLPBN_SRCDIR)/matrix_cpt_utils.yap \
$(CLPBN_SRCDIR)/pgrammar.yap \
@ -72,6 +79,8 @@ CLPBN_SCHOOL_EXAMPLES= \
$(CLPBN_EXDIR)/School/parschema.yap \
$(CLPBN_EXDIR)/School/school_128.yap \
$(CLPBN_EXDIR)/School/school_32.yap \
$(CLPBN_EXDIR)/School/sch32.yap \
$(CLPBN_EXDIR)/School/school32_data.yap \
$(CLPBN_EXDIR)/School/school_64.yap \
$(CLPBN_EXDIR)/School/tables.yap
@ -92,12 +101,13 @@ CLPBN_EXAMPLES= \
install: $(CLBN_TOP) $(CLBN_PROGRAMS) $(CLPBN_PROGRAMS)
mkdir -p $(DESTDIR)$(SHAREDIR)/clpbn
mkdir -p $(DESTDIR)$(SHAREDIR)/clpbn/learning
mkdir -p $(DESTDIR)$(SHAREDIR)/clpbn/examples/School
mkdir -p $(DESTDIR)$(SHAREDIR)/clpbn/examples/HMMer
mkdir -p $(DESTDIR)$(EXDIR)
mkdir -p $(DESTDIR)$(EXDIR)/School
mkdir -p $(DESTDIR)$(EXDIR)/HMMer
for h in $(CLPBN_TOP); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR); done
for h in $(CLPBN_PROGRAMS); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn; done
for h in $(CLPBN_LEARNING_PROGRAMS); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn/learning; done
for h in $(CLPBN_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn/examples; done
for h in $(CLPBN_SCHOOL_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn/examples/School; done
for h in $(CLPBN_HMMER_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn/examples/HMMer; done
for h in $(CLPBN_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(EXDIR); done
for h in $(CLPBN_SCHOOL_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(EXDIR)/School; done
for h in $(CLPBN_HMMER_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(EXDIR)/HMMer; done

View File

@ -13,6 +13,7 @@
clpbn_init_graph/1,
probability/2,
conditional_probability/3,
use_parfactors/1,
op( 500, xfy, with)]).
:- use_module(library(atts)).
@ -43,6 +44,7 @@
check_if_bp_done/1,
init_bp_solver/4,
run_bp_solver/3,
call_bp_ground/6,
finalize_bp_solver/1
]).
@ -61,11 +63,17 @@
run_jt_solver/3
]).
:- use_module('clpbn/bnt',
[do_bnt/3,
check_if_bnt_done/1
:- use_module('clpbn/bdd',
[bdd/3,
init_bdd_solver/4,
run_bdd_solver/3
]).
%% :- use_module('clpbn/bnt',
%% [do_bnt/3,
%% check_if_bnt_done/1
%% ]).
:- use_module('clpbn/gibbs',
[gibbs/3,
check_if_gibbs_done/1,
@ -111,7 +119,7 @@
[clpbn2gviz/4]).
:- use_module(clpbn/ground_factors,
[generate_bn/2]).
[generate_network/5]).
:- dynamic solver/1,output/1,use/1,suppress_attribute_display/1, parameter_softening/1, em_solver/1, use_parfactors/1.
@ -223,9 +231,17 @@ clpbn_marginalise(V, Dist) :-
% called by top-level
% or by call_residue/2
%
project_attributes(GVars, AVars0) :-
project_attributes(GVars, _AVars0) :-
use_parfactors(on),
clpbn_flag(solver, Solver), Solver \= fove, !,
generate_network(GVars, GKeys, Keys, Factors, Evidence),
(ground(GVars) ->
true
;
call_ground_solver(Solver, GVars, GKeys, Keys, Factors, Evidence, _Avars0)
).
project_attributes(GVars, AVars) :-
suppress_attribute_display(false),
generate_vars(GVars, AVars0, AVars),
AVars = [_|_],
solver(Solver),
( GVars = [_|_] ; Solver = graphs), !,
@ -243,11 +259,6 @@ project_attributes(GVars, AVars0) :-
).
project_attributes(_, _).
generate_vars(GVars, _, NewAVars) :-
use_parfactors(on), !,
generate_bn(GVars, NewAVars).
generate_vars(_GVars, AVars, AVars).
clpbn_vars(AVars, DiffVars, AllVars) :-
sort_vars_by_key(AVars,SortedAVars,DiffVars),
incorporate_evidence(SortedAVars, AllVars).
@ -289,6 +300,8 @@ write_out(ve, GVars, AVars, DiffVars) :-
ve(GVars, AVars, DiffVars).
write_out(jt, GVars, AVars, DiffVars) :-
jt(GVars, AVars, DiffVars).
write_out(bdd, GVars, AVars, DiffVars) :-
bdd(GVars, AVars, DiffVars).
write_out(bp, GVars, AVars, DiffVars) :-
bp(GVars, AVars, DiffVars).
write_out(gibbs, GVars, AVars, DiffVars) :-
@ -298,6 +311,11 @@ write_out(bnt, GVars, AVars, DiffVars) :-
write_out(fove, GVars, AVars, DiffVars) :-
fove(GVars, AVars, DiffVars).
% call a solver with keys, not actual variables
call_ground_solver(bp, GVars, GoalKeys, Keys, Factors, Evidence, Answ) :-
call_bp_ground(GVars, GoalKeys, Keys, Factors, Evidence, Answ).
get_bnode(Var, Goal) :-
get_atts(Var, [key(Key),dist(Dist,Parents)]),
get_dist(Dist,_,Domain,CPT),
@ -382,6 +400,9 @@ bind_clpbn(_, Var, _, _, _, _, []) :-
bind_clpbn(_, Var, _, _, _, _, []) :-
use(jt),
check_if_ve_done(Var), !.
bind_clpbn(_, Var, _, _, _, _, []) :-
use(bdd),
check_if_bdd_done(Var), !.
bind_clpbn(T, Var, Key0, _, _, _, []) :-
get_atts(Var, [key(Key)]), !,
(
@ -397,11 +418,12 @@ fresh_attvar(Var, NVar) :-
% I will now allow two CLPBN variables to be bound together.
%bind_clpbns(Key, Dist, Parents, Key, Dist, Parents).
bind_clpbns(Key, Dist, _Parents, Key1, Dist1, _Parents1) :-
bind_clpbns(Key, Dist, Parents, Key1, Dist1, Parents1) :-
Key == Key1, !,
get_dist(Dist,_Type,_Domain,_Table),
get_dist(Dist1,_Type1,_Domain1,_Table1),
Dist = Dist1.
Dist = Dist1,
Parents = Parents1.
bind_clpbns(Key, _, _, _, Key1, _, _, _) :-
Key\=Key1, !, fail.
bind_clpbns(_, _, _, _, _, _, _, _) :-
@ -452,6 +474,8 @@ clpbn_init_solver(bp, LVs, Vs0, VarsWithUnboundKeys, State) :-
init_bp_solver(LVs, Vs0, VarsWithUnboundKeys, State).
clpbn_init_solver(jt, LVs, Vs0, VarsWithUnboundKeys, State) :-
init_jt_solver(LVs, Vs0, VarsWithUnboundKeys, State).
clpbn_init_solver(bdd, LVs, Vs0, VarsWithUnboundKeys, State) :-
init_bdd_solver(LVs, Vs0, VarsWithUnboundKeys, State).
clpbn_init_solver(pcg, LVs, Vs0, VarsWithUnboundKeys, State) :-
init_pcg_solver(LVs, Vs0, VarsWithUnboundKeys, State).
@ -478,6 +502,9 @@ clpbn_run_solver(bp, LVs, LPs, State) :-
clpbn_run_solver(jt, LVs, LPs, State) :-
run_jt_solver(LVs, LPs, State).
clpbn_run_solver(bdd, LVs, LPs, State) :-
run_bdd_solver(LVs, LPs, State).
clpbn_run_solver(pcg, LVs, LPs, State) :-
run_pcg_solver(LVs, LPs, State).
@ -538,4 +565,5 @@ match_probability([p(V0=C)=Prob|_], C, V, Prob) :-
match_probability([_|Probs], C, V, Prob) :-
match_probability(Probs, C, V, Prob).
:- use_parfactors(on) -> true ; assert(use_parfactors(off)).

View File

@ -27,7 +27,7 @@
:- use_module(library('clpbn/dists'),
[
dist/4,
add_dist/6,
get_dist_domain_size/2]).
:- use_module(library('clpbn/matrix_cpt_utils'),
@ -44,8 +44,9 @@ check_for_agg_vars([_|Vs0], Vs1) :-
% transform aggregate distribution into tree
simplify_dist(avg(Domain), V, Key, Parents, Vs0, VsF) :- !,
cpt_average([V|Parents], Key, Domain, NewDist, Vs0, VsF),
dist(NewDist, Id, Key, ParentsF),
clpbn:put_atts(V, [dist(Id,ParentsF)]).
NewDist = p(Dom, Tab, Ps),
add_dist(Dom, tab, Tab, Ps, Key, Id),
clpbn:put_atts(V, [dist(Id,Ps)]).
simplify_dist(_, _, _, _, Vs0, Vs0).
cpt_average(AllVars, Key, Els0, Tab, Vs, NewVs) :-

View File

@ -0,0 +1,802 @@
/************************************************
BDDs in CLP(BN)
A variable is represented by the N possible cases it can take
V = v(Va, Vb, Vc)
The generic formula is
V <- X, Y
Va <- P*X1*Y1 + Q*X2*Y2 + ...
**************************************************/
:- module(clpbn_bdd,
[bdd/3,
set_solver_parameter/2,
init_bdd_solver/4,
run_bdd_solver/3,
finalize_bdd_solver/1,
check_if_bdd_done/1
]).
:- use_module(library('clpbn/dists'),
[dist/4,
get_dist_domain/2,
get_dist_domain_size/2,
get_dist_all_sizes/2,
get_dist_params/2
]).
:- use_module(library('clpbn/display'),
[clpbn_bind_vals/3]).
:- use_module(library('clpbn/aggregates'),
[check_for_agg_vars/2]).
:- use_module(library(atts)).
:- use_module(library(hacks)).
:- use_module(library(lists)).
:- use_module(library(dgraphs)).
:- use_module(library(bdd)).
:- use_module(library(rbtrees)).
:- use_module(library(bhash)).
:- use_module(library(matrix)).
:- dynamic network_counting/1.
:- attribute order/1.
check_if_bdd_done(_Var).
bdd([[]],_,_) :- !.
bdd([QueryVars], AllVars, AllDiffs) :-
init_bdd_solver(_, AllVars, _, BayesNet),
run_bdd_solver([QueryVars], LPs, BayesNet),
finalize_bdd_solver(BayesNet),
clpbn_bind_vals([QueryVars], [LPs], AllDiffs).
init_bdd_solver(_, AllVars0, _, bdd(Term, Leaves, Tops)) :-
% check_for_agg_vars(AllVars0, AllVars1),
sort_vars(AllVars0, AllVars, Leaves),
order_vars(AllVars, 0),
rb_new(Vars0),
rb_new(Pars0),
init_tops(Leaves,Tops),
get_vars_info(AllVars, Vars0, _Vars, Pars0, _Pars, Leaves, Tops, Term, []).
order_vars([], _).
order_vars([V|AllVars], I0) :-
put_atts(V, [order(I0)]),
I is I0+1,
order_vars(AllVars, I).
init_tops([],[]).
init_tops(_.Leaves,_.Tops) :-
init_tops(Leaves,Tops).
sort_vars(AllVars0, AllVars, Leaves) :-
dgraph_new(Graph0),
build_graph(AllVars0, Graph0, Graph),
dgraph_leaves(Graph, Leaves),
dgraph_top_sort(Graph, AllVars).
build_graph([], Graph, Graph).
build_graph(V.AllVars0, Graph0, Graph) :-
clpbn:get_atts(V, [dist(_DistId, Parents)]), !,
dgraph_add_vertex(Graph0, V, Graph1),
add_parents(Parents, V, Graph1, GraphI),
build_graph(AllVars0, GraphI, Graph).
build_graph(_V.AllVars0, Graph0, Graph) :-
build_graph(AllVars0, Graph0, Graph).
add_parents([], _V, Graph, Graph).
add_parents(V0.Parents, V, Graph0, GraphF) :-
dgraph_add_edge(Graph0, V0, V, GraphI),
add_parents(Parents, V, GraphI, GraphF).
get_vars_info([], Vs, Vs, Ps, Ps, _, _) --> [].
get_vars_info([V|MoreVs], Vs, VsF, Ps, PsF, Lvs, Outs) -->
{ clpbn:get_atts(V, [dist(DistId, Parents)]) }, !,
%{writeln(v:DistId:Parents)},
[DIST],
{ get_var_info(V, DistId, Parents, Vs, Vs2, Ps, Ps1, Lvs, Outs, DIST) },
get_vars_info(MoreVs, Vs2, VsF, Ps1, PsF, Lvs, Outs).
get_vars_info([_|MoreVs], Vs0, VsF, Ps0, PsF, VarsInfo, Lvs, Outs) :-
get_vars_info(MoreVs, Vs0, VsF, Ps0, PsF, VarsInfo, Lvs, Outs).
%
% let's have some fun with avg
%
get_var_info(V, avg(Domain), Parents, Vs, Vs2, Ps, Ps, Lvs, Outs, DIST) :- !,
length(Domain, DSize),
% run_though_avg(V, DSize, Domain, Parents, Vs, Vs2, Lvs, Outs, DIST).
top_down_with_tabling(V, DSize, Domain, Parents, Vs, Vs2, Lvs, Outs, DIST).
% bup_avg(V, DSize, Domain, Parents, Vs, Vs2, Lvs, Outs, DIST).
% standard random variable
get_var_info(V, DistId, Parents0, Vs, Vs2, Ps, Ps1, Lvs, Outs, DIST) :-
% clpbn:get_atts(V, [key(K)]), writeln(V:K:DistId:Parents),
reorder_vars(Parents0, Parents, Map),
check_p(DistId, Map, Parms, _ParmVars, Ps, Ps1),
unbound_parms(Parms, ParmVars),
check_v(V, DistId, DIST, Vs, Vs1),
DIST = info(V, Tree, Ev, Values, Formula, ParmVars, Parms),
% get a list of form [[P00,P01], [P10,P11], [P20,P21]]
get_parents(Parents, PVars, Vs1, Vs2),
cross_product(Values, Ev, PVars, ParmVars, Formula0),
% (numbervars(Formula0,0,_),writeln(formula0:Ev:Formula0), fail ; true),
get_evidence(V, Tree, Ev, Formula0, Formula, Lvs, Outs).
%, (numbervars(Formula,0,_),writeln(formula:Formula), fail ; true)
%
% reorder all variables and make sure we get a
% map of how the transfer was done.
%
% position zero is output
%
reorder_vars(Vs, OVs, Map) :-
add_pos(Vs, 1, PVs),
keysort(PVs, SVs),
remove_key(SVs, OVs, Map).
add_pos([], _, []).
add_pos([V|Vs], I0, [K-(I0,V)|PVs]) :-
get_atts(V,[order(K)]),
I is I0+1,
add_pos(Vs, I, PVs).
remove_key([], [], []).
remove_key([_-(I,V)|SVs], [V|OVs], [I|Map]) :-
remove_key(SVs, OVs, Map).
%%%%%%%%%%%%%%%%%%%%%%%%%
%
% use top-down to generate average
%
run_though_avg(V, 3, Domain, Parents0, Vs, Vs2, Lvs, Outs, DIST) :-
reorder_vars(Parents0, Parents, _Map),
check_v(V, avg(Domain,Parents0), DIST, Vs, Vs1),
DIST = info(V, Tree, Ev, [V0,V1,V2], Formula, [], []),
get_parents(Parents, PVars, Vs1, Vs2),
length(Parents, N),
generate_3tree(F00, PVars, 0, 0, 0, N, N0, N1, N2, R, (N1+2*N2 =< N/2), (N1+2*(N2+R) =< N/2)),
simplify_exp(F00, F0),
% generate_3tree(F1, PVars, 0, 0, 0, N, N0, N1, N2, R, ((N1+2*(N2+R) > N/2, N1+2*N2 < (3*N)/2))),
generate_3tree(F20, PVars, 0, 0, 0, N, N0, N1, N2, R, (N1+2*(N2+R) >= (3*N)/2), N1+2*N2 >= (3*N)/2),
% simplify_exp(F20, F2),
F20=F2,
Formula0 = [V0=F0*Ev0,V2=F2*Ev2,V1=not(F0+F2)*Ev1],
Ev = [Ev0,Ev1,Ev2],
get_evidence(V, Tree, Ev, Formula0, Formula, Lvs, Outs).
generate_3tree(OUT, _, I00, I10, I20, IR0, N0, N1, N2, R, _Exp, ExpF) :-
IR is IR0-1,
satisf(I00, I10, I20, IR, N0, N1, N2, R, ExpF),
!,
OUT = 1.
generate_3tree(OUT, [[P0,P1,P2]], I00, I10, I20, IR0, N0, N1, N2, R, Exp, _ExpF) :-
IR is IR0-1,
( satisf(I00+1, I10, I20, IR, N0, N1, N2, R, Exp) ->
L0 = [P0|L1]
;
L0 = L1
),
( satisf(I00, I10+1, I20, IR, N0, N1, N2, R, Exp) ->
L1 = [P1|L2]
;
L1 = L2
),
( satisf(I00, I10, I20+1, IR, N0, N1, N2, R, Exp) ->
L2 = [P2]
;
L2 = []
),
to_disj(L0, OUT).
generate_3tree(OUT, [[P0,P1,P2]|Ps], I00, I10, I20, IR0, N0, N1, N2, R, Exp, ExpF) :-
IR is IR0-1,
( satisf(I00+1, I10, I20, IR, N0, N1, N2, R, Exp) ->
I0 is I00+1, generate_3tree(O0, Ps, I0, I10, I20, IR, N0, N1, N2, R, Exp, ExpF)
->
L0 = [P0*O0|L1]
;
L0 = L1
),
( satisf(I00, I10+1, I20, IR0, N0, N1, N2, R, Exp) ->
I1 is I10+1, generate_3tree(O1, Ps, I00, I1, I20, IR, N0, N1, N2, R, Exp, ExpF)
->
L1 = [P1*O1|L2]
;
L1 = L2
),
( satisf(I00, I10, I20+1, IR0, N0, N1, N2, R, Exp) ->
I2 is I20+1, generate_3tree(O2, Ps, I00, I10, I2, IR, N0, N1, N2, R, Exp, ExpF)
->
L2 = [P2*O2]
;
L2 = []
),
to_disj(L0, OUT).
satisf(I0, I1, I2, IR, N0, N1, N2, R, Exp) :-
\+ \+ ( I0 = N0, I1=N1, I2=N2, IR=R, call(Exp) ).
not_satisf(I0, I1, I2, IR, N0, N1, N2, R, Exp) :-
\+ ( I0 = N0, I1=N1, I2=N2, IR=R, call(Exp) ).
%%%%%%%%%%%%%%%%%%%%%%%%%
%
% use top-down to generate average
%
top_down_with_tabling(V, Size, Domain, Parents0, Vs, Vs2, Lvs, Outs, DIST) :-
reorder_vars(Parents0, Parents, _Map),
check_v(V, avg(Domain,Parents), DIST, Vs, Vs1),
DIST = info(V, Tree, Ev, OVs, Formula, [], []),
get_parents(Parents, PVars, Vs1, Vs2),
length(Parents, N),
Max is (Size-1)*N, % This should be true
avg_borders(0, Size, Max, Borders),
b_hash_new(H0),
avg_trees(0, Max, PVars, Size, F1, 0, Borders, OVs, Ev, H0, H),
generate_avg_code(H, Formula, F),
% Formula0 = [V0=F0*Ev0,V2=F2*Ev2,V1=not(F0+F2)*Ev1],
% Ev = [Ev0,Ev1,Ev2],
get_evidence(V, Tree, Ev, F1, F, Lvs, Outs).
avg_trees(Size, _, _, Size, F0, _, F0, [], [], H, H) :- !.
avg_trees(I0, Max, PVars, Size, [V=O*E|F0], Im, [IM|Borders], [V|OVs], [E|Ev], H0, H) :-
I is I0+1,
avg_tree(PVars, 0, Max, Im, IM, Size, O, H0, HI),
Im1 is IM+1,
avg_trees(I, Max, PVars, Size, F0, Im1, Borders, OVs, Ev, HI, H).
avg_tree( _PVars, P, _, Im, IM, _Size, O, H0, H0) :-
b_hash_lookup(k(P,Im,IM), O=_Exp, H0), !.
avg_tree([], _P, _Max, _Im, _IM, _Size, 1, H, H).
avg_tree([Vals|PVars], P, Max, Im, IM, Size, O, H0, HF) :-
b_hash_insert(H0, k(P,Im,IM), O=Simp, HI),
MaxI is Max-(Size-1),
avg_exp(Vals, PVars, 0, P, MaxI, Size, Im, IM, HI, HF, Exp),
simplify_exp(Exp, Simp).
avg_exp([], _, _, _P, _Max, _Size, _Im, _IM, H, H, 0).
avg_exp([Val|Vals], PVars, I0, P0, Max, Size, Im, IM, HI, HF, O) :-
(Vals = [] -> O=O1 ; O = Val*O1+not(Val)*O2 ),
Im1 is max(0, Im-I0),
IM1 is IM-I0,
( IM1 < 0 -> O1 = 0, H2 = HI; /* we have exceed maximum */
Im1 > Max -> O1 = 0, H2 = HI; /* we cannot make to minimum */
Im1 = 0, IM1 > Max -> O1 = 1, H2 = HI; /* we cannot exceed maximum */
P is P0+1,
avg_tree(PVars, P, Max, Im1, IM1, Size, O1, HI, H2)
),
I is I0+1,
avg_exp(Vals, PVars, I, P0, Max, Size, Im, IM, H2, HF, O2).
generate_avg_code(H, Formula, Formula0) :-
b_hash_to_list(H,L),
sort(L, S),
strip_and_add(S, Formula0, Formula).
strip_and_add([], F, F).
strip_and_add([_-Exp|S], F0, F) :-
strip_and_add(S, [Exp|F0], F).
%%%%%%%%%%%%%%%%%%%%%%%%%
%
% use bottom-up dynamic programming to generate average
%
bup_avg(V, Size, Domain, Parents0, Vs, Vs2, Lvs, Outs, DIST) :-
reorder_vars(Parents0, Parents, _),
check_v(V, avg(Domain,Parents), DIST, Vs, Vs1),
DIST = info(V, Tree, Ev, OVs, Formula, [], []),
get_parents(Parents, PVars, Vs1, Vs2),
length(Parents, N),
Max is (Size-1)*N, % This should be true
ArraySize is Max+1,
functor(Protected, protected, ArraySize),
avg_domains(0, Size, 0, Max, LDomains),
Domains =.. [d|LDomains],
Reach is (Size-1),
generate_sums(PVars, Size, Max, Reach, Protected, Domains, ArraySize, Sums, F0),
% bin_sums(PVars, Sums, F00),
% reverse(F00,F0),
% easier to do recursion on lists
Sums =.. [_|LSums],
generate_avg(0, Size, 0, Max, LSums, OVs, Ev, F1, []),
reverse(F0, RF0),
get_evidence(V, Tree, Ev, F1, F2, Lvs, Outs),
append(RF0, F2, Formula).
%
% use binary approach, like what is standard
%
bin_sums(Vs, Sums, F) :-
vs_to_sums(Vs, Sums0),
bin_sums(Sums0, Sums, F, []).
vs_to_sums([], []).
vs_to_sums([V|Vs], [Sum|Sums0]) :-
Sum =.. [sum|V],
vs_to_sums(Vs, Sums0).
bin_sums([Sum], Sum) --> !.
bin_sums(LSums, Sum) -->
{ halve(LSums, Sums1, Sums2) },
bin_sums(Sums1, Sum1),
bin_sums(Sums2, Sum2),
sum(Sum1, Sum2, Sum).
halve(LSums, Sums1, Sums2) :-
length(LSums, L),
Take is L div 2,
head(Take, LSums, Sums1, Sums2).
head(0, L, [], L) :- !.
head(Take, [H|L], [H|Sums1], Sum2) :-
Take1 is Take-1,
head(Take1, L, Sums1, Sum2).
sum(Sum1, Sum2, Sum) -->
{ functor(Sum1, _, M1),
functor(Sum2, _, M2),
Max is M1+M2-2,
Max1 is Max+1,
Max0 is M2-1,
functor(Sum, sum, Max1),
Sum1 =.. [_|PVals] },
expand_sums(PVals, 0, Max0, Max1, M2, Sum2, Sum).
%
% bottom up step by step
%
%
generate_sums([PVals], Size, Max, _, _Protected, _Domains, _, Sum, []) :- !,
Max is Size-1,
Sum =.. [sum|PVals].
generate_sums([PVals|Parents], Size, Max, Reach, Protected, Domains, ASize, NewSums, F) :-
NewReach is Reach+(Size-1),
generate_sums(Parents, Size, Max0, NewReach, Protected, Domains, ASize, Sums, F0),
Max is Max0+(Size-1),
Max1 is Max+1,
functor(NewSums, sum, Max1),
protect_avg(0, Max0, Protected, Domains, ASize, Reach),
expand_sums(PVals, 0, Max0, Max1, Size, Sums, Protected, NewSums, F, F0).
protect_avg(Max0,Max0,_Protected, _Domains, _ASize, _Reach) :- !.
protect_avg(I0, Max0, Protected, Domains, ASize, Reach) :-
I is I0+1,
Top is I+Reach,
( Top > ASize ;
arg(I, Domains, CD),
arg(Top, Domains, CD)
), !,
arg(I, Protected, yes),
protect_avg(I, Max0, Protected, Domains, ASize, Reach).
protect_avg(I0, Max0, Protected, Domains, ASize, Reach) :-
I is I0+1,
protect_avg(I, Max0, Protected, Domains, ASize, Reach).
%
% outer loop: generate array of sums at level j= Sum[j0...jMax]
%
expand_sums(_Parents, Max, _, Max, _Size, _Sums, _P, _NewSums, F0, F0) :- !.
expand_sums(Parents, I0, Max0, Max, Size, Sums, Prot, NewSums, [O=SUM|F], F0) :-
I is I0+1,
arg(I, Prot, P),
var(P), !,
arg(I, NewSums, O),
sum_all(Parents, 0, I0, Max0, Sums, List),
to_disj(List, SUM),
expand_sums(Parents, I, Max0, Max, Size, Sums, Prot, NewSums, F, F0).
expand_sums(Parents, I0, Max0, Max, Size, Sums, Prot, NewSums, F, F0) :-
I is I0+1,
arg(I, Sums, O),
arg(I, NewSums, O),
expand_sums(Parents, I, Max0, Max, Size, Sums, Prot, NewSums, F, F0).
%
%inner loop: find all parents that contribute to A_ji,
% that is generate Pk*Sum_(j-1)l and k+l st k+l = i
%
sum_all([], _, _, _, _, []).
sum_all([V|Vs], Pos, I, Max0, Sums, [O|List]) :-
J is I-Pos,
J >= 0,
J =< Max0, !,
J1 is J+1,
arg(J1, Sums, S0),
( J < I -> O = V*S0 ; O = S0*V ),
Pos1 is Pos+1,
sum_all(Vs, Pos1, I, Max0, Sums, List).
sum_all([_V|Vs], Pos, I, Max0, Sums, List) :-
Pos1 is Pos+1,
sum_all(Vs, Pos1, I, Max0, Sums, List).
gen_arg(J, Sums, Max, S0) :-
gen_arg(0, Max, J, Sums, S0).
gen_arg(Max, Max, J, Sums, S0) :- !,
I is Max+1,
arg(I, Sums, A),
( Max = J -> S0 = A ; S0 = not(A)).
gen_arg(I0, Max, J, Sums, S) :-
I is I0+1,
arg(I, Sums, A),
( I0 = J -> S = A*S0 ; S = not(A)*S0),
gen_arg(I, Max, J, Sums, S0).
avg_borders(Size, Size, _Max, []) :- !.
avg_borders(I0, Size, Max, [J|Vals]) :-
I is I0+1,
Border is (I*Max)/Size,
J is integer(round(Border)),
avg_borders(I, Size, Max, Vals).
avg_domains(Size, Size, _J, _Max, []).
avg_domains(I0, Size, J0, Max, Vals) :-
I is I0+1,
Border is (I*Max)/Size,
fetch_domain_for_avg(J0, Border, J, I0, Vals, ValsI),
avg_domains(I, Size, J, Max, ValsI).
fetch_domain_for_avg(J, Border, J, _, Vals, Vals) :-
J > Border, !.
fetch_domain_for_avg(J0, Border, J, I0, [I0|LVals], RLVals) :-
J1 is J0+1,
fetch_domain_for_avg(J1, Border, J, I0, LVals, RLVals).
generate_avg(Size, Size, _J, _Max, [], [], [], F, F).
generate_avg(I0, Size, J0, Max, LSums, [O|OVs], [Ev|Evs], [O=Ev*Disj|F], F0) :-
I is I0+1,
Border is (I*Max)/Size,
fetch_for_avg(J0, Border, J, LSums, MySums, RSums),
to_disj(MySums, Disj),
generate_avg(I, Size, J, Max, RSums, OVs, Evs, F, F0).
fetch_for_avg(J, Border, J, RSums, [], RSums) :-
J > Border, !.
fetch_for_avg(J0, Border, J, [S|LSums], [S|MySums], RSums) :-
J1 is J0+1,
fetch_for_avg(J1, Border, J, LSums, MySums, RSums).
to_disj([], 0).
to_disj([V], V).
to_disj([V,V1|Vs], Out) :-
to_disj2([V1|Vs], V, Out).
to_disj2([V], V0, V0+V).
to_disj2([V,V1|Vs], V0, Out) :-
to_disj2([V1|Vs], V0+V, Out).
%
% look for parameters in the rb-tree, or add a new.
% distid is the key
%
check_p(DistId, Map, Parms, ParmVars, Ps, Ps) :-
rb_lookup(DistId-Map, theta(Parms, ParmVars), Ps), !.
check_p(DistId, Map, Parms, ParmVars, Ps, PsF) :-
get_dist_params(DistId, Parms0),
get_dist_all_sizes(DistId, Sizes),
swap_parms(Parms0, Sizes, [0|Map], Parms1),
length(Parms1, L0),
get_dist_domain_size(DistId, Size),
L1 is L0 div Size,
L is L0-L1,
initial_maxes(L1, Multipliers),
copy(L, Multipliers, NextMults, NextMults, Parms1, Parms, ParmVars),
%writeln(t:Size:Parms0:Parms:ParmVars),
rb_insert(Ps, DistId-Map, theta(Parms, ParmVars), PsF).
swap_parms(Parms0, Sizes, Map, Parms1) :-
matrix_new(floats, Sizes, Parms0, T0),
matrix_shuffle(T0,Map,TF),
matrix_to_list(TF, Parms1).
%
% we are using switches by two
%
initial_maxes(0, []) :- !.
initial_maxes(Size, [1.0|Multipliers]) :- !,
Size1 is Size-1,
initial_maxes(Size1, Multipliers).
copy(0, [], [], _, _Parms0, [], []) :- !.
copy(N, [], [], Ms, Parms0, Parms, ParmVars) :-!,
copy(N, Ms, NewMs, NewMs, Parms0, Parms, ParmVars).
copy(N, D.Ds, ND.NDs, New, El.Parms0, NEl.Parms, V.ParmVars) :-
N1 is N-1,
(El == 0.0 ->
NEl = 0,
V = NEl,
ND = D
;El == 1.0 ->
NEl = 1,
V = NEl,
ND = 0.0
;El == 0 ->
NEl = 0,
V = NEl,
ND = D
;El =:= 1 ->
NEl = 1,
V = NEl,
ND = 0.0,
V = NEl
;
NEl is El/D,
ND is D-El,
V = NEl
),
copy(N1, Ds, NDs, New, Parms0, Parms, ParmVars).
unbound_parms([], []).
unbound_parms(_.Parms, _.ParmVars) :-
unbound_parms(Parms, ParmVars).
check_v(V, _, INFO, Vs, Vs) :-
rb_lookup(V, INFO, Vs), !.
check_v(V, DistId, INFO, Vs0, Vs) :-
get_dist_domain_size(DistId, Size),
length(Values, Size),
length(Ev, Size),
INFO = info(V, _Tree, Ev, Values, _Formula, _, _),
rb_insert(Vs0, V, INFO, Vs).
get_parents([], [], Vs, Vs).
get_parents(V.Parents, Values.PVars, Vs0, Vs) :-
clpbn:get_atts(V, [dist(DistId, _)]),
check_v(V, DistId, INFO, Vs0, Vs1),
INFO = info(V, _Parent, _Ev, Values, _, _, _),
get_parents(Parents, PVars, Vs1, Vs).
%
% construct the formula, this is the key...
%
cross_product(Values, Ev, PVars, ParmVars, Formulas) :-
arrangements(PVars, Arranges),
apply_parents_first(Values, Ev, ParmCombos, ParmCombos, Arranges, Formulas, ParmVars).
%
% if we have the parent variables with two values, we get
% [[XP,YP],[XP,YN],[XN,YP],[XN,YN]]
%
arrangements([], [[]]).
arrangements([L1|Ls],O) :-
arrangements(Ls, LN),
expand(L1, LN, O, []).
expand([], _LN) --> [].
expand([H|L1], LN) -->
concatenate_all(H, LN),
expand(L1, LN).
concatenate_all(_H, []) --> [].
concatenate_all(H, L.LN) -->
[[H|L]],
concatenate_all(H, LN).
%
% core of algorithm
%
% Values -> Output Vars for BDD
% Es -> Evidence variables
% Previous -> top of difference list with parameters used so far
% P0 -> end of difference list with parameters used so far
% Pvars -> Parents
% Eqs -> Output Equations
% Pars -> Output Theta Parameters
%
apply_parents_first([Value], [E], Previous, [], PVars, [Value=Disj*E], Parameters) :- !,
apply_last_parent(PVars, Previous, Disj),
flatten(Previous, Parameters).
apply_parents_first([Value|Values], [E|Ev], Previous, P0, PVars, (Value=Disj*E).Formulas, Parameters) :-
P0 = [TheseParents|End],
apply_first_parent(PVars, Disj, TheseParents),
apply_parents_second(Values, Ev, Previous, End, PVars, Formulas, Parameters).
apply_parents_second([Value], [E], Previous, [], PVars, [Value=Disj*E], Parameters) :- !,
apply_last_parent(PVars, Previous, Disj),
flatten(Previous, Parameters).
apply_parents_second([Value|Values], [E|Ev], Previous, P0, PVars, (Value=Disj*E).Formulas, Parameters) :-
apply_middle_parent(PVars, Previous, Disj, TheseParents),
% this must be done after applying middle parents because of the var
% test.
P0 = [TheseParents|End],
apply_parents_second(Values, Ev, Previous, End, PVars, Formulas, Parameters).
apply_first_parent([Parents], Conj, [Theta]) :- !,
parents_to_conj(Parents,Theta,Conj).
apply_first_parent(Parents.PVars, Conj+Disj, Theta.TheseParents) :-
parents_to_conj(Parents,Theta,Conj),
apply_first_parent(PVars, Disj, TheseParents).
apply_middle_parent([Parents], Other, Conj, [ThetaPar]) :- !,
skim_for_theta(Other, Theta, _, ThetaPar),
parents_to_conj(Parents,Theta,Conj).
apply_middle_parent(Parents.PVars, Other, Conj+Disj, ThetaPar.TheseParents) :-
skim_for_theta(Other, Theta, Remaining, ThetaPar),
parents_to_conj(Parents,(Theta),Conj),
apply_middle_parent(PVars, Remaining, Disj, TheseParents).
apply_last_parent([Parents], Other, Conj) :- !,
parents_to_conj(Parents,(Theta),Conj),
skim_for_theta(Other, Theta, _, _).
apply_last_parent(Parents.PVars, Other, Conj+Disj) :-
parents_to_conj(Parents,(Theta),Conj),
skim_for_theta(Other, Theta, Remaining, _),
apply_last_parent(PVars, Remaining, Disj).
%
%
% simplify stuff, removing process that is cancelled by 0s
%
parents_to_conj([], Theta, Theta) :- !.
parents_to_conj(Ps, Theta, Theta*Conj) :-
parents_to_conj2(Ps, Conj).
parents_to_conj2([P],P) :- !.
parents_to_conj2(P.Ps,P*Conj) :-
parents_to_conj2(Ps,Conj).
%
% first case we haven't reached the end of the list so we need
% to create a new parameter variable
%
skim_for_theta([[P|Other]|V], not(P)*New, [Other|_], New) :- var(V), !.
%
% last theta, it is just negation of the other ones
%
skim_for_theta([[P|Other]], not(P), [Other], _) :- !.
%
% recursive case, build-up
%
skim_for_theta([[P|Other]|More], not(P)*Ps, [Other|Left], New ) :-
skim_for_theta(More, Ps, Left, New ).
get_evidence(V, Tree, Ev, F0, F, Leaves, Finals) :-
clpbn:get_atts(V, [evidence(Pos)]), !,
zero_pos(0, Pos, Ev),
insert_output(Leaves, V, Finals, Tree, Outs, SendOut),
get_outs(F0, F, SendOut, Outs).
% hidden deterministic node, can be removed.
get_evidence(V, _Tree, Ev, F0, [], _Leaves, _Finals) :-
clpbn:get_atts(V, [key(K)]),
functor(K, Name, 2),
( Name = 'AVG' ; Name = 'MAX' ; Name = 'MIN' ),
!,
one_list(Ev),
eval_outs(F0).
%% no evidence !!!
get_evidence(V, Tree, _Values, F0, F1, Leaves, Finals) :-
insert_output(Leaves, V, Finals, Tree, Outs, SendOut),
get_outs(F0, F1, SendOut, Outs).
zero_pos(_, _Pos, []).
zero_pos(Pos, Pos, 1.Values) :- !,
I is Pos+1,
zero_pos(I, Pos, Values).
zero_pos(I0, Pos, 0.Values) :-
I is I0+1,
zero_pos(I, Pos, Values).
one_list([]).
one_list(1.Ev) :-
one_list(Ev).
%
% insert a node with the disj of all alternatives, this is only done if node ends up to be in the output
%
insert_output([], _V, [], _Out, _Outs, []).
insert_output(V._Leaves, V0, [Top|_], Top, Outs, [Top = Outs]) :- V == V0, !.
insert_output(_.Leaves, V, _.Finals, Top, Outs, SendOut) :-
insert_output(Leaves, V, Finals, Top, Outs, SendOut).
get_outs([V=F], [V=NF|End], End, V) :- !,
% writeln(f0:F),
simplify_exp(F,NF).
get_outs((V=F).Outs, (V=NF).NOuts, End, (F0 + V)) :-
% writeln(f0:F),
simplify_exp(F,NF),
get_outs(Outs, NOuts, End, F0).
eval_outs([]).
eval_outs((V=F).Outs) :-
simplify_exp(F,NF),
V = NF,
eval_outs(Outs).
run_bdd_solver([[V]], LPs, bdd(Term, _Leaves, Nodes)) :-
build_out_node(Nodes, Node),
findall(Prob, get_prob(Term, Node, V, Prob),TermProbs),
sumlist(TermProbs, Sum),
writeln(TermProbs:Sum),
normalise(TermProbs, Sum, LPs).
build_out_node([_Top], []).
build_out_node([T,T1|Tops], [Top = T*Top]) :-
build_out_node2(T1.Tops, Top).
build_out_node2([Top], Top).
build_out_node2([T,T1|Tops], T*Top) :-
build_out_node2(T1.Tops, Top).
get_prob(Term, Node, V, SP) :-
bind_all(Term, Node, Bindings, V, AllParms, AllParmValues),
% reverse(AllParms, RAllParms),
term_variables(AllParms, NVs),
build_bdd(Bindings, NVs, AllParms, AllParmValues, Bdd),
bdd_to_probability_sum_product(Bdd, SP),
bdd_close(Bdd).
build_bdd(Bindings, NVs, VTheta, Theta, Bdd) :-
bdd_from_list(Bindings, NVs, Bdd),
bdd_size(Bdd, Len),
number_codes(Len,Codes),
atom_codes(Name,Codes),
bdd_print(Bdd, Name),
writeln(length=Len),
VTheta = Theta.
bind_all([], End, End, _V, [], []).
bind_all(info(V, _Tree, Ev, _Values, Formula, ParmVars, Parms).Term, End, BindsF, V0, ParmVars.AllParms, Parms.AllTheta) :-
V0 == V, !,
set_to_one_zeros(Ev),
bind_formula(Formula, BindsF, BindsI),
bind_all(Term, End, BindsI, V0, AllParms, AllTheta).
bind_all(info(_V, _Tree, Ev, _Values, Formula, ParmVars, Parms).Term, End, BindsF, V0, ParmVars.AllParms, Parms.AllTheta) :-
set_to_ones(Ev),!,
bind_formula(Formula, BindsF, BindsI),
bind_all(Term, End, BindsI, V0, AllParms, AllTheta).
% evidence: no need to add any stuff.
bind_all(info(_V, _Tree, _Ev, _Values, Formula, ParmVars, Parms).Term, End, BindsF, V0, ParmVars.AllParms, Parms.AllTheta) :-
bind_formula(Formula, BindsF, BindsI),
bind_all(Term, End, BindsI, V0, AllParms, AllTheta).
bind_formula([], L, L).
bind_formula(B.Formula, B.BsF, Bs0) :-
bind_formula(Formula, BsF, Bs0).
set_to_one_zeros([1|Values]) :-
set_to_zeros(Values).
set_to_one_zeros([0|Values]) :-
set_to_one_zeros(Values).
set_to_zeros([]).
set_to_zeros(0.Values) :-
set_to_zeros(Values).
set_to_ones([]).
set_to_ones(1.Values) :-
set_to_ones(Values).
normalise([], _Sum, []).
normalise(P.TermProbs, Sum, NP.LPs) :-
NP is P/Sum,
normalise(TermProbs, Sum, LPs).
finalize_bdd_solver(_).

View File

@ -1,16 +1,16 @@
/************************************************
/*******************************************************
Belief Propagation in CLP(BN)
Belief Propagation and Variable Elimination Interface
**************************************************/
********************************************************/
:- module(clpbn_bp,
[bp/3,
check_if_bp_done/1,
set_horus_flag/2,
init_bp_solver/4,
run_bp_solver/3,
call_bp_ground/6,
finalize_bp_solver/1
]).
@ -24,154 +24,143 @@
:- use_module(library('clpbn/display'),
[clpbn_bind_vals/3]).
[clpbn_bind_vals/3]).
:- use_module(library('clpbn/aggregates'),
[check_for_agg_vars/2]).
[check_for_agg_vars/2]).
:- use_module(library(charsio),
[term_to_atom/2]).
:- use_module(library(pfl),
[skolem/2,
get_pfl_parameters/2
]).
:- use_module(library(lists)).
:- use_module(library(atts)).
:- use_module(library(lists)).
:- use_module(library(charsio)).
:- load_foreign_files(['horus'], [], init_predicates).
:- attribute id/1.
:- use_module(library(bhash)).
%:- set_horus_flag(inf_alg, ve).
:- set_horus_flag(inf_alg, bn_bp).
%:- set_horus_flag(inf_alg, fg_bp).
%: -set_horus_flag(inf_alg, cbp).
:- use_module(horus,
[create_ground_network/4,
set_factors_params/2,
run_ground_solver/3,
set_vars_information/2,
free_ground_network/1
]).
:- set_horus_flag(schedule, seq_fixed).
%:- set_horus_flag(schedule, seq_random).
%:- set_horus_flag(schedule, parallel).
%:- set_horus_flag(schedule, max_residual).
:- set_horus_flag(accuracy, 0.0001).
call_bp_ground(QueryVars, QueryKeys, AllKeys, Factors, Evidence, Output) :-
writeln(here:Factors),
b_hash_new(Hash0),
keys_to_ids(AllKeys, 0, Hash0, Hash),
get_factors_type(Factors, Type),
evidence_to_ids(Evidence, Hash, EvidenceIds),
factors_to_ids(Factors, Hash, FactorIds),
writeln(type:Type), writeln(''),
writeln(allKeys:AllKeys), writeln(''),
writeln(factors:Factors), writeln(''),
writeln(factorIds:FactorIds), writeln(''),
writeln(evidence:Evidence), writeln(''),
writeln(evidenceIds:EvidenceIds), writeln(''),
create_ground_network(Type, FactorIds, EvidenceIds, Network),
%get_vars_information(AllKeys, StatesNames),
%set_vars_information(AllKeys, StatesNames),
run_solver(ground(Network,Hash), QueryKeys, Solutions),
writeln(answer:Solutions),
clpbn_bind_vals([QueryVars], Solutions, Output),
free_ground_network(Network).
:- set_horus_flag(max_iter, 1000).
:- set_horus_flag(use_logarithms, false).
%:- set_horus_flag(use_logarithms, true).
run_solver(ground(Network,Hash), QueryKeys, Solutions) :-
%get_dists_parameters(DistIds, DistsParams),
%set_factors_params(Network, DistsParams),
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
writeln(queryKeys:QueryKeys), writeln(''),
writeln(queryIds:QueryIds), writeln(''),
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
run_ground_solver(Network, [QueryIds], Solutions).
:- set_horus_flag(order_factor_variables, false).
%:- set_horus_flag(order_factor_variables, true).
keys_to_ids([], _, Hash, Hash).
keys_to_ids([Key|AllKeys], I0, Hash0, Hash) :-
b_hash_insert(Hash0, Key, I0, HashI),
I is I0+1,
keys_to_ids(AllKeys, I, HashI, Hash).
get_factors_type([f(bayes, _, _, _)|_], bayes) :- ! .
get_factors_type([f(markov, _, _, _)|_], markov) :- ! .
list_of_keys_to_ids([], _, []).
list_of_keys_to_ids([Key|QueryKeys], Hash, [Id|QueryIds]) :-
b_hash_lookup(Key, Id, Hash),
list_of_keys_to_ids(QueryKeys, Hash, QueryIds).
factors_to_ids([], _, []).
factors_to_ids([f(_, DistId, Keys, CPT)|Fs], Hash, [f(Ids, Ranges, CPT, DistId)|NFs]) :-
list_of_keys_to_ids(Keys, Hash, Ids),
get_ranges(Keys, Ranges),
factors_to_ids(Fs, Hash, NFs).
get_ranges([],[]).
get_ranges(K.Ks, Range.Rs) :- !,
skolem(K,Domain),
length(Domain,Range),
get_ranges(Ks, Rs).
evidence_to_ids([], _, []).
evidence_to_ids([Key=Ev|QueryKeys], Hash, [Id=Ev|QueryIds]) :-
b_hash_lookup(Key, Id, Hash),
evidence_to_ids(QueryKeys, Hash, QueryIds).
get_vars_information([], []).
get_vars_information(Key.QueryKeys, Domain.StatesNames) :-
pfl:skolem(Key, Domain),
get_vars_information(QueryKeys, StatesNames).
finalize_bp_solver(bp(Network, _)) :-
free_ground_network(Network).
bp([[]],_,_) :- !.
bp([QueryVars], AllVars, Output) :-
init_bp_solver(_, AllVars, _, Network),
run_bp_solver([QueryVars], LPs, Network),
finalize_bp_solver(Network),
clpbn_bind_vals([QueryVars], LPs, Output).
init_bp_solver(_, AllVars, _, Network),
run_bp_solver([QueryVars], LPs, Network),
finalize_bp_solver(Network),
clpbn_bind_vals([QueryVars], LPs, Output).
init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
check_for_agg_vars(AllVars0, AllVars),
writeln('clpbn_vars:'),
print_clpbn_vars(AllVars),
assign_ids(AllVars, 0),
get_vars_info(AllVars, VarsInfo, DistIds0),
sort(DistIds0, DistIds),
create_ground_network(VarsInfo, BayesNet).
%get_extra_vars_info(AllVars, ExtraVarsInfo),
%set_extra_vars_info(BayesNet, ExtraVarsInfo).
%check_for_agg_vars(AllVars0, AllVars),
get_vars_info(AllVars0, VarsInfo, DistIds0),
sort(DistIds0, DistIds),
create_ground_network(VarsInfo, BayesNet),
true.
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
get_dists_parameters(DistIds, DistsParams),
set_bayes_net_params(Network, DistsParams),
flatten_1_element_sublists(QueryVars, QueryVars1),
vars_to_ids(QueryVars1, QueryVarsIds),
run_other_solvers(Network, QueryVarsIds, Solutions).
finalize_bp_solver(bp(Network, _)) :-
free_bayesian_network(Network).
assign_ids([], _).
assign_ids([V|Vs], Count) :-
put_atts(V, [id(Count)]),
Count1 is Count + 1,
assign_ids(Vs, Count1).
get_vars_info([], [], []).
get_vars_info(V.Vs,
var(VarId,DS,Ev,PIds,DistId).VarsInfo,
DistId.DistIds) :-
clpbn:get_atts(V, [dist(DistId, Parents)]), !,
get_atts(V, [id(VarId)]),
get_dist_domain_size(DistId, DS),
get_evidence(V, Ev),
vars_to_ids(Parents, PIds),
get_vars_info(Vs, VarsInfo, DistIds).
get_evidence(V, Ev) :-
clpbn:get_atts(V, [evidence(Ev)]), !.
get_evidence(_V, -1). % no evidence !!!
vars_to_ids([], []).
vars_to_ids([L|Vars], [LIds|Ids]) :-
is_list(L), !,
vars_to_ids(L, LIds),
vars_to_ids(Vars, Ids).
vars_to_ids([V|Vars], [VarId|Ids]) :-
get_atts(V, [id(VarId)]),
vars_to_ids(Vars, Ids).
get_extra_vars_info([], []).
get_extra_vars_info([V|Vs], [v(VarId, Label, Domain)|VarsInfo]) :-
get_atts(V, [id(VarId)]), !,
clpbn:get_atts(V, [key(Key),dist(DistId, _)]),
term_to_atom(Key, Label),
get_dist_domain(DistId, Domain0),
numbers_to_atoms(Domain0, Domain),
get_extra_vars_info(Vs, VarsInfo).
get_extra_vars_info([_|Vs], VarsInfo) :-
get_extra_vars_info(Vs, VarsInfo).
get_dists_parameters(DistIds, DistsParams),
set_factors_params(Network, DistsParams),
vars_to_ids(QueryVars, QueryVarsIds),
run_ground_solver(Network, QueryVarsIds, Solutions).
get_dists_parameters([],[]).
get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
get_dist_params(Id, Params),
get_dists_parameters(Ids, DistsInfo).
numbers_to_atoms([], []).
numbers_to_atoms([Atom|L0], [Atom|L]) :-
atom(Atom), !,
numbers_to_atoms(L0, L).
numbers_to_atoms([Number|L0], [Atom|L]) :-
number_atom(Number, Atom),
numbers_to_atoms(L0, L).
flatten_1_element_sublists([],[]).
flatten_1_element_sublists([[H|[]]|T],[H|R]) :- !,
flatten_1_element_sublists(T,R).
flatten_1_element_sublists([H|T],[H|R]) :-
flatten_1_element_sublists(T,R).
print_clpbn_vars(Var.AllVars) :-
clpbn:get_atts(Var, [key(Key),dist(DistId,Parents)]),
parents_to_keys(Parents, ParentKeys),
writeln(Var:Key:ParentKeys:DistId),
print_clpbn_vars(AllVars).
print_clpbn_vars([]).
parents_to_keys([], []).
parents_to_keys(Var.Parents, Key.Keys) :-
clpbn:get_atts(Var, [key(Key)]),
parents_to_keys(Parents, Keys).
get_dist_params(Id, Params),
get_dists_parameters(Ids, DistsInfo).

View File

@ -0,0 +1,77 @@
#include <cstdlib>
#include <cassert>
#include <iostream>
#include <fstream>
#include <sstream>
#include "BayesBall.h"
#include "Util.h"
FactorGraph*
BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
{
assert (fg_.isFromBayesNetwork());
Scheduling scheduling;
for (unsigned i = 0; i < queryIds.size(); i++) {
assert (dag_.getNode (queryIds[i]));
DAGraphNode* n = dag_.getNode (queryIds[i]);
scheduling.push (ScheduleInfo (n, false, true));
}
while (!scheduling.empty()) {
ScheduleInfo& sch = scheduling.front();
DAGraphNode* n = sch.node;
n->setAsVisited();
if (n->hasEvidence() == false && sch.visitedFromChild) {
if (n->isMarkedOnTop() == false) {
n->markOnTop();
scheduleParents (n, scheduling);
}
if (n->isMarkedOnBottom() == false) {
n->markOnBottom();
scheduleChilds (n, scheduling);
}
}
if (sch.visitedFromParent) {
if (n->hasEvidence() && n->isMarkedOnTop() == false) {
n->markOnTop();
scheduleParents (n, scheduling);
}
if (n->hasEvidence() == false && n->isMarkedOnBottom() == false) {
n->markOnBottom();
scheduleChilds (n, scheduling);
}
}
scheduling.pop();
}
FactorGraph* fg = new FactorGraph();
constructGraph (fg);
return fg;
}
void
BayesBall::constructGraph (FactorGraph* fg) const
{
const FacNodes& facNodes = fg_.facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
const DAGraphNode* n = dag_.getNode (
facNodes[i]->factor().argument (0));
if (n->isMarkedOnTop()) {
fg->addFactor (Factor (facNodes[i]->factor()));
} else if (n->hasEvidence() && n->isVisited()) {
VarIds varIds = { facNodes[i]->factor().argument (0) };
Ranges ranges = { facNodes[i]->factor().range (0) };
Params params (ranges[0], LogAware::noEvidence());
params[n->getEvidence()] = LogAware::withEvidence();
fg->addFactor (Factor (varIds, ranges, params));
}
}
}

View File

@ -0,0 +1,85 @@
#ifndef HORUS_BAYESBALL_H
#define HORUS_BAYESBALL_H
#include <vector>
#include <queue>
#include <list>
#include <map>
#include "FactorGraph.h"
#include "BayesNet.h"
#include "Horus.h"
using namespace std;
struct ScheduleInfo
{
ScheduleInfo (DAGraphNode* n, bool vfp, bool vfc) :
node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
DAGraphNode* node;
bool visitedFromParent;
bool visitedFromChild;
};
typedef queue<ScheduleInfo, list<ScheduleInfo>> Scheduling;
class BayesBall
{
public:
BayesBall (FactorGraph& fg)
: fg_(fg) , dag_(fg.getStructure())
{
dag_.clear();
}
FactorGraph* getMinimalFactorGraph (const VarIds&);
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
{
BayesBall bb (fg);
return bb.getMinimalFactorGraph (vids);
}
private:
void constructGraph (FactorGraph* fg) const;
void scheduleParents (const DAGraphNode* n, Scheduling& sch) const;
void scheduleChilds (const DAGraphNode* n, Scheduling& sch) const;
FactorGraph& fg_;
DAGraph& dag_;
};
inline void
BayesBall::scheduleParents (const DAGraphNode* n, Scheduling& sch) const
{
const vector<DAGraphNode*>& ps = n->parents();
for (vector<DAGraphNode*>::const_iterator it = ps.begin();
it != ps.end(); it++) {
sch.push (ScheduleInfo (*it, false, true));
}
}
inline void
BayesBall::scheduleChilds (const DAGraphNode* n, Scheduling& sch) const
{
const vector<DAGraphNode*>& cs = n->childs();
for (vector<DAGraphNode*>::const_iterator it = cs.begin();
it != cs.end(); it++) {
sch.push (ScheduleInfo (*it, true, false));
}
}
#endif // HORUS_BAYESBALL_H

View File

@ -5,381 +5,57 @@
#include <fstream>
#include <sstream>
#include "xmlParser/xmlParser.h"
#include "BayesNet.h"
#include "Util.h"
BayesNet::~BayesNet (void)
void
DAGraph::addNode (DAGraphNode* n)
{
for (unsigned i = 0; i < nodes_.size(); i++) {
delete nodes_[i];
}
assert (Util::contains (varMap_, n->varId()) == false);
nodes_.push_back (n);
varMap_[n->varId()] = n;
}
void
BayesNet::readFromBifFormat (const char* fileName)
DAGraph::addEdge (VarId vid1, VarId vid2)
{
XMLNode xMainNode = XMLNode::openFileHelper (fileName, "BIF");
// only the first network is parsed, others are ignored
XMLNode xNode = xMainNode.getChildNode ("NETWORK");
unsigned nVars = xNode.nChildNode ("VARIABLE");
for (unsigned i = 0; i < nVars; i++) {
XMLNode var = xNode.getChildNode ("VARIABLE", i);
if (string (var.getAttribute ("TYPE")) != "nature") {
cerr << "error: only \"nature\" variables are supported" << endl;
abort();
}
States states;
string label = var.getChildNode("NAME").getText();
unsigned nrStates = var.nChildNode ("OUTCOME");
for (unsigned j = 0; j < nrStates; j++) {
if (var.getChildNode("OUTCOME", j).getText() == 0) {
stringstream ss;
ss << j + 1;
states.push_back (ss.str());
} else {
states.push_back (var.getChildNode("OUTCOME", j).getText());
}
}
addNode (label, states);
}
unsigned nDefs = xNode.nChildNode ("DEFINITION");
if (nVars != nDefs) {
cerr << "error: different number of variables and definitions" << endl;
abort();
}
for (unsigned i = 0; i < nDefs; i++) {
XMLNode def = xNode.getChildNode ("DEFINITION", i);
string label = def.getChildNode("FOR").getText();
BayesNode* node = getBayesNode (label);
if (!node) {
cerr << "error: unknow variable `" << label << "'" << endl;
abort();
}
BnNodeSet parents;
unsigned nParams = node->nrStates();
for (int j = 0; j < def.nChildNode ("GIVEN"); j++) {
string parentLabel = def.getChildNode("GIVEN", j).getText();
BayesNode* parentNode = getBayesNode (parentLabel);
if (!parentNode) {
cerr << "error: unknow variable `" << parentLabel << "'" << endl;
abort();
}
nParams *= parentNode->nrStates();
parents.push_back (parentNode);
}
node->setParents (parents);
unsigned count = 0;
Params params (nParams);
stringstream s (def.getChildNode("TABLE").getText());
while (!s.eof() && count < nParams) {
s >> params[count];
count ++;
}
if (count != nParams) {
cerr << "error: invalid number of parameters " ;
cerr << "for variable `" << label << "'" << endl;
abort();
}
params = reorderParameters (params, node->nrStates());
Distribution* dist = new Distribution (params);
node->setDistribution (dist);
addDistribution (dist);
}
setIndexes();
if (Globals::logDomain) {
distributionsToLogs();
}
unordered_map<VarId, DAGraphNode*>::iterator it1;
unordered_map<VarId, DAGraphNode*>::iterator it2;
it1 = varMap_.find (vid1);
it2 = varMap_.find (vid2);
assert (it1 != varMap_.end());
assert (it2 != varMap_.end());
it1->second->addChild (it2->second);
it2->second->addParent (it1->second);
}
BayesNode*
BayesNet::addNode (string label, const States& states)
const DAGraphNode*
DAGraph::getNode (VarId vid) const
{
VarId vid = nodes_.size();
varMap_.insert (make_pair (vid, nodes_.size()));
GraphicalModel::addVariableInformation (vid, label, states);
BayesNode* node = new BayesNode (VarNode (vid, states.size()));
nodes_.push_back (node);
return node;
unordered_map<VarId, DAGraphNode*>::const_iterator it;
it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0;
}
BayesNode*
BayesNet::addNode (VarId vid, unsigned dsize, int evidence, Distribution* dist)
DAGraphNode*
DAGraph::getNode (VarId vid)
{
varMap_.insert (make_pair (vid, nodes_.size()));
nodes_.push_back (new BayesNode (vid, dsize, evidence, dist));
return nodes_.back();
}
BayesNode*
BayesNet::getBayesNode (VarId vid) const
{
IndexMap::const_iterator it = varMap_.find (vid);
if (it == varMap_.end()) {
return 0;
} else {
return nodes_[it->second];
}
}
BayesNode*
BayesNet::getBayesNode (string label) const
{
BayesNode* node = 0;
for (unsigned i = 0; i < nodes_.size(); i++) {
if (nodes_[i]->label() == label) {
node = nodes_[i];
break;
}
}
return node;
}
VarNode*
BayesNet::getVariableNode (VarId vid) const
{
BayesNode* node = getBayesNode (vid);
assert (node);
return node;
}
VarNodes
BayesNet::getVariableNodes (void) const
{
VarNodes vars;
for (unsigned i = 0; i < nodes_.size(); i++) {
vars.push_back (nodes_[i]);
}
return vars;
unordered_map<VarId, DAGraphNode*>::const_iterator it;
it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0;
}
void
BayesNet::addDistribution (Distribution* dist)
{
dists_.push_back (dist);
}
Distribution*
BayesNet::getDistribution (unsigned distId) const
{
Distribution* dist = 0;
for (unsigned i = 0; i < dists_.size(); i++) {
if (dists_[i]->id == (int) distId) {
dist = dists_[i];
break;
}
}
return dist;
}
const BnNodeSet&
BayesNet::getBayesNodes (void) const
{
return nodes_;
}
unsigned
BayesNet::nrNodes (void) const
{
return nodes_.size();
}
BnNodeSet
BayesNet::getRootNodes (void) const
{
BnNodeSet roots;
for (unsigned i = 0; i < nodes_.size(); i++) {
if (nodes_[i]->isRoot()) {
roots.push_back (nodes_[i]);
}
}
return roots;
}
BnNodeSet
BayesNet::getLeafNodes (void) const
{
BnNodeSet leafs;
for (unsigned i = 0; i < nodes_.size(); i++) {
if (nodes_[i]->isLeaf()) {
leafs.push_back (nodes_[i]);
}
}
return leafs;
}
BayesNet*
BayesNet::getMinimalRequesiteNetwork (VarId vid) const
{
return getMinimalRequesiteNetwork (VarIds() = {vid});
}
BayesNet*
BayesNet::getMinimalRequesiteNetwork (const VarIds& queryVarIds) const
{
BnNodeSet queryVars;
Scheduling scheduling;
for (unsigned i = 0; i < queryVarIds.size(); i++) {
BayesNode* n = getBayesNode (queryVarIds[i]);
assert (n);
queryVars.push_back (n);
scheduling.push (ScheduleInfo (n, false, true));
}
vector<StateInfo*> states (nodes_.size(), 0);
while (!scheduling.empty()) {
ScheduleInfo& sch = scheduling.front();
StateInfo* state = states[sch.node->getIndex()];
if (!state) {
state = new StateInfo();
states[sch.node->getIndex()] = state;
} else {
state->visited = true;
}
if (!sch.node->hasEvidence() && sch.visitedFromChild) {
if (!state->markedOnTop) {
state->markedOnTop = true;
scheduleParents (sch.node, scheduling);
}
if (!state->markedOnBottom) {
state->markedOnBottom = true;
scheduleChilds (sch.node, scheduling);
}
}
if (sch.visitedFromParent) {
if (sch.node->hasEvidence() && !state->markedOnTop) {
state->markedOnTop = true;
scheduleParents (sch.node, scheduling);
}
if (!sch.node->hasEvidence() && !state->markedOnBottom) {
state->markedOnBottom = true;
scheduleChilds (sch.node, scheduling);
}
}
scheduling.pop();
}
/*
cout << "\t\ttop\tbottom" << endl;
cout << "variable\t\tmarked\tmarked\tvisited\tobserved" << endl;
cout << "----------------------------------------------------------" ;
cout << endl;
for (unsigned i = 0; i < states.size(); i++) {
cout << nodes_[i]->label() << ":\t\t" ;
if (states[i]) {
states[i]->markedOnTop ? cout << "yes\t" : cout << "no\t" ;
states[i]->markedOnBottom ? cout << "yes\t" : cout << "no\t" ;
states[i]->visited ? cout << "yes\t" : cout << "no\t" ;
nodes_[i]->hasEvidence() ? cout << "yes" : cout << "no" ;
cout << endl;
} else {
cout << "no\tno\tno\t" ;
nodes_[i]->hasEvidence() ? cout << "yes" : cout << "no" ;
cout << endl;
}
}
cout << endl;
*/
BayesNet* bn = new BayesNet();
constructGraph (bn, states);
for (unsigned i = 0; i < nodes_.size(); i++) {
delete states[i];
}
return bn;
}
void
BayesNet::constructGraph (BayesNet* bn,
const vector<StateInfo*>& states) const
{
BnNodeSet mrnNodes;
vector<VarIds> parents;
for (unsigned i = 0; i < nodes_.size(); i++) {
bool isRequired = false;
if (states[i]) {
isRequired = (nodes_[i]->hasEvidence() && states[i]->visited)
||
states[i]->markedOnTop;
}
if (isRequired) {
parents.push_back (VarIds());
if (states[i]->markedOnTop) {
const BnNodeSet& ps = nodes_[i]->getParents();
for (unsigned j = 0; j < ps.size(); j++) {
parents.back().push_back (ps[j]->varId());
}
}
assert (bn->getBayesNode (nodes_[i]->varId()) == 0);
BayesNode* mrnNode = bn->addNode (nodes_[i]->varId(),
nodes_[i]->nrStates(),
nodes_[i]->getEvidence(),
nodes_[i]->getDistribution());
mrnNodes.push_back (mrnNode);
}
}
for (unsigned i = 0; i < mrnNodes.size(); i++) {
BnNodeSet ps;
for (unsigned j = 0; j < parents[i].size(); j++) {
assert (bn->getBayesNode (parents[i][j]) != 0);
ps.push_back (bn->getBayesNode (parents[i][j]));
}
mrnNodes[i]->setParents (ps);
}
bn->setIndexes();
}
bool
BayesNet::isPolyTree (void) const
{
return !containsUndirectedCycle();
}
void
BayesNet::setIndexes (void)
DAGraph::setIndexes (void)
{
for (unsigned i = 0; i < nodes_.size(); i++) {
nodes_[i]->setIndex (i);
@ -389,233 +65,43 @@ BayesNet::setIndexes (void)
void
BayesNet::distributionsToLogs (void)
{
for (unsigned i = 0; i < dists_.size(); i++) {
Util::toLog (dists_[i]->params);
}
}
void
BayesNet::freeDistributions (void)
{
for (unsigned i = 0; i < dists_.size(); i++) {
delete dists_[i];
}
}
void
BayesNet::printGraphicalModel (void) const
DAGraph::clear (void)
{
for (unsigned i = 0; i < nodes_.size(); i++) {
cout << *nodes_[i];
nodes_[i]->clear();
}
}
void
BayesNet::exportToGraphViz (const char* fileName,
bool showNeighborless,
const VarIds& highlightVarIds) const
DAGraph::exportToGraphViz (const char* fileName)
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "BayesNet::exportToDotFile()" << endl;
cerr << "DAGraph::exportToDotFile()" << endl;
abort();
}
out << "digraph {" << endl;
out << "ranksep=1" << endl;
for (unsigned i = 0; i < nodes_.size(); i++) {
if (showNeighborless || nodes_[i]->hasNeighbors()) {
out << nodes_[i]->varId() ;
if (nodes_[i]->hasEvidence()) {
out << " [" ;
out << "label=\"" << nodes_[i]->label() << "\"," ;
out << "style=filled, fillcolor=yellow" ;
out << "]" ;
} else {
out << " [" ;
out << "label=\"" << nodes_[i]->label() << "\"" ;
out << "]" ;
}
out << endl;
out << nodes_[i]->varId() ;
out << " [" ;
out << "label=\"" << nodes_[i]->label() << "\"" ;
if (nodes_[i]->hasEvidence()) {
out << ",style=filled, fillcolor=yellow" ;
}
out << "]" << endl;
}
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
BayesNode* node = getBayesNode (highlightVarIds[i]);
if (node) {
out << node->varId() ;
out << " [shape=box3d]" << endl;
} else {
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
abort();
}
}
for (unsigned i = 0; i < nodes_.size(); i++) {
const BnNodeSet& childs = nodes_[i]->getChilds();
const vector<DAGraphNode*>& childs = nodes_[i]->childs();
for (unsigned j = 0; j < childs.size(); j++) {
out << nodes_[i]->varId() << " -> " << childs[j]->varId() << " [style=bold]" << endl ;
out << nodes_[i]->varId() << " -> " << childs[j]->varId();
out << " [style=bold]" << endl ;
}
}
out << "}" << endl;
out.close();
}
void
BayesNet::exportToBifFormat (const char* fileName) const
{
ofstream out (fileName);
if(!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "BayesNet::exportToBifFile()" << endl;
abort();
}
out << "<?xml version=\"1.0\" encoding=\"US-ASCII\"?>" << endl;
out << "<BIF VERSION=\"0.3\">" << endl;
out << "<NETWORK>" << endl;
out << "<NAME>" << fileName << "</NAME>" << endl << endl;
for (unsigned i = 0; i < nodes_.size(); i++) {
out << "<VARIABLE TYPE=\"nature\">" << endl;
out << "\t<NAME>" << nodes_[i]->label() << "</NAME>" << endl;
const States& states = nodes_[i]->states();
for (unsigned j = 0; j < states.size(); j++) {
out << "\t<OUTCOME>" << states[j] << "</OUTCOME>" << endl;
}
out << "</VARIABLE>" << endl << endl;
}
for (unsigned i = 0; i < nodes_.size(); i++) {
out << "<DEFINITION>" << endl;
out << "\t<FOR>" << nodes_[i]->label() << "</FOR>" << endl;
const BnNodeSet& parents = nodes_[i]->getParents();
for (unsigned j = 0; j < parents.size(); j++) {
out << "\t<GIVEN>" << parents[j]->label();
out << "</GIVEN>" << endl;
}
Params params = revertParameterReorder (nodes_[i]->getParameters(),
nodes_[i]->nrStates());
out << "\t<TABLE>" ;
for (unsigned j = 0; j < params.size(); j++) {
out << " " << params[j];
}
out << " </TABLE>" << endl;
out << "</DEFINITION>" << endl << endl;
}
out << "</NETWORK>" << endl;
out << "</BIF>" << endl << endl;
out.close();
}
bool
BayesNet::containsUndirectedCycle (void) const
{
vector<bool> visited (nodes_.size(), false);
for (unsigned i = 0; i < nodes_.size(); i++) {
int v = nodes_[i]->getIndex();
if (!visited[v]) {
if (containsUndirectedCycle (v, -1, visited)) {
return true;
}
}
}
return false;
}
bool
BayesNet::containsUndirectedCycle (int v, int p, vector<bool>& visited) const
{
visited[v] = true;
vector<int> adjacencies = getAdjacentNodes (v);
for (unsigned i = 0; i < adjacencies.size(); i++) {
int w = adjacencies[i];
if (!visited[w]) {
if (containsUndirectedCycle (w, v, visited)) {
return true;
}
}
else if (visited[w] && w != p) {
return true;
}
}
return false; // no cycle detected in this component
}
vector<int>
BayesNet::getAdjacentNodes (int v) const
{
vector<int> adjacencies;
const BnNodeSet& parents = nodes_[v]->getParents();
const BnNodeSet& childs = nodes_[v]->getChilds();
for (unsigned i = 0; i < parents.size(); i++) {
adjacencies.push_back (parents[i]->getIndex());
}
for (unsigned i = 0; i < childs.size(); i++) {
adjacencies.push_back (childs[i]->getIndex());
}
return adjacencies;
}
Params
BayesNet::reorderParameters (const Params& params, unsigned dsize) const
{
// the interchange format for bayesian networks keeps the probabilities
// in the following order:
// p(a1|b1,c1) p(a2|b1,c1) p(a1|b1,c2) p(a2|b1,c2) p(a1|b2,c1) p(a2|b2,c1)
// p(a1|b2,c2) p(a2|b2,c2).
//
// however, in clpbn we keep the probabilities in this order:
// p(a1|b1,c1) p(a1|b1,c2) p(a1|b2,c1) p(a1|b2,c2) p(a2|b1,c1) p(a2|b1,c2)
// p(a2|b2,c1) p(a2|b2,c2).
unsigned count = 0;
unsigned rowSize = params.size() / dsize;
Params reordered;
while (reordered.size() < params.size()) {
unsigned idx = count;
for (unsigned i = 0; i < rowSize; i++) {
reordered.push_back (params[idx]);
idx += dsize ;
}
count++;
}
return reordered;
}
Params
BayesNet::revertParameterReorder (const Params& params, unsigned dsize) const
{
unsigned count = 0;
unsigned rowSize = params.size() / dsize;
Params reordered;
while (reordered.size() < params.size()) {
unsigned idx = count;
for (unsigned i = 0; i < dsize; i++) {
reordered.push_back (params[idx]);
idx += rowSize;
}
count ++;
}
return reordered;
}

View File

@ -6,118 +6,83 @@
#include <list>
#include <map>
#include "GraphicalModel.h"
#include "BayesNode.h"
#include "Var.h"
#include "Horus.h"
using namespace std;
class Distribution;
struct ScheduleInfo
{
ScheduleInfo (BayesNode* n, bool vfp, bool vfc)
{
node = n;
visitedFromParent = vfp;
visitedFromChild = vfc;
}
BayesNode* node;
bool visitedFromParent;
bool visitedFromChild;
};
class Var;
struct StateInfo
{
StateInfo (void)
{
visited = true;
markedOnTop = false;
markedOnBottom = false;
}
bool visited;
bool markedOnTop;
bool markedOnBottom;
};
typedef vector<Distribution*> DistSet;
typedef queue<ScheduleInfo, list<ScheduleInfo> > Scheduling;
class BayesNet : public GraphicalModel
class DAGraphNode : public Var
{
public:
BayesNet (void) {};
~BayesNet (void);
DAGraphNode (Var* v) : Var (v) , visited_(false),
markedOnTop_(false), markedOnBottom_(false) { }
void readFromBifFormat (const char*);
BayesNode* addNode (string, const States&);
// BayesNode* addNode (VarId, unsigned, int, BnNodeSet&, Distribution*);
BayesNode* addNode (VarId, unsigned, int, Distribution*);
BayesNode* getBayesNode (VarId) const;
BayesNode* getBayesNode (string) const;
VarNode* getVariableNode (VarId) const;
VarNodes getVariableNodes (void) const;
void addDistribution (Distribution*);
Distribution* getDistribution (unsigned) const;
const BnNodeSet& getBayesNodes (void) const;
unsigned nrNodes (void) const;
BnNodeSet getRootNodes (void) const;
BnNodeSet getLeafNodes (void) const;
BayesNet* getMinimalRequesiteNetwork (VarId) const;
BayesNet* getMinimalRequesiteNetwork (const VarIds&) const;
void constructGraph (
BayesNet*, const vector<StateInfo*>&) const;
bool isPolyTree (void) const;
void setIndexes (void);
void distributionsToLogs (void);
void freeDistributions (void);
void printGraphicalModel (void) const;
void exportToGraphViz (const char*, bool = true,
const VarIds& = VarIds()) const;
void exportToBifFormat (const char*) const;
const vector<DAGraphNode*>& childs (void) const { return childs_; }
vector<DAGraphNode*>& childs (void) { return childs_; }
const vector<DAGraphNode*>& parents (void) const { return parents_; }
vector<DAGraphNode*>& parents (void) { return parents_; }
void addParent (DAGraphNode* p) { parents_.push_back (p); }
void addChild (DAGraphNode* c) { childs_.push_back (c); }
bool isVisited (void) const { return visited_; }
void setAsVisited (void) { visited_ = true; }
bool isMarkedOnTop (void) const { return markedOnTop_; }
void markOnTop (void) { markedOnTop_ = true; }
bool isMarkedOnBottom (void) const { return markedOnBottom_; }
void markOnBottom (void) { markedOnBottom_ = true; }
void clear (void) { visited_ = markedOnTop_ = markedOnBottom_ = false; }
private:
DISALLOW_COPY_AND_ASSIGN (BayesNet);
bool visited_;
bool markedOnTop_;
bool markedOnBottom_;
bool containsUndirectedCycle (void) const;
bool containsUndirectedCycle (int, int, vector<bool>&)const;
vector<int> getAdjacentNodes (int) const;
Params reorderParameters (const Params&, unsigned) const;
Params revertParameterReorder (const Params&, unsigned) const;
void scheduleParents (const BayesNode*, Scheduling&) const;
void scheduleChilds (const BayesNode*, Scheduling&) const;
BnNodeSet nodes_;
DistSet dists_;
typedef unordered_map<unsigned, unsigned> IndexMap;
IndexMap varMap_;
vector<DAGraphNode*> childs_;
vector<DAGraphNode*> parents_;
};
inline void
BayesNet::scheduleParents (const BayesNode* n, Scheduling& sch) const
class DAGraph
{
const BnNodeSet& ps = n->getParents();
for (BnNodeSet::const_iterator it = ps.begin(); it != ps.end(); it++) {
sch.push (ScheduleInfo (*it, false, true));
}
}
public:
DAGraph (void) { }
void addNode (DAGraphNode* n);
void addEdge (VarId vid1, VarId vid2);
inline void
BayesNet::scheduleChilds (const BayesNode* n, Scheduling& sch) const
{
const BnNodeSet& cs = n->getChilds();
for (BnNodeSet::const_iterator it = cs.begin(); it != cs.end(); it++) {
sch.push (ScheduleInfo (*it, true, false));
}
}
const DAGraphNode* getNode (VarId vid) const;
DAGraphNode* getNode (VarId vid);
bool empty (void) const { return nodes_.empty(); }
void setIndexes (void);
void clear (void);
void exportToGraphViz (const char*);
private:
vector<DAGraphNode*> nodes_;
unordered_map<VarId, DAGraphNode*> varMap_;
};
#endif // HORUS_BAYESNET_H

View File

@ -1,291 +0,0 @@
#include <cstdlib>
#include <cassert>
#include <iomanip>
#include <iostream>
#include <sstream>
#include "BayesNode.h"
BayesNode::BayesNode (VarId vid,
unsigned dsize,
int evidence,
Distribution* dist)
: VarNode (vid, dsize, evidence)
{
dist_ = dist;
}
BayesNode::BayesNode (VarId vid,
unsigned dsize,
int evidence,
const BnNodeSet& parents,
Distribution* dist)
: VarNode (vid, dsize, evidence)
{
parents_ = parents;
dist_ = dist;
for (unsigned int i = 0; i < parents.size(); i++) {
parents[i]->addChild (this);
}
}
void
BayesNode::setParents (const BnNodeSet& parents)
{
parents_ = parents;
for (unsigned int i = 0; i < parents.size(); i++) {
parents[i]->addChild (this);
}
}
void
BayesNode::addChild (BayesNode* node)
{
childs_.push_back (node);
}
void
BayesNode::setDistribution (Distribution* dist)
{
assert (dist);
dist_ = dist;
}
Distribution*
BayesNode::getDistribution (void)
{
return dist_;
}
const Params&
BayesNode::getParameters (void)
{
return dist_->params;
}
Params
BayesNode::getRow (int rowIndex) const
{
int rowSize = getRowSize();
int offset = rowSize * rowIndex;
Params row (rowSize);
for (int i = 0; i < rowSize; i++) {
row[i] = dist_->params[offset + i] ;
}
return row;
}
bool
BayesNode::isRoot (void)
{
return getParents().empty();
}
bool
BayesNode::isLeaf (void)
{
return getChilds().empty();
}
bool
BayesNode::hasNeighbors (void) const
{
return childs_.size() != 0 || parents_.size() != 0;
}
int
BayesNode::getCptSize (void)
{
return dist_->params.size();
}
int
BayesNode::getIndexOfParent (const BayesNode* parent) const
{
for (unsigned int i = 0; i < parents_.size(); i++) {
if (parents_[i] == parent) {
return i;
}
}
return -1;
}
string
BayesNode::cptEntryToString (
int row,
const vector<unsigned>& stateConf) const
{
stringstream ss;
ss << "p(" ;
ss << states()[row];
if (parents_.size() > 0) {
ss << "|" ;
for (unsigned int i = 0; i < stateConf.size(); i++) {
if (i != 0) {
ss << ",";
}
ss << parents_[i]->states()[stateConf[i]];
}
}
ss << ")" ;
return ss.str();
}
vector<string>
BayesNode::getDomainHeaders (void) const
{
unsigned nParents = parents_.size();
unsigned rowSize = getRowSize();
unsigned nReps = 1;
vector<string> headers (rowSize);
for (int i = nParents - 1; i >= 0; i--) {
States states = parents_[i]->states();
unsigned index = 0;
while (index < rowSize) {
for (unsigned j = 0; j < parents_[i]->nrStates(); j++) {
for (unsigned r = 0; r < nReps; r++) {
if (headers[index] != "") {
headers[index] = states[j] + "," + headers[index];
} else {
headers[index] = states[j];
}
index++;
}
}
}
nReps *= parents_[i]->nrStates();
}
return headers;
}
ostream&
operator << (ostream& o, const BayesNode& node)
{
o << "variable " << node.getIndex() << endl;
o << "Var Id: " << node.varId() << endl;
o << "Label: " << node.label() << endl;
o << "Evidence: " ;
if (node.hasEvidence()) {
o << node.getEvidence();
}
else {
o << "no" ;
}
o << endl;
o << "Parents: " ;
const BnNodeSet& parents = node.getParents();
if (parents.size() != 0) {
for (unsigned int i = 0; i < parents.size() - 1; i++) {
o << parents[i]->label() << ", " ;
}
o << parents[parents.size() - 1]->label();
}
o << endl;
o << "Childs: " ;
const BnNodeSet& childs = node.getChilds();
if (childs.size() != 0) {
for (unsigned int i = 0; i < childs.size() - 1; i++) {
o << childs[i]->label() << ", " ;
}
o << childs[childs.size() - 1]->label();
}
o << endl;
o << "Domain: " ;
States states = node.states();
for (unsigned int i = 0; i < states.size() - 1; i++) {
o << states[i] << ", " ;
}
if (states.size() != 0) {
o << states[states.size() - 1];
}
o << endl;
// min width of first column
const unsigned int MIN_DOMAIN_WIDTH = 4;
// min width of following columns
const unsigned int MIN_COMBO_WIDTH = 12;
unsigned int domainWidth = states[0].length();
for (unsigned int i = 1; i < states.size(); i++) {
if (states[i].length() > domainWidth) {
domainWidth = states[i].length();
}
}
domainWidth = (domainWidth < MIN_DOMAIN_WIDTH)
? MIN_DOMAIN_WIDTH
: domainWidth;
o << left << setw (domainWidth) << "cpt" << right;
vector<int> widths;
int lineWidth = domainWidth;
vector<string> headers = node.getDomainHeaders();
if (!headers.empty()) {
for (unsigned int i = 0; i < headers.size(); i++) {
unsigned int len = headers[i].length();
int w = (len < MIN_COMBO_WIDTH) ? MIN_COMBO_WIDTH : len;
widths.push_back (w);
o << setw (w) << headers[i];
lineWidth += w;
}
o << endl;
} else {
cout << endl;
widths.push_back (domainWidth);
lineWidth += MIN_COMBO_WIDTH;
}
for (int i = 0; i < lineWidth; i++) {
o << "-" ;
}
o << endl;
for (unsigned int i = 0; i < states.size(); i++) {
Params row = node.getRow (i);
o << left << setw (domainWidth) << states[i] << right;
for (unsigned j = 0; j < node.getRowSize(); j++) {
o << setw (widths[j]) << row[j];
}
o << endl;
}
o << endl;
return o;
}

View File

@ -1,61 +0,0 @@
#ifndef HORUS_BAYESNODE_H
#define HORUS_BAYESNODE_H
#include <vector>
#include "VarNode.h"
#include "Distribution.h"
#include "Horus.h"
using namespace std;
class BayesNode : public VarNode
{
public:
BayesNode (const VarNode& v) : VarNode (v) {}
BayesNode (VarId, unsigned, int, Distribution*);
BayesNode (VarId, unsigned, int, const BnNodeSet&, Distribution*);
void setParents (const BnNodeSet&);
void addChild (BayesNode*);
void setDistribution (Distribution*);
Distribution* getDistribution (void);
const Params& getParameters (void);
Params getRow (int) const;
bool isRoot (void);
bool isLeaf (void);
bool hasNeighbors (void) const;
int getCptSize (void);
int getIndexOfParent (const BayesNode*) const;
string cptEntryToString (int, const vector<unsigned>&) const;
const BnNodeSet& getParents (void) const { return parents_; }
const BnNodeSet& getChilds (void) const { return childs_; }
unsigned getRowSize (void) const
{
return dist_->params.size() / nrStates();
}
double getProbability (int row, unsigned col)
{
int idx = (row * getRowSize()) + col;
return dist_->params[idx];
}
private:
DISALLOW_COPY_AND_ASSIGN (BayesNode);
States getDomainHeaders (void) const;
friend ostream& operator << (ostream&, const BayesNode&);
BnNodeSet parents_;
BnNodeSet childs_;
Distribution* dist_;
};
ostream& operator << (ostream&, const BayesNode&);
#endif // HORUS_BAYESNODE_H

View File

@ -1,803 +0,0 @@
#include <cstdlib>
#include <limits>
#include <time.h>
#include <algorithm>
#include <iostream>
#include <sstream>
#include <iomanip>
#include "BnBpSolver.h"
#include "Indexer.h"
BnBpSolver::BnBpSolver (const BayesNet& bn) : Solver (&bn)
{
bayesNet_ = &bn;
}
BnBpSolver::~BnBpSolver (void)
{
for (unsigned i = 0; i < nodesI_.size(); i++) {
delete nodesI_[i];
}
for (unsigned i = 0; i < links_.size(); i++) {
delete links_[i];
}
}
void
BnBpSolver::runSolver (void)
{
clock_t start;
if (COLLECT_STATISTICS) {
start = clock();
}
initializeSolver();
runLoopySolver();
if (DL >= 2) {
cout << endl;
if (nIters_ < BpOptions::maxIter) {
cout << "Belief propagation converged in " ;
cout << nIters_ << " iterations" << endl;
} else {
cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl;
}
}
unsigned size = bayesNet_->nrNodes();
if (COLLECT_STATISTICS) {
unsigned nIters = 0;
bool loopy = bayesNet_->isPolyTree() == false;
if (loopy) nIters = nIters_;
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
Statistics::updateStatistics (size, loopy, nIters, time);
}
if (EXPORT_TO_GRAPHVIZ && size > EXPORT_MINIMAL_SIZE) {
stringstream ss;
ss << Statistics::getSolvedNetworksCounting() << "." << size << ".dot" ;
bayesNet_->exportToGraphViz (ss.str().c_str());
}
}
Params
BnBpSolver::getPosterioriOf (VarId vid)
{
BayesNode* node = bayesNet_->getBayesNode (vid);
assert (node);
return nodesI_[node->getIndex()]->getBeliefs();
}
Params
BnBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
{
if (DL >= 2) {
cout << "calculating joint distribution on: " ;
for (unsigned i = 0; i < jointVarIds.size(); i++) {
VarNode* var = bayesNet_->getBayesNode (jointVarIds[i]);
cout << var->label() << " " ;
}
cout << endl;
}
return getJointByConditioning (jointVarIds);
}
void
BnBpSolver::initializeSolver (void)
{
const BnNodeSet& nodes = bayesNet_->getBayesNodes();
for (unsigned i = 0; i < nodesI_.size(); i++) {
delete nodesI_[i];
}
nodesI_.clear();
nodesI_.reserve (nodes.size());
links_.clear();
sortedOrder_.clear();
linkMap_.clear();
for (unsigned i = 0; i < nodes.size(); i++) {
nodesI_.push_back (new BpNodeInfo (nodes[i]));
}
BnNodeSet roots = bayesNet_->getRootNodes();
for (unsigned i = 0; i < roots.size(); i++) {
const Params& params = roots[i]->getParameters();
Params& piVals = ninf(roots[i])->getPiValues();
for (unsigned ri = 0; ri < roots[i]->nrStates(); ri++) {
piVals[ri] = params[ri];
}
}
for (unsigned i = 0; i < nodes.size(); i++) {
const BnNodeSet& parents = nodes[i]->getParents();
for (unsigned j = 0; j < parents.size(); j++) {
BpLink* newLink = new BpLink (
parents[j], nodes[i], LinkOrientation::DOWN);
links_.push_back (newLink);
ninf(nodes[i])->addIncomingParentLink (newLink);
ninf(parents[j])->addOutcomingChildLink (newLink);
}
const BnNodeSet& childs = nodes[i]->getChilds();
for (unsigned j = 0; j < childs.size(); j++) {
BpLink* newLink = new BpLink (
childs[j], nodes[i], LinkOrientation::UP);
links_.push_back (newLink);
ninf(nodes[i])->addIncomingChildLink (newLink);
ninf(childs[j])->addOutcomingParentLink (newLink);
}
}
for (unsigned i = 0; i < nodes.size(); i++) {
if (nodes[i]->hasEvidence()) {
Params& piVals = ninf(nodes[i])->getPiValues();
Params& ldVals = ninf(nodes[i])->getLambdaValues();
for (unsigned xi = 0; xi < nodes[i]->nrStates(); xi++) {
piVals[xi] = Util::noEvidence();
ldVals[xi] = Util::noEvidence();
}
piVals[nodes[i]->getEvidence()] = Util::withEvidence();
ldVals[nodes[i]->getEvidence()] = Util::withEvidence();
}
}
}
void
BnBpSolver::runLoopySolver()
{
nIters_ = 0;
while (!converged() && nIters_ < BpOptions::maxIter) {
nIters_++;
if (DL >= 2) {
cout << "****************************************" ;
cout << "****************************************" ;
cout << endl;
cout << " Iteration " << nIters_ << endl;
cout << "****************************************" ;
cout << "****************************************" ;
cout << endl;
}
switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_RANDOM:
random_shuffle (links_.begin(), links_.end());
// no break
case BpOptions::Schedule::SEQ_FIXED:
for (unsigned i = 0; i < links_.size(); i++) {
calculateAndUpdateMessage (links_[i]);
updateValues (links_[i]);
}
break;
case BpOptions::Schedule::PARALLEL:
for (unsigned i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
}
for (unsigned i = 0; i < links_.size(); i++) {
updateMessage (links_[i]);
updateValues (links_[i]);
}
break;
case BpOptions::Schedule::MAX_RESIDUAL:
maxResidualSchedule();
break;
}
if (DL >= 2) {
cout << endl;
}
}
}
bool
BnBpSolver::converged (void) const
{
// this can happen if the graph is fully disconnected
if (links_.size() == 0) {
return true;
}
if (nIters_ == 0 || nIters_ == 1) {
return false;
}
bool converged = true;
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
if (maxResidual < BpOptions::accuracy) {
converged = true;
} else {
converged = false;
}
} else {
for (unsigned i = 0; i < links_.size(); i++) {
double residual = links_[i]->getResidual();
if (DL >= 2) {
cout << links_[i]->toString() + " residual change = " ;
cout << residual << endl;
}
if (residual > BpOptions::accuracy) {
converged = false;
break;
}
}
}
return converged;
}
void
BnBpSolver::maxResidualSchedule (void)
{
if (nIters_ == 1) {
for (unsigned i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
linkMap_.insert (make_pair (links_[i], it));
}
return;
}
for (unsigned c = 0; c < sortedOrder_.size(); c++) {
if (DL >= 2) {
cout << "current residuals:" << endl;
for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); it ++) {
cout << " " << setw (30) << left << (*it)->toString();
cout << "residual = " << (*it)->getResidual() << endl;
}
}
SortedOrder::iterator it = sortedOrder_.begin();
BpLink* link = *it;
if (link->getResidual() < BpOptions::accuracy) {
sortedOrder_.erase (it);
it = sortedOrder_.begin();
return;
}
updateMessage (link);
updateValues (link);
link->clearResidual();
sortedOrder_.erase (it);
linkMap_.find (link)->second = sortedOrder_.insert (link);
const BpLinkSet& outParentLinks =
ninf(link->getDestination())->getOutcomingParentLinks();
for (unsigned i = 0; i < outParentLinks.size(); i++) {
if (outParentLinks[i]->getDestination() != link->getSource()
&& outParentLinks[i]->getDestination()->hasEvidence() == false) {
calculateMessage (outParentLinks[i]);
BpLinkMap::iterator iter = linkMap_.find (outParentLinks[i]);
sortedOrder_.erase (iter->second);
iter->second = sortedOrder_.insert (outParentLinks[i]);
}
}
const BpLinkSet& outChildLinks =
ninf(link->getDestination())->getOutcomingChildLinks();
for (unsigned i = 0; i < outChildLinks.size(); i++) {
if (outChildLinks[i]->getDestination() != link->getSource()) {
calculateMessage (outChildLinks[i]);
BpLinkMap::iterator iter = linkMap_.find (outChildLinks[i]);
sortedOrder_.erase (iter->second);
iter->second = sortedOrder_.insert (outChildLinks[i]);
}
}
if (DL >= 2) {
cout << "----------------------------------------" ;
cout << "----------------------------------------" << endl;
}
}
}
void
BnBpSolver::updatePiValues (BayesNode* x)
{
// π(Xi)
if (DL >= 3) {
cout << "updating " << PI_SYMBOL << " values for " << x->label() << endl;
}
Params& piValues = ninf(x)->getPiValues();
const BpLinkSet& parentLinks = ninf(x)->getIncomingParentLinks();
const BnNodeSet& ps = x->getParents();
Ranges ranges;
for (unsigned i = 0; i < ps.size(); i++) {
ranges.push_back (ps[i]->nrStates());
}
StatesIndexer indexer (ranges, false);
stringstream* calcs1 = 0;
stringstream* calcs2 = 0;
Params messageProducts (indexer.size());
for (unsigned k = 0; k < indexer.size(); k++) {
if (DL >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
}
double messageProduct = Util::multIdenty();
if (Globals::logDomain) {
for (unsigned i = 0; i < parentLinks.size(); i++) {
messageProduct += parentLinks[i]->getMessage()[indexer[i]];
}
} else {
for (unsigned i = 0; i < parentLinks.size(); i++) {
messageProduct *= parentLinks[i]->getMessage()[indexer[i]];
if (DL >= 5) {
if (i != 0) *calcs1 << " + " ;
if (i != 0) *calcs2 << " + " ;
*calcs1 << parentLinks[i]->toString (indexer[i]);
*calcs2 << parentLinks[i]->getMessage()[indexer[i]];
}
}
}
messageProducts[k] = messageProduct;
if (DL >= 5) {
cout << " mp" << k;
cout << " = " << (*calcs1).str();
if (parentLinks.size() == 1) {
cout << " = " << messageProduct << endl;
} else {
cout << " = " << (*calcs2).str();
cout << " = " << messageProduct << endl;
}
delete calcs1;
delete calcs2;
}
++ indexer;
}
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
double sum = Util::addIdenty();
if (DL >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
}
indexer.reset();
if (Globals::logDomain) {
for (unsigned k = 0; k < indexer.size(); k++) {
Util::logSum (sum,
x->getProbability(xi, indexer.linearIndex()) + messageProducts[k]);
++ indexer;
}
} else {
for (unsigned k = 0; k < indexer.size(); k++) {
sum += x->getProbability (xi, indexer.linearIndex()) * messageProducts[k];
if (DL >= 5) {
if (k != 0) *calcs1 << " + " ;
if (k != 0) *calcs2 << " + " ;
*calcs1 << x->cptEntryToString (xi, indexer.indices());
*calcs1 << ".mp" << k;
*calcs2 << Util::fl (x->getProbability (xi, indexer.linearIndex()));
*calcs2 << "*" << messageProducts[k];
}
++ indexer;
}
}
piValues[xi] = sum;
if (DL >= 5) {
cout << " " << PI_SYMBOL << "(" << x->label() << ")" ;
cout << "[" << x->states()[xi] << "]" ;
cout << " = " << (*calcs1).str();
cout << " = " << (*calcs2).str();
cout << " = " << piValues[xi] << endl;
delete calcs1;
delete calcs2;
}
}
}
void
BnBpSolver::updateLambdaValues (BayesNode* x)
{
// λ(Xi)
if (DL >= 3) {
cout << "updating " << LD_SYMBOL << " values for " << x->label() << endl;
}
Params& lambdaValues = ninf(x)->getLambdaValues();
const BpLinkSet& childLinks = ninf(x)->getIncomingChildLinks();
stringstream* calcs1 = 0;
stringstream* calcs2 = 0;
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
if (DL >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
}
double product = Util::multIdenty();
if (Globals::logDomain) {
for (unsigned i = 0; i < childLinks.size(); i++) {
product += childLinks[i]->getMessage()[xi];
}
} else {
for (unsigned i = 0; i < childLinks.size(); i++) {
product *= childLinks[i]->getMessage()[xi];
if (DL >= 5) {
if (i != 0) *calcs1 << "." ;
if (i != 0) *calcs2 << "*" ;
*calcs1 << childLinks[i]->toString (xi);
*calcs2 << childLinks[i]->getMessage()[xi];
}
}
}
lambdaValues[xi] = product;
if (DL >= 5) {
cout << " " << LD_SYMBOL << "(" << x->label() << ")" ;
cout << "[" << x->states()[xi] << "]" ;
cout << " = " << (*calcs1).str();
if (childLinks.size() == 1) {
cout << " = " << product << endl;
} else {
cout << " = " << (*calcs2).str();
cout << " = " << lambdaValues[xi] << endl;
}
delete calcs1;
delete calcs2;
}
}
}
void
BnBpSolver::calculatePiMessage (BpLink* link)
{
// πX(Zi)
BayesNode* z = link->getSource();
BayesNode* x = link->getDestination();
Params& zxPiNextMessage = link->getNextMessage();
const BpLinkSet& zChildLinks = ninf(z)->getIncomingChildLinks();
stringstream* calcs1 = 0;
stringstream* calcs2 = 0;
const Params& zPiValues = ninf(z)->getPiValues();
for (unsigned zi = 0; zi < z->nrStates(); zi++) {
double product = zPiValues[zi];
if (DL >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
*calcs1 << PI_SYMBOL << "(" << z->label() << ")";
*calcs1 << "[" << z->states()[zi] << "]" ;
*calcs2 << product;
}
if (Globals::logDomain) {
for (unsigned i = 0; i < zChildLinks.size(); i++) {
if (zChildLinks[i]->getSource() != x) {
product += zChildLinks[i]->getMessage()[zi];
}
}
} else {
for (unsigned i = 0; i < zChildLinks.size(); i++) {
if (zChildLinks[i]->getSource() != x) {
product *= zChildLinks[i]->getMessage()[zi];
if (DL >= 5) {
*calcs1 << "." << zChildLinks[i]->toString (zi);
*calcs2 << " * " << zChildLinks[i]->getMessage()[zi];
}
}
}
}
zxPiNextMessage[zi] = product;
if (DL >= 5) {
cout << " " << link->toString();
cout << "[" << z->states()[zi] << "]" ;
cout << " = " << (*calcs1).str();
if (zChildLinks.size() == 1) {
cout << " = " << product << endl;
} else {
cout << " = " << (*calcs2).str();
cout << " = " << product << endl;
}
delete calcs1;
delete calcs2;
}
}
Util::normalize (zxPiNextMessage);
}
void
BnBpSolver::calculateLambdaMessage (BpLink* link)
{
// λY(Xi)
BayesNode* y = link->getSource();
BayesNode* x = link->getDestination();
if (x->hasEvidence()) {
return;
}
Params& yxLambdaNextMessage = link->getNextMessage();
const BpLinkSet& yParentLinks = ninf(y)->getIncomingParentLinks();
const Params& yLambdaValues = ninf(y)->getLambdaValues();
int parentIndex = y->getIndexOfParent (x);
stringstream* calcs1 = 0;
stringstream* calcs2 = 0;
const BnNodeSet& ps = y->getParents();
Ranges ranges;
for (unsigned i = 0; i < ps.size(); i++) {
ranges.push_back (ps[i]->nrStates());
}
StatesIndexer indexer (ranges, false);
unsigned N = indexer.size() / x->nrStates();
Params messageProducts (N);
for (unsigned k = 0; k < N; k++) {
while (indexer[parentIndex] != 0) {
++ indexer;
}
if (DL >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
}
double messageProduct = Util::multIdenty();
if (Globals::logDomain) {
for (unsigned i = 0; i < yParentLinks.size(); i++) {
if (yParentLinks[i]->getSource() != x) {
messageProduct += yParentLinks[i]->getMessage()[indexer[i]];
}
}
} else {
for (unsigned i = 0; i < yParentLinks.size(); i++) {
if (yParentLinks[i]->getSource() != x) {
if (DL >= 5) {
if (messageProduct != Util::multIdenty()) *calcs1 << "*" ;
if (messageProduct != Util::multIdenty()) *calcs2 << "*" ;
*calcs1 << yParentLinks[i]->toString (indexer[i]);
*calcs2 << yParentLinks[i]->getMessage()[indexer[i]];
}
messageProduct *= yParentLinks[i]->getMessage()[indexer[i]];
}
}
}
messageProducts[k] = messageProduct;
++ indexer;
if (DL >= 5) {
cout << " mp" << k;
cout << " = " << (*calcs1).str();
if (yParentLinks.size() == 1) {
cout << 1 << endl;
} else if (yParentLinks.size() == 2) {
cout << " = " << messageProduct << endl;
} else {
cout << " = " << (*calcs2).str();
cout << " = " << messageProduct << endl;
}
delete calcs1;
delete calcs2;
}
}
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
if (DL >= 5) {
calcs1 = new stringstream;
calcs2 = new stringstream;
}
double outerSum = Util::addIdenty();
for (unsigned yi = 0; yi < y->nrStates(); yi++) {
if (DL >= 5) {
(yi != 0) ? *calcs1 << " + {" : *calcs1 << "{" ;
(yi != 0) ? *calcs2 << " + {" : *calcs2 << "{" ;
}
double innerSum = Util::addIdenty();
indexer.reset();
if (Globals::logDomain) {
for (unsigned k = 0; k < N; k++) {
while (indexer[parentIndex] != xi) {
++ indexer;
}
Util::logSum (innerSum, y->getProbability (
yi, indexer.linearIndex()) + messageProducts[k]);
++ indexer;
}
Util::logSum (outerSum, innerSum + yLambdaValues[yi]);
} else {
for (unsigned k = 0; k < N; k++) {
while (indexer[parentIndex] != xi) {
++ indexer;
}
if (DL >= 5) {
if (k != 0) *calcs1 << " + " ;
if (k != 0) *calcs2 << " + " ;
*calcs1 << y->cptEntryToString (yi, indexer.indices());
*calcs1 << ".mp" << k;
*calcs2 << y->getProbability (yi, indexer.linearIndex());
*calcs2 << "*" << messageProducts[k];
}
innerSum += y->getProbability (
yi, indexer.linearIndex()) * messageProducts[k];
++ indexer;
}
outerSum += innerSum * yLambdaValues[yi];
}
if (DL >= 5) {
*calcs1 << "}." << LD_SYMBOL << "(" << y->label() << ")" ;
*calcs1 << "[" << y->states()[yi] << "]";
*calcs2 << "}*" << yLambdaValues[yi];
}
}
yxLambdaNextMessage[xi] = outerSum;
if (DL >= 5) {
cout << " " << link->toString();
cout << "[" << x->states()[xi] << "]" ;
cout << " = " << (*calcs1).str();
cout << " = " << (*calcs2).str();
cout << " = " << yxLambdaNextMessage[xi] << endl;
delete calcs1;
delete calcs2;
}
}
Util::normalize (yxLambdaNextMessage);
}
Params
BnBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
{
BnNodeSet jointVars;
for (unsigned i = 0; i < jointVarIds.size(); i++) {
assert (bayesNet_->getBayesNode (jointVarIds[i]));
jointVars.push_back (bayesNet_->getBayesNode (jointVarIds[i]));
}
BayesNet* mrn = bayesNet_->getMinimalRequesiteNetwork (jointVarIds[0]);
BnBpSolver solver (*mrn);
solver.runSolver();
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
delete mrn;
VarIds observedVids = {jointVars[0]->varId()};
for (unsigned i = 1; i < jointVarIds.size(); i++) {
assert (jointVars[i]->hasEvidence() == false);
VarIds reqVars = {jointVarIds[i]};
reqVars.insert (reqVars.end(), observedVids.begin(), observedVids.end());
mrn = bayesNet_->getMinimalRequesiteNetwork (reqVars);
Params newBeliefs;
VarNodes observedVars;
for (unsigned j = 0; j < observedVids.size(); j++) {
observedVars.push_back (mrn->getBayesNode (observedVids[j]));
}
StatesIndexer idx (observedVars, false);
while (idx.valid()) {
for (unsigned j = 0; j < observedVars.size(); j++) {
observedVars[j]->setEvidence (idx[j]);
}
BnBpSolver solver (*mrn);
solver.runSolver();
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
for (unsigned k = 0; k < beliefs.size(); k++) {
newBeliefs.push_back (beliefs[k]);
}
++ idx;
}
int count = -1;
for (unsigned j = 0; j < newBeliefs.size(); j++) {
if (j % jointVars[i]->nrStates() == 0) {
count ++;
}
newBeliefs[j] *= prevBeliefs[count];
}
prevBeliefs = newBeliefs;
observedVids.push_back (jointVars[i]->varId());
delete mrn;
}
return prevBeliefs;
}
void
BnBpSolver::printPiLambdaValues (const BayesNode* var) const
{
cout << left;
cout << setw (10) << "states" ;
cout << setw (20) << PI_SYMBOL << "(" + var->label() + ")" ;
cout << setw (20) << LD_SYMBOL << "(" + var->label() + ")" ;
cout << setw (16) << "belief" ;
cout << endl;
cout << "--------------------------------" ;
cout << "--------------------------------" ;
cout << endl;
const States& states = var->states();
const Params& piVals = ninf(var)->getPiValues();
const Params& ldVals = ninf(var)->getLambdaValues();
const Params& beliefs = ninf(var)->getBeliefs();
for (unsigned xi = 0; xi < var->nrStates(); xi++) {
cout << setw (10) << states[xi];
cout << setw (19) << piVals[xi];
cout << setw (19) << ldVals[xi];
cout.precision (PRECISION);
cout << setw (16) << beliefs[xi];
cout << endl;
}
cout << endl;
}
void
BnBpSolver::printAllMessageStatus (void) const
{
const BnNodeSet& nodes = bayesNet_->getBayesNodes();
for (unsigned i = 0; i < nodes.size(); i++) {
printPiLambdaValues (nodes[i]);
}
}
BpNodeInfo::BpNodeInfo (BayesNode* node)
{
node_ = node;
piVals_.resize (node->nrStates(), Util::one());
ldVals_.resize (node->nrStates(), Util::one());
}
Params
BpNodeInfo::getBeliefs (void) const
{
double sum = 0.0;
Params beliefs (node_->nrStates());
if (Globals::logDomain) {
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
beliefs[xi] = exp (piVals_[xi] + ldVals_[xi]);
sum += beliefs[xi];
}
} else {
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
beliefs[xi] = piVals_[xi] * ldVals_[xi];
sum += beliefs[xi];
}
}
assert (sum);
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
beliefs[xi] /= sum;
}
return beliefs;
}
bool
BpNodeInfo::receivedBottomInfluence (void) const
{
// if all lambda values are equal, then neither
// this node neither its descendents have evidence,
// we can use this to don't send lambda messages his parents
bool childInfluenced = false;
for (unsigned xi = 1; xi < node_->nrStates(); xi++) {
if (ldVals_[xi] != ldVals_[0]) {
childInfluenced = true;
break;
}
}
return childInfluenced;
}

View File

@ -1,245 +0,0 @@
#ifndef HORUS_BNBPSOLVER_H
#define HORUS_BNBPSOLVER_H
#include <vector>
#include <set>
#include "Solver.h"
#include "BayesNet.h"
#include "Horus.h"
#include "Util.h"
using namespace std;
class BpNodeInfo;
static const string PI_SYMBOL = "pi" ;
static const string LD_SYMBOL = "ld" ;
enum LinkOrientation {UP, DOWN};
class BpLink
{
public:
BpLink (BayesNode* s, BayesNode* d, LinkOrientation o)
{
source_ = s;
destin_ = d;
orientation_ = o;
if (orientation_ == LinkOrientation::DOWN) {
v1_.resize (s->nrStates(), Util::tl (1.0 / s->nrStates()));
v2_.resize (s->nrStates(), Util::tl (1.0 / s->nrStates()));
} else {
v1_.resize (d->nrStates(), Util::tl (1.0 / d->nrStates()));
v2_.resize (d->nrStates(), Util::tl (1.0 / d->nrStates()));
}
currMsg_ = &v1_;
nextMsg_ = &v2_;
residual_ = 0;
msgSended_ = false;
}
void updateMessage (void)
{
swap (currMsg_, nextMsg_);
msgSended_ = true;
}
void updateResidual (void)
{
residual_ = Util::getMaxNorm (v1_, v2_);
}
string toString (void) const
{
stringstream ss;
if (orientation_ == LinkOrientation::DOWN) {
ss << PI_SYMBOL;
} else {
ss << LD_SYMBOL;
}
ss << "(" << source_->label();
ss << " --> " << destin_->label() << ")" ;
return ss.str();
}
string toString (unsigned stateIndex) const
{
stringstream ss;
ss << toString() << "[" ;
if (orientation_ == LinkOrientation::DOWN) {
ss << source_->states()[stateIndex] << "]" ;
} else {
ss << destin_->states()[stateIndex] << "]" ;
}
return ss.str();
}
BayesNode* getSource (void) const { return source_; }
BayesNode* getDestination (void) const { return destin_; }
LinkOrientation getOrientation (void) const { return orientation_; }
const Params& getMessage (void) const { return *currMsg_; }
Params& getNextMessage (void) { return *nextMsg_; }
bool messageWasSended (void) const { return msgSended_; }
double getResidual (void) const { return residual_; }
void clearResidual (void) { residual_ = 0;}
private:
BayesNode* source_;
BayesNode* destin_;
LinkOrientation orientation_;
Params v1_;
Params v2_;
Params* currMsg_;
Params* nextMsg_;
bool msgSended_;
double residual_;
};
typedef vector<BpLink*> BpLinkSet;
class BpNodeInfo
{
public:
BpNodeInfo (BayesNode*);
Params getBeliefs (void) const;
bool receivedBottomInfluence (void) const;
Params& getPiValues (void) { return piVals_; }
Params& getLambdaValues (void) { return ldVals_; }
const BpLinkSet& getIncomingParentLinks (void) { return inParentLinks_; }
const BpLinkSet& getIncomingChildLinks (void) { return inChildLinks_; }
const BpLinkSet& getOutcomingParentLinks (void) { return outParentLinks_; }
const BpLinkSet& getOutcomingChildLinks (void) { return outChildLinks_; }
void addIncomingParentLink (BpLink* l) { inParentLinks_.push_back (l); }
void addIncomingChildLink (BpLink* l) { inChildLinks_.push_back (l); }
void addOutcomingParentLink (BpLink* l) { outParentLinks_.push_back (l); }
void addOutcomingChildLink (BpLink* l) { outChildLinks_.push_back (l); }
private:
DISALLOW_COPY_AND_ASSIGN (BpNodeInfo);
const BayesNode* node_;
Params piVals_; // pi values
Params ldVals_; // lambda values
BpLinkSet inParentLinks_;
BpLinkSet inChildLinks_;
BpLinkSet outParentLinks_;
BpLinkSet outChildLinks_;
};
class BnBpSolver : public Solver
{
public:
BnBpSolver (const BayesNet&);
~BnBpSolver (void);
void runSolver (void);
Params getPosterioriOf (VarId);
Params getJointDistributionOf (const VarIds&);
private:
DISALLOW_COPY_AND_ASSIGN (BnBpSolver);
void initializeSolver (void);
void runLoopySolver (void);
void maxResidualSchedule (void);
bool converged (void) const;
void updatePiValues (BayesNode*);
void updateLambdaValues (BayesNode*);
void calculateLambdaMessage (BpLink*);
void calculatePiMessage (BpLink*);
Params getJointByJunctionNode (const VarIds&);
Params getJointByConditioning (const VarIds&) const;
void printPiLambdaValues (const BayesNode*) const;
void printAllMessageStatus (void) const;
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
{
if (DL >= 3) {
cout << "calculating & updating " << link->toString() << endl;
}
if (link->getOrientation() == LinkOrientation::DOWN) {
calculatePiMessage (link);
} else if (link->getOrientation() == LinkOrientation::UP) {
calculateLambdaMessage (link);
}
if (calcResidual) {
link->updateResidual();
}
link->updateMessage();
}
void calculateMessage (BpLink* link, bool calcResidual = true)
{
if (DL >= 3) {
cout << "calculating " << link->toString() << endl;
}
if (link->getOrientation() == LinkOrientation::DOWN) {
calculatePiMessage (link);
} else if (link->getOrientation() == LinkOrientation::UP) {
calculateLambdaMessage (link);
}
if (calcResidual) {
link->updateResidual();
}
}
void updateMessage (BpLink* link)
{
if (DL >= 3) {
cout << "updating " << link->toString() << endl;
}
link->updateMessage();
}
void updateValues (BpLink* link)
{
if (!link->getDestination()->hasEvidence()) {
if (link->getOrientation() == LinkOrientation::DOWN) {
updatePiValues (link->getDestination());
} else if (link->getOrientation() == LinkOrientation::UP) {
updateLambdaValues (link->getDestination());
}
}
}
BpNodeInfo* ninf (const BayesNode* node) const
{
assert (node);
assert (node == bayesNet_->getBayesNode (node->varId()));
assert (node->getIndex() < nodesI_.size());
return nodesI_[node->getIndex()];
}
const BayesNet* bayesNet_;
vector<BpLink*> links_;
vector<BpNodeInfo*> nodesI_;
unsigned nIters_;
struct compare
{
inline bool operator() (const BpLink* e1, const BpLink* e2)
{
return e1->getResidual() > e2->getResidual();
}
};
typedef multiset<BpLink*, compare> SortedOrder;
SortedOrder sortedOrder_;
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
BpLinkMap linkMap_;
};
#endif // HORUS_BNBPSOLVER_H

View File

@ -5,21 +5,22 @@
#include <iostream>
#include "FgBpSolver.h"
#include "BpSolver.h"
#include "FactorGraph.h"
#include "Factor.h"
#include "Indexer.h"
#include "Horus.h"
FgBpSolver::FgBpSolver (const FactorGraph& fg) : Solver (&fg)
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
{
factorGraph_ = &fg;
fg_ = &fg;
runned_ = false;
}
FgBpSolver::~FgBpSolver (void)
BpSolver::~BpSolver (void)
{
for (unsigned i = 0; i < varsI_.size(); i++) {
delete varsI_[i];
@ -34,64 +35,45 @@ FgBpSolver::~FgBpSolver (void)
void
FgBpSolver::runSolver (void)
Params
BpSolver::solveQuery (VarIds queryVids)
{
clock_t start;
if (COLLECT_STATISTICS) {
start = clock();
}
runLoopySolver();
if (DL >= 2) {
cout << endl;
if (nIters_ < BpOptions::maxIter) {
cout << "Sum-Product converged in " ;
cout << nIters_ << " iterations" << endl;
} else {
cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl;
}
}
unsigned size = factorGraph_->getVarNodes().size();
if (COLLECT_STATISTICS) {
unsigned nIters = 0;
bool loopy = factorGraph_->isTree() == false;
if (loopy) nIters = nIters_;
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
Statistics::updateStatistics (size, loopy, nIters, time);
}
if (EXPORT_TO_GRAPHVIZ && size > EXPORT_MINIMAL_SIZE) {
stringstream ss;
ss << Statistics::getSolvedNetworksCounting() << "." << size << ".dot" ;
factorGraph_->exportToGraphViz (ss.str().c_str());
assert (queryVids.empty() == false);
if (queryVids.size() == 1) {
return getPosterioriOf (queryVids[0]);
} else {
return getJointDistributionOf (queryVids);
}
}
Params
FgBpSolver::getPosterioriOf (VarId vid)
BpSolver::getPosterioriOf (VarId vid)
{
assert (factorGraph_->getFgVarNode (vid));
FgVarNode* var = factorGraph_->getFgVarNode (vid);
if (runned_ == false) {
runSolver();
}
assert (fg_->getVarNode (vid));
VarNode* var = fg_->getVarNode (vid);
Params probs;
if (var->hasEvidence()) {
probs.resize (var->nrStates(), Util::noEvidence());
probs[var->getEvidence()] = Util::withEvidence();
probs.resize (var->range(), LogAware::noEvidence());
probs[var->getEvidence()] = LogAware::withEvidence();
} else {
probs.resize (var->nrStates(), Util::multIdenty());
probs.resize (var->range(), LogAware::multIdenty());
const SpLinkSet& links = ninf(var)->getLinks();
if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) {
Util::add (probs, links[i]->getMessage());
}
Util::normalize (probs);
LogAware::normalize (probs);
Util::fromLog (probs);
} else {
for (unsigned i = 0; i < links.size(); i++) {
Util::multiply (probs, links[i]->getMessage());
}
Util::normalize (probs);
LogAware::normalize (probs);
}
}
return probs;
@ -100,13 +82,16 @@ FgBpSolver::getPosterioriOf (VarId vid)
Params
FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
{
FgVarNode* vn = factorGraph_->getFgVarNode (jointVarIds[0]);
const FgFacSet& factorNodes = vn->neighbors();
if (runned_ == false) {
runSolver();
}
int idx = -1;
for (unsigned i = 0; i < factorNodes.size(); i++) {
if (factorNodes[i]->factor()->contains (jointVarIds)) {
VarNode* vn = fg_->getVarNode (jointVarIds[0]);
const FacNodes& facNodes = vn->neighbors();
for (unsigned i = 0; i < facNodes.size(); i++) {
if (facNodes[i]->factor().contains (jointVarIds)) {
idx = i;
break;
}
@ -114,18 +99,18 @@ FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
if (idx == -1) {
return getJointByConditioning (jointVarIds);
} else {
Factor r (*factorNodes[idx]->factor());
const SpLinkSet& links = ninf(factorNodes[idx])->getLinks();
Factor res (facNodes[idx]->factor());
const SpLinkSet& links = ninf(facNodes[idx])->getLinks();
for (unsigned i = 0; i < links.size(); i++) {
Factor msg (links[i]->getVariable()->varId(),
links[i]->getVariable()->nrStates(),
Factor msg ({links[i]->getVariable()->varId()},
{links[i]->getVariable()->range()},
getVar2FactorMsg (links[i]));
r.multiply (msg);
res.multiply (msg);
}
r.sumOutAllExcept (jointVarIds);
r.reorderVariables (jointVarIds);
r.normalize();
Params jointDist = r.getParameters();
res.sumOutAllExcept (jointVarIds);
res.reorderArguments (jointVarIds);
res.normalize();
Params jointDist = res.params();
if (Globals::logDomain) {
Util::fromLog (jointDist);
}
@ -136,35 +121,29 @@ FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
void
FgBpSolver::runLoopySolver (void)
BpSolver::runSolver (void)
{
clock_t start;
if (Constants::COLLECT_STATS) {
start = clock();
}
initializeSolver();
nIters_ = 0;
while (!converged() && nIters_ < BpOptions::maxIter) {
nIters_ ++;
if (DL >= 2) {
cout << "****************************************" ;
cout << "****************************************" ;
cout << endl;
cout << " Iteration " << nIters_ << endl;
cout << "****************************************" ;
cout << "****************************************" ;
cout << endl;
if (Constants::DEBUG >= 2) {
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
// cout << endl;
}
switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_RANDOM:
random_shuffle (links_.begin(), links_.end());
// no break
case BpOptions::Schedule::SEQ_RANDOM:
random_shuffle (links_.begin(), links_.end());
// no break
case BpOptions::Schedule::SEQ_FIXED:
for (unsigned i = 0; i < links_.size(); i++) {
calculateAndUpdateMessage (links_[i]);
}
break;
case BpOptions::Schedule::PARALLEL:
for (unsigned i = 0; i < links_.size(); i++) {
calculateMessage (links_[i]);
@ -173,61 +152,43 @@ FgBpSolver::runLoopySolver (void)
updateMessage(links_[i]);
}
break;
case BpOptions::Schedule::MAX_RESIDUAL:
maxResidualSchedule();
break;
}
if (DL >= 2) {
if (Constants::DEBUG >= 2) {
cout << endl;
}
}
if (Constants::DEBUG >= 2) {
cout << endl;
if (nIters_ < BpOptions::maxIter) {
cout << "Sum-Product converged in " ;
cout << nIters_ << " iterations" << endl;
} else {
cout << "The maximum number of iterations was hit, terminating..." ;
cout << endl;
}
}
unsigned size = fg_->varNodes().size();
if (Constants::COLLECT_STATS) {
unsigned nIters = 0;
bool loopy = fg_->isTree() == false;
if (loopy) nIters = nIters_;
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
Statistics::updateStatistics (size, loopy, nIters, time);
}
runned_ = true;
}
void
FgBpSolver::initializeSolver (void)
BpSolver::createLinks (void)
{
const FgVarSet& varNodes = factorGraph_->getVarNodes();
for (unsigned i = 0; i < varsI_.size(); i++) {
delete varsI_[i];
}
varsI_.reserve (varNodes.size());
for (unsigned i = 0; i < varNodes.size(); i++) {
varsI_.push_back (new SPNodeInfo());
}
const FgFacSet& facNodes = factorGraph_->getFactorNodes();
for (unsigned i = 0; i < facsI_.size(); i++) {
delete facsI_[i];
}
facsI_.reserve (facNodes.size());
const FacNodes& facNodes = fg_->facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
facsI_.push_back (new SPNodeInfo());
}
for (unsigned i = 0; i < links_.size(); i++) {
delete links_[i];
}
createLinks();
for (unsigned i = 0; i < links_.size(); i++) {
FgFacNode* src = links_[i]->getFactor();
FgVarNode* dst = links_[i]->getVariable();
ninf (dst)->addSpLink (links_[i]);
ninf (src)->addSpLink (links_[i]);
}
}
void
FgBpSolver::createLinks (void)
{
const FgFacSet& facNodes = factorGraph_->getFactorNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
const FgVarSet& neighbors = facNodes[i]->neighbors();
const VarNodes& neighbors = facNodes[i]->neighbors();
for (unsigned j = 0; j < neighbors.size(); j++) {
links_.push_back (new SpLink (facNodes[i], neighbors[j]));
}
@ -236,42 +197,8 @@ FgBpSolver::createLinks (void)
bool
FgBpSolver::converged (void)
{
if (links_.size() == 0) {
return true;
}
if (nIters_ == 0 || nIters_ == 1) {
return false;
}
bool converged = true;
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
if (maxResidual > BpOptions::accuracy) {
converged = false;
} else {
converged = true;
}
} else {
for (unsigned i = 0; i < links_.size(); i++) {
double residual = links_[i]->getResidual();
if (DL >= 2) {
cout << links_[i]->toString() + " residual = " << residual << endl;
}
if (residual > BpOptions::accuracy) {
converged = false;
if (DL == 0) break;
}
}
}
return converged;
}
void
FgBpSolver::maxResidualSchedule (void)
BpSolver::maxResidualSchedule (void)
{
if (nIters_ == 1) {
for (unsigned i = 0; i < links_.size(); i++) {
@ -283,7 +210,7 @@ FgBpSolver::maxResidualSchedule (void)
}
for (unsigned c = 0; c < links_.size(); c++) {
if (DL >= 2) {
if (Constants::DEBUG >= 2) {
cout << "current residuals:" << endl;
for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); it ++) {
@ -303,7 +230,7 @@ FgBpSolver::maxResidualSchedule (void)
linkMap_.find (link)->second = sortedOrder_.insert (link);
// update the messages that depend on message source --> destin
const FgFacSet& factorNeighbors = link->getVariable()->neighbors();
const FacNodes& factorNeighbors = link->getVariable()->neighbors();
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
if (factorNeighbors[i] != link->getFactor()) {
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
@ -317,9 +244,8 @@ FgBpSolver::maxResidualSchedule (void)
}
}
}
if (DL >= 2) {
cout << "----------------------------------------" ;
cout << "----------------------------------------" << endl;
if (Constants::DEBUG >= 2) {
Util::printDashedLine();
}
}
}
@ -327,26 +253,26 @@ FgBpSolver::maxResidualSchedule (void)
void
FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
BpSolver::calculateFactor2VariableMsg (SpLink* link)
{
const FgFacNode* src = link->getFactor();
const FgVarNode* dst = link->getVariable();
FacNode* src = link->getFactor();
const VarNode* dst = link->getVariable();
const SpLinkSet& links = ninf(src)->getLinks();
// calculate the product of messages that were sent
// to factor `src', except from var `dst'
unsigned msgSize = 1;
for (unsigned i = 0; i < links.size(); i++) {
msgSize *= links[i]->getVariable()->nrStates();
msgSize *= links[i]->getVariable()->range();
}
unsigned repetitions = 1;
Params msgProduct (msgSize, Util::multIdenty());
Params msgProduct (msgSize, LogAware::multIdenty());
if (Globals::logDomain) {
for (int i = links.size() - 1; i >= 0; i--) {
if (links[i]->getVariable() != dst) {
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->nrStates();
repetitions *= links[i]->getVariable()->range();
} else {
unsigned ds = links[i]->getVariable()->nrStates();
unsigned ds = links[i]->getVariable()->range();
Util::add (msgProduct, Params (ds, 1.0), repetitions);
repetitions *= ds;
}
@ -354,70 +280,64 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
} else {
for (int i = links.size() - 1; i >= 0; i--) {
if (links[i]->getVariable() != dst) {
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
cout << " message from " << links[i]->getVariable()->label();
cout << ": " << endl;
}
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
repetitions *= links[i]->getVariable()->nrStates();
repetitions *= links[i]->getVariable()->range();
} else {
unsigned ds = links[i]->getVariable()->nrStates();
unsigned ds = links[i]->getVariable()->range();
Util::multiply (msgProduct, Params (ds, 1.0), repetitions);
repetitions *= ds;
}
}
}
Factor result (src->factor()->getVarIds(),
src->factor()->getRanges(),
msgProduct);
result.multiply (*(src->factor()));
if (DL >= 5) {
cout << " message product: " ;
cout << Util::parametersToString (msgProduct) << endl;
cout << " original factor: " ;
cout << Util::parametersToString (src->getParameters()) << endl;
cout << " factor product: " ;
cout << Util::parametersToString (result.getParameters()) << endl;
Factor result (src->factor().arguments(),
src->factor().ranges(), msgProduct);
result.multiply (src->factor());
if (Constants::DEBUG >= 5) {
cout << " message product: " << msgProduct << endl;
cout << " original factor: " << src->factor().params() << endl;
cout << " factor product: " << result.params() << endl;
}
result.sumOutAllExcept (dst->varId());
if (DL >= 5) {
if (Constants::DEBUG >= 5) {
cout << " marginalized: " ;
cout << Util::parametersToString (result.getParameters()) << endl;
cout << result.params() << endl;
}
const Params& resultParams = result.getParameters();
const Params& resultParams = result.params();
Params& message = link->getNextMessage();
for (unsigned i = 0; i < resultParams.size(); i++) {
message[i] = resultParams[i];
}
Util::normalize (message);
if (DL >= 5) {
cout << " curr msg: " ;
cout << Util::parametersToString (link->getMessage()) << endl;
cout << " next msg: " ;
cout << Util::parametersToString (message) << endl;
LogAware::normalize (message);
if (Constants::DEBUG >= 5) {
cout << " curr msg: " << link->getMessage() << endl;
cout << " next msg: " << message << endl;
}
}
Params
FgBpSolver::getVar2FactorMsg (const SpLink* link) const
BpSolver::getVar2FactorMsg (const SpLink* link) const
{
const FgVarNode* src = link->getVariable();
const FgFacNode* dst = link->getFactor();
const VarNode* src = link->getVariable();
const FacNode* dst = link->getFactor();
Params msg;
if (src->hasEvidence()) {
msg.resize (src->nrStates(), Util::noEvidence());
msg[src->getEvidence()] = Util::withEvidence();
if (DL >= 5) {
cout << Util::parametersToString (msg);
msg.resize (src->range(), LogAware::noEvidence());
msg[src->getEvidence()] = LogAware::withEvidence();
if (Constants::DEBUG >= 5) {
cout << msg;
}
} else {
msg.resize (src->nrStates(), Util::one());
msg.resize (src->range(), LogAware::one());
}
if (DL >= 5) {
cout << Util::parametersToString (msg);
if (Constants::DEBUG >= 5) {
cout << msg;
}
const SpLinkSet& links = ninf (src)->getLinks();
if (Globals::logDomain) {
@ -430,14 +350,14 @@ FgBpSolver::getVar2FactorMsg (const SpLink* link) const
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) {
Util::multiply (msg, links[i]->getMessage());
if (DL >= 5) {
cout << " x " << Util::parametersToString (links[i]->getMessage());
if (Constants::DEBUG >= 5) {
cout << " x " << links[i]->getMessage();
}
}
}
}
if (DL >= 5) {
cout << " = " << Util::parametersToString (msg);
if (Constants::DEBUG >= 5) {
cout << " = " << msg;
}
return msg;
}
@ -445,16 +365,16 @@ FgBpSolver::getVar2FactorMsg (const SpLink* link) const
Params
FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
{
FgVarSet jointVars;
VarNodes jointVars;
for (unsigned i = 0; i < jointVarIds.size(); i++) {
assert (factorGraph_->getFgVarNode (jointVarIds[i]));
jointVars.push_back (factorGraph_->getFgVarNode (jointVarIds[i]));
assert (fg_->getVarNode (jointVarIds[i]));
jointVars.push_back (fg_->getVarNode (jointVarIds[i]));
}
FactorGraph* fg = new FactorGraph (*factorGraph_);
FgBpSolver solver (*fg);
FactorGraph* fg = new FactorGraph (*fg_);
BpSolver solver (*fg);
solver.runSolver();
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
@ -463,9 +383,9 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
for (unsigned i = 1; i < jointVarIds.size(); i++) {
assert (jointVars[i]->hasEvidence() == false);
Params newBeliefs;
VarNodes observedVars;
Vars observedVars;
for (unsigned j = 0; j < observedVids.size(); j++) {
observedVars.push_back (fg->getFgVarNode (observedVids[j]));
observedVars.push_back (fg->getVarNode (observedVids[j]));
}
StatesIndexer idx (observedVars, false);
while (idx.valid()) {
@ -473,7 +393,7 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
observedVars[j]->setEvidence (idx[j]);
}
++ idx;
FgBpSolver solver (*fg);
BpSolver solver (*fg);
solver.runSolver();
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
for (unsigned k = 0; k < beliefs.size(); k++) {
@ -483,7 +403,7 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
int count = -1;
for (unsigned j = 0; j < newBeliefs.size(); j++) {
if (j % jointVars[i]->nrStates() == 0) {
if (j % jointVars[i]->range() == 0) {
count ++;
}
newBeliefs[j] *= prevBeliefs[count];
@ -497,15 +417,76 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
void
FgBpSolver::printLinkInformation (void) const
BpSolver::initializeSolver (void)
{
const VarNodes& varNodes = fg_->varNodes();
varsI_.reserve (varNodes.size());
for (unsigned i = 0; i < varNodes.size(); i++) {
varsI_.push_back (new SPNodeInfo());
}
const FacNodes& facNodes = fg_->facNodes();
facsI_.reserve (facNodes.size());
for (unsigned i = 0; i < facNodes.size(); i++) {
facsI_.push_back (new SPNodeInfo());
}
createLinks();
for (unsigned i = 0; i < links_.size(); i++) {
FacNode* src = links_[i]->getFactor();
VarNode* dst = links_[i]->getVariable();
ninf (dst)->addSpLink (links_[i]);
ninf (src)->addSpLink (links_[i]);
}
}
bool
BpSolver::converged (void)
{
if (links_.size() == 0) {
return true;
}
if (nIters_ <= 1) {
return false;
}
bool converged = true;
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
if (maxResidual > BpOptions::accuracy) {
converged = false;
} else {
converged = true;
}
} else {
for (unsigned i = 0; i < links_.size(); i++) {
double residual = links_[i]->getResidual();
if (Constants::DEBUG >= 2) {
cout << links_[i]->toString() + " residual = " << residual << endl;
}
if (residual > BpOptions::accuracy) {
converged = false;
if (Constants::DEBUG == 0) break;
}
}
if (Constants::DEBUG >= 2) {
cout << endl;
}
}
return converged;
}
void
BpSolver::printLinkInformation (void) const
{
for (unsigned i = 0; i < links_.size(); i++) {
SpLink* l = links_[i];
cout << l->toString() << ":" << endl;
cout << " curr msg = " ;
cout << Util::parametersToString (l->getMessage()) << endl;
cout << l->getMessage() << endl;
cout << " next msg = " ;
cout << Util::parametersToString (l->getNextMessage()) << endl;
cout << l->getNextMessage() << endl;
cout << " residual = " << l->getResidual() << endl;
}
}

View File

@ -0,0 +1,188 @@
#ifndef HORUS_BPSOLVER_H
#define HORUS_BPSOLVER_H
#include <set>
#include <vector>
#include <sstream>
#include "Solver.h"
#include "Factor.h"
#include "FactorGraph.h"
#include "Util.h"
using namespace std;
class SpLink
{
public:
SpLink (FacNode* fn, VarNode* vn)
{
fac_ = fn;
var_ = vn;
v1_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
v2_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
currMsg_ = &v1_;
nextMsg_ = &v2_;
msgSended_ = false;
residual_ = 0.0;
}
virtual ~SpLink (void) { };
FacNode* getFactor (void) const { return fac_; }
VarNode* getVariable (void) const { return var_; }
const Params& getMessage (void) const { return *currMsg_; }
Params& getNextMessage (void) { return *nextMsg_; }
bool messageWasSended (void) const { return msgSended_; }
double getResidual (void) const { return residual_; }
void clearResidual (void) { residual_ = 0.0; }
void updateResidual (void)
{
residual_ = LogAware::getMaxNorm (v1_,v2_);
}
virtual void updateMessage (void)
{
swap (currMsg_, nextMsg_);
msgSended_ = true;
}
string toString (void) const
{
stringstream ss;
ss << fac_->getLabel();
ss << " -- " ;
ss << var_->label();
return ss.str();
}
protected:
FacNode* fac_;
VarNode* var_;
Params v1_;
Params v2_;
Params* currMsg_;
Params* nextMsg_;
bool msgSended_;
double residual_;
};
typedef vector<SpLink*> SpLinkSet;
class SPNodeInfo
{
public:
void addSpLink (SpLink* link) { links_.push_back (link); }
const SpLinkSet& getLinks (void) { return links_; }
private:
SpLinkSet links_;
};
class BpSolver : public Solver
{
public:
BpSolver (const FactorGraph&);
virtual ~BpSolver (void);
Params solveQuery (VarIds);
virtual Params getPosterioriOf (VarId);
virtual Params getJointDistributionOf (const VarIds&);
protected:
void runSolver (void);
virtual void createLinks (void);
virtual void maxResidualSchedule (void);
virtual void calculateFactor2VariableMsg (SpLink*);
virtual Params getVar2FactorMsg (const SpLink*) const;
virtual Params getJointByConditioning (const VarIds&) const;
SPNodeInfo* ninf (const VarNode* var) const
{
return varsI_[var->getIndex()];
}
SPNodeInfo* ninf (const FacNode* fac) const
{
return facsI_[fac->getIndex()];
}
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
{
if (Constants::DEBUG >= 3) {
cout << "calculating & updating " << link->toString() << endl;
}
calculateFactor2VariableMsg (link);
if (calcResidual) {
link->updateResidual();
}
link->updateMessage();
}
void calculateMessage (SpLink* link, bool calcResidual = true)
{
if (Constants::DEBUG >= 3) {
cout << "calculating " << link->toString() << endl;
}
calculateFactor2VariableMsg (link);
if (calcResidual) {
link->updateResidual();
}
}
void updateMessage (SpLink* link)
{
link->updateMessage();
if (Constants::DEBUG >= 3) {
cout << "updating " << link->toString() << endl;
}
}
struct CompareResidual
{
inline bool operator() (const SpLink* link1, const SpLink* link2)
{
return link1->getResidual() > link2->getResidual();
}
};
SpLinkSet links_;
unsigned nIters_;
vector<SPNodeInfo*> varsI_;
vector<SPNodeInfo*> facsI_;
bool runned_;
const FactorGraph* fg_;
typedef multiset<SpLink*, CompareResidual> SortedOrder;
SortedOrder sortedOrder_;
typedef unordered_map<SpLink*, SortedOrder::iterator> SpLinkMap;
SpLinkMap linkMap_;
private:
void initializeSolver (void);
bool converged (void);
void printLinkInformation (void) const;
};
#endif // HORUS_BPSOLVER_H

View File

@ -1,7 +1,6 @@
#include "CFactorGraph.h"
#include "Factor.h"
#include "Distribution.h"
bool CFactorGraph::checkForIdenticalFactors = true;
@ -11,22 +10,22 @@ CFactorGraph::CFactorGraph (const FactorGraph& fg)
groundFg_ = &fg;
freeColor_ = 0;
const FgVarSet& varNodes = fg.getVarNodes();
const VarNodes& varNodes = fg.varNodes();
varSignatures_.reserve (varNodes.size());
for (unsigned i = 0; i < varNodes.size(); i++) {
unsigned c = (varNodes[i]->neighbors().size() * 2) + 1;
varSignatures_.push_back (Signature (c));
}
const FgFacSet& facNodes = fg.getFactorNodes();
factorSignatures_.reserve (facNodes.size());
const FacNodes& facNodes = fg.facNodes();
facSignatures_.reserve (facNodes.size());
for (unsigned i = 0; i < facNodes.size(); i++) {
unsigned c = facNodes[i]->neighbors().size() + 1;
factorSignatures_.push_back (Signature (c));
facSignatures_.push_back (Signature (c));
}
varColors_.resize (varNodes.size());
factorColors_.resize (facNodes.size());
facColors_.resize (facNodes.size());
setInitialColors();
createGroups();
}
@ -50,9 +49,9 @@ CFactorGraph::setInitialColors (void)
{
// create the initial variable colors
VarColorMap colorMap;
const FgVarSet& varNodes = groundFg_->getVarNodes();
const VarNodes& varNodes = groundFg_->varNodes();
for (unsigned i = 0; i < varNodes.size(); i++) {
unsigned dsize = varNodes[i]->nrStates();
unsigned dsize = varNodes[i]->range();
VarColorMap::iterator it = colorMap.find (dsize);
if (it == colorMap.end()) {
it = colorMap.insert (make_pair (
@ -71,29 +70,40 @@ CFactorGraph::setInitialColors (void)
setColor (varNodes[i], stateColors[idx]);
}
const FgFacSet& facNodes = groundFg_->getFactorNodes();
if (checkForIdenticalFactors) {
for (unsigned i = 0, s = facNodes.size(); i < s; i++) {
Distribution* dist1 = facNodes[i]->getDistribution();
for (unsigned j = 0; j < i; j++) {
Distribution* dist2 = facNodes[j]->getDistribution();
if (dist1 != dist2 && dist1->params == dist2->params) {
if (facNodes[i]->factor()->getRanges() ==
facNodes[j]->factor()->getRanges()) {
facNodes[i]->factor()->setDistribution (dist2);
}
const FacNodes& facNodes = groundFg_->facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
facNodes[i]->factor().setDistId (Util::maxUnsigned());
}
// FIXME FIXME FIXME : pfl should give correct dist ids.
if (checkForIdenticalFactors || true) {
unsigned groupCount = 1;
for (unsigned i = 0; i < facNodes.size(); i++) {
Factor& f1 = facNodes[i]->factor();
if (f1.distId() != Util::maxUnsigned()) {
continue;
}
f1.setDistId (groupCount);
for (unsigned j = i + 1; j < facNodes.size(); j++) {
Factor& f2 = facNodes[j]->factor();
if (f2.distId() != Util::maxUnsigned()) {
continue;
}
if (f1.size() == f2.size() &&
f1.ranges() == f2.ranges() &&
f1.params() == f2.params()) {
f2.setDistId (groupCount);
}
}
groupCount ++;
}
}
// create the initial factor colors
DistColorMap distColors;
for (unsigned i = 0; i < facNodes.size(); i++) {
const Distribution* dist = facNodes[i]->getDistribution();
DistColorMap::iterator it = distColors.find (dist);
unsigned distId = facNodes[i]->factor().distId();
DistColorMap::iterator it = distColors.find (distId);
if (it == distColors.end()) {
it = distColors.insert (make_pair (dist, getFreeColor())).first;
it = distColors.insert (make_pair (distId, getFreeColor())).first;
}
setColor (facNodes[i], it->second);
}
@ -104,31 +114,31 @@ CFactorGraph::setInitialColors (void)
void
CFactorGraph::createGroups (void)
{
VarSignMap varGroups;
FacSignMap factorGroups;
VarSignMap varGroups;
FacSignMap facGroups;
unsigned nIters = 0;
bool groupsHaveChanged = true;
const FgVarSet& varNodes = groundFg_->getVarNodes();
const FgFacSet& facNodes = groundFg_->getFactorNodes();
const VarNodes& varNodes = groundFg_->varNodes();
const FacNodes& facNodes = groundFg_->facNodes();
while (groupsHaveChanged || nIters == 1) {
nIters ++;
unsigned prevFactorGroupsSize = factorGroups.size();
factorGroups.clear();
unsigned prevFactorGroupsSize = facGroups.size();
facGroups.clear();
// set a new color to the factors with the same signature
for (unsigned i = 0; i < facNodes.size(); i++) {
const Signature& signature = getSignature (facNodes[i]);
FacSignMap::iterator it = factorGroups.find (signature);
if (it == factorGroups.end()) {
it = factorGroups.insert (make_pair (signature, FgFacSet())).first;
FacSignMap::iterator it = facGroups.find (signature);
if (it == facGroups.end()) {
it = facGroups.insert (make_pair (signature, FacNodes())).first;
}
it->second.push_back (facNodes[i]);
}
for (FacSignMap::iterator it = factorGroups.begin();
it != factorGroups.end(); it++) {
for (FacSignMap::iterator it = facGroups.begin();
it != facGroups.end(); it++) {
Color newColor = getFreeColor();
FgFacSet& groupMembers = it->second;
FacNodes& groupMembers = it->second;
for (unsigned i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
@ -141,36 +151,37 @@ CFactorGraph::createGroups (void)
const Signature& signature = getSignature (varNodes[i]);
VarSignMap::iterator it = varGroups.find (signature);
if (it == varGroups.end()) {
it = varGroups.insert (make_pair (signature, FgVarSet())).first;
it = varGroups.insert (make_pair (signature, VarNodes())).first;
}
it->second.push_back (varNodes[i]);
}
for (VarSignMap::iterator it = varGroups.begin();
it != varGroups.end(); it++) {
Color newColor = getFreeColor();
FgVarSet& groupMembers = it->second;
VarNodes& groupMembers = it->second;
for (unsigned i = 0; i < groupMembers.size(); i++) {
setColor (groupMembers[i], newColor);
}
}
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|| prevFactorGroupsSize != factorGroups.size();
|| prevFactorGroupsSize != facGroups.size();
}
//printGroups (varGroups, factorGroups);
createClusters (varGroups, factorGroups);
printGroups (varGroups, facGroups);
createClusters (varGroups, facGroups);
}
void
CFactorGraph::createClusters (const VarSignMap& varGroups,
const FacSignMap& factorGroups)
CFactorGraph::createClusters (
const VarSignMap& varGroups,
const FacSignMap& facGroups)
{
varClusters_.reserve (varGroups.size());
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); it++) {
const FgVarSet& groupVars = it->second;
const VarNodes& groupVars = it->second;
VarCluster* vc = new VarCluster (groupVars);
for (unsigned i = 0; i < groupVars.size(); i++) {
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
@ -178,12 +189,12 @@ CFactorGraph::createClusters (const VarSignMap& varGroups,
varClusters_.push_back (vc);
}
facClusters_.reserve (factorGroups.size());
for (FacSignMap::const_iterator it = factorGroups.begin();
it != factorGroups.end(); it++) {
FgFacNode* groupFactor = it->second[0];
const FgVarSet& neighs = groupFactor->neighbors();
VarClusterSet varClusters;
facClusters_.reserve (facGroups.size());
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); it++) {
FacNode* groupFactor = it->second[0];
const VarNodes& neighs = groupFactor->neighbors();
VarClusters varClusters;
varClusters.reserve (neighs.size());
for (unsigned i = 0; i < neighs.size(); i++) {
VarId vid = neighs[i]->varId();
@ -196,15 +207,15 @@ CFactorGraph::createClusters (const VarSignMap& varGroups,
const Signature&
CFactorGraph::getSignature (const FgVarNode* varNode)
CFactorGraph::getSignature (const VarNode* varNode)
{
Signature& sign = varSignatures_[varNode->getIndex()];
vector<Color>::iterator it = sign.colors.begin();
const FgFacSet& neighs = varNode->neighbors();
const FacNodes& neighs = varNode->neighbors();
for (unsigned i = 0; i < neighs.size(); i++) {
*it = getColor (neighs[i]);
it ++;
*it = neighs[i]->factor()->indexOf (varNode->varId());
*it = neighs[i]->factor().indexOf (varNode->varId());
it ++;
}
*it = getColor (varNode);
@ -214,11 +225,11 @@ CFactorGraph::getSignature (const FgVarNode* varNode)
const Signature&
CFactorGraph::getSignature (const FgFacNode* facNode)
CFactorGraph::getSignature (const FacNode* facNode)
{
Signature& sign = factorSignatures_[facNode->getIndex()];
Signature& sign = facSignatures_[facNode->getIndex()];
vector<Color>::iterator it = sign.colors.begin();
const FgVarSet& neighs = facNode->neighbors();
const VarNodes& neighs = facNode->neighbors();
for (unsigned i = 0; i < neighs.size(); i++) {
*it = getColor (neighs[i]);
it ++;
@ -230,55 +241,53 @@ CFactorGraph::getSignature (const FgFacNode* facNode)
FactorGraph*
CFactorGraph::getCompressedFactorGraph (void)
CFactorGraph::getGroundFactorGraph (void) const
{
FactorGraph* fg = new FactorGraph();
for (unsigned i = 0; i < varClusters_.size(); i++) {
FgVarNode* var = varClusters_[i]->getGroundFgVarNodes()[0];
FgVarNode* newVar = new FgVarNode (var);
VarNode* var = varClusters_[i]->getGroundVarNodes()[0];
VarNode* newVar = new VarNode (var);
varClusters_[i]->setRepresentativeVariable (newVar);
fg->addVariable (newVar);
fg->addVarNode (newVar);
}
for (unsigned i = 0; i < facClusters_.size(); i++) {
const VarClusterSet& myVarClusters = facClusters_[i]->getVarClusters();
VarNodes myGroundVars;
const VarClusters& myVarClusters = facClusters_[i]->getVarClusters();
Vars myGroundVars;
myGroundVars.reserve (myVarClusters.size());
for (unsigned j = 0; j < myVarClusters.size(); j++) {
FgVarNode* v = myVarClusters[j]->getRepresentativeVariable();
VarNode* v = myVarClusters[j]->getRepresentativeVariable();
myGroundVars.push_back (v);
}
Factor* newFactor = new Factor (myGroundVars,
facClusters_[i]->getGroundFactors()[0]->getDistribution());
FgFacNode* fn = new FgFacNode (newFactor);
FacNode* fn = new FacNode (Factor (myGroundVars,
facClusters_[i]->getGroundFactors()[0]->factor().params()));
facClusters_[i]->setRepresentativeFactor (fn);
fg->addFactor (fn);
fg->addFacNode (fn);
for (unsigned j = 0; j < myGroundVars.size(); j++) {
fg->addEdge (fn, static_cast<FgVarNode*> (myGroundVars[j]));
fg->addEdge (static_cast<VarNode*> (myGroundVars[j]), fn);
}
}
fg->setIndexes();
return fg;
}
unsigned
CFactorGraph::getGroundEdgeCount (
CFactorGraph::getEdgeCount (
const FacCluster* fc,
const VarCluster* vc) const
{
const FgFacSet& clusterGroundFactors = fc->getGroundFactors();
FgVarNode* varNode = vc->getGroundFgVarNodes()[0];
unsigned count = 0;
VarId vid = vc->getGroundVarNodes().front()->varId();
const FacNodes& clusterGroundFactors = fc->getGroundFactors();
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
if (clusterGroundFactors[i]->factor()->indexOf (varNode->varId()) != -1) {
if (clusterGroundFactors[i]->factor().contains (vid)) {
count ++;
}
}
// CFgVarSet vars = vc->getGroundFgVarNodes();
// CVarNodes vars = vc->getGroundVarNodes();
// for (unsigned i = 1; i < vars.size(); i++) {
// FgVarNode* var = vc->getGroundFgVarNodes()[i];
// VarNode* var = vc->getGroundVarNodes()[i];
// unsigned count2 = 0;
// for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
// if (clusterGroundFactors[i]->getPosition (var) != -1) {
@ -293,14 +302,15 @@ CFactorGraph::getGroundEdgeCount (
void
CFactorGraph::printGroups (const VarSignMap& varGroups,
const FacSignMap& factorGroups) const
CFactorGraph::printGroups (
const VarSignMap& varGroups,
const FacSignMap& facGroups) const
{
unsigned count = 1;
cout << "variable groups:" << endl;
for (VarSignMap::const_iterator it = varGroups.begin();
it != varGroups.end(); it++) {
const FgVarSet& groupMembers = it->second;
const VarNodes& groupMembers = it->second;
if (groupMembers.size() > 0) {
cout << count << ": " ;
for (unsigned i = 0; i < groupMembers.size(); i++) {
@ -313,9 +323,9 @@ CFactorGraph::printGroups (const VarSignMap& varGroups,
count = 1;
cout << endl << "factor groups:" << endl;
for (FacSignMap::const_iterator it = factorGroups.begin();
it != factorGroups.end(); it++) {
const FgFacSet& groupMembers = it->second;
for (FacSignMap::const_iterator it = facGroups.begin();
it != facGroups.end(); it++) {
const FacNodes& groupMembers = it->second;
if (groupMembers.size() > 0) {
cout << ++count << ": " ;
for (unsigned i = 0; i < groupMembers.size(); i++) {

View File

@ -15,23 +15,25 @@ class Signature;
class SignatureHash;
typedef long Color;
typedef unordered_map<unsigned, vector<Color> > VarColorMap;
typedef unordered_map<const Distribution*, Color> DistColorMap;
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
typedef vector<VarCluster*> VarClusterSet;
typedef vector<FacCluster*> FacClusterSet;
typedef unordered_map<Signature, FgVarSet, SignatureHash> VarSignMap;
typedef unordered_map<Signature, FgFacSet, SignatureHash> FacSignMap;
typedef long Color;
typedef unordered_map<unsigned, vector<Color>> VarColorMap;
typedef unordered_map<unsigned, Color> DistColorMap;
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
typedef vector<VarCluster*> VarClusters;
typedef vector<FacCluster*> FacClusters;
typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap;
typedef unordered_map<Signature, FacNodes, SignatureHash> FacSignMap;
struct Signature
{
Signature (unsigned size)
{
colors.resize (size);
}
Signature (unsigned size) : colors(size) { }
bool operator< (const Signature& sig) const
{
if (colors.size() < sig.colors.size()) {
@ -49,6 +51,7 @@ struct Signature
}
return false;
}
bool operator== (const Signature& sig) const
{
if (colors.size() != sig.colors.size()) {
@ -61,12 +64,14 @@ struct Signature
}
return true;
}
vector<Color> colors;
};
struct SignatureHash {
struct SignatureHash
{
size_t operator() (const Signature &sig) const
{
size_t val = hash<size_t>()(sig.colors.size());
@ -82,7 +87,7 @@ struct SignatureHash {
class VarCluster
{
public:
VarCluster (const FgVarSet& vs)
VarCluster (const VarNodes& vs)
{
for (unsigned i = 0; i < vs.size(); i++) {
groundVars_.push_back (vs[i]);
@ -94,26 +99,28 @@ class VarCluster
facClusters_.push_back (fc);
}
const FacClusterSet& getFacClusters (void) const
const FacClusters& getFacClusters (void) const
{
return facClusters_;
}
FgVarNode* getRepresentativeVariable (void) const { return representVar_; }
void setRepresentativeVariable (FgVarNode* v) { representVar_ = v; }
const FgVarSet& getGroundFgVarNodes (void) const { return groundVars_; }
VarNode* getRepresentativeVariable (void) const { return representVar_; }
void setRepresentativeVariable (VarNode* v) { representVar_ = v; }
const VarNodes& getGroundVarNodes (void) const { return groundVars_; }
private:
FgVarSet groundVars_;
FacClusterSet facClusters_;
FgVarNode* representVar_;
VarNodes groundVars_;
FacClusters facClusters_;
VarNode* representVar_;
};
class FacCluster
{
public:
FacCluster (const FgFacSet& groundFactors, const VarClusterSet& vcs)
FacCluster (const FacNodes& groundFactors, const VarClusters& vcs)
{
groundFactors_ = groundFactors;
varClusters_ = vcs;
@ -122,12 +129,12 @@ class FacCluster
}
}
const VarClusterSet& getVarClusters (void) const
const VarClusters& getVarClusters (void) const
{
return varClusters_;
}
bool containsGround (const FgFacNode* fn)
bool containsGround (const FacNode* fn)
{
for (unsigned i = 0; i < groundFactors_.size(); i++) {
if (groundFactors_[i] == fn) {
@ -137,24 +144,26 @@ class FacCluster
return false;
}
FgFacNode* getRepresentativeFactor (void) const
FacNode* getRepresentativeFactor (void) const
{
return representFactor_;
}
void setRepresentativeFactor (FgFacNode* fn)
void setRepresentativeFactor (FacNode* fn)
{
representFactor_ = fn;
}
const FgFacSet& getGroundFactors (void) const
const FacNodes& getGroundFactors (void) const
{
return groundFactors_;
}
private:
FgFacSet groundFactors_;
VarClusterSet varClusters_;
FgFacNode* representFactor_;
FacNodes groundFactors_;
VarClusters varClusters_;
FacNode* representFactor_;
};
@ -162,51 +171,48 @@ class CFactorGraph
{
public:
CFactorGraph (const FactorGraph&);
~CFactorGraph (void);
FactorGraph* getCompressedFactorGraph (void);
unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const;
const VarClusters& getVarClusters (void) { return varClusters_; }
FgVarNode* getEquivalentVariable (VarId vid)
const FacClusters& getFacClusters (void) { return facClusters_; }
VarNode* getEquivalentVariable (VarId vid)
{
VarCluster* vc = vid2VarCluster_.find (vid)->second;
return vc->getRepresentativeVariable();
}
const VarClusterSet& getVarClusters (void) { return varClusters_; }
const FacClusterSet& getFacClusters (void) { return facClusters_; }
FactorGraph* getGroundFactorGraph (void) const;
unsigned getEdgeCount (const FacCluster*, const VarCluster*) const;
static bool checkForIdenticalFactors;
private:
void setInitialColors (void);
void createGroups (void);
void createClusters (const VarSignMap&, const FacSignMap&);
const Signature& getSignature (const FgVarNode*);
const Signature& getSignature (const FgFacNode*);
void printGroups (const VarSignMap&, const FacSignMap&) const;
Color getFreeColor (void) {
Color getFreeColor (void)
{
++ freeColor_;
return freeColor_ - 1;
}
Color getColor (const FgVarNode* vn) const
Color getColor (const VarNode* vn) const
{
return varColors_[vn->getIndex()];
}
Color getColor (const FgFacNode* fn) const {
return factorColors_[fn->getIndex()];
Color getColor (const FacNode* fn) const {
return facColors_[fn->getIndex()];
}
void setColor (const FgVarNode* vn, Color c)
void setColor (const VarNode* vn, Color c)
{
varColors_[vn->getIndex()] = c;
}
void setColor (const FgFacNode* fn, Color c)
void setColor (const FacNode* fn, Color c)
{
factorColors_[fn->getIndex()] = c;
facColors_[fn->getIndex()] = c;
}
VarCluster* getVariableCluster (VarId vid) const
@ -214,14 +220,26 @@ class CFactorGraph
return vid2VarCluster_.find (vid)->second;
}
void setInitialColors (void);
void createGroups (void);
void createClusters (const VarSignMap&, const FacSignMap&);
const Signature& getSignature (const VarNode*);
const Signature& getSignature (const FacNode*);
void printGroups (const VarSignMap&, const FacSignMap&) const;
Color freeColor_;
vector<Color> varColors_;
vector<Color> factorColors_;
vector<Color> facColors_;
vector<Signature> varSignatures_;
vector<Signature> factorSignatures_;
VarClusterSet varClusters_;
FacClusterSet facClusters_;
VarId2VarCluster vid2VarCluster_;
vector<Signature> facSignatures_;
VarClusters varClusters_;
FacClusters facClusters_;
VarId2VarCluster vid2VarCluster_;
const FactorGraph* groundFg_;
};

View File

@ -1,10 +1,41 @@
#include "CbpSolver.h"
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
{
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
if (Constants::COLLECT_STATS) {
nGroundVars = fg_->varNodes().size();
nGroundFacs = fg_->facNodes().size();
const VarNodes& vars = fg_->varNodes();
nWithoutNeighs = 0;
for (unsigned i = 0; i < vars.size(); i++) {
const FacNodes& factors = vars[i]->neighbors();
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
nWithoutNeighs ++;
}
}
}
cfg_ = new CFactorGraph (fg);
fg_ = cfg_->getGroundFactorGraph();
if (Constants::COLLECT_STATS) {
unsigned nClusterVars = fg_->varNodes().size();
unsigned nClusterFacs = fg_->facNodes().size();
Statistics::updateCompressingStatistics (nGroundVars,
nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs);
}
Util::printHeader ("Uncompressed Factor Graph");
fg.print();
Util::printHeader ("Compressed Factor Graph");
fg_->print();
}
CbpSolver::~CbpSolver (void)
{
delete lfg_;
delete factorGraph_;
delete cfg_;
delete fg_;
for (unsigned i = 0; i < links_.size(); i++) {
delete links_[i];
}
@ -16,28 +47,31 @@ CbpSolver::~CbpSolver (void)
Params
CbpSolver::getPosterioriOf (VarId vid)
{
assert (lfg_->getEquivalentVariable (vid));
FgVarNode* var = lfg_->getEquivalentVariable (vid);
if (runned_ == false) {
runSolver();
}
assert (cfg_->getEquivalentVariable (vid));
VarNode* var = cfg_->getEquivalentVariable (vid);
Params probs;
if (var->hasEvidence()) {
probs.resize (var->nrStates(), Util::noEvidence());
probs[var->getEvidence()] = Util::withEvidence();
probs.resize (var->range(), LogAware::noEvidence());
probs[var->getEvidence()] = LogAware::withEvidence();
} else {
probs.resize (var->nrStates(), Util::multIdenty());
probs.resize (var->range(), LogAware::multIdenty());
const SpLinkSet& links = ninf(var)->getLinks();
if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::add (probs, l->getPoweredMessage());
}
Util::normalize (probs);
Util::fromLog (probs);
for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::add (probs, l->poweredMessage());
}
LogAware::normalize (probs);
Util::fromLog (probs);
} else {
for (unsigned i = 0; i < links.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::multiply (probs, l->getPoweredMessage());
Util::multiply (probs, l->poweredMessage());
}
Util::normalize (probs);
LogAware::normalize (probs);
}
}
return probs;
@ -46,55 +80,14 @@ CbpSolver::getPosterioriOf (VarId vid)
Params
CbpSolver::getJointDistributionOf (const VarIds& jointVarIds)
CbpSolver::getJointDistributionOf (const VarIds& jointVids)
{
VarIds eqVarIds;
for (unsigned i = 0; i < jointVarIds.size(); i++) {
eqVarIds.push_back (lfg_->getEquivalentVariable (jointVarIds[i])->varId());
for (unsigned i = 0; i < jointVids.size(); i++) {
VarNode* vn = cfg_->getEquivalentVariable (jointVids[i]);
eqVarIds.push_back (vn->varId());
}
return FgBpSolver::getJointDistributionOf (eqVarIds);
}
void
CbpSolver::initializeSolver (void)
{
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
if (COLLECT_STATISTICS) {
nGroundVars = factorGraph_->getVarNodes().size();
nGroundFacs = factorGraph_->getFactorNodes().size();
const FgVarSet& vars = factorGraph_->getVarNodes();
nWithoutNeighs = 0;
for (unsigned i = 0; i < vars.size(); i++) {
const FgFacSet& factors = vars[i]->neighbors();
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
nWithoutNeighs ++;
}
}
}
lfg_ = new CFactorGraph (*factorGraph_);
// cout << "Uncompressed Factor Graph" << endl;
// factorGraph_->printGraphicalModel();
// factorGraph_->exportToGraphViz ("uncompressed_fg.dot");
factorGraph_ = lfg_->getCompressedFactorGraph();
if (COLLECT_STATISTICS) {
unsigned nClusterVars = factorGraph_->getVarNodes().size();
unsigned nClusterFacs = factorGraph_->getFactorNodes().size();
Statistics::updateCompressingStatistics (nGroundVars, nGroundFacs,
nClusterVars, nClusterFacs,
nWithoutNeighs);
}
// cout << "Compressed Factor Graph" << endl;
// factorGraph_->printGraphicalModel();
// factorGraph_->exportToGraphViz ("compressed_fg.dot");
// abort();
FgBpSolver::initializeSolver();
return BpSolver::getJointDistributionOf (eqVarIds);
}
@ -102,12 +95,13 @@ CbpSolver::initializeSolver (void)
void
CbpSolver::createLinks (void)
{
const FacClusterSet fcs = lfg_->getFacClusters();
const FacClusters& fcs = cfg_->getFacClusters();
for (unsigned i = 0; i < fcs.size(); i++) {
const VarClusterSet vcs = fcs[i]->getVarClusters();
const VarClusters& vcs = fcs[i]->getVarClusters();
for (unsigned j = 0; j < vcs.size(); j++) {
unsigned c = lfg_->getGroundEdgeCount (fcs[i], vcs[j]);
links_.push_back (new CbpSolverLink (fcs[i]->getRepresentativeFactor(),
unsigned c = cfg_->getEdgeCount (fcs[i], vcs[j]);
links_.push_back (new CbpSolverLink (
fcs[i]->getRepresentativeFactor(),
vcs[j]->getRepresentativeVariable(), c));
}
}
@ -123,7 +117,7 @@ CbpSolver::maxResidualSchedule (void)
calculateMessage (links_[i]);
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
linkMap_.insert (make_pair (links_[i], it));
if (DL >= 2 && DL < 5) {
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
cout << "calculating " << links_[i]->toString() << endl;
}
}
@ -131,7 +125,7 @@ CbpSolver::maxResidualSchedule (void)
}
for (unsigned c = 0; c < links_.size(); c++) {
if (DL >= 2) {
if (Constants::DEBUG >= 2) {
cout << endl << "current residuals:" << endl;
for (SortedOrder::iterator it = sortedOrder_.begin();
it != sortedOrder_.end(); it ++) {
@ -142,7 +136,7 @@ CbpSolver::maxResidualSchedule (void)
SortedOrder::iterator it = sortedOrder_.begin();
SpLink* link = *it;
if (DL >= 2) {
if (Constants::DEBUG >= 2) {
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
}
if (link->getResidual() < BpOptions::accuracy) {
@ -154,12 +148,12 @@ CbpSolver::maxResidualSchedule (void)
linkMap_.find (link)->second = sortedOrder_.insert (link);
// update the messages that depend on message source --> destin
const FgFacSet& factorNeighbors = link->getVariable()->neighbors();
const FacNodes& factorNeighbors = link->getVariable()->neighbors();
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
for (unsigned j = 0; j < links.size(); j++) {
if (links[j]->getVariable() != link->getVariable()) {
if (DL >= 2 && DL < 5) {
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
cout << " calculating " << links[j]->toString() << endl;
}
calculateMessage (links[j]);
@ -174,7 +168,7 @@ CbpSolver::maxResidualSchedule (void)
const SpLinkSet& links = ninf(link->getFactor())->getLinks();
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getVariable() != link->getVariable()) {
if (DL >= 2 && DL < 5) {
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
cout << " calculating " << links[i]->toString() << endl;
}
calculateMessage (links[i]);
@ -192,43 +186,43 @@ Params
CbpSolver::getVar2FactorMsg (const SpLink* link) const
{
Params msg;
const FgVarNode* src = link->getVariable();
const FgFacNode* dst = link->getFactor();
const VarNode* src = link->getVariable();
const FacNode* dst = link->getFactor();
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
if (src->hasEvidence()) {
msg.resize (src->nrStates(), Util::noEvidence());
msg.resize (src->range(), LogAware::noEvidence());
double value = link->getMessage()[src->getEvidence()];
msg[src->getEvidence()] = Util::pow (value, l->getNumberOfEdges() - 1);
msg[src->getEvidence()] = LogAware::pow (value, l->nrEdges() - 1);
} else {
msg = link->getMessage();
Util::pow (msg, l->getNumberOfEdges() - 1);
LogAware::pow (msg, l->nrEdges() - 1);
}
if (DL >= 5) {
cout << " " << "init: " << Util::parametersToString (msg) << endl;
if (Constants::DEBUG >= 5) {
cout << " " << "init: " << msg << endl;
}
const SpLinkSet& links = ninf(src)->getLinks();
if (Globals::logDomain) {
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::add (msg, l->getPoweredMessage());
Util::add (msg, l->poweredMessage());
}
}
} else {
for (unsigned i = 0; i < links.size(); i++) {
if (links[i]->getFactor() != dst) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
Util::multiply (msg, l->getPoweredMessage());
if (DL >= 5) {
Util::multiply (msg, l->poweredMessage());
if (Constants::DEBUG >= 5) {
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
cout << Util::parametersToString (l->getPoweredMessage()) << endl;
cout << l->poweredMessage() << endl;
}
}
}
}
if (DL >= 5) {
cout << " result = " << Util::parametersToString (msg) << endl;
if (Constants::DEBUG >= 5) {
cout << " result = " << msg << endl;
}
return msg;
}
@ -241,12 +235,9 @@ CbpSolver::printLinkInformation (void) const
for (unsigned i = 0; i < links_.size(); i++) {
CbpSolverLink* l = static_cast<CbpSolverLink*> (links_[i]);
cout << l->toString() << ":" << endl;
cout << " curr msg = " ;
cout << Util::parametersToString (l->getMessage()) << endl;
cout << " next msg = " ;
cout << Util::parametersToString (l->getNextMessage()) << endl;
cout << " powered = " ;
cout << Util::parametersToString (l->getPoweredMessage()) << endl;
cout << " curr msg = " << l->getMessage() << endl;
cout << " next msg = " << l->getNextMessage() << endl;
cout << " powered = " << l->poweredMessage() << endl;
cout << " residual = " << l->getResidual() << endl;
}
}

View File

@ -1,7 +1,7 @@
#ifndef HORUS_CBP_H
#define HORUS_CBP_H
#include "FgBpSolver.h"
#include "BpSolver.h"
#include "CFactorGraph.h"
class Factor;
@ -9,49 +9,51 @@ class Factor;
class CbpSolverLink : public SpLink
{
public:
CbpSolverLink (FgFacNode* fn, FgVarNode* vn, unsigned c) : SpLink (fn, vn)
{
edgeCount_ = c;
poweredMsg_.resize (vn->nrStates(), Util::one());
}
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c)
: SpLink (fn, vn), nrEdges_(c),
pwdMsg_(vn->range(), LogAware::one()) { }
unsigned nrEdges (void) const { return nrEdges_; }
const Params& poweredMessage (void) const { return pwdMsg_; }
void updateMessage (void)
{
poweredMsg_ = *nextMsg_;
pwdMsg_ = *nextMsg_;
swap (currMsg_, nextMsg_);
msgSended_ = true;
Util::pow (poweredMsg_, edgeCount_);
LogAware::pow (pwdMsg_, nrEdges_);
}
unsigned getNumberOfEdges (void) const { return edgeCount_; }
const Params& getPoweredMessage (void) const { return poweredMsg_; }
private:
Params poweredMsg_;
unsigned edgeCount_;
unsigned nrEdges_;
Params pwdMsg_;
};
class CbpSolver : public FgBpSolver
class CbpSolver : public BpSolver
{
public:
CbpSolver (FactorGraph& fg) : FgBpSolver (fg) { }
~CbpSolver (void);
CbpSolver (const FactorGraph& fg);
Params getPosterioriOf (VarId);
Params getJointDistributionOf (const VarIds&);
~CbpSolver (void);
Params getPosterioriOf (VarId);
Params getJointDistributionOf (const VarIds&);
private:
void initializeSolver (void);
void createLinks (void);
void maxResidualSchedule (void);
Params getVar2FactorMsg (const SpLink*) const;
void printLinkInformation (void) const;
void createLinks (void);
void maxResidualSchedule (void);
CFactorGraph* lfg_;
Params getVar2FactorMsg (const SpLink*) const;
void printLinkInformation (void) const;
CFactorGraph* cfg_;
};
#endif // HORUS_CBP_H

View File

@ -1,10 +1,11 @@
#include <queue>
#include <fstream>
#include "ConstraintTree.h"
#include "Util.h"
void
CTNode::addChild (CTNode* child, bool updateLevels)
{
@ -42,6 +43,34 @@ CTNode::removeChild (CTNode* child)
void
CTNode::removeChilds (void)
{
childs_.clear();
}
void
CTNode::removeAndDeleteChild (CTNode* child)
{
removeChild (child);
CTNode::deleteSubtree (child);
}
void
CTNode::removeAndDeleteAllChilds (void)
{
for (unsigned i = 0; i < childs_.size(); i++) {
deleteSubtree (childs_[i]);
}
childs_.clear();
}
SymbolSet
CTNode::childSymbols (void) const
{
@ -66,6 +95,32 @@ CTNode::updateChildLevels (CTNode* n, unsigned level)
CTNode*
CTNode::copySubtree (const CTNode* n)
{
CTNode* newNode = new CTNode (*n);
const CTNodes& childs = n->childs();
for (unsigned i = 0; i < childs.size(); i++) {
newNode->addChild (copySubtree (childs[i]));
}
return newNode;
}
void
CTNode::deleteSubtree (CTNode* n)
{
assert (n);
const CTNodes& childs = n->childs();
for (unsigned i = 0; i < childs.size(); i++) {
deleteSubtree (childs[i]);
}
delete n;
}
ostream& operator<< (ostream &out, const CTNode& n)
{
// out << "(" << n.level() << ") " ;
@ -75,6 +130,17 @@ ostream& operator<< (ostream &out, const CTNode& n)
ConstraintTree::ConstraintTree (unsigned nrLvs)
{
for (unsigned i = 0; i < nrLvs; i++) {
logVars_.push_back (LogVar (i));
}
root_ = new CTNode (0, 0);
logVarSet_ = LogVarSet (logVars_);
}
ConstraintTree::ConstraintTree (const LogVars& logVars)
{
root_ = new CTNode (0, 0);
@ -99,7 +165,7 @@ ConstraintTree::ConstraintTree (const LogVars& logVars,
ConstraintTree::ConstraintTree (const ConstraintTree& ct)
{
root_ = copySubtree (ct.root_);
root_ = CTNode::copySubtree (ct.root_);
logVars_ = ct.logVars_;
logVarSet_ = ct.logVarSet_;
}
@ -108,7 +174,7 @@ ConstraintTree::ConstraintTree (const ConstraintTree& ct)
ConstraintTree::~ConstraintTree (void)
{
deleteSubtree (root_);
CTNode::deleteSubtree (root_);
}
@ -200,21 +266,28 @@ ConstraintTree::moveToBottom (const LogVars& lvs)
void
ConstraintTree::join (ConstraintTree* ct, bool assertWhenNotFound)
{
{
if (logVarSet_.empty()) {
delete root_;
root_ = CTNode::copySubtree (ct->root());
logVars_ = ct->logVars();
logVarSet_ = ct->logVarSet();
return;
}
LogVarSet intersect = logVarSet_ & ct->logVarSet_;
if (intersect.empty()) {
const CTNodes& childs = ct->root()->childs();
CTNodes leafs = getNodesAtLevel (getLevel (logVars_.back()));
for (unsigned i = 0; i < leafs.size(); i++) {
for (unsigned j = 0; j < childs.size(); j++) {
leafs[i]->addChild (copySubtree (childs[j]));
leafs[i]->addChild (CTNode::copySubtree (childs[j]));
}
}
logVars_.insert (logVars_.end(), ct->logVars_.begin(), ct->logVars_.end());
Util::addToVector (logVars_, ct->logVars_);
logVarSet_ |= ct->logVarSet_;
} else {
moveToBottom (intersect.elements());
ct->moveToTop (intersect.elements());
@ -222,25 +295,27 @@ ConstraintTree::join (ConstraintTree* ct, bool assertWhenNotFound)
CTNodes nodes = getNodesAtLevel (level);
Tuples tuples;
CTNodes continuationNodes;
CTNodes continNodes;
getTuples (ct->root(),
Tuples(),
intersect.size(),
tuples,
continuationNodes);
continNodes);
for (unsigned i = 0; i < tuples.size(); i++) {
bool tupleFounded = false;
for (unsigned j = 0; j < nodes.size(); j++) {
tupleFounded |= join (nodes[j], tuples[i], 0, continuationNodes[i]);
tupleFounded |= join (nodes[j], tuples[i], 0, continNodes[i]);
}
if (assertWhenNotFound) {
assert (tupleFounded);
}
}
LogVarSet newLvs = ct->logVarSet_ - intersect;
logVars_.insert (logVars_.end(), newLvs.begin(), newLvs.end());
logVarSet_ |= newLvs;
LogVars newLvs (ct->logVars().begin() + intersect.size(),
ct->logVars().end());
Util::addToVector (logVars_, newLvs);
logVarSet_ |= LogVarSet (newLvs);
}
}
@ -280,6 +355,10 @@ ConstraintTree::rename (LogVar X_old, LogVar X_new)
void
ConstraintTree::applySubstitution (const Substitution& theta)
{
LogVars discardedLvs = theta.getDiscardedLogVars();
for (unsigned i = 0; i < discardedLvs.size(); i++) {
remove(discardedLvs[i]);
}
for (unsigned i = 0; i < logVars_.size(); i++) {
logVars_[i] = theta.newNameFor (logVars_[i]);
}
@ -308,11 +387,7 @@ ConstraintTree::remove (const LogVarSet& X)
unsigned level = getLevel (X.front()) - 1;
CTNodes nodes = getNodesAtLevel (level);
for (unsigned i = 0; i < nodes.size(); i++) {
CTNodes childs = nodes[i]->childs();
for (unsigned j = 0; j < childs.size(); j++) {
nodes[i]->removeChild (childs[j]);
deleteSubtree (childs[j]);
}
nodes[i]->removeAndDeleteAllChilds();
}
logVars_.resize (logVars_.size() - X.size());
logVarSet_ -= X;
@ -545,16 +620,16 @@ ConstraintTree::split (
for (unsigned i = 0; i < commNodes.size(); i++) {
commCt->root()->addChild (commNodes[i]);
}
//cout << commCt->tupleSet() << " + " ;
//cout << exclCt->tupleSet() << " = " ;
//cout << tupleSet() << endl << endl;
// cout << commCt->tupleSet() << " + " ;
// cout << exclCt->tupleSet() << " = " ;
// cout << tupleSet() << endl << endl;
// if (((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet()) == false) {
// exportToGraphViz ("_fail.dot", true);
// commCt->exportToGraphViz ("_fail_comm.dot", true);
// exclCt->exportToGraphViz ("_fail_excl.dot", true);
// }
assert ((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet());
assert ((exclCt->tupleSet (stopLevel) & ct->tupleSet (stopLevel)).empty());
// assert ((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet());
// assert ((exclCt->tupleSet (stopLevel) & ct->tupleSet (stopLevel)).empty());
return {commCt, exclCt};
}
@ -601,36 +676,32 @@ ConstraintTree::jointCountNormalize (
LogVar X_new1,
LogVar X_new2)
{
exportToGraphViz ("C.dot", true);
commCt->exportToGraphViz ("C_comm.dot", true);
exclCt->exportToGraphViz ("C_exlc.dot", true);
unsigned N = getConditionalCount (X);
cout << "My tuples: " << tupleSet() << endl;
cout << "CommCt tuples: " << commCt->tupleSet() << endl;
cout << "ExclCt tuples: " << exclCt->tupleSet() << endl;
cout << "Counted Lv: " << X << endl;
cout << "Original N: " << N << endl;
cout << endl;
// cout << "My tuples: " << tupleSet() << endl;
// cout << "CommCt tuples: " << commCt->tupleSet() << endl;
// cout << "ExclCt tuples: " << exclCt->tupleSet() << endl;
// cout << "Counted Lv: " << X << endl;
// cout << "X_new1: " << X_new1 << endl;
// cout << "X_new2: " << X_new2 << endl;
// cout << "Original N: " << N << endl;
// cout << endl;
ConstraintTrees normCts1 = commCt->countNormalize (X);
vector<unsigned> counts1 (normCts1.size());
for (unsigned i = 0; i < normCts1.size(); i++) {
counts1[i] = normCts1[i]->getConditionalCount (X);
cout << "normCts1[" << i << "] #" << counts1[i] ;
cout << " " << normCts1[i]->tupleSet() << endl;
// cout << "normCts1[" << i << "] #" << counts1[i] ;
// cout << " " << normCts1[i]->tupleSet() << endl;
}
ConstraintTrees normCts2 = exclCt->countNormalize (X);
vector<unsigned> counts2 (normCts2.size());
for (unsigned i = 0; i < normCts2.size(); i++) {
counts2[i] = normCts2[i]->getConditionalCount (X);
cout << "normCts2[" << i << "] #" << counts2[i] ;
cout << " " << normCts2[i]->tupleSet() << endl;
// cout << "normCts2[" << i << "] #" << counts2[i] ;
// cout << " " << normCts2[i]->tupleSet() << endl;
}
cout << endl;
cout << "1###### " << normCts1.size() << endl;
cout << "2###### " << normCts2.size() << endl;
// cout << endl;
ConstraintTree* excl1 = 0;
for (unsigned i = 0; i < normCts1.size(); i++) {
@ -638,7 +709,7 @@ ConstraintTree::jointCountNormalize (
excl1 = normCts1[i];
normCts1.erase (normCts1.begin() + i);
counts1.erase (counts1.begin() + i);
cout << ">joint-count(" << N << ",0)" << endl;
// cout << "joint-count(" << N << ",0)" << endl;
break;
}
}
@ -649,22 +720,21 @@ ConstraintTree::jointCountNormalize (
excl2 = normCts2[i];
normCts2.erase (normCts2.begin() + i);
counts2.erase (counts2.begin() + i);
cout << ">>joint-count(0," << N << ")" << endl;
// cout << "joint-count(0," << N << ")" << endl;
break;
}
}
cout << "3###### " << normCts1.size() << endl;
cout << "4###### " << normCts2.size() << endl;
for (unsigned i = 0; i < normCts1.size(); i++) {
unsigned j;
for (j = 0; counts1[i] + counts2[j] != N; j++) ;
cout << "joint-count(" << counts1[i] << "," << counts2[j] << ")" << endl;
// cout << "joint-count(" << counts1[i] ;
// cout << "," << counts2[j] << ")" << endl;
const CTNodes& childs = normCts2[j]->root_->childs();
for (unsigned k = 0; k < childs.size(); k++) {
normCts1[i]->root_->addChild (childs[k]);
normCts1[i]->root_->addChild (CTNode::copySubtree (childs[k]));
}
delete normCts2[j];
}
ConstraintTrees cts = normCts1;
@ -683,11 +753,6 @@ ConstraintTree::jointCountNormalize (
cts.push_back (excl2);
}
for (unsigned i = 0; i < cts.size(); i++) {
stringstream ss;
ss << "aaacts_" << i + 1 << ".dot" ;
cts[i]->exportToGraphViz (ss.str().c_str(), true);
}
return cts;
}
@ -735,11 +800,11 @@ ConstraintTree::expand (LogVar X)
unsigned nrSymbols = getConditionalCount (X);
for (unsigned i = 0; i < nodes.size(); i++) {
Symbols symbols;
CTNodes childs = nodes[i]->childs();
const CTNodes& childs = nodes[i]->childs();
for (unsigned j = 0; j < childs.size(); j++) {
symbols.push_back (childs[j]->symbol());
nodes[i]->removeChild (childs[j]);
}
nodes[i]->removeAndDeleteAllChilds();
CTNode* prev = nodes[i];
assert (symbols.size() == nrSymbols);
for (unsigned j = 0; j < nrSymbols; j++) {
@ -768,7 +833,7 @@ ConstraintTree::ground (LogVar X)
ConstraintTrees cts;
const CTNodes& nodes = root_->childs();
for (unsigned i = 0; i < nodes.size(); i++) {
CTNode* copy = copySubtree (nodes[i]);
CTNode* copy = CTNode::copySubtree (nodes[i]);
copy->setSymbol (nodes[i]->symbol());
ConstraintTree* newCt = new ConstraintTree (logVars_);
newCt->root()->addChild (copy);
@ -840,19 +905,19 @@ ConstraintTree::getNodesAtLevel (unsigned level) const
void
ConstraintTree::swapLogVar (LogVar X)
{
TupleSet before = tupleSet();
LogVars::iterator it =
std::find (logVars_.begin(),logVars_.end(), X);
assert (it != logVars_.end());
unsigned pos = std::distance (logVars_.begin(), it);
const CTNodes& nodes = getNodesAtLevel (pos);
for (unsigned i = 0; i < nodes.size(); i++) {
const CTNodes childs = nodes[i]->childs();
for (unsigned j = 0; j < childs.size(); j++) {
nodes[i]->removeChild (childs[j]);
const CTNodes grandsons = childs[j]->childs();
CTNodes childsCopy = nodes[i]->childs();
nodes[i]->removeChilds();
for (unsigned j = 0; j < childsCopy.size(); j++) {
const CTNodes grandsons = childsCopy[j]->childs();
for (unsigned k = 0; k < grandsons.size(); k++) {
CTNode* childCopy = new CTNode (*childs[j]);
CTNode* childCopy = new CTNode (*childsCopy[j]);
const CTNodes greatGrandsons = grandsons[k]->childs();
for (unsigned t = 0; t < greatGrandsons.size(); t++) {
grandsons[k]->removeChild (greatGrandsons[t]);
@ -863,10 +928,9 @@ ConstraintTree::swapLogVar (LogVar X)
grandsons[k]->setLevel (grandsons[k]->level() - 1);
nodes[i]->addChild (grandsons[k], false);
}
delete childs[j];
delete childsCopy[j];
}
}
std::swap (logVars_[pos], logVars_[pos + 1]);
}
@ -884,7 +948,7 @@ ConstraintTree::join (
if (currIdx == tuple.size() - 1) {
const CTNodes& childs = appendNode->childs();
for (unsigned i = 0; i < childs.size(); i++) {
n->addChild (copySubtree (childs[i]));
n->addChild (CTNode::copySubtree (childs[i]));
}
return true;
}
@ -985,7 +1049,7 @@ ConstraintTree::countNormalize (
{
if (n->level() == stopLevel) {
return vector<pair<CTNode*, unsigned>>() = {
make_pair (copySubtree (n), countTuples (n))
make_pair (CTNode::copySubtree (n), countTuples (n))
};
}
@ -1004,65 +1068,6 @@ ConstraintTree::countNormalize (
}
/*
void
ConstraintTree::split (
CTNode* n1,
CTNode* n2,
CTNodes& nodes,
unsigned stopLevel)
{
CTNodes& childs1 = n1->childs();
CTNodes& childs2 = n2->childs();
// cout << string (n1->level() * 8, '-') << "Level = " << n1->level() + 1;
// cout << ", #I = " << childs1.size();
// cout << ", #J = " << childs2.size() << endl;
for (unsigned i = 0; i < childs1.size(); i++) {
for (unsigned j = 0; j < childs2.size(); j++) {
if (childs1[i]->symbol() != childs2[j]->symbol()) {
continue;
}
if (childs1[i]->level() == stopLevel) {
CTNode* newNode = copySubtree (childs1[i]);
newNode->setSymbol (childs1[i]->symbol());
nodes.push_back (newNode);
childs1[i]->setSymbol (Symbol::invalid());
break;
} else {
CTNodes lowerNodes;
split (childs1[i], childs2[j], lowerNodes, stopLevel);
if (lowerNodes.empty() == false) {
CTNode* me = new CTNode (childs1[i]->symbol(), childs1[i]->level());
for (unsigned k = 0; k < lowerNodes.size(); k++) {
me->addChild (lowerNodes[k]);
}
nodes.push_back (me);
}
if (childs1[i]->isLeaf()) {
break;
}
}
}
}
for (int i = 0; i < (int)childs1.size(); i++) {
// cout << string (n1->level() * 8, '-') << childs1[i];
if (childs1[i]->symbol() == Symbol::invalid()) {
// cout << " empty, removing..." ;
n1->removeChild (childs1[i]);
i --;
} else if (childs1[i]->isLeaf() &&
childs1[i]->level() != stopLevel) {
// cout << " leaf, removing..." ;
n1->removeChild (childs1[i]);
i --;
}
// cout << endl;
}
}
*/
void
ConstraintTree::split (
@ -1085,7 +1090,7 @@ ConstraintTree::split (
continue;
}
if (childs1[i]->level() == stopLevel) {
CTNode* newNode = copySubtree (childs1[i]);
CTNode* newNode = CTNode::copySubtree (childs1[i]);
nodes.push_back (newNode);
childs1[i]->setSymbol (Symbol::invalid());
} else {
@ -1103,11 +1108,11 @@ ConstraintTree::split (
for (int i = 0; i < (int)childs1.size(); i++) {
if (childs1[i]->symbol() == Symbol::invalid()) {
n1->removeChild (childs1[i]);
n1->removeAndDeleteChild (childs1[i]);
i --;
} else if (childs1[i]->isLeaf() &&
childs1[i]->level() != stopLevel) {
n1->removeChild (childs1[i]);
n1->removeAndDeleteChild (childs1[i]);
i --;
}
}
@ -1141,29 +1146,3 @@ ConstraintTree::overlap (
return false;
}
CTNode*
ConstraintTree::copySubtree (const CTNode* n)
{
CTNode* newNode = new CTNode (*n);
const CTNodes& childs = n->childs();
for (unsigned i = 0; i < childs.size(); i++) {
newNode->addChild (copySubtree (childs[i]));
}
return newNode;
}
void
ConstraintTree::deleteSubtree (CTNode* n)
{
assert (n);
const CTNodes& childs = n->childs();
for (unsigned i = 0; i < childs.size(); i++) {
deleteSubtree (childs[i]);
}
delete n;
}

View File

@ -21,7 +21,6 @@ typedef vector<ConstraintTree*> ConstraintTrees;
class CTNode
{
public:
@ -47,29 +46,44 @@ class CTNode
bool isLeaf (void) const { return childs_.empty(); }
void addChild (CTNode*, bool = true);
void removeChild (CTNode*);
SymbolSet childSymbols (void) const;
void addChild (CTNode*, bool = true);
void removeChild (CTNode*);
void removeChilds (void);
void removeAndDeleteChild (CTNode*);
void removeAndDeleteAllChilds (void);
SymbolSet childSymbols (void) const;
static CTNode* copySubtree (const CTNode*);
static void deleteSubtree (CTNode*);
private:
void updateChildLevels (CTNode*, unsigned);
void updateChildLevels (CTNode*, unsigned);
Symbol symbol_;
CTNodes childs_;
unsigned level_;
Symbol symbol_;
CTNodes childs_;
unsigned level_;
};
ostream& operator<< (ostream &out, const CTNode&);
class ConstraintTree
{
public:
ConstraintTree (unsigned);
ConstraintTree (const LogVars&);
ConstraintTree (const LogVars&, const Tuples&);
ConstraintTree (const ConstraintTree&);
~ConstraintTree (void);
CTNode* root (void) const { return root_; }
@ -94,94 +108,95 @@ class ConstraintTree
assert (LogVarSet (logVars_) == logVarSet_);
}
void addTuple (const Tuple&);
bool containsTuple (const Tuple&);
void moveToTop (const LogVars&);
void moveToBottom (const LogVars&);
void join (ConstraintTree*, bool = false);
unsigned getLevel (LogVar) const;
void rename (LogVar, LogVar);
void applySubstitution (const Substitution&);
void project (const LogVarSet&);
void remove (const LogVarSet&);
bool isSingleton (LogVar);
LogVarSet singletons (void);
TupleSet tupleSet (unsigned = 0) const;
TupleSet tupleSet (const LogVars&);
unsigned size (void) const;
unsigned nrSymbols (LogVar);
void exportToGraphViz (const char*, bool = false) const;
bool isCountNormalized (const LogVarSet&);
unsigned getConditionalCount (const LogVarSet&);
TinySet<unsigned> getConditionalCounts (const LogVarSet&);
bool isCarteesianProduct (const LogVarSet&) const;
void addTuple (const Tuple&);
bool containsTuple (const Tuple&);
void moveToTop (const LogVars&);
void moveToBottom (const LogVars&);
void join (ConstraintTree*, bool = false);
unsigned getLevel (LogVar) const;
void rename (LogVar, LogVar);
void applySubstitution (const Substitution&);
void project (const LogVarSet&);
void remove (const LogVarSet&);
bool isSingleton (LogVar);
LogVarSet singletons (void);
TupleSet tupleSet (unsigned = 0) const;
TupleSet tupleSet (const LogVars&);
unsigned size (void) const;
unsigned nrSymbols (LogVar);
void exportToGraphViz (const char*, bool = false) const;
bool isCountNormalized (const LogVarSet&);
unsigned getConditionalCount (const LogVarSet&);
TinySet<unsigned> getConditionalCounts (const LogVarSet&);
bool isCarteesianProduct (const LogVarSet&) const;
std::pair<ConstraintTree*, ConstraintTree*> split (
const Tuple&,
unsigned);
const Tuple&, unsigned);
std::pair<ConstraintTree*, ConstraintTree*> split (
const ConstraintTree*,
unsigned) const;
const ConstraintTree*, unsigned) const;
ConstraintTrees countNormalize (const LogVarSet&);
ConstraintTrees countNormalize (const LogVarSet&);
ConstraintTrees jointCountNormalize (
ConstraintTree*,
ConstraintTree*,
LogVar,
LogVar,
LogVar);
ConstraintTree*, ConstraintTree*, LogVar, LogVar, LogVar);
static bool identical (
const ConstraintTree*,
const ConstraintTree*,
unsigned);
static bool identical (
const ConstraintTree*, const ConstraintTree*, unsigned);
static bool overlap (
const ConstraintTree*,
const ConstraintTree*,
unsigned);
static bool overlap (
const ConstraintTree*, const ConstraintTree*, unsigned);
LogVars expand (LogVar);
ConstraintTrees ground (LogVar);
LogVars expand (LogVar);
ConstraintTrees ground (LogVar);
private:
unsigned countTuples (const CTNode*) const;
CTNodes getNodesBelow (CTNode*) const;
CTNodes getNodesAtLevel (unsigned) const;
void swapLogVar (LogVar);
bool join (CTNode*, const Tuple&, unsigned, CTNode*);
unsigned countTuples (const CTNode*) const;
bool indenticalSubtrees (
const CTNode*,
const CTNode*,
bool) const;
CTNodes getNodesBelow (CTNode*) const;
void getTuples (
CTNode*,
Tuples,
unsigned,
Tuples&,
CTNodes&) const;
CTNodes getNodesAtLevel (unsigned) const;
void swapLogVar (LogVar);
bool join (CTNode*, const Tuple&, unsigned, CTNode*);
bool indenticalSubtrees (
const CTNode*, const CTNode*, bool) const;
void getTuples (CTNode*, Tuples, unsigned, Tuples&, CTNodes&) const;
vector<std::pair<CTNode*, unsigned>> countNormalize (
const CTNode*,
unsigned);
const CTNode*, unsigned);
static void split (
CTNode*,
CTNode*,
CTNodes&,
unsigned);
static void split (
CTNode*, CTNode*, CTNodes&, unsigned);
static bool overlap (const CTNode*, const CTNode*, unsigned);
static CTNode* copySubtree (const CTNode*);
static void deleteSubtree (CTNode*);
static bool overlap (const CTNode*, const CTNode*, unsigned);
CTNode* root_;
LogVars logVars_;
LogVarSet logVarSet_;
CTNode* root_;
LogVars logVars_;
LogVarSet logVarSet_;
};

View File

@ -1,45 +0,0 @@
#ifndef HORUS_DISTRIBUTION_H
#define HORUS_DISTRIBUTION_H
#include <vector>
#include "Horus.h"
//TODO die die die die die
using namespace std;
struct Distribution
{
public:
Distribution (int id)
{
this->id = id;
}
Distribution (const Params& params, int id = -1)
{
this->id = id;
this->params = params;
}
void updateParameters (const Params& params)
{
this->params = params;
}
bool shared (void)
{
return id != -1;
}
int id;
Params params;
private:
DISALLOW_COPY_AND_ASSIGN (Distribution);
};
#endif // HORUS_DISTRIBUTION_H

View File

@ -1,53 +1,39 @@
#include <limits>
#include "ElimGraph.h"
#include "BayesNet.h"
#include <fstream>
#include "ElimGraph.h"
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
ElimGraph::ElimGraph (const BayesNet& bayesNet)
ElimGraph::ElimGraph (const vector<Factor*>& factors)
{
const BnNodeSet& bnNodes = bayesNet.getBayesNodes();
for (unsigned i = 0; i < bnNodes.size(); i++) {
if (bnNodes[i]->hasEvidence() == false) {
addNode (new EgNode (bnNodes[i]));
}
}
for (unsigned i = 0; i < bnNodes.size(); i++) {
if (bnNodes[i]->hasEvidence() == false) {
EgNode* n = getEgNode (bnNodes[i]->varId());
const BnNodeSet& childs = bnNodes[i]->getChilds();
for (unsigned j = 0; j < childs.size(); j++) {
if (childs[j]->hasEvidence() == false) {
addEdge (n, getEgNode (childs[j]->varId()));
for (unsigned i = 0; i < factors.size(); i++) {
const VarIds& vids = factors[i]->arguments();
for (unsigned j = 0; j < vids.size() - 1; j++) {
EgNode* n1 = getEgNode (vids[j]);
if (n1 == 0) {
n1 = new EgNode (vids[j], factors[i]->range (j));
addNode (n1);
}
for (unsigned k = j + 1; k < vids.size(); k++) {
EgNode* n2 = getEgNode (vids[k]);
if (n2 == 0) {
n2 = new EgNode (vids[k], factors[i]->range (k));
addNode (n2);
}
if (neighbors (n1, n2) == false) {
addEdge (n1, n2);
}
}
}
if (vids.size() == 1) {
if (getEgNode (vids[0]) == 0) {
addNode (new EgNode (vids[0], factors[i]->range (0)));
}
}
}
for (unsigned i = 0; i < bnNodes.size(); i++) {
vector<EgNode*> neighs;
const vector<BayesNode*>& parents = bnNodes[i]->getParents();
for (unsigned i = 0; i < parents.size(); i++) {
if (parents[i]->hasEvidence() == false) {
neighs.push_back (getEgNode (parents[i]->varId()));
}
}
if (neighs.size() > 0) {
for (unsigned i = 0; i < neighs.size() - 1; i++) {
for (unsigned j = i+1; j < neighs.size(); j++) {
if (!neighbors (neighs[i], neighs[j])) {
addEdge (neighs[i], neighs[j]);
}
}
}
}
}
setIndexes();
}
@ -61,40 +47,16 @@ ElimGraph::~ElimGraph (void)
void
ElimGraph::addNode (EgNode* n)
{
nodes_.push_back (n);
varMap_.insert (make_pair (n->varId(), n));
}
EgNode*
ElimGraph::getEgNode (VarId vid) const
{
unordered_map<VarId,EgNode*>::const_iterator it =varMap_.find (vid);
if (it ==varMap_.end()) {
return 0;
} else {
return it->second;
}
}
VarIds
ElimGraph::getEliminatingOrder (const VarIds& exclude)
{
VarIds elimOrder;
marked_.resize (nodes_.size(), false);
for (unsigned i = 0; i < exclude.size(); i++) {
assert (getEgNode (exclude[i]));
EgNode* node = getEgNode (exclude[i]);
assert (node);
marked_[*node] = true;
}
unsigned nVarsToEliminate = nodes_.size() - exclude.size();
for (unsigned i = 0; i < nVarsToEliminate; i++) {
EgNode* node = getLowestCostNode();
@ -107,6 +69,100 @@ ElimGraph::getEliminatingOrder (const VarIds& exclude)
void
ElimGraph::print (void) const
{
for (unsigned i = 0; i < nodes_.size(); i++) {
cout << "node " << nodes_[i]->label() << " neighs:" ;
vector<EgNode*> neighs = nodes_[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
cout << " " << neighs[j]->label();
}
cout << endl;
}
}
void
ElimGraph::exportToGraphViz (
const char* fileName,
bool showNeighborless,
const VarIds& highlightVarIds) const
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "Markov::exportToDotFile()" << endl;
abort();
}
out << "strict graph {" << endl;
for (unsigned i = 0; i < nodes_.size(); i++) {
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
out << '"' << nodes_[i]->label() << '"' << endl;
}
}
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
EgNode* node =getEgNode (highlightVarIds[i]);
if (node) {
out << '"' << node->label() << '"' ;
out << " [shape=box3d]" << endl;
} else {
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
abort();
}
}
for (unsigned i = 0; i < nodes_.size(); i++) {
vector<EgNode*> neighs = nodes_[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
out << '"' << nodes_[i]->label() << '"' << " -- " ;
out << '"' << neighs[j]->label() << '"' << endl;
}
}
out << "}" << endl;
out.close();
}
VarIds
ElimGraph::getEliminationOrder (
const vector<Factor*> factors,
VarIds excludedVids)
{
ElimGraph graph (factors);
// graph.print();
// graph.exportToGraphViz ("_egg.dot");
return graph.getEliminatingOrder (excludedVids);
}
void
ElimGraph::addNode (EgNode* n)
{
nodes_.push_back (n);
n->setIndex (nodes_.size() - 1);
varMap_.insert (make_pair (n->varId(), n));
}
EgNode*
ElimGraph::getEgNode (VarId vid) const
{
unordered_map<VarId, EgNode*>::const_iterator it;
it = varMap_.find (vid);
return (it != varMap_.end()) ? it->second : 0;
}
EgNode*
ElimGraph::getLowestCostNode (void) const
{
@ -164,7 +220,7 @@ ElimGraph::getWeightCost (const EgNode* n) const
const vector<EgNode*>& neighs = n->neighbors();
for (unsigned i = 0; i < neighs.size(); i++) {
if (marked_[*neighs[i]] == false) {
cost *= neighs[i]->nrStates();
cost *= neighs[i]->range();
}
}
return cost;
@ -204,7 +260,7 @@ ElimGraph::getWeightedFillCost (const EgNode* n) const
for (unsigned j = i+1; j < neighs.size(); j++) {
if (marked_[*neighs[j]] == true) continue;
if (!neighbors (neighs[i], neighs[j])) {
cost += neighs[i]->nrStates() * neighs[j]->nrStates();
cost += neighs[i]->range() * neighs[j]->range();
}
}
}
@ -245,78 +301,3 @@ ElimGraph::neighbors (const EgNode* n1, const EgNode* n2) const
return false;
}
void
ElimGraph::setIndexes (void)
{
for (unsigned i = 0; i < nodes_.size(); i++) {
nodes_[i]->setIndex (i);
}
}
void
ElimGraph::printGraphicalModel (void) const
{
for (unsigned i = 0; i < nodes_.size(); i++) {
cout << "node " << nodes_[i]->label() << " neighs:" ;
vector<EgNode*> neighs = nodes_[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
cout << " " << neighs[j]->label();
}
cout << endl;
}
}
void
ElimGraph::exportToGraphViz (const char* fileName,
bool showNeighborless,
const VarIds& highlightVarIds) const
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "Markov::exportToDotFile()" << endl;
abort();
}
out << "strict graph {" << endl;
for (unsigned i = 0; i < nodes_.size(); i++) {
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
out << '"' << nodes_[i]->label() << '"' ;
if (nodes_[i]->hasEvidence()) {
out << " [style=filled, fillcolor=yellow]" << endl;
} else {
out << endl;
}
}
}
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
EgNode* node =getEgNode (highlightVarIds[i]);
if (node) {
out << '"' << node->label() << '"' ;
out << " [shape=box3d]" << endl;
} else {
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
abort();
}
}
for (unsigned i = 0; i < nodes_.size(); i++) {
vector<EgNode*> neighs = nodes_[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
out << '"' << nodes_[i]->label() << '"' << " -- " ;
out << '"' << neighs[j]->label() << '"' << endl;
}
}
out << "}" << endl;
out.close();
}

View File

@ -17,15 +17,15 @@ enum ElimHeuristic
};
class EgNode : public VarNode {
class EgNode : public Var
{
public:
EgNode (VarNode* var) : VarNode (var) { }
void addNeighbor (EgNode* n)
{
neighs_.push_back (n);
}
EgNode (VarId vid, unsigned range) : Var (vid, range) { }
void addNeighbor (EgNode* n) { neighs_.push_back (n); }
const vector<EgNode*>& neighbors (void) const { return neighs_; }
private:
vector<EgNode*> neighs_;
};
@ -34,22 +34,18 @@ class EgNode : public VarNode {
class ElimGraph
{
public:
ElimGraph (const BayesNet&);
ElimGraph (const vector<Factor*>&); // TODO
~ElimGraph (void);
void addEdge (EgNode* n1, EgNode* n2)
{
assert (n1 != n2);
n1->addNeighbor (n2);
n2->addNeighbor (n1);
}
void addNode (EgNode*);
EgNode* getEgNode (VarId) const;
VarIds getEliminatingOrder (const VarIds&);
void printGraphicalModel (void) const;
void exportToGraphViz (const char*, bool = true,
const VarIds& = VarIds()) const;
void setIndexes();
VarIds getEliminatingOrder (const VarIds&);
void print (void) const;
void exportToGraphViz (const char*, bool = true,
const VarIds& = VarIds()) const;
static VarIds getEliminationOrder (const vector<Factor*>, VarIds);
static void setEliminationHeuristic (ElimHeuristic h)
{
@ -57,18 +53,34 @@ class ElimGraph
}
private:
EgNode* getLowestCostNode (void) const;
unsigned getNeighborsCost (const EgNode*) const;
unsigned getWeightCost (const EgNode*) const;
unsigned getFillCost (const EgNode*) const;
unsigned getWeightedFillCost (const EgNode*) const;
void connectAllNeighbors (const EgNode*);
bool neighbors (const EgNode*, const EgNode*) const;
void addEdge (EgNode* n1, EgNode* n2)
{
assert (n1 != n2);
n1->addNeighbor (n2);
n2->addNeighbor (n1);
}
void addNode (EgNode*);
EgNode* getEgNode (VarId) const;
EgNode* getLowestCostNode (void) const;
unsigned getNeighborsCost (const EgNode*) const;
unsigned getWeightCost (const EgNode*) const;
unsigned getFillCost (const EgNode*) const;
unsigned getWeightedFillCost (const EgNode*) const;
void connectAllNeighbors (const EgNode*);
bool neighbors (const EgNode*, const EgNode*) const;
vector<EgNode*> nodes_;
vector<bool> marked_;
unordered_map<VarId,EgNode*> varMap_;
unordered_map<VarId, EgNode*> varMap_;
static ElimHeuristic elimHeuristic_;
};

View File

@ -8,7 +8,7 @@
#include "Factor.h"
#include "Indexer.h"
#include "Util.h"
Factor::Factor (const Factor& g)
@ -18,206 +18,33 @@ Factor::Factor (const Factor& g)
Factor::Factor (VarId vid, unsigned nStates)
Factor::Factor (
const VarIds& vids,
const Ranges& ranges,
const Params& params,
unsigned distId)
{
varids_.push_back (vid);
ranges_.push_back (nStates);
dist_ = new Distribution (Params (nStates, 1.0));
}
Factor::Factor (const VarNodes& vars)
{
int nParams = 1;
for (unsigned i = 0; i < vars.size(); i++) {
varids_.push_back (vars[i]->varId());
ranges_.push_back (vars[i]->nrStates());
nParams *= vars[i]->nrStates();
}
// create a uniform distribution
double val = 1.0 / nParams;
dist_ = new Distribution (Params (nParams, val));
}
Factor::Factor (VarId vid, unsigned nStates, const Params& params)
{
varids_.push_back (vid);
ranges_.push_back (nStates);
dist_ = new Distribution (params);
}
Factor::Factor (VarNodes& vars, Distribution* dist)
{
for (unsigned i = 0; i < vars.size(); i++) {
varids_.push_back (vars[i]->varId());
ranges_.push_back (vars[i]->nrStates());
}
dist_ = dist;
}
Factor::Factor (const VarNodes& vars, const Params& params)
{
for (unsigned i = 0; i < vars.size(); i++) {
varids_.push_back (vars[i]->varId());
ranges_.push_back (vars[i]->nrStates());
}
dist_ = new Distribution (params);
}
Factor::Factor (const VarIds& vids,
const Ranges& ranges,
const Params& params)
{
varids_ = vids;
args_ = vids;
ranges_ = ranges;
dist_ = new Distribution (params);
params_ = params;
distId_ = distId;
assert (params_.size() == Util::expectedSize (ranges_));
}
Factor::~Factor (void)
Factor::Factor (
const Vars& vars,
const Params& params,
unsigned distId)
{
if (dist_->shared() == false) {
delete dist_;
}
}
void
Factor::setParameters (const Params& params)
{
assert (dist_->params.size() == params.size());
dist_->params = params;
}
void
Factor::copyFromFactor (const Factor& g)
{
varids_ = g.getVarIds();
ranges_ = g.getRanges();
dist_ = new Distribution (g.getParameters());
}
void
Factor::multiply (const Factor& g)
{
if (varids_.size() == 0) {
copyFromFactor (g);
return;
}
const VarIds& g_varids = g.getVarIds();
const Ranges& g_ranges = g.getRanges();
const Params& g_params = g.getParameters();
if (varids_ == g_varids) {
// optimization: if the factors contain the same set of variables,
// we can do a 1 to 1 operation on the parameters
if (Globals::logDomain) {
Util::add (dist_->params, g_params);
} else {
Util::multiply (dist_->params, g_params);
}
} else {
bool sharedVars = false;
vector<unsigned> gvarpos;
for (unsigned i = 0; i < g_varids.size(); i++) {
int idx = indexOf (g_varids[i]);
if (idx == -1) {
insertVariable (g_varids[i], g_ranges[i]);
gvarpos.push_back (varids_.size() - 1);
} else {
sharedVars = true;
gvarpos.push_back (idx);
}
}
if (sharedVars == false) {
// optimization: if the original factors doesn't have common variables,
// we don't need to marry the states of the common variables
unsigned count = 0;
for (unsigned i = 0; i < dist_->params.size(); i++) {
if (Globals::logDomain) {
dist_->params[i] += g_params[count];
} else {
dist_->params[i] *= g_params[count];
}
count ++;
if (count >= g_params.size()) {
count = 0;
}
}
} else {
StatesIndexer indexer (ranges_, false);
while (indexer.valid()) {
unsigned g_li = 0;
unsigned prod = 1;
for (int j = gvarpos.size() - 1; j >= 0; j--) {
g_li += indexer[gvarpos[j]] * prod;
prod *= g_ranges[j];
}
if (Globals::logDomain) {
dist_->params[indexer] += g_params[g_li];
} else {
dist_->params[indexer] *= g_params[g_li];
}
++ indexer;
}
}
}
}
void
Factor::insertVariable (VarId varId, unsigned nrStates)
{
assert (indexOf (varId) == -1);
Params oldParams = dist_->params;
dist_->params.clear();
dist_->params.reserve (oldParams.size() * nrStates);
for (unsigned i = 0; i < oldParams.size(); i++) {
for (unsigned reps = 0; reps < nrStates; reps++) {
dist_->params.push_back (oldParams[i]);
}
}
varids_.push_back (varId);
ranges_.push_back (nrStates);
}
void
Factor::insertVariables (const VarIds& varIds, const Ranges& ranges)
{
Params oldParams = dist_->params;
unsigned nrStates = 1;
for (unsigned i = 0; i < varIds.size(); i++) {
assert (indexOf (varIds[i]) == -1);
varids_.push_back (varIds[i]);
ranges_.push_back (ranges[i]);
nrStates *= ranges[i];
}
dist_->params.clear();
dist_->params.reserve (oldParams.size() * nrStates);
for (unsigned i = 0; i < oldParams.size(); i++) {
for (unsigned reps = 0; reps < nrStates; reps++) {
dist_->params.push_back (oldParams[i]);
}
for (unsigned i = 0; i < vars.size(); i++) {
args_.push_back (vars[i]->varId());
ranges_.push_back (vars[i]->range());
}
params_ = params;
distId_ = distId;
assert (params_.size() == Util::expectedSize (ranges_));
}
@ -226,10 +53,10 @@ void
Factor::sumOutAllExcept (VarId vid)
{
assert (indexOf (vid) != -1);
while (varids_.back() != vid) {
while (args_.back() != vid) {
sumOutLastVariable();
}
while (varids_.front() != vid) {
while (args_.front() != vid) {
sumOutFirstVariable();
}
}
@ -239,9 +66,10 @@ Factor::sumOutAllExcept (VarId vid)
void
Factor::sumOutAllExcept (const VarIds& vids)
{
for (unsigned i = 0; i < varids_.size(); i++) {
if (std::find (vids.begin(), vids.end(), varids_[i]) == vids.end()) {
sumOut (varids_[i]);
for (int i = 0; i < (int)args_.size(); i++) {
if (Util::contains (vids, args_[i]) == false) {
sumOut (args_[i]);
i --;
}
}
}
@ -254,11 +82,11 @@ Factor::sumOut (VarId vid)
int idx = indexOf (vid);
assert (idx != -1);
if (vid == varids_.back()) {
if (vid == args_.back()) {
sumOutLastVariable(); // optimization
return;
}
if (vid == varids_.front()) {
if (vid == args_.front()) {
sumOutFirstVariable(); // optimization
return;
}
@ -271,7 +99,7 @@ Factor::sumOut (VarId vid)
// on the left of `var', with the states of the remaining vars fixed
unsigned leftVarOffset = 1;
for (int i = varids_.size() - 1; i > idx; i--) {
for (int i = args_.size() - 1; i > idx; i--) {
varOffset *= ranges_[i];
leftVarOffset *= ranges_[i];
}
@ -280,25 +108,24 @@ Factor::sumOut (VarId vid)
unsigned offset = 0;
unsigned count1 = 0;
unsigned count2 = 0;
unsigned newpsSize = dist_->params.size() / ranges_[idx];
unsigned newpsSize = params_.size() / ranges_[idx];
Params newps;
newps.reserve (newpsSize);
Params& params = dist_->params;
while (newps.size() < newpsSize) {
double sum = Util::addIdenty();
double sum = LogAware::addIdenty();
for (unsigned i = 0; i < ranges_[idx]; i++) {
if (Globals::logDomain) {
Util::logSum (sum, params[offset]);
sum = Util::logSum (sum, params_[offset]);
} else {
sum += params[offset];
sum += params_[offset];
}
offset += varOffset;
}
newps.push_back (sum);
count1 ++;
if (idx == (int)varids_.size() - 1) {
if (idx == (int)args_.size() - 1) {
offset = count1 * ranges_[idx];
} else {
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
@ -308,9 +135,9 @@ Factor::sumOut (VarId vid)
offset = (leftVarOffset * count2) + count1;
}
}
varids_.erase (varids_.begin() + idx);
args_.erase (args_.begin() + idx);
ranges_.erase (ranges_.begin() + idx);
dist_->params = newps;
params_ = newps;
}
@ -318,20 +145,19 @@ Factor::sumOut (VarId vid)
void
Factor::sumOutFirstVariable (void)
{
Params& params = dist_->params;
unsigned nStates = ranges_.front();
unsigned sep = params.size() / nStates;
unsigned range = ranges_.front();
unsigned sep = params_.size() / range;
if (Globals::logDomain) {
for (unsigned i = sep; i < params.size(); i++) {
Util::logSum (params[i % sep], params[i]);
for (unsigned i = sep; i < params_.size(); i++) {
params_[i % sep] = Util::logSum (params_[i % sep], params_[i]);
}
} else {
for (unsigned i = sep; i < params.size(); i++) {
params[i % sep] += params[i];
for (unsigned i = sep; i < params_.size(); i++) {
params_[i % sep] += params_[i];
}
}
params.resize (sep);
varids_.erase (varids_.begin());
params_.resize (sep);
args_.erase (args_.begin());
ranges_.erase (ranges_.begin());
}
@ -340,143 +166,55 @@ Factor::sumOutFirstVariable (void)
void
Factor::sumOutLastVariable (void)
{
Params& params = dist_->params;
unsigned nStates = ranges_.back();
unsigned range = ranges_.back();
unsigned idx1 = 0;
unsigned idx2 = 0;
if (Globals::logDomain) {
while (idx1 < params.size()) {
params[idx2] = params[idx1];
while (idx1 < params_.size()) {
params_[idx2] = params_[idx1];
idx1 ++;
for (unsigned j = 1; j < nStates; j++) {
Util::logSum (params[idx2], params[idx1]);
for (unsigned j = 1; j < range; j++) {
params_[idx2] = Util::logSum (params_[idx2], params_[idx1]);
idx1 ++;
}
idx2 ++;
}
} else {
while (idx1 < params.size()) {
params[idx2] = params[idx1];
while (idx1 < params_.size()) {
params_[idx2] = params_[idx1];
idx1 ++;
for (unsigned j = 1; j < nStates; j++) {
params[idx2] += params[idx1];
for (unsigned j = 1; j < range; j++) {
params_[idx2] += params_[idx1];
idx1 ++;
}
idx2 ++;
}
}
params.resize (idx2);
varids_.pop_back();
params_.resize (idx2);
args_.pop_back();
ranges_.pop_back();
}
void
Factor::orderVariables (void)
Factor::multiply (Factor& g)
{
VarIds sortedVarIds = varids_;
sort (sortedVarIds.begin(), sortedVarIds.end());
reorderVariables (sortedVarIds);
}
void
Factor::reorderVariables (const VarIds& newVarIds)
{
assert (newVarIds.size() == varids_.size());
if (newVarIds == varids_) {
if (args_.size() == 0) {
copyFromFactor (g);
return;
}
Ranges newRanges;
vector<unsigned> positions;
for (unsigned i = 0; i < newVarIds.size(); i++) {
unsigned idx = indexOf (newVarIds[i]);
newRanges.push_back (ranges_[idx]);
positions.push_back (idx);
}
unsigned N = ranges_.size();
Params newParams (dist_->params.size());
for (unsigned i = 0; i < dist_->params.size(); i++) {
unsigned li = i;
// calculate vector index corresponding to linear index
vector<unsigned> vi (N);
for (int k = N-1; k >= 0; k--) {
vi[k] = li % ranges_[k];
li /= ranges_[k];
}
// convert permuted vector index to corresponding linear index
unsigned prod = 1;
unsigned new_li = 0;
for (int k = N-1; k >= 0; k--) {
new_li += vi[positions[k]] * prod;
prod *= ranges_[positions[k]];
}
newParams[new_li] = dist_->params[i];
}
varids_ = newVarIds;
ranges_ = newRanges;
dist_->params = newParams;
TFactor<VarId>::multiply (g);
}
void
Factor::absorveEvidence (VarId vid, unsigned evidence)
Factor::reorderAccordingVarIds (void)
{
int idx = indexOf (vid);
assert (idx != -1);
Params oldParams = dist_->params;
dist_->params.clear();
dist_->params.reserve (oldParams.size() / ranges_[idx]);
StatesIndexer indexer (ranges_);
for (unsigned i = 0; i < evidence; i++) {
indexer.increment (idx);
}
while (indexer.valid()) {
dist_->params.push_back (oldParams[indexer]);
indexer.incrementExcluding (idx);
}
varids_.erase (varids_.begin() + idx);
ranges_.erase (ranges_.begin() + idx);
}
void
Factor::normalize (void)
{
Util::normalize (dist_->params);
}
bool
Factor::contains (const VarIds& vars) const
{
for (unsigned i = 0; i < vars.size(); i++) {
if (indexOf (vars[i]) == -1) {
return false;
}
}
return true;
}
int
Factor::indexOf (VarId vid) const
{
for (unsigned i = 0; i < varids_.size(); i++) {
if (varids_[i] == vid) {
return i;
}
}
return -1;
VarIds sortedVarIds = args_;
sort (sortedVarIds.begin(), sortedVarIds.end());
reorderArguments (sortedVarIds);
}
@ -486,9 +224,9 @@ Factor::getLabel (void) const
{
stringstream ss;
ss << "f(" ;
for (unsigned i = 0; i < varids_.size(); i++) {
for (unsigned i = 0; i < args_.size(); i++) {
if (i != 0) ss << "," ;
ss << VarNode (varids_[i], ranges_[i]).label();
ss << Var (args_[i], ranges_[i]).label();
}
ss << ")" ;
return ss.str();
@ -499,14 +237,14 @@ Factor::getLabel (void) const
void
Factor::print (void) const
{
VarNodes vars;
for (unsigned i = 0; i < varids_.size(); i++) {
vars.push_back (new VarNode (varids_[i], ranges_[i]));
Vars vars;
for (unsigned i = 0; i < args_.size(); i++) {
vars.push_back (new Var (args_[i], ranges_[i]));
}
vector<string> jointStrings = Util::getJointStateStrings (vars);
for (unsigned i = 0; i < dist_->params.size(); i++) {
cout << "f(" << jointStrings[i] << ")" ;
cout << " = " << dist_->params[i] << endl;
vector<string> jointStrings = Util::getStateLines (vars);
for (unsigned i = 0; i < params_.size(); i++) {
cout << "[" << distId_ << "] f(" << jointStrings[i] << ")" ;
cout << " = " << params_[i] << endl;
}
cout << endl;
for (unsigned i = 0; i < vars.size(); i++) {
@ -515,3 +253,13 @@ Factor::print (void) const
}
void
Factor::copyFromFactor (const Factor& g)
{
args_ = g.arguments();
ranges_ = g.ranges();
params_ = g.params();
distId_ = g.distId();
}

View File

@ -3,64 +3,285 @@
#include <vector>
#include "Distribution.h"
#include "VarNode.h"
#include "Var.h"
#include "Indexer.h"
#include "Util.h"
using namespace std;
class Distribution;
template <typename T>
class TFactor
{
public:
const vector<T>& arguments (void) const { return args_; }
vector<T>& arguments (void) { return args_; }
const Ranges& ranges (void) const { return ranges_; }
const Params& params (void) const { return params_; }
Params& params (void) { return params_; }
unsigned nrArguments (void) const { return args_.size(); }
unsigned size (void) const { return params_.size(); }
unsigned distId (void) const { return distId_; }
void setDistId (unsigned id) { distId_ = id; }
void normalize (void) { LogAware::normalize (params_); }
void setParams (const Params& newParams)
{
params_ = newParams;
assert (params_.size() == Util::expectedSize (ranges_));
}
int indexOf (const T& t) const
{
int idx = -1;
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i] == t) {
idx = i;
break;
}
}
return idx;
}
const T& argument (unsigned idx) const
{
assert (idx < args_.size());
return args_[idx];
}
T& argument (unsigned idx)
{
assert (idx < args_.size());
return args_[idx];
}
unsigned range (unsigned idx) const
{
assert (idx < ranges_.size());
return ranges_[idx];
}
void multiply (TFactor<T>& g)
{
const vector<T>& g_args = g.arguments();
const Ranges& g_ranges = g.ranges();
const Params& g_params = g.params();
if (args_ == g_args) {
// optimization: if the factors contain the same set of args,
// we can do a 1 to 1 operation on the parameters
if (Globals::logDomain) {
Util::add (params_, g_params);
} else {
Util::multiply (params_, g_params);
}
} else {
bool sharedArgs = false;
vector<unsigned> gvarpos;
for (unsigned i = 0; i < g_args.size(); i++) {
int idx = indexOf (g_args[i]);
if (idx == -1) {
insertArgument (g_args[i], g_ranges[i]);
gvarpos.push_back (args_.size() - 1);
} else {
sharedArgs = true;
gvarpos.push_back (idx);
}
}
if (sharedArgs == false) {
// optimization: if the original factors doesn't have common args,
// we don't need to marry the states of the common args
unsigned count = 0;
for (unsigned i = 0; i < params_.size(); i++) {
if (Globals::logDomain) {
params_[i] += g_params[count];
} else {
params_[i] *= g_params[count];
}
count ++;
if (count >= g_params.size()) {
count = 0;
}
}
} else {
StatesIndexer indexer (ranges_, false);
while (indexer.valid()) {
unsigned g_li = 0;
unsigned prod = 1;
for (int j = gvarpos.size() - 1; j >= 0; j--) {
g_li += indexer[gvarpos[j]] * prod;
prod *= g_ranges[j];
}
if (Globals::logDomain) {
params_[indexer] += g_params[g_li];
} else {
params_[indexer] *= g_params[g_li];
}
++ indexer;
}
}
}
}
void absorveEvidence (const T& arg, unsigned evidence)
{
int idx = indexOf (arg);
assert (idx != -1);
assert (evidence < ranges_[idx]);
Params copy = params_;
params_.clear();
params_.reserve (copy.size() / ranges_[idx]);
StatesIndexer indexer (ranges_);
for (unsigned i = 0; i < evidence; i++) {
indexer.increment (idx);
}
while (indexer.valid()) {
params_.push_back (copy[indexer]);
indexer.incrementExcluding (idx);
}
args_.erase (args_.begin() + idx);
ranges_.erase (ranges_.begin() + idx);
}
void reorderArguments (const vector<T> newArgs)
{
assert (newArgs.size() == args_.size());
if (newArgs == args_) {
return; // already in the wanted order
}
Ranges newRanges;
vector<unsigned> positions;
for (unsigned i = 0; i < newArgs.size(); i++) {
unsigned idx = indexOf (newArgs[i]);
newRanges.push_back (ranges_[idx]);
positions.push_back (idx);
}
unsigned N = ranges_.size();
Params newParams (params_.size());
for (unsigned i = 0; i < params_.size(); i++) {
unsigned li = i;
// calculate vector index corresponding to linear index
vector<unsigned> vi (N);
for (int k = N-1; k >= 0; k--) {
vi[k] = li % ranges_[k];
li /= ranges_[k];
}
// convert permuted vector index to corresponding linear index
unsigned prod = 1;
unsigned new_li = 0;
for (int k = N - 1; k >= 0; k--) {
new_li += vi[positions[k]] * prod;
prod *= ranges_[positions[k]];
}
newParams[new_li] = params_[i];
}
args_ = newArgs;
ranges_ = newRanges;
params_ = newParams;
}
bool contains (const T& arg) const
{
return Util::contains (args_, arg);
}
bool contains (const vector<T>& args) const
{
for (unsigned i = 0; i < args_.size(); i++) {
if (contains (args[i]) == false) {
return false;
}
}
return true;
}
protected:
vector<T> args_;
Ranges ranges_;
Params params_;
unsigned distId_;
private:
void insertArgument (const T& arg, unsigned range)
{
assert (indexOf (arg) == -1);
Params copy = params_;
params_.clear();
params_.reserve (copy.size() * range);
for (unsigned i = 0; i < copy.size(); i++) {
for (unsigned reps = 0; reps < range; reps++) {
params_.push_back (copy[i]);
}
}
args_.push_back (arg);
ranges_.push_back (range);
}
void insertArguments (const vector<T>& args, const Ranges& ranges)
{
Params copy = params_;
unsigned nrStates = 1;
for (unsigned i = 0; i < args.size(); i++) {
assert (indexOf (args[i]) == -1);
args_.push_back (args[i]);
ranges_.push_back (ranges[i]);
nrStates *= ranges[i];
}
params_.clear();
params_.reserve (copy.size() * nrStates);
for (unsigned i = 0; i < copy.size(); i++) {
for (unsigned reps = 0; reps < nrStates; reps++) {
params_.push_back (copy[i]);
}
}
}
};
class Factor
class Factor : public TFactor<VarId>
{
public:
Factor (void) { }
Factor (const Factor&);
Factor (VarId, unsigned);
Factor (const VarNodes&);
Factor (VarId, unsigned, const Params&);
Factor (VarNodes&, Distribution*);
Factor (const VarNodes&, const Params&);
Factor (const VarIds&, const Ranges&, const Params&);
~Factor (void);
void setParameters (const Params&);
void copyFromFactor (const Factor& f);
void multiply (const Factor&);
void insertVariable (VarId, unsigned);
void insertVariables (const VarIds&, const Ranges&);
void sumOutAllExcept (VarId);
void sumOutAllExcept (const VarIds&);
void sumOut (VarId);
void sumOutFirstVariable (void);
void sumOutLastVariable (void);
void orderVariables (void);
void reorderVariables (const VarIds&);
void absorveEvidence (VarId, unsigned);
void normalize (void);
bool contains (const VarIds&) const;
int indexOf (VarId) const;
string getLabel (void) const;
void print (void) const;
Factor (const VarIds&, const Ranges&, const Params&,
unsigned = Util::maxUnsigned());
const VarIds& getVarIds (void) const { return varids_; }
const Ranges& getRanges (void) const { return ranges_; }
const Params& getParameters (void) const { return dist_->params; }
Distribution* getDistribution (void) const { return dist_; }
unsigned nrVariables (void) const { return varids_.size(); }
unsigned nrParameters() const { return dist_->params.size(); }
Factor (const Vars&, const Params&,
unsigned = Util::maxUnsigned());
void setDistribution (Distribution* dist)
{
dist_ = dist;
}
void sumOutAllExcept (VarId);
void sumOutAllExcept (const VarIds&);
void sumOut (VarId);
void sumOutFirstVariable (void);
void sumOutLastVariable (void);
void multiply (Factor&);
void reorderAccordingVarIds (void);
string getLabel (void) const;
void print (void) const;
private:
void copyFromFactor (const Factor& f);
VarIds varids_;
Ranges ranges_;
Distribution* dist_;
};
#endif // HORUS_FACTOR_H

View File

@ -9,6 +9,7 @@
#include "FactorGraph.h"
#include "Factor.h"
#include "BayesNet.h"
#include "BayesBall.h"
#include "Util.h"
@ -17,140 +18,92 @@ bool FactorGraph::orderFactorVariables = false;
FactorGraph::FactorGraph (const FactorGraph& fg)
{
const FgVarSet& vars = fg.getVarNodes();
for (unsigned i = 0; i < vars.size(); i++) {
FgVarNode* varNode = new FgVarNode (vars[i]);
addVariable (varNode);
const VarNodes& varNodes = fg.varNodes();
for (unsigned i = 0; i < varNodes.size(); i++) {
addVarNode (new VarNode (varNodes[i]));
}
const FgFacSet& facs = fg.getFactorNodes();
for (unsigned i = 0; i < facs.size(); i++) {
FgFacNode* facNode = new FgFacNode (facs[i]);
addFactor (facNode);
const FgVarSet& neighs = facs[i]->neighbors();
const FacNodes& facNodes = fg.facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
FacNode* facNode = new FacNode (facNodes[i]->factor());
addFacNode (facNode);
const VarNodes& neighs = facNodes[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (facNode, varNodes_[neighs[j]->getIndex()]);
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
}
}
}
FactorGraph::FactorGraph (const BayesNet& bn)
{
const BnNodeSet& nodes = bn.getBayesNodes();
for (unsigned i = 0; i < nodes.size(); i++) {
FgVarNode* varNode = new FgVarNode (nodes[i]);
addVariable (varNode);
}
for (unsigned i = 0; i < nodes.size(); i++) {
const BnNodeSet& parents = nodes[i]->getParents();
if (!(nodes[i]->hasEvidence() && parents.size() == 0)) {
VarNodes neighs;
neighs.push_back (varNodes_[nodes[i]->getIndex()]);
for (unsigned j = 0; j < parents.size(); j++) {
neighs.push_back (varNodes_[parents[j]->getIndex()]);
}
FgFacNode* fn = new FgFacNode (
new Factor (neighs, nodes[i]->getDistribution()));
if (orderFactorVariables) {
sort (neighs.begin(), neighs.end(), CompVarId());
fn->factor()->orderVariables();
}
addFactor (fn);
for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
}
}
}
setIndexes();
}
void
FactorGraph::readFromUaiFormat (const char* fileName)
{
ifstream is (fileName);
std::ifstream is (fileName);
if (!is.is_open()) {
cerr << "error: cannot read from file " + std::string (fileName) << endl;
cerr << "error: cannot read from file " << fileName << endl;
abort();
}
ignoreLines (is);
string line;
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
getline (is, line);
if (line != "MARKOV") {
cerr << "error: the network must be a MARKOV network " << endl;
abort();
}
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
unsigned nVars;
is >> nVars;
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
vector<int> domainSizes (nVars);
for (unsigned i = 0; i < nVars; i++) {
unsigned ds;
is >> ds;
domainSizes[i] = ds;
// read the number of vars
ignoreLines (is);
unsigned nrVars;
is >> nrVars;
// read the range of each var
ignoreLines (is);
Ranges ranges (nrVars);
for (unsigned i = 0; i < nrVars; i++) {
is >> ranges[i];
}
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
for (unsigned i = 0; i < nVars; i++) {
addVariable (new FgVarNode (i, domainSizes[i]));
}
unsigned nFactors;
is >> nFactors;
for (unsigned i = 0; i < nFactors; i++) {
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
unsigned nFactorVars;
is >> nFactorVars;
VarNodes neighs;
for (unsigned j = 0; j < nFactorVars; j++) {
unsigned vid;
unsigned nrFactors;
unsigned nrArgs;
unsigned vid;
is >> nrFactors;
vector<VarIds> factorVarIds;
vector<Ranges> factorRanges;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrArgs;
factorVarIds.push_back ({ });
factorRanges.push_back ({ });
for (unsigned j = 0; j < nrArgs; j++) {
is >> vid;
FgVarNode* neigh = getFgVarNode (vid);
if (!neigh) {
cerr << "error: invalid variable identifier (" << vid << ")" << endl;
if (vid >= ranges.size()) {
cerr << "error: invalid variable identifier `" << vid << "'" << endl;
cerr << "identifiers must be between 0 and " << ranges.size() - 1 ;
cerr << endl;
abort();
}
neighs.push_back (neigh);
}
FgFacNode* fn = new FgFacNode (new Factor (neighs));
addFactor (fn);
for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
factorVarIds.back().push_back (vid);
factorRanges.back().push_back (ranges[vid]);
}
}
for (unsigned i = 0; i < nFactors; i++) {
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
unsigned nParams;
is >> nParams;
if (facNodes_[i]->getParameters().size() != nParams) {
cerr << "error: invalid number of parameters for factor " ;
cerr << facNodes_[i]->getLabel() ;
cerr << ", expected: " << facNodes_[i]->getParameters().size();
cerr << ", given: " << nParams << endl;
// read the parameters
unsigned nrParams;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
is >> nrParams;
if (nrParams != Util::expectedSize (factorRanges[i])) {
cerr << "error: invalid number of parameters for factor nº " << i ;
cerr << ", expected: " << Util::expectedSize (factorRanges[i]);
cerr << ", given: " << nrParams << endl;
abort();
}
Params params (nParams);
for (unsigned j = 0; j < nParams; j++) {
double param;
is >> param;
params[j] = param;
Params params (nrParams);
for (unsigned j = 0; j < nrParams; j++) {
is >> params[j];
}
if (Globals::logDomain) {
Util::toLog (params);
}
facNodes_[i]->factor()->setParameters (params);
addFactor (Factor (factorVarIds[i], factorRanges[i], params));
}
is.close();
setIndexes();
}
@ -158,87 +111,58 @@ FactorGraph::readFromUaiFormat (const char* fileName)
void
FactorGraph::readFromLibDaiFormat (const char* fileName)
{
ifstream is (fileName);
std::ifstream is (fileName);
if (!is.is_open()) {
cerr << "error: cannot read from file " + std::string (fileName) << endl;
cerr << "error: cannot read from file " << fileName << endl;
abort();
}
string line;
unsigned nFactors;
while ((is.peek()) == '#') getline (is, line);
is >> nFactors;
if (is.fail()) {
cerr << "error: cannot read the number of factors" << endl;
abort();
}
getline (is, line);
if (is.fail() || line.size() > 0) {
cerr << "error: cannot read the number of factors" << endl;
abort();
}
for (unsigned i = 0; i < nFactors; i++) {
unsigned nVars;
while ((is.peek()) == '#') getline (is, line);
is >> nVars;
ignoreLines (is);
unsigned nrFactors;
unsigned nrArgs;
VarId vid;
is >> nrFactors;
for (unsigned i = 0; i < nrFactors; i++) {
ignoreLines (is);
// read the factor arguments
is >> nrArgs;
VarIds vids;
for (unsigned j = 0; j < nVars; j++) {
VarId vid;
while ((is.peek()) == '#') getline (is, line);
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
is >> vid;
vids.push_back (vid);
}
VarNodes neighs;
unsigned nParams = 1;
for (unsigned j = 0; j < nVars; j++) {
unsigned dsize;
while ((is.peek()) == '#') getline (is, line);
is >> dsize;
FgVarNode* var = getFgVarNode (vids[j]);
if (var == 0) {
var = new FgVarNode (vids[j], dsize);
addVariable (var);
} else {
if (var->nrStates() != dsize) {
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
cerr << "more factors with different domain sizes" << endl;
}
// read ranges
Ranges ranges (nrArgs);
for (unsigned j = 0; j < nrArgs; j++) {
ignoreLines (is);
is >> ranges[j];
VarNode* var = getVarNode (vids[j]);
if (var != 0 && ranges[j] != var->range()) {
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
cerr << "more factors with a different range" << endl;
}
neighs.push_back (var);
nParams *= var->nrStates();
}
Params params (nParams, 0);
// read parameters
ignoreLines (is);
unsigned nNonzeros;
while ((is.peek()) == '#') getline (is, line);
is >> nNonzeros;
Params params (Util::expectedSize (ranges), 0);
for (unsigned j = 0; j < nNonzeros; j++) {
ignoreLines (is);
unsigned index;
double val;
while ((is.peek()) == '#') getline (is, line);
is >> index;
while ((is.peek()) == '#') getline (is, line);
ignoreLines (is);
double val;
is >> val;
params[index] = val;
}
reverse (neighs.begin(), neighs.end());
reverse (vids.begin(), vids.end());
if (Globals::logDomain) {
Util::toLog (params);
}
FgFacNode* fn = new FgFacNode (new Factor (neighs, params));
addFactor (fn);
for (unsigned j = 0; j < neighs.size(); j++) {
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
}
addFactor (Factor (vids, ranges, params));
}
is.close();
setIndexes();
}
@ -256,17 +180,41 @@ FactorGraph::~FactorGraph (void)
void
FactorGraph::addVariable (FgVarNode* vn)
FactorGraph::addFactor (const Factor& factor)
{
varNodes_.push_back (vn);
vn->setIndex (varNodes_.size() - 1);
varMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1));
FacNode* fn = new FacNode (factor);
addFacNode (fn);
const VarIds& vids = factor.arguments();
for (unsigned i = 0; i < vids.size(); i++) {
bool found = false;
for (unsigned j = 0; j < varNodes_.size(); j++) {
if (varNodes_[j]->varId() == vids[i]) {
addEdge (varNodes_[j], fn);
found = true;
}
}
if (found == false) {
VarNode* vn = new VarNode (vids[i], factor.range (i));
addVarNode (vn);
addEdge (vn, fn);
}
}
}
void
FactorGraph::addFactor (FgFacNode* fn)
FactorGraph::addVarNode (VarNode* vn)
{
varNodes_.push_back (vn);
vn->setIndex (varNodes_.size() - 1);
varMap_.insert (make_pair (vn->varId(), vn));
}
void
FactorGraph::addFacNode (FacNode* fn)
{
facNodes_.push_back (fn);
fn->setIndex (facNodes_.size() - 1);
@ -275,7 +223,7 @@ FactorGraph::addFactor (FgFacNode* fn)
void
FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn)
FactorGraph::addEdge (VarNode* vn, FacNode* fn)
{
vn->addNeighbor (fn);
fn->addNeighbor (vn);
@ -283,37 +231,6 @@ FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn)
void
FactorGraph::addEdge (FgFacNode* fn, FgVarNode* vn)
{
fn->addNeighbor (vn);
vn->addNeighbor (fn);
}
VarNode*
FactorGraph::getVariableNode (VarId vid) const
{
FgVarNode* vn = getFgVarNode (vid);
assert (vn);
return vn;
}
VarNodes
FactorGraph::getVariableNodes (void) const
{
VarNodes vars;
for (unsigned i = 0; i < varNodes_.size(); i++) {
vars.push_back (varNodes_[i]);
}
return vars;
}
bool
FactorGraph::isTree (void) const
{
@ -322,51 +239,42 @@ FactorGraph::isTree (void) const
void
FactorGraph::setIndexes (void)
DAGraph&
FactorGraph::getStructure (void)
{
for (unsigned i = 0; i < varNodes_.size(); i++) {
varNodes_[i]->setIndex (i);
}
for (unsigned i = 0; i < facNodes_.size(); i++) {
facNodes_[i]->setIndex (i);
assert (fromBayesNet_);
if (structure_.empty()) {
for (unsigned i = 0; i < varNodes_.size(); i++) {
structure_.addNode (new DAGraphNode (varNodes_[i]));
}
for (unsigned i = 0; i < facNodes_.size(); i++) {
const VarIds& vids = facNodes_[i]->factor().arguments();
for (unsigned j = 1; j < vids.size(); j++) {
structure_.addEdge (vids[j], vids[0]);
}
}
}
return structure_;
}
void
FactorGraph::freeDistributions (void)
{
set<Distribution*> dists;
for (unsigned i = 0; i < facNodes_.size(); i++) {
dists.insert (facNodes_[i]->factor()->getDistribution());
}
for (set<Distribution*>::iterator it = dists.begin();
it != dists.end(); it++) {
delete *it;
}
}
void
FactorGraph::printGraphicalModel (void) const
FactorGraph::print (void) const
{
for (unsigned i = 0; i < varNodes_.size(); i++) {
cout << "VarId = " << varNodes_[i]->varId() << endl;
cout << "Label = " << varNodes_[i]->label() << endl;
cout << "Nr States = " << varNodes_[i]->nrStates() << endl;
cout << "Evidence = " << varNodes_[i]->getEvidence() << endl;
cout << "Factors = " ;
cout << "var id = " << varNodes_[i]->varId() << endl;
cout << "label = " << varNodes_[i]->label() << endl;
cout << "range = " << varNodes_[i]->range() << endl;
cout << "evidence = " << varNodes_[i]->getEvidence() << endl;
cout << "factors = " ;
for (unsigned j = 0; j < varNodes_[i]->neighbors().size(); j++) {
cout << varNodes_[i]->neighbors()[j]->getLabel() << " " ;
}
cout << endl << endl;
}
for (unsigned i = 0; i < facNodes_.size(); i++) {
facNodes_[i]->factor()->print();
cout << endl;
facNodes_[i]->factor().print();
}
}
@ -381,31 +289,26 @@ FactorGraph::exportToGraphViz (const char* fileName) const
cerr << "FactorGraph::exportToDotFile()" << endl;
abort();
}
out << "graph \"" << fileName << "\" {" << endl;
for (unsigned i = 0; i < varNodes_.size(); i++) {
if (varNodes_[i]->hasEvidence()) {
out << '"' << varNodes_[i]->label() << '"' ;
out << " [style=filled, fillcolor=yellow]" << endl;
}
}
for (unsigned i = 0; i < facNodes_.size(); i++) {
out << '"' << facNodes_[i]->getLabel() << '"' ;
out << " [label=\"" << facNodes_[i]->getLabel();
out << "\"" << ", shape=box]" << endl;
}
for (unsigned i = 0; i < facNodes_.size(); i++) {
const FgVarSet& myVars = facNodes_[i]->neighbors();
const VarNodes& myVars = facNodes_[i]->neighbors();
for (unsigned j = 0; j < myVars.size(); j++) {
out << '"' << facNodes_[i]->getLabel() << '"' ;
out << " -- " ;
out << '"' << myVars[j]->label() << '"' << endl;
}
}
out << "}" << endl;
out.close();
}
@ -417,30 +320,26 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "FactorGraph::exportToUaiFormat()" << endl;
cerr << "error: cannot open file " << fileName << endl;
abort();
}
out << "MARKOV" << endl;
out << varNodes_.size() << endl;
for (unsigned i = 0; i < varNodes_.size(); i++) {
out << varNodes_[i]->nrStates() << " " ;
out << varNodes_[i]->range() << " " ;
}
out << endl;
out << facNodes_.size() << endl;
for (unsigned i = 0; i < facNodes_.size(); i++) {
const FgVarSet& factorVars = facNodes_[i]->neighbors();
const VarNodes& factorVars = facNodes_[i]->neighbors();
out << factorVars.size();
for (unsigned j = 0; j < factorVars.size(); j++) {
out << " " << factorVars[j]->getIndex();
}
out << endl;
}
for (unsigned i = 0; i < facNodes_.size(); i++) {
Params params = facNodes_[i]->getParameters();
Params params = facNodes_[i]->factor().params();
if (Globals::logDomain) {
Util::fromLog (params);
}
@ -450,7 +349,6 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
}
out << endl;
}
out.close();
}
@ -461,23 +359,22 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "FactorGraph::exportToLibDaiFormat()" << endl;
cerr << "error: cannot open file " << fileName << endl;
abort();
}
out << facNodes_.size() << endl << endl;
for (unsigned i = 0; i < facNodes_.size(); i++) {
const FgVarSet& factorVars = facNodes_[i]->neighbors();
const VarNodes& factorVars = facNodes_[i]->neighbors();
out << factorVars.size() << endl;
for (int j = factorVars.size() - 1; j >= 0; j--) {
out << factorVars[j]->varId() << " " ;
}
out << endl;
for (unsigned j = 0; j < factorVars.size(); j++) {
out << factorVars[j]->nrStates() << " " ;
out << factorVars[j]->range() << " " ;
}
out << endl;
Params params = facNodes_[i]->factor()->getParameters();
Params params = facNodes_[i]->factor().params();
if (Globals::logDomain) {
Util::fromLog (params);
}
@ -492,6 +389,17 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
void
FactorGraph::ignoreLines (std::ifstream& is) const
{
string ignoreStr;
while (is.peek() == '#' || is.peek() == '\n') {
getline (is, ignoreStr);
}
}
bool
FactorGraph::containsCycle (void) const
{
@ -511,13 +419,14 @@ FactorGraph::containsCycle (void) const
bool
FactorGraph::containsCycle (const FgVarNode* v,
const FgFacNode* p,
vector<bool>& visitedVars,
vector<bool>& visitedFactors) const
FactorGraph::containsCycle (
const VarNode* v,
const FacNode* p,
vector<bool>& visitedVars,
vector<bool>& visitedFactors) const
{
visitedVars[v->getIndex()] = true;
const FgFacSet& adjacencies = v->neighbors();
const FacNodes& adjacencies = v->neighbors();
for (unsigned i = 0; i < adjacencies.size(); i++) {
int w = adjacencies[i]->getIndex();
if (!visitedFactors[w]) {
@ -535,13 +444,14 @@ FactorGraph::containsCycle (const FgVarNode* v,
bool
FactorGraph::containsCycle (const FgFacNode* v,
const FgVarNode* p,
vector<bool>& visitedVars,
vector<bool>& visitedFactors) const
FactorGraph::containsCycle (
const FacNode* v,
const VarNode* p,
vector<bool>& visitedVars,
vector<bool>& visitedFactors) const
{
visitedFactors[v->getIndex()] = true;
const FgVarSet& adjacencies = v->neighbors();
const VarNodes& adjacencies = v->neighbors();
for (unsigned i = 0; i < adjacencies.size(); i++) {
int w = adjacencies[i]->getIndex();
if (!visitedVars[w]) {

View File

@ -3,135 +3,139 @@
#include <vector>
#include "GraphicalModel.h"
#include "Distribution.h"
#include "Factor.h"
#include "BayesNet.h"
#include "Horus.h"
using namespace std;
class BayesNet;
class FgFacNode;
class FgVarNode : public VarNode
class FacNode;
class VarNode : public Var
{
public:
FgVarNode (VarId varId, unsigned nrStates) : VarNode (varId, nrStates) { }
FgVarNode (const VarNode* v) : VarNode (v) { }
VarNode (VarId varId, unsigned nrStates)
: Var (varId, nrStates) { }
void addNeighbor (FgFacNode* fn) { neighs_.push_back (fn); }
const FgFacSet& neighbors (void) const { return neighs_; }
VarNode (const Var* v) : Var (v) { }
void addNeighbor (FacNode* fn) { neighs_.push_back (fn); }
const FacNodes& neighbors (void) const { return neighs_; }
private:
DISALLOW_COPY_AND_ASSIGN (FgVarNode);
// members
FgFacSet neighs_;
DISALLOW_COPY_AND_ASSIGN (VarNode);
FacNodes neighs_;
};
class FgFacNode
class FacNode
{
public:
FgFacNode (const FgFacNode* fn) {
factor_ = new Factor (*fn->factor());
index_ = -1;
}
FgFacNode (Factor* f) : factor_(new Factor(*f)), index_(-1) { }
Factor* factor() const { return factor_; }
void addNeighbor (FgVarNode* vn) { neighs_.push_back (vn); }
const FgVarSet& neighbors (void) const { return neighs_; }
FacNode (const Factor& f) : factor_(f), index_(-1) { }
const Factor& factor (void) const { return factor_; }
Factor& factor (void) { return factor_; }
void addNeighbor (VarNode* vn) { neighs_.push_back (vn); }
const VarNodes& neighbors (void) const { return neighs_; }
int getIndex (void) const { return index_; }
void setIndex (int index) { index_ = index; }
string getLabel (void) { return factor_.getLabel(); }
int getIndex (void) const
{
assert (index_ != -1);
return index_;
}
void setIndex (int index)
{
index_ = index;
}
Distribution* getDistribution (void)
{
return factor_->getDistribution();
}
const Params& getParameters (void) const
{
return factor_->getParameters();
}
string getLabel (void)
{
return factor_->getLabel();
}
private:
DISALLOW_COPY_AND_ASSIGN (FgFacNode);
DISALLOW_COPY_AND_ASSIGN (FacNode);
Factor* factor_;
int index_;
FgVarSet neighs_;
VarNodes neighs_;
Factor factor_;
int index_;
};
struct CompVarId
{
bool operator() (const VarNode* vn1, const VarNode* vn2) const
bool operator() (const Var* v1, const Var* v2) const
{
return vn1->varId() < vn2->varId();
return v1->varId() < v2->varId();
}
};
class FactorGraph : public GraphicalModel
class FactorGraph
{
public:
FactorGraph (void) {};
FactorGraph (bool fbn = false) : fromBayesNet_(fbn) { }
FactorGraph (const FactorGraph&);
FactorGraph (const BayesNet&);
~FactorGraph (void);
void readFromUaiFormat (const char*);
void readFromLibDaiFormat (const char*);
void addVariable (FgVarNode*);
void addFactor (FgFacNode*);
void addEdge (FgVarNode*, FgFacNode*);
void addEdge (FgFacNode*, FgVarNode*);
VarNode* getVariableNode (unsigned) const;
VarNodes getVariableNodes (void) const;
bool isTree (void) const;
void setIndexes (void);
void freeDistributions (void);
void printGraphicalModel (void) const;
void exportToGraphViz (const char*) const;
void exportToUaiFormat (const char*) const;
void exportToLibDaiFormat (const char*) const;
const FgVarSet& getVarNodes (void) const { return varNodes_; }
const FgFacSet& getFactorNodes (void) const { return facNodes_; }
const VarNodes& varNodes (void) const { return varNodes_; }
FgVarNode* getFgVarNode (VarId vid) const
const FacNodes& facNodes (void) const { return facNodes_; }
bool isFromBayesNetwork (void) const { return fromBayesNet_ ; }
VarNode* getVarNode (VarId vid) const
{
IndexMap::const_iterator it = varMap_.find (vid);
if (it == varMap_.end()) {
return 0;
} else {
return varNodes_[it->second];
}
VarMap::const_iterator it = varMap_.find (vid);
return it != varMap_.end() ? it->second : 0;
}
void readFromUaiFormat (const char*);
void readFromLibDaiFormat (const char*);
void addFactor (const Factor& factor);
void addVarNode (VarNode*);
void addFacNode (FacNode*);
void addEdge (VarNode*, FacNode*);
bool isTree (void) const;
DAGraph& getStructure (void);
void print (void) const;
void exportToGraphViz (const char*) const;
void exportToUaiFormat (const char*) const;
void exportToLibDaiFormat (const char*) const;
static bool orderFactorVariables;
private:
//DISALLOW_COPY_AND_ASSIGN (FactorGraph);
bool containsCycle (void) const;
bool containsCycle (const FgVarNode*, const FgFacNode*,
vector<bool>&, vector<bool>&) const;
bool containsCycle (const FgFacNode*, const FgVarNode*,
vector<bool>&, vector<bool>&) const;
// DISALLOW_COPY_AND_ASSIGN (FactorGraph);
FgVarSet varNodes_;
FgFacSet facNodes_;
void ignoreLines (std::ifstream&) const;
typedef unordered_map<unsigned, unsigned> IndexMap;
IndexMap varMap_;
bool containsCycle (void) const;
bool containsCycle (const VarNode*, const FacNode*,
vector<bool>&, vector<bool>&) const;
bool containsCycle (const FacNode*, const VarNode*,
vector<bool>&, vector<bool>&) const;
VarNodes varNodes_;
FacNodes facNodes_;
DAGraph structure_;
bool fromBayesNet_;
typedef unordered_map<unsigned, VarNode*> VarMap;
VarMap varMap_;
};
#endif // HORUS_FACTORGRAPH_H

View File

@ -1,175 +0,0 @@
#ifndef HORUS_FGBPSOLVER_H
#define HORUS_FGBPSOLVER_H
#include <set>
#include <vector>
#include <sstream>
#include "Solver.h"
#include "Factor.h"
#include "FactorGraph.h"
#include "Util.h"
using namespace std;
class SpLink
{
public:
SpLink (FgFacNode* fn, FgVarNode* vn)
{
fac_ = fn;
var_ = vn;
v1_.resize (vn->nrStates(), Util::tl (1.0 / vn->nrStates()));
v2_.resize (vn->nrStates(), Util::tl (1.0 / vn->nrStates()));
currMsg_ = &v1_;
nextMsg_ = &v2_;
msgSended_ = false;
residual_ = 0.0;
}
virtual ~SpLink (void) {};
virtual void updateMessage (void)
{
swap (currMsg_, nextMsg_);
msgSended_ = true;
}
void updateResidual (void)
{
residual_ = Util::getMaxNorm (v1_, v2_);
}
string toString (void) const
{
stringstream ss;
ss << fac_->getLabel();
ss << " -- " ;
ss << var_->label();
return ss.str();
}
FgFacNode* getFactor (void) const { return fac_; }
FgVarNode* getVariable (void) const { return var_; }
const Params& getMessage (void) const { return *currMsg_; }
Params& getNextMessage (void) { return *nextMsg_; }
bool messageWasSended (void) const { return msgSended_; }
double getResidual (void) const { return residual_; }
void clearResidual (void) { residual_ = 0.0; }
protected:
FgFacNode* fac_;
FgVarNode* var_;
Params v1_;
Params v2_;
Params* currMsg_;
Params* nextMsg_;
bool msgSended_;
double residual_;
};
typedef vector<SpLink*> SpLinkSet;
class SPNodeInfo
{
public:
void addSpLink (SpLink* link) { links_.push_back (link); }
const SpLinkSet& getLinks (void) { return links_; }
private:
SpLinkSet links_;
};
class FgBpSolver : public Solver
{
public:
FgBpSolver (const FactorGraph&);
virtual ~FgBpSolver (void);
void runSolver (void);
virtual Params getPosterioriOf (VarId);
virtual Params getJointDistributionOf (const VarIds&);
protected:
virtual void initializeSolver (void);
virtual void createLinks (void);
virtual void maxResidualSchedule (void);
virtual void calculateFactor2VariableMsg (SpLink*) const;
virtual Params getVar2FactorMsg (const SpLink*) const;
virtual Params getJointByConditioning (const VarIds&) const;
virtual void printLinkInformation (void) const;
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
{
if (DL >= 3) {
cout << "calculating & updating " << link->toString() << endl;
}
calculateFactor2VariableMsg (link);
if (calcResidual) {
link->updateResidual();
}
link->updateMessage();
}
void calculateMessage (SpLink* link, bool calcResidual = true)
{
if (DL >= 3) {
cout << "calculating " << link->toString() << endl;
}
calculateFactor2VariableMsg (link);
if (calcResidual) {
link->updateResidual();
}
}
void updateMessage (SpLink* link)
{
link->updateMessage();
if (DL >= 3) {
cout << "updating " << link->toString() << endl;
}
}
SPNodeInfo* ninf (const FgVarNode* var) const
{
return varsI_[var->getIndex()];
}
SPNodeInfo* ninf (const FgFacNode* fac) const
{
return facsI_[fac->getIndex()];
}
struct CompareResidual {
inline bool operator() (const SpLink* link1, const SpLink* link2)
{
return link1->getResidual() > link2->getResidual();
}
};
SpLinkSet links_;
unsigned nIters_;
vector<SPNodeInfo*> varsI_;
vector<SPNodeInfo*> facsI_;
const FactorGraph* factorGraph_;
typedef multiset<SpLink*, CompareResidual> SortedOrder;
SortedOrder sortedOrder_;
typedef unordered_map<SpLink*, SortedOrder::iterator> SpLinkMap;
SpLinkMap linkMap_;
private:
void runLoopySolver (void);
bool converged (void);
};
#endif // HORUS_FGBPSOLVER_H

View File

@ -8,7 +8,9 @@
vector<LiftedOperator*>
LiftedOperator::getValidOps (ParfactorList& pfList, const Grounds& query)
LiftedOperator::getValidOps (
ParfactorList& pfList,
const Grounds& query)
{
vector<LiftedOperator*> validOps;
vector<SumOutOperator*> sumOutOps;
@ -28,12 +30,15 @@ LiftedOperator::getValidOps (ParfactorList& pfList, const Grounds& query)
void
LiftedOperator::printValidOps (ParfactorList& pfList, const Grounds& query)
LiftedOperator::printValidOps (
ParfactorList& pfList,
const Grounds& query)
{
vector<LiftedOperator*> validOps;
validOps = LiftedOperator::getValidOps (pfList, query);
for (unsigned i = 0; i < validOps.size(); i++) {
cout << "-> " << validOps[i]->toString() << endl;
delete validOps[i];
}
}
@ -56,14 +61,14 @@ SumOutOperator::getCost (void)
pfIter = pfList_.begin();
while (pfIter != pfList_.end()) {
if ((*pfIter)->containsGroup (groupSet[i])) {
int idx = (*pfIter)->indexOfFormulaWithGroup (groupSet[i]);
int idx = (*pfIter)->indexOfGroup (groupSet[i]);
cost *= (*pfIter)->range (idx);
break;
}
++ pfIter;
}
}
return cost;
return cost;
}
@ -77,14 +82,13 @@ SumOutOperator::apply (void)
pfList_.remove (iters[0]);
for (unsigned i = 1; i < iters.size(); i++) {
product->multiply (**(iters[i]));
delete *(iters[i]);
pfList_.remove (iters[i]);
pfList_.removeAndDelete (iters[i]);
}
if (product->nrFormulas() == 1) {
if (product->nrArguments() == 1) {
delete product;
return;
}
int fIdx = product->indexOfFormulaWithGroup (group_);
int fIdx = product->indexOfGroup (group_);
LogVarSet excl = product->exclusiveLogVars (fIdx);
if (product->constr()->isCountNormalized (excl)) {
product->sumOut (fIdx);
@ -96,21 +100,21 @@ SumOutOperator::apply (void)
pfList_.add (pfs[i]);
}
delete product;
pfList_.shatter();
}
}
vector<SumOutOperator*>
SumOutOperator::getValidOps (ParfactorList& pfList, const Grounds& query)
SumOutOperator::getValidOps (
ParfactorList& pfList,
const Grounds& query)
{
vector<SumOutOperator*> validOps;
set<unsigned> allGroups;
ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) {
assert (*it);
const ProbFormulas& formulas = (*it)->formulas();
const ProbFormulas& formulas = (*it)->arguments();
for (unsigned i = 0; i < formulas.size(); i++) {
allGroups.insert (formulas[i].group());
}
@ -134,8 +138,8 @@ SumOutOperator::toString (void)
stringstream ss;
vector<ParfactorList::iterator> pfIters;
pfIters = parfactorsWithGroup (pfList_, group_);
int idx = (*pfIters[0])->indexOfFormulaWithGroup (group_);
ProbFormula f = (*pfIters[0])->formula (idx);
int idx = (*pfIters[0])->indexOfGroup (group_);
ProbFormula f = (*pfIters[0])->argument (idx);
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
ss << "sum out " << f.functor() << "/" << f.arity();
ss << "|" << tupleSet << " (group " << group_ << ")";
@ -158,9 +162,9 @@ SumOutOperator::validOp (
}
unordered_map<unsigned, unsigned> groupToRange;
for (unsigned i = 0; i < pfIters.size(); i++) {
int fIdx = (*pfIters[i])->indexOfFormulaWithGroup (group);
if ((*pfIters[i])->formulas()[fIdx].contains (
(*pfIters[i])->elimLogVars()) == false) {
int fIdx = (*pfIters[i])->indexOfGroup (group);
if ((*pfIters[i])->argument (fIdx).contains (
(*pfIters[i])->elimLogVars()) == false) {
return false;
}
vector<unsigned> ranges = (*pfIters[i])->ranges();
@ -206,8 +210,8 @@ SumOutOperator::isToEliminate (
unsigned group,
const Grounds& query)
{
int fIdx = g->indexOfFormulaWithGroup (group);
const ProbFormula& formula = g->formula (fIdx);
int fIdx = g->indexOfGroup (group);
const ProbFormula& formula = g->argument (fIdx);
bool toElim = true;
for (unsigned i = 0; i < query.size(); i++) {
if (formula.functor() == query[i].functor() &&
@ -228,7 +232,7 @@ unsigned
CountingOperator::getCost (void)
{
unsigned cost = 0;
int fIdx = (*pfIter_)->indexOfFormulaWithLogVar (X_);
int fIdx = (*pfIter_)->indexOfLogVar (X_);
unsigned range = (*pfIter_)->range (fIdx);
unsigned size = (*pfIter_)->size() / range;
TinySet<unsigned> counts;
@ -247,18 +251,19 @@ CountingOperator::apply (void)
if ((*pfIter_)->constr()->isCountNormalized (X_)) {
(*pfIter_)->countConvert (X_);
} else {
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
Parfactor* pf = *pfIter_;
pfList_.remove (pfIter_);
Parfactors pfs = FoveSolver::countNormalize (pf, X_);
for (unsigned i = 0; i < pfs.size(); i++) {
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
bool cartProduct = pfs[i]->constr()->isCarteesianProduct (
(*pfIter_)->countedLogVars() | X_);
pfs[i]->countedLogVars() | X_);
if (condCount > 1 && cartProduct) {
pfs[i]->countConvert (X_);
}
pfList_.add (pfs[i]);
}
pfList_.deleteAndRemove (pfIter_);
pfList_.shatter();
delete pf;
}
}
@ -289,14 +294,17 @@ CountingOperator::toString (void)
{
stringstream ss;
ss << "count convert " << X_ << " in " ;
ss << (*pfIter_)->getHeaderString();
ss << (*pfIter_)->getLabel();
ss << " [cost=" << getCost() << "]" << endl;
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
for (unsigned i = 0; i < pfs.size(); i++) {
ss << " º " << pfs[i]->getHeaderString() << endl;
ss << " º " << pfs[i]->getLabel() << endl;
}
}
for (unsigned i = 0; i < pfs.size(); i++) {
delete pfs[i];
}
return ss.str();
}
@ -308,8 +316,8 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
if (g->nrFormulas (X) != 1) {
return false;
}
int fIdx = g->indexOfFormulaWithLogVar (X);
if (g->formulas()[fIdx].isCounting()) {
int fIdx = g->indexOfLogVar (X);
if (g->argument (fIdx).isCounting()) {
return false;
}
bool countNormalized = g->constr()->isCountNormalized (X);
@ -332,10 +340,10 @@ GroundOperator::getCost (void)
unsigned cost = 0;
bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_);
if (isCountingLv) {
int fIdx = (*pfIter_)->indexOfFormulaWithLogVar (X_);
int fIdx = (*pfIter_)->indexOfLogVar (X_);
unsigned currSize = (*pfIter_)->size();
unsigned nrHists = (*pfIter_)->range (fIdx);
unsigned range = (*pfIter_)->formula(fIdx).range();
unsigned range = (*pfIter_)->argument (fIdx).range();
unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_);
cost = (currSize / nrHists) * (std::pow (range, nrSymbols));
} else {
@ -350,18 +358,17 @@ void
GroundOperator::apply (void)
{
bool countedLv = (*pfIter_)->countedLogVars().contains (X_);
Parfactor* pf = *pfIter_;
pfList_.remove (pfIter_);
if (countedLv) {
(*pfIter_)->fullExpand (X_);
(*pfIter_)->setNewGroups();
pfList_.shatter();
pf->fullExpand (X_);
pfList_.add (pf);
} else {
ConstraintTrees cts = (*pfIter_)->constr()->ground (X_);
ConstraintTrees cts = pf->constr()->ground (X_);
for (unsigned i = 0; i < cts.size(); i++) {
Parfactor* newPf = new Parfactor (*pfIter_, cts[i]);
pfList_.add (newPf);
pfList_.add (new Parfactor (pf, cts[i]));
}
pfList_.deleteAndRemove (pfIter_);
pfList_.shatter();
delete pf;
}
}
@ -393,24 +400,13 @@ GroundOperator::toString (void)
((*pfIter_)->countedLogVars().contains (X_))
? ss << "full expanding "
: ss << "grounding " ;
ss << X_ << " in " << (*pfIter_)->getHeaderString();
ss << X_ << " in " << (*pfIter_)->getLabel();
ss << " [cost=" << getCost() << "]" << endl;
return ss.str();
}
FoveSolver::FoveSolver (const ParfactorList* pfList)
{
for (ParfactorList::const_iterator it = pfList->begin();
it != pfList->end();
it ++) {
pfList_.addShattered (new Parfactor (**it));
}
}
Params
FoveSolver::getPosterioriOf (const Ground& query)
{
@ -422,14 +418,12 @@ FoveSolver::getPosterioriOf (const Ground& query)
Params
FoveSolver::getJointDistributionOf (const Grounds& query)
{
shatterAgainstQuery (query);
runSolver (query);
(*pfList_.begin())->normalize();
Params params = (*pfList_.begin())->params();
if (Globals::logDomain) {
Util::fromLog (params);
}
delete *pfList_.begin();
return params;
}
@ -438,32 +432,38 @@ FoveSolver::getJointDistributionOf (const Grounds& query)
void
FoveSolver::absorveEvidence (
ParfactorList& pfList,
const ObservedFormulas& obsFormulas)
ObservedFormulas& obsFormulas)
{
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
bool increment = true;
for (unsigned i = 0; i < obsFormulas.size(); i++) {
if (absorved (pfList, it, obsFormulas[i])) {
it = pfList.deleteAndRemove (it);
increment = false;
break;
}
}
if (increment) {
++ it;
for (unsigned i = 0; i < obsFormulas.size(); i++) {
Parfactors newPfs;
ParfactorList::iterator it = pfList.begin();
while (it != pfList.end()) {
Parfactor* pf = *it;
it = pfList.remove (it);
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
if (absorvedPfs.empty() == false) {
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
// just remove pf;
} else {
Util::addToVector (newPfs, absorvedPfs);
}
delete pf;
} else {
it = pfList.insertShattered (it, pf);
++ it;
}
}
pfList.add (newPfs);
}
pfList.shatter();
if (obsFormulas.empty() == false) {
cout << "*******************************************************" << endl;
if (Constants::DEBUG >= 2 && obsFormulas.empty() == false) {
Util::printAsteriskLine();
cout << "AFTER EVIDENCE ABSORVED" << endl;
for (unsigned i = 0; i < obsFormulas.size(); i++) {
cout << " -> " << *obsFormulas[i] << endl;
cout << " -> " << obsFormulas[i] << endl;
}
cout << "*******************************************************" << endl;
Util::printAsteriskLine();
pfList.print();
}
pfList.print();
}
@ -473,14 +473,14 @@ FoveSolver::countNormalize (
Parfactor* g,
const LogVarSet& set)
{
if (set.empty()) {
assert (false); // TODO
return {};
}
Parfactors normPfs;
ConstraintTrees normCts = g->constr()->countNormalize (set);
for (unsigned i = 0; i < normCts.size(); i++) {
normPfs.push_back (new Parfactor (g, normCts[i]));
if (set.empty()) {
normPfs.push_back (new Parfactor (*g));
} else {
ConstraintTrees normCts = g->constr()->countNormalize (set);
for (unsigned i = 0; i < normCts.size(); i++) {
normPfs.push_back (new Parfactor (g, normCts[i]));
}
}
return normPfs;
}
@ -490,17 +490,25 @@ FoveSolver::countNormalize (
void
FoveSolver::runSolver (const Grounds& query)
{
shatterAgainstQuery (query);
runWeakBayesBall (query);
while (true) {
cout << "---------------------------------------------------" << endl;
pfList_.print();
LiftedOperator::printValidOps (pfList_, query);
if (Constants::DEBUG >= 2) {
Util::printDashedLine();
pfList_.print();
LiftedOperator::printValidOps (pfList_, query);
}
LiftedOperator* op = getBestOperation (query);
if (op == 0) {
break;
}
cout << "best operation: " << op->toString() << endl;
if (Constants::DEBUG >= 2) {
cout << "best operation: " << op->toString() << endl;
}
op->apply();
delete op;
}
assert (pfList_.size() > 0);
if (pfList_.size() > 1) {
ParfactorList::iterator pfIter = pfList_.begin();
pfIter ++;
@ -514,26 +522,6 @@ FoveSolver::runSolver (const Grounds& query)
bool
FoveSolver::allEliminated (const Grounds&)
{
ParfactorList::iterator pfIter = pfList_.begin();
while (pfIter != pfList_.end()) {
const ProbFormulas formulas = (*pfIter)->formulas();
for (unsigned i = 0; i < formulas.size(); i++) {
//bool toElim = false;
//for (unsigned j = 0; j < queries.size(); j++) {
// if ((*pfIter)->containsGround (queries[i]) == false) {
// return
// }
}
++ pfIter;
}
return false;
}
LiftedOperator*
FoveSolver::getBestOperation (const Grounds& query)
{
@ -548,156 +536,176 @@ FoveSolver::getBestOperation (const Grounds& query)
bestCost = cost;
}
}
for (unsigned i = 0; i < validOps.size(); i++) {
if (validOps[i] != bestOp) {
delete validOps[i];
}
}
return bestOp;
}
void
FoveSolver::runWeakBayesBall (const Grounds& query)
{
queue<unsigned> todo; // groups to process
set<unsigned> done; // processed or in queue
for (unsigned i = 0; i < query.size(); i++) {
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
int group = (*it)->findGroup (query[i]);
if (group != -1) {
todo.push (group);
done.insert (group);
break;
}
++ it;
}
}
set<Parfactor*> requiredPfs;
while (todo.empty() == false) {
unsigned group = todo.front();
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if (Util::contains (requiredPfs, *it) == false &&
(*it)->containsGroup (group)) {
vector<unsigned> groups = (*it)->getAllGroups();
for (unsigned i = 0; i < groups.size(); i++) {
if (Util::contains (done, groups[i]) == false) {
todo.push (groups[i]);
done.insert (groups[i]);
}
}
requiredPfs.insert (*it);
}
++ it;
}
todo.pop();
}
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if (Util::contains (requiredPfs, *it) == false) {
it = pfList_.removeAndDelete (it);
} else {
++ it;
}
}
if (Constants::DEBUG >= 2) {
Util::printHeader ("REQUIRED PARFACTORS");
pfList_.print();
}
}
void
FoveSolver::shatterAgainstQuery (const Grounds& query)
{
// return;
for (unsigned i = 0; i < query.size(); i++) {
if (query[i].isAtom()) {
continue;
}
ParfactorList pfListCopy = pfList_;
pfList_.clear();
for (ParfactorList::iterator it = pfListCopy.begin();
it != pfListCopy.end(); ++ it) {
Parfactor* pf = *it;
if (pf->containsGround (query[i])) {
bool found = false;
Parfactors newPfs;
ParfactorList::iterator it = pfList_.begin();
while (it != pfList_.end()) {
if ((*it)->containsGround (query[i])) {
found = true;
std::pair<ConstraintTree*, ConstraintTree*> split =
pf->constr()->split (query[i].args(), query[i].arity());
(*it)->constr()->split (query[i].args(), query[i].arity());
ConstraintTree* commCt = split.first;
ConstraintTree* exclCt = split.second;
pfList_.add (new Parfactor (pf, commCt));
newPfs.push_back (new Parfactor (*it, commCt));
if (exclCt->empty() == false) {
pfList_.add (new Parfactor (pf, exclCt));
newPfs.push_back (new Parfactor (*it, exclCt));
} else {
delete exclCt;
}
delete pf;
it = pfList_.removeAndDelete (it);
} else {
pfList_.add (pf);
++ it;
}
}
pfList_.shatter();
if (found == false) {
cerr << "error: could not find a parfactor with ground " ;
cerr << "`" << query[i] << "'" << endl;
exit (0);
}
pfList_.add (newPfs);
}
cout << endl;
cout << "*******************************************************" << endl;
cout << "SHATTERED AGAINST THE QUERY" << endl;
for (unsigned i = 0; i < query.size(); i++) {
cout << " -> " << query[i] << endl;
if (Constants::DEBUG >= 2) {
cout << endl;
Util::printAsteriskLine();
cout << "SHATTERED AGAINST THE QUERY" << endl;
for (unsigned i = 0; i < query.size(); i++) {
cout << " -> " << query[i] << endl;
}
Util::printAsteriskLine();
pfList_.print();
}
cout << "*******************************************************" << endl;
pfList_.print();
}
bool
FoveSolver::absorved (
ParfactorList& pfList,
ParfactorList::iterator pfIter,
const ObservedFormula* obsFormula)
Parfactors
FoveSolver::absorve (
ObservedFormula& obsFormula,
Parfactor* g)
{
Parfactors absorvedPfs;
Parfactor* g = *pfIter;
const ProbFormulas& formulas = g->formulas();
const ProbFormulas& formulas = g->arguments();
for (unsigned i = 0; i < formulas.size(); i++) {
if (obsFormula->functor() == formulas[i].functor() &&
obsFormula->arity() == formulas[i].arity()) {
if (obsFormula.functor() == formulas[i].functor() &&
obsFormula.arity() == formulas[i].arity()) {
if (obsFormula->isAtom()) {
if (obsFormula.isAtom()) {
if (formulas.size() > 1) {
g->absorveEvidence (i, obsFormula->evidence());
g->absorveEvidence (formulas[i], obsFormula.evidence());
} else {
return true;
// hack to erase parfactor g
absorvedPfs.push_back (0);
}
break;
}
g->constr()->moveToTop (formulas[i].logVars());
std::pair<ConstraintTree*, ConstraintTree*> res
= g->constr()->split (obsFormula->constr(), formulas[i].arity());
= g->constr()->split (&(obsFormula.constr()), formulas[i].arity());
ConstraintTree* commCt = res.first;
ConstraintTree* exclCt = res.second;
if (commCt->empty()) {
delete commCt;
delete exclCt;
continue;
}
if (exclCt->empty() == false) {
pfList.add (new Parfactor (g, exclCt));
} else {
delete exclCt;
}
if (formulas.size() > 1) {
LogVarSet excl = g->exclusiveLogVars (i);
Parfactors countNormPfs = countNormalize (g, excl);
for (unsigned j = 0; j < countNormPfs.size(); j++) {
countNormPfs[j]->absorveEvidence (i, obsFormula->evidence());
absorvedPfs.push_back (countNormPfs[j]);
if (commCt->empty() == false) {
if (formulas.size() > 1) {
LogVarSet excl = g->exclusiveLogVars (i);
Parfactors countNormPfs = countNormalize (g, excl);
for (unsigned j = 0; j < countNormPfs.size(); j++) {
countNormPfs[j]->absorveEvidence (
formulas[i], obsFormula.evidence());
absorvedPfs.push_back (countNormPfs[j]);
}
} else {
delete commCt;
}
if (exclCt->empty() == false) {
absorvedPfs.push_back (new Parfactor (g, exclCt));
} else {
delete exclCt;
}
if (absorvedPfs.empty()) {
// hack to erase parfactor g
absorvedPfs.push_back (0);
}
break;
} else {
delete commCt;
delete exclCt;
}
return true;
}
}
return false;
}
bool
FoveSolver::proper (
const ProbFormula& f1,
ConstraintTree* c1,
const ProbFormula& f2,
ConstraintTree* c2)
{
return disjoint (f1, c1, f2, c2)
|| identical (f1, c1, f2, c2);
}
bool
FoveSolver::identical (
const ProbFormula& f1,
ConstraintTree* c1,
const ProbFormula& f2,
ConstraintTree* c2)
{
if (f1.sameSkeletonAs (f2) == false) {
return false;
}
c1->moveToTop (f1.logVars());
c2->moveToTop (f2.logVars());
return ConstraintTree::identical (
c1, c2, f1.logVars().size());
}
bool
FoveSolver::disjoint (
const ProbFormula& f1,
ConstraintTree* c1,
const ProbFormula& f2,
ConstraintTree* c2)
{
if (f1.sameSkeletonAs (f2) == false) {
return true;
}
c1->moveToTop (f1.logVars());
c2->moveToTop (f2.logVars());
return ConstraintTree::overlap (
c1, c2, f1.arity()) == false;
return absorvedPfs;
}

View File

@ -9,10 +9,14 @@ class LiftedOperator
{
public:
virtual unsigned getCost (void) = 0;
virtual void apply (void) = 0;
virtual string toString (void) = 0;
static vector<LiftedOperator*> getValidOps (
ParfactorList&, const Grounds&);
static void printValidOps (ParfactorList&, const Grounds&);
};
@ -23,18 +27,26 @@ class SumOutOperator : public LiftedOperator
public:
SumOutOperator (unsigned group, ParfactorList& pfList)
: group_(group), pfList_(pfList) { }
unsigned getCost (void);
void apply (void);
static vector<SumOutOperator*> getValidOps (
ParfactorList&, const Grounds&);
string toString (void);
private:
static bool validOp (unsigned, ParfactorList&, const Grounds&);
static vector<ParfactorList::iterator> parfactorsWithGroup (
ParfactorList& pfList, unsigned group);
static bool isToEliminate (Parfactor*, unsigned, const Grounds&);
unsigned group_;
ParfactorList& pfList_;
unsigned group_;
ParfactorList& pfList_;
};
@ -47,15 +59,21 @@ class CountingOperator : public LiftedOperator
LogVar X,
ParfactorList& pfList)
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
unsigned getCost (void);
void apply (void);
static vector<CountingOperator*> getValidOps (ParfactorList&);
string toString (void);
private:
static bool validOp (Parfactor*, LogVar);
ParfactorList::iterator pfIter_;
LogVar X_;
ParfactorList& pfList_;
ParfactorList::iterator pfIter_;
LogVar X_;
ParfactorList& pfList_;
};
@ -68,14 +86,19 @@ class GroundOperator : public LiftedOperator
LogVar X,
ParfactorList& pfList)
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
unsigned getCost (void);
void apply (void);
static vector<GroundOperator*> getValidOps (ParfactorList&);
string toString (void);
private:
ParfactorList::iterator pfIter_;
LogVar X_;
ParfactorList& pfList_;
ParfactorList::iterator pfIter_;
LogVar X_;
ParfactorList& pfList_;
};
@ -83,49 +106,29 @@ class GroundOperator : public LiftedOperator
class FoveSolver
{
public:
FoveSolver (const ParfactorList*);
FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { }
Params getPosterioriOf (const Ground&);
Params getJointDistributionOf (const Grounds&);
Params getPosterioriOf (const Ground&);
static void absorveEvidence (
ParfactorList& pfList,
const ObservedFormulas& obsFormulas);
Params getJointDistributionOf (const Grounds&);
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
static void absorveEvidence (
ParfactorList& pfList, ObservedFormulas& obsFormulas);
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
private:
void runSolver (const Grounds&);
bool allEliminated (const Grounds&);
LiftedOperator* getBestOperation (const Grounds&);
void shatterAgainstQuery (const Grounds&);
void runSolver (const Grounds&);
static bool absorved (
ParfactorList& pfList,
ParfactorList::iterator pfIter,
const ObservedFormula*);
LiftedOperator* getBestOperation (const Grounds&);
public:
void runWeakBayesBall (const Grounds&);
static bool proper (
const ProbFormula&,
ConstraintTree*,
const ProbFormula&,
ConstraintTree*);
void shatterAgainstQuery (const Grounds&);
static bool identical (
const ProbFormula&,
ConstraintTree*,
const ProbFormula&,
ConstraintTree*);
static Parfactors absorve (ObservedFormula&, Parfactor*);
static bool disjoint (
const ProbFormula&,
ConstraintTree*,
const ProbFormula&,
ConstraintTree*);
ParfactorList pfList_;
ParfactorList pfList_;
};
#endif // HORUS_FOVESOLVER_H

View File

@ -1,67 +0,0 @@
#ifndef HORUS_GRAPHICALMODEL_H
#define HORUS_GRAPHICALMODEL_H
#include <sstream>
#include "VarNode.h"
#include "Distribution.h"
#include "Horus.h"
using namespace std;
struct VariableInfo
{
VariableInfo (string l, const States& sts)
{
label = l;
states = sts;
}
string label;
States states;
};
class GraphicalModel
{
public:
virtual ~GraphicalModel (void) {};
virtual VarNode* getVariableNode (VarId) const = 0;
virtual VarNodes getVariableNodes (void) const = 0;
virtual void printGraphicalModel (void) const = 0;
static void addVariableInformation (VarId vid, string label,
const States& states)
{
assert (varsInfo_.find (vid) == varsInfo_.end());
varsInfo_.insert (make_pair (vid, VariableInfo (label, states)));
}
static VariableInfo getVariableInformation (VarId vid)
{
assert (varsInfo_.find (vid) != varsInfo_.end());
return varsInfo_.find (vid)->second;
}
static bool variablesHaveInformation (void)
{
return varsInfo_.size() != 0;
}
static void clearVariablesInformation (void)
{
varsInfo_.clear();
}
static void addDistribution (unsigned id, Distribution* dist)
{
distsInfo_[id] = dist;
}
static void updateDistribution (unsigned id, const Params& params)
{
distsInfo_[id]->updateParameters (params);
}
private:
static unordered_map<VarId,VariableInfo> varsInfo_;
static unordered_map<unsigned,Distribution*> distsInfo_;
};
#endif // HORUS_GRAPHICALMODEL_H

View File

@ -84,16 +84,34 @@ HistogramSet::nrHistograms (unsigned N, unsigned R)
unsigned
HistogramSet::findIndex (
const Histogram& hist,
const vector<Histogram>& histograms)
const Histogram& h,
const vector<Histogram>& hists)
{
vector<Histogram>::const_iterator it = std::lower_bound (
histograms.begin(),
histograms.end(),
hist,
std::greater<Histogram>());
assert (it != histograms.end() && *it == hist);
return std::distance (histograms.begin(), it);
hists.begin(), hists.end(), h, std::greater<Histogram>());
assert (it != hists.end() && *it == h);
return std::distance (hists.begin(), it);
}
vector<double>
HistogramSet::getNumAssigns (unsigned N, unsigned R)
{
HistogramSet hs (N, R);
unsigned N_factorial = Util::factorial (N);
unsigned H = hs.nrHistograms();
vector<double> numAssigns;
numAssigns.reserve (H);
for (unsigned h = 0; h < H; h++) {
unsigned prod = 1;
for (unsigned r = 0; r < R; r++) {
prod *= Util::factorial (hs[r]);
}
numAssigns.push_back (LogAware::tl (N_factorial / prod));
hs.nextHistogram();
}
return numAssigns;
}

View File

@ -26,8 +26,9 @@ class HistogramSet
static unsigned nrHistograms (unsigned, unsigned);
static unsigned findIndex (
const Histogram&,
const vector<Histogram>&);
const Histogram&, const vector<Histogram>&);
static vector<double> getNumAssigns (unsigned, unsigned);
friend std::ostream& operator<< (ostream &os, const HistogramSet& hs);

View File

@ -1,17 +1,9 @@
#ifndef HORUS_HORUS_H
#define HORUS_HORUS_H
#include <cmath>
#include <cassert>
#include <limits>
#include <algorithm>
#include <vector>
#include <unordered_map>
#include <iostream>
#include <fstream>
#include <sstream>
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
TypeName(const TypeName&); \
@ -19,55 +11,51 @@
using namespace std;
class VarNode;
class BayesNode;
class FgVarNode;
class FgFacNode;
class Var;
class Factor;
class VarNode;
class FacNode;
typedef vector<double> Params;
typedef unsigned VarId;
typedef vector<VarId> VarIds;
typedef vector<VarNode*> VarNodes;
typedef vector<BayesNode*> BnNodeSet;
typedef vector<FgVarNode*> FgVarSet;
typedef vector<FgFacNode*> FgFacSet;
typedef vector<Factor*> FactorSet;
typedef vector<string> States;
typedef vector<unsigned> Ranges;
typedef vector<double> Params;
typedef unsigned VarId;
typedef vector<VarId> VarIds;
typedef vector<Var*> Vars;
typedef vector<VarNode*> VarNodes;
typedef vector<FacNode*> FacNodes;
typedef vector<Factor*> Factors;
typedef vector<string> States;
typedef vector<unsigned> Ranges;
namespace Globals {
extern bool logDomain;
enum InfAlgorithms
{
VE, // variable elimination
BP, // belief propagation
CBP // counting belief propagation
};
// level of debug information
static const unsigned DL = 1;
namespace Globals {
static const int NO_EVIDENCE = -1;
extern bool logDomain;
extern InfAlgorithms infAlgorithm;
};
namespace Constants {
// level of debug information
const unsigned DEBUG = 0;
const int NO_EVIDENCE = -1;
// number of digits to show when printing a parameter
static const unsigned PRECISION = 5;
const unsigned PRECISION = 5;
static const bool COLLECT_STATISTICS = false;
const bool COLLECT_STATS = false;
static const bool EXPORT_TO_GRAPHVIZ = false;
static const unsigned EXPORT_MINIMAL_SIZE = 100;
static const double INF = -numeric_limits<double>::infinity();
namespace InfAlgorithms {
enum InfAlgs
{
VE, // variable elimination
BN_BP, // bayesian network belief propagation
FG_BP, // factor graph belief propagation
CBP // counting bp solver
};
extern InfAlgs infAlgorithm;
};

View File

@ -3,197 +3,89 @@
#include <iostream>
#include <sstream>
#include "BayesNet.h"
#include "FactorGraph.h"
#include "VarElimSolver.h"
#include "BnBpSolver.h"
#include "FgBpSolver.h"
#include "BpSolver.h"
#include "CbpSolver.h"
//#include "TinySet.h"
#include "LiftedUtils.h"
using namespace std;
void processArguments (BayesNet&, int, const char* []);
void processArguments (FactorGraph&, int, const char* []);
void runSolver (Solver*, const VarNodes&);
void runSolver (const FactorGraph&, const VarIds&);
const string USAGE = "usage: \
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
class Cenas
{
public:
Cenas (int cc)
{
c = cc;
}
//operator int (void) const
//{
// cout << "return int" << endl;
// return c;
//}
operator double (void) const
{
cout << "return double" << endl;
return 0.0;
}
private:
int c;
};
./hcli ve|bp|cbp NETWORK_FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
int
main (int argc, const char* argv[])
{
LogVar X = 3;
LogVarSet Xs = X;
cout << "set: " << X << endl;
Cenas c1 (1);
Cenas c2 (3);
cout << (c1 < c2) << endl;
return 0;
if (!argv[1]) {
if (argc <= 1) {
cerr << "error: no solver specified" << endl;
cerr << "error: no graphical model specified" << endl;
cerr << USAGE << endl;
exit (0);
}
const string& fileName = argv[1];
const string& extension = fileName.substr (fileName.find_last_of ('.') + 1);
if (extension == "xml") {
BayesNet bn;
bn.readFromBifFormat (argv[1]);
processArguments (bn, argc, argv);
} else if (extension == "uai") {
FactorGraph fg;
fg.readFromUaiFormat (argv[1]);
processArguments (fg, argc, argv);
} else if (extension == "fg") {
FactorGraph fg;
fg.readFromLibDaiFormat (argv[1]);
processArguments (fg, argc, argv);
} else {
cerr << "error: the graphical model must be defined either " ;
cerr << "in a xml, uai or libDAI file" << endl;
if (argc <= 2) {
cerr << "error: no graphical model specified" << endl;
cerr << USAGE << endl;
exit (0);
}
string solver (argv[1]);
if (solver == "ve") {
Globals::infAlgorithm = InfAlgorithms::VE;
} else if (solver == "bp") {
Globals::infAlgorithm = InfAlgorithms::BP;
} else if (solver == "cbp") {
Globals::infAlgorithm = InfAlgorithms::CBP;
} else {
cerr << "error: unknow solver `" << solver << "'" << endl ;
cerr << USAGE << endl;
exit(0);
}
string fileName (argv[2]);
string extension = fileName.substr (
fileName.find_last_of ('.') + 1);
FactorGraph fg;
if (extension == "uai") {
fg.readFromUaiFormat (fileName.c_str());
} else if (extension == "fg") {
fg.readFromLibDaiFormat (fileName.c_str());
} else {
cerr << "error: the graphical model must be defined either " ;
cerr << "in a UAI or libDAI file" << endl;
exit (0);
}
processArguments (fg, argc, argv);
return 0;
}
void
processArguments (BayesNet& bn, int argc, const char* argv[])
{
VarNodes queryVars;
for (int i = 2; i < argc; i++) {
const string& arg = argv[i];
if (arg.find ('=') == std::string::npos) {
BayesNode* queryVar = bn.getBayesNode (arg);
if (queryVar) {
queryVars.push_back (queryVar);
} else {
cerr << "error: there isn't a variable labeled of " ;
cerr << "`" << arg << "'" ;
cerr << endl;
bn.freeDistributions();
exit (0);
}
} else {
size_t pos = arg.find ('=');
const string& label = arg.substr (0, pos);
const string& state = arg.substr (pos + 1);
if (label.empty()) {
cerr << "error: missing left argument" << endl;
cerr << USAGE << endl;
bn.freeDistributions();
exit (0);
}
if (state.empty()) {
cerr << "error: missing right argument" << endl;
cerr << USAGE << endl;
bn.freeDistributions();
exit (0);
}
BayesNode* node = bn.getBayesNode (label);
if (node) {
if (node->isValidState (state)) {
node->setEvidence (state);
} else {
cerr << "error: `" << state << "' " ;
cerr << "is not a valid state for " ;
cerr << "`" << node->label() << "'" ;
cerr << endl;
bn.freeDistributions();
exit (0);
}
} else {
cerr << "error: there isn't a variable labeled of " ;
cerr << "`" << label << "'" ;
cerr << endl;
bn.freeDistributions();
exit (0);
}
}
}
Solver* solver = 0;
FactorGraph* fg = 0;
switch (InfAlgorithms::infAlgorithm) {
case InfAlgorithms::VE:
fg = new FactorGraph (bn);
solver = new VarElimSolver (*fg);
break;
case InfAlgorithms::BN_BP:
solver = new BnBpSolver (bn);
break;
case InfAlgorithms::FG_BP:
fg = new FactorGraph (bn);
solver = new FgBpSolver (*fg);
break;
case InfAlgorithms::CBP:
fg = new FactorGraph (bn);
solver = new CbpSolver (*fg);
break;
default:
assert (false);
}
runSolver (solver, queryVars);
delete fg;
bn.freeDistributions();
}
void
processArguments (FactorGraph& fg, int argc, const char* argv[])
{
VarNodes queryVars;
for (int i = 2; i < argc; i++) {
VarIds queryIds;
for (int i = 3; i < argc; i++) {
const string& arg = argv[i];
if (arg.find ('=') == std::string::npos) {
if (!Util::isInteger (arg)) {
cerr << "error: `" << arg << "' " ;
cerr << "is not a valid variable id" ;
cerr << endl;
fg.freeDistributions();
exit (0);
}
VarId vid;
stringstream ss;
ss << arg;
ss >> vid;
VarNode* queryVar = fg.getFgVarNode (vid);
VarNode* queryVar = fg.getVarNode (vid);
if (queryVar) {
queryVars.push_back (queryVar);
queryIds.push_back (vid);
} else {
cerr << "error: there isn't a variable with " ;
cerr << "`" << vid << "' as id" ;
cerr << endl;
fg.freeDistributions();
exit (0);
}
} else {
@ -201,33 +93,29 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
if (arg.substr (0, pos).empty()) {
cerr << "error: missing left argument" << endl;
cerr << USAGE << endl;
fg.freeDistributions();
exit (0);
}
if (arg.substr (pos + 1).empty()) {
cerr << "error: missing right argument" << endl;
cerr << USAGE << endl;
fg.freeDistributions();
exit (0);
}
if (!Util::isInteger (arg.substr (0, pos))) {
cerr << "error: `" << arg.substr (0, pos) << "' " ;
cerr << "is not a variable id" ;
cerr << endl;
fg.freeDistributions();
exit (0);
}
VarId vid;
stringstream ss;
ss << arg.substr (0, pos);
ss >> vid;
VarNode* var = fg.getFgVarNode (vid);
VarNode* var = fg.getVarNode (vid);
if (var) {
if (!Util::isInteger (arg.substr (pos + 1))) {
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
cerr << "is not a state index" ;
cerr << endl;
fg.freeDistributions();
exit (0);
}
int stateIndex;
@ -241,29 +129,31 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
cerr << "is not a valid state index for variable " ;
cerr << "`" << var->varId() << "'" ;
cerr << endl;
fg.freeDistributions();
exit (0);
}
} else {
cerr << "error: there isn't a variable with " ;
cerr << "`" << vid << "' as id" ;
cerr << endl;
fg.freeDistributions();
exit (0);
}
}
}
runSolver (fg, queryIds);
}
void
runSolver (const FactorGraph& fg, const VarIds& queryIds)
{
Solver* solver = 0;
switch (InfAlgorithms::infAlgorithm) {
switch (Globals::infAlgorithm) {
case InfAlgorithms::VE:
solver = new VarElimSolver (fg);
break;
case InfAlgorithms::BN_BP:
case InfAlgorithms::FG_BP:
//cout << "here!" << endl;
//fg.printGraphicalModel();
//fg.exportToLibDaiFormat ("net.fg");
solver = new FgBpSolver (fg);
case InfAlgorithms::BP:
solver = new BpSolver (fg);
break;
case InfAlgorithms::CBP:
solver = new CbpSolver (fg);
@ -271,28 +161,10 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
default:
assert (false);
}
runSolver (solver, queryVars);
fg.freeDistributions();
}
void
runSolver (Solver* solver, const VarNodes& queryVars)
{
VarIds vids;
for (unsigned i = 0; i < queryVars.size(); i++) {
vids.push_back (queryVars[i]->varId());
}
if (queryVars.size() == 0) {
solver->runSolver();
if (queryIds.size() == 0) {
solver->printAllPosterioris();
} else if (queryVars.size() == 1) {
solver->runSolver();
solver->printPosterioriOf (vids[0]);
} else {
solver->runSolver();
solver->printJointDistributionOf (vids);
solver->printAnswer (queryIds);
}
delete solver;
}

View File

@ -7,22 +7,50 @@
#include <YapInterface.h>
#include "BayesNet.h"
#include "ParfactorList.h"
#include "FactorGraph.h"
#include "FoveSolver.h"
#include "VarElimSolver.h"
#include "BnBpSolver.h"
#include "FgBpSolver.h"
#include "BpSolver.h"
#include "CbpSolver.h"
#include "ElimGraph.h"
#include "FoveSolver.h"
#include "ParfactorList.h"
#include "BayesBall.h"
using namespace std;
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
Params readParameters (YAP_Term);
vector<unsigned> readUnsignedList (YAP_Term);
void readLiftedEvidence (YAP_Term, ObservedFormulas&);
Parfactor* readParfactor (YAP_Term);
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results);
void runBpSolver (FactorGraph* fg, const vector<VarIds>& tasks,
vector<Params>& results);
vector<unsigned>
readUnsignedList (YAP_Term list)
{
vector<unsigned> vec;
while (list != YAP_TermNil()) {
vec.push_back ((unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (list)));
list = YAP_TailOfTerm (list);
}
return vec;
}
Params readParams (YAP_Term);
int createLiftedNetwork (void)
@ -30,107 +58,121 @@ int createLiftedNetwork (void)
Parfactors parfactors;
YAP_Term parfactorList = YAP_ARG1;
while (parfactorList != YAP_TermNil()) {
YAP_Term parfactor = YAP_HeadOfTerm (parfactorList);
// read dist id
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, parfactor));
// read the ranges
Ranges ranges;
YAP_Term rangeList = YAP_ArgOfTerm (3, parfactor);
while (rangeList != YAP_TermNil()) {
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
ranges.push_back (range);
rangeList = YAP_TailOfTerm (rangeList);
}
// read parametric random vars
ProbFormulas formulas;
unsigned count = 0;
unordered_map<YAP_Term, LogVar> lvMap;
YAP_Term pvList = YAP_ArgOfTerm (2, parfactor);
while (pvList != YAP_TermNil()) {
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
if (YAP_IsAtomTerm (formulaTerm)) {
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
Symbol functor = LiftedUtils::getSymbol (name);
formulas.push_back (ProbFormula (functor, ranges[count]));
} else {
LogVars logVars;
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
Symbol functor = LiftedUtils::getSymbol (name);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
if (it != lvMap.end()) {
logVars.push_back (it->second);
} else {
unsigned newLv = lvMap.size();
lvMap[ti] = newLv;
logVars.push_back (newLv);
}
}
formulas.push_back (ProbFormula (functor, logVars, ranges[count]));
}
count ++;
pvList = YAP_TailOfTerm (pvList);
}
// read the parameters
const Params& params = readParams (YAP_ArgOfTerm (4, parfactor));
// read the constraint
Tuples tuples;
if (lvMap.size() >= 1) {
YAP_Term tupleList = YAP_ArgOfTerm (5, parfactor);
while (tupleList != YAP_TermNil()) {
YAP_Term term = YAP_HeadOfTerm (tupleList);
assert (YAP_IsApplTerm (term));
YAP_Functor yapFunctor = YAP_FunctorOfTerm (term);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
assert (lvMap.size() == arity);
Tuple tuple (arity);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, term);
if (YAP_IsAtomTerm (ti) == false) {
cerr << "error: bad formed constraint" << endl;
abort();
}
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
tuple[i - 1] = LiftedUtils::getSymbol (name);
}
tuples.push_back (tuple);
tupleList = YAP_TailOfTerm (tupleList);
}
}
parfactors.push_back (new Parfactor (formulas, params, tuples, distId));
YAP_Term pfTerm = YAP_HeadOfTerm (parfactorList);
parfactors.push_back (readParfactor (pfTerm));
parfactorList = YAP_TailOfTerm (parfactorList);
}
// LiftedUtils::printSymbolDictionary();
cout << "*******************************************************" << endl;
cout << "INITIAL PARFACTORS" << endl;
cout << "*******************************************************" << endl;
for (unsigned i = 0; i < parfactors.size(); i++) {
parfactors[i]->print();
cout << endl;
if (Constants::DEBUG > 2) {
// Util::printHeader ("INITIAL PARFACTORS");
// for (unsigned i = 0; i < parfactors.size(); i++) {
// parfactors[i]->print();
// }
}
ParfactorList* pfList = new ParfactorList();
for (unsigned i = 0; i < parfactors.size(); i++) {
pfList->add (parfactors[i]);
}
cout << endl;
cout << "*******************************************************" << endl;
cout << "SHATTERED PARFACTORS" << endl;
cout << "*******************************************************" << endl;
pfList->shatter();
pfList->print();
// insert the evidence
ObservedFormulas obsFormulas;
YAP_Term observedList = YAP_ARG2;
ParfactorList* pfList = new ParfactorList (parfactors);
if (Constants::DEBUG >= 2) {
Util::printHeader ("SHATTERED PARFACTORS");
pfList->print();
}
// read evidence
ObservedFormulas* obsFormulas = new ObservedFormulas();
readLiftedEvidence (YAP_ARG2, *(obsFormulas));
LiftedNetwork* net = new LiftedNetwork (pfList, obsFormulas);
YAP_Int p = (YAP_Int) (net);
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
}
Parfactor* readParfactor (YAP_Term pfTerm)
{
// read dist id
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
// read the ranges
Ranges ranges;
YAP_Term rangeList = YAP_ArgOfTerm (3, pfTerm);
while (rangeList != YAP_TermNil()) {
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
ranges.push_back (range);
rangeList = YAP_TailOfTerm (rangeList);
}
// read parametric random vars
ProbFormulas formulas;
unsigned count = 0;
unordered_map<YAP_Term, LogVar> lvMap;
YAP_Term pvList = YAP_ArgOfTerm (2, pfTerm);
while (pvList != YAP_TermNil()) {
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
if (YAP_IsAtomTerm (formulaTerm)) {
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
Symbol functor = LiftedUtils::getSymbol (name);
formulas.push_back (ProbFormula (functor, ranges[count]));
} else {
LogVars logVars;
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
Symbol functor = LiftedUtils::getSymbol (name);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
if (it != lvMap.end()) {
logVars.push_back (it->second);
} else {
unsigned newLv = lvMap.size();
lvMap[ti] = newLv;
logVars.push_back (newLv);
}
}
formulas.push_back (ProbFormula (functor, logVars, ranges[count]));
}
count ++;
pvList = YAP_TailOfTerm (pvList);
}
// read the parameters
const Params& params = readParameters (YAP_ArgOfTerm (4, pfTerm));
// read the constraint
Tuples tuples;
if (lvMap.size() >= 1) {
YAP_Term tupleList = YAP_ArgOfTerm (5, pfTerm);
while (tupleList != YAP_TermNil()) {
YAP_Term term = YAP_HeadOfTerm (tupleList);
assert (YAP_IsApplTerm (term));
YAP_Functor yapFunctor = YAP_FunctorOfTerm (term);
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
assert (lvMap.size() == arity);
Tuple tuple (arity);
for (unsigned i = 1; i <= arity; i++) {
YAP_Term ti = YAP_ArgOfTerm (i, term);
if (YAP_IsAtomTerm (ti) == false) {
cerr << "error: constraint has free variables" << endl;
abort();
}
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
tuple[i - 1] = LiftedUtils::getSymbol (name);
}
tuples.push_back (tuple);
tupleList = YAP_TailOfTerm (tupleList);
}
}
return new Parfactor (formulas, params, tuples, distId);
}
void readLiftedEvidence (
YAP_Term observedList,
ObservedFormulas& obsFormulas)
{
while (observedList != YAP_TermNil()) {
YAP_Term pair = YAP_HeadOfTerm (observedList);
YAP_Term ground = YAP_ArgOfTerm (1, pair);
@ -155,22 +197,18 @@ int createLiftedNetwork (void)
unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair));
bool found = false;
for (unsigned i = 0; i < obsFormulas.size(); i++) {
if (obsFormulas[i]->functor() == functor &&
obsFormulas[i]->arity() == args.size() &&
obsFormulas[i]->evidence() == evidence) {
obsFormulas[i]->addTuple (args);
if (obsFormulas[i].functor() == functor &&
obsFormulas[i].arity() == args.size() &&
obsFormulas[i].evidence() == evidence) {
obsFormulas[i].addTuple (args);
found = true;
}
}
if (found == false) {
obsFormulas.push_back (new ObservedFormula (functor, evidence, args));
obsFormulas.push_back (ObservedFormula (functor, evidence, args));
}
observedList = YAP_TailOfTerm (observedList);
}
FoveSolver::absorveEvidence (*pfList, obsFormulas);
YAP_Int p = (YAP_Int) (pfList);
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
}
}
@ -178,93 +216,46 @@ int createLiftedNetwork (void)
int
createGroundNetwork (void)
{
Statistics::incrementPrimaryNetworksCounting();
// cout << "creating network number " ;
// cout << Statistics::getPrimaryNetworksCounting() << endl;
// if (Statistics::getPrimaryNetworksCounting() > 98) {
// Statistics::writeStatisticsToFile ("../../compressing.stats");
// }
BayesNet* bn = new BayesNet();
YAP_Term varList = YAP_ARG1;
BnNodeSet nodes;
vector<VarIds> parents;
while (varList != YAP_TermNil()) {
YAP_Term var = YAP_HeadOfTerm (varList);
VarId vid = (VarId) YAP_IntOfTerm (YAP_ArgOfTerm (1, var));
unsigned dsize = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, var));
int evidence = (int) YAP_IntOfTerm (YAP_ArgOfTerm (3, var));
YAP_Term parentL = YAP_ArgOfTerm (4, var);
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (5, var));
parents.push_back (VarIds());
while (parentL != YAP_TermNil()) {
unsigned parentId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (parentL));
parents.back().push_back (parentId);
parentL = YAP_TailOfTerm (parentL);
}
Distribution* dist = bn->getDistribution (distId);
if (!dist) {
dist = new Distribution (distId);
bn->addDistribution (dist);
}
assert (bn->getBayesNode (vid) == 0);
nodes.push_back (bn->addNode (vid, dsize, evidence, dist));
varList = YAP_TailOfTerm (varList);
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
bool fromBayesNet = factorsType == "bayes";
FactorGraph* fg = new FactorGraph (fromBayesNet);
YAP_Term factorList = YAP_ARG2;
while (factorList != YAP_TermNil()) {
YAP_Term factor = YAP_HeadOfTerm (factorList);
// read the var ids
VarIds varIds = readUnsignedList (YAP_ArgOfTerm (1, factor));
// read the ranges
Ranges ranges = readUnsignedList (YAP_ArgOfTerm (2, factor));
// read the parameters
Params params = readParameters (YAP_ArgOfTerm (3, factor));
// read dist id
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (4, factor));
fg->addFactor (Factor (varIds, ranges, params, distId));
factorList = YAP_TailOfTerm (factorList);
}
for (unsigned i = 0; i < nodes.size(); i++) {
BnNodeSet ps;
for (unsigned j = 0; j < parents[i].size(); j++) {
assert (bn->getBayesNode (parents[i][j]) != 0);
ps.push_back (bn->getBayesNode (parents[i][j]));
}
nodes[i]->setParents (ps);
YAP_Term evidenceList = YAP_ARG3;
while (evidenceList != YAP_TermNil()) {
YAP_Term evTerm = YAP_HeadOfTerm (evidenceList);
unsigned vid = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (1, evTerm)));
unsigned ev = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (2, evTerm)));
assert (fg->getVarNode (vid));
fg->getVarNode (vid)->setEvidence (ev);
evidenceList = YAP_TailOfTerm (evidenceList);
}
bn->setIndexes();
YAP_Int p = (YAP_Int) (bn);
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG2);
}
int
setBayesNetParams (void)
{
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term distList = YAP_ARG2;
while (distList != YAP_TermNil()) {
YAP_Term dist = YAP_HeadOfTerm (distList);
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
const Params params = readParams (YAP_ArgOfTerm (2, dist));
bn->getDistribution(distId)->updateParameters (params);
distList = YAP_TailOfTerm (distList);
}
return TRUE;
}
int
setParfactorGraphParams (void)
{
// FIXME
// ParfactorGraph* pfg = (ParfactorGraph*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term distList = YAP_ARG2;
while (distList != YAP_TermNil()) {
// YAP_Term dist = YAP_HeadOfTerm (distList);
// unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
// const Params params = readParams (YAP_ArgOfTerm (2, dist));
// pfg->getDistribution(distId)->setData (params);
distList = YAP_TailOfTerm (distList);
}
return TRUE;
YAP_Int p = (YAP_Int) (fg);
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG4);
}
Params
readParams (YAP_Term paramL)
readParameters (YAP_Term paramL)
{
Params params;
while (paramL!= YAP_TermNil()) {
assert (YAP_IsPairTerm (paramL));
while (paramL != YAP_TermNil()) {
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
paramL = YAP_TailOfTerm (paramL);
}
@ -279,15 +270,14 @@ readParams (YAP_Term paramL)
int
runLiftedSolver (void)
{
ParfactorList* pfList = (ParfactorList*) YAP_IntOfTerm (YAP_ARG1);
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term taskList = YAP_ARG2;
vector<Params> results;
ParfactorList pfListCopy (*network->first);
FoveSolver::absorveEvidence (pfListCopy, *network->second);
while (taskList != YAP_TermNil()) {
YAP_Term jointList = YAP_HeadOfTerm (taskList);
Grounds queryVars;
assert (YAP_IsPairTerm (taskList));
assert (YAP_IsPairTerm (jointList));
YAP_Term jointList = YAP_HeadOfTerm (taskList);
while (jointList != YAP_TermNil()) {
YAP_Term ground = YAP_HeadOfTerm (jointList);
if (YAP_IsAtomTerm (ground)) {
@ -310,11 +300,11 @@ runLiftedSolver (void)
}
jointList = YAP_TailOfTerm (jointList);
}
FoveSolver solver (pfList);
FoveSolver solver (pfListCopy);
if (queryVars.size() == 1) {
results.push_back (solver.getPosterioriOf (queryVars[0]));
} else {
assert (false); // TODO joint dist
results.push_back (solver.getJointDistributionOf (queryVars));
}
taskList = YAP_TailOfTerm (taskList);
}
@ -339,77 +329,23 @@ runLiftedSolver (void)
int
runOtherSolvers (void)
runGroundSolver (void)
{
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term taskList = YAP_ARG2;
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
vector<VarIds> tasks;
std::set<VarId> vids;
YAP_Term taskList = YAP_ARG2;
while (taskList != YAP_TermNil()) {
if (YAP_IsPairTerm (YAP_HeadOfTerm (taskList))) {
tasks.push_back (VarIds());
YAP_Term jointList = YAP_HeadOfTerm (taskList);
while (jointList != YAP_TermNil()) {
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
assert (bn->getBayesNode (vid));
tasks.back().push_back (vid);
vids.insert (vid);
jointList = YAP_TailOfTerm (jointList);
}
} else {
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (taskList));
assert (bn->getBayesNode (vid));
tasks.push_back (VarIds() = {vid});
vids.insert (vid);
}
tasks.push_back (readUnsignedList (YAP_HeadOfTerm (taskList)));
taskList = YAP_TailOfTerm (taskList);
}
Solver* bpSolver = 0;
GraphicalModel* graphicalModel = 0;
CFactorGraph::checkForIdenticalFactors = false;
if (InfAlgorithms::infAlgorithm != InfAlgorithms::VE) {
BayesNet* mrn = bn->getMinimalRequesiteNetwork (
VarIds (vids.begin(), vids.end()));
if (InfAlgorithms::infAlgorithm == InfAlgorithms::BN_BP) {
graphicalModel = mrn;
bpSolver = new BnBpSolver (*static_cast<BayesNet*> (graphicalModel));
} else if (InfAlgorithms::infAlgorithm == InfAlgorithms::FG_BP) {
graphicalModel = new FactorGraph (*mrn);
bpSolver = new FgBpSolver (*static_cast<FactorGraph*> (graphicalModel));
delete mrn;
} else if (InfAlgorithms::infAlgorithm == InfAlgorithms::CBP) {
graphicalModel = new FactorGraph (*mrn);
bpSolver = new CbpSolver (*static_cast<FactorGraph*> (graphicalModel));
delete mrn;
}
bpSolver->runSolver();
}
vector<Params> results;
results.reserve (tasks.size());
for (unsigned i = 0; i < tasks.size(); i++) {
//if (i == 1) exit (0);
if (InfAlgorithms::infAlgorithm == InfAlgorithms::VE) {
BayesNet* mrn = bn->getMinimalRequesiteNetwork (tasks[i]);
VarElimSolver* veSolver = new VarElimSolver (*mrn);
if (tasks[i].size() == 1) {
results.push_back (veSolver->getPosterioriOf (tasks[i][0]));
} else {
results.push_back (veSolver->getJointDistributionOf (tasks[i]));
}
delete mrn;
delete veSolver;
} else {
if (tasks[i].size() == 1) {
results.push_back (bpSolver->getPosterioriOf (tasks[i][0]));
} else {
results.push_back (bpSolver->getJointDistributionOf (tasks[i]));
}
}
if (Globals::infAlgorithm == InfAlgorithms::VE) {
runVeSolver (fg, tasks, results);
} else {
runBpSolver (fg, tasks, results);
}
delete bpSolver;
delete graphicalModel;
YAP_Term list = YAP_TermNil();
for (int i = results.size() - 1; i >= 0; i--) {
@ -424,32 +360,142 @@ runOtherSolvers (void)
}
list = YAP_MkPairTerm (queryBeliefsL, list);
}
return YAP_Unify (list, YAP_ARG3);
}
int
setExtraVarsInfo (void)
void runVeSolver (
FactorGraph* fg,
const vector<VarIds>& tasks,
vector<Params>& results)
{
// BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
GraphicalModel::clearVariablesInformation();
YAP_Term varsInfoL = YAP_ARG2;
while (varsInfoL != YAP_TermNil()) {
YAP_Term head = YAP_HeadOfTerm (varsInfoL);
VarId vid = YAP_IntOfTerm (YAP_ArgOfTerm (1, head));
YAP_Atom label = YAP_AtomOfTerm (YAP_ArgOfTerm (2, head));
YAP_Term statesL = YAP_ArgOfTerm (3, head);
States states;
while (statesL != YAP_TermNil()) {
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (statesL));
states.push_back ((char*) YAP_AtomName (atom));
statesL = YAP_TailOfTerm (statesL);
results.reserve (tasks.size());
for (unsigned i = 0; i < tasks.size(); i++) {
FactorGraph* mfg = fg;
if (fg->isFromBayesNetwork()) {
mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
}
GraphicalModel::addVariableInformation (vid,
(char*) YAP_AtomName (label), states);
varsInfoL = YAP_TailOfTerm (varsInfoL);
VarElimSolver solver (*mfg);
results.push_back (solver.solveQuery (tasks[i]));
if (fg->isFromBayesNetwork()) {
delete mfg;
}
}
}
void runBpSolver (
FactorGraph* fg,
const vector<VarIds>& tasks,
vector<Params>& results)
{
std::set<VarId> vids;
for (unsigned i = 0; i < tasks.size(); i++) {
Util::addToSet (vids, tasks[i]);
}
Solver* solver = 0;
FactorGraph* mfg = fg;
if (fg->isFromBayesNetwork()) {
mfg = BayesBall::getMinimalFactorGraph (
*fg, VarIds (vids.begin(),vids.end()));
}
if (Globals::infAlgorithm == InfAlgorithms::BP) {
solver = new BpSolver (*mfg);
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) {
CFactorGraph::checkForIdenticalFactors = false;
solver = new CbpSolver (*mfg);
} else {
cerr << "error: unknow solver" << endl;
abort();
}
results.reserve (tasks.size());
for (unsigned i = 0; i < tasks.size(); i++) {
results.push_back (solver->solveQuery (tasks[i]));
}
if (fg->isFromBayesNetwork()) {
delete mfg;
}
delete solver;
}
int
setParfactorsParams (void)
{
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
ParfactorList* pfList = network->first;
YAP_Term distList = YAP_ARG2;
unordered_map<unsigned, Params> paramsMap;
while (distList != YAP_TermNil()) {
YAP_Term dist = YAP_HeadOfTerm (distList);
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
assert (Util::contains (paramsMap, distId) == false);
paramsMap[distId] = readParameters (YAP_ArgOfTerm (2, dist));
distList = YAP_TailOfTerm (distList);
}
ParfactorList::iterator it = pfList->begin();
while (it != pfList->end()) {
assert (Util::contains (paramsMap, (*it)->distId()));
// (*it)->setParams (paramsMap[(*it)->distId()]);
++ it;
}
return TRUE;
}
int
setFactorsParams (void)
{
return TRUE; // TODO
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
YAP_Term distList = YAP_ARG2;
unordered_map<unsigned, Params> paramsMap;
while (distList != YAP_TermNil()) {
YAP_Term dist = YAP_HeadOfTerm (distList);
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
assert (Util::contains (paramsMap, distId) == false);
paramsMap[distId] = readParameters (YAP_ArgOfTerm (2, dist));
distList = YAP_TailOfTerm (distList);
}
const FacNodes& facNodes = fg->facNodes();
for (unsigned i = 0; i < facNodes.size(); i++) {
unsigned distId = facNodes[i]->factor().distId();
assert (Util::contains (paramsMap, distId));
facNodes[i]->factor().setParams (paramsMap[distId]);
}
return TRUE;
}
int
setVarsInformation (void)
{
Var::clearVarsInfo();
YAP_Term labelsL = YAP_ARG1;
vector<string> labels;
while (labelsL != YAP_TermNil()) {
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (labelsL));
labels.push_back ((char*) YAP_AtomName (atom));
labelsL = YAP_TailOfTerm (labelsL);
}
unsigned count = 0;
YAP_Term stateNamesL = YAP_ARG2;
while (stateNamesL != YAP_TermNil()) {
States states;
YAP_Term namesL = YAP_HeadOfTerm (stateNamesL);
while (namesL != YAP_TermNil()) {
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (namesL));
states.push_back ((char*) YAP_AtomName (atom));
namesL = YAP_TailOfTerm (namesL);
}
Var::addVarInfo (count, labels[count], states);
count ++;
stateNamesL = YAP_TailOfTerm (stateNamesL);
}
return TRUE;
}
@ -463,13 +509,11 @@ setHorusFlag (void)
if (key == "inf_alg") {
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
if ( value == "ve") {
InfAlgorithms::infAlgorithm = InfAlgorithms::VE;
} else if (value == "bn_bp") {
InfAlgorithms::infAlgorithm = InfAlgorithms::BN_BP;
} else if (value == "fg_bp") {
InfAlgorithms::infAlgorithm = InfAlgorithms::FG_BP;
Globals::infAlgorithm = InfAlgorithms::VE;
} else if (value == "bp") {
Globals::infAlgorithm = InfAlgorithms::BP;
} else if (value == "cbp") {
InfAlgorithms::infAlgorithm = InfAlgorithms::CBP;
Globals::infAlgorithm = InfAlgorithms::CBP;
} else {
cerr << "warning: invalid value `" << value << "' " ;
cerr << "for `" << key << "'" << endl;
@ -541,21 +585,21 @@ setHorusFlag (void)
int
freeBayesNetwork (void)
freeGroundNetwork (void)
{
//Statistics::writeStatisticsToFile ("stats.txt");
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
bn->freeDistributions();
delete bn;
delete (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
return TRUE;
}
int
freeParfactorGraph (void)
freeParfactors (void)
{
delete (ParfactorList*) YAP_IntOfTerm (YAP_ARG1);
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
delete network->first;
delete network->second;
delete network;
return TRUE;
}
@ -564,15 +608,15 @@ freeParfactorGraph (void)
extern "C" void
init_predicates (void)
{
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 2);
YAP_UserCPredicate ("set_parfactor_graph_params", setParfactorGraphParams, 2);
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
YAP_UserCPredicate ("run_other_solvers", runOtherSolvers, 3);
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1);
YAP_UserCPredicate ("free_parfactor_graph", freeParfactorGraph, 1);
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 4);
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
YAP_UserCPredicate ("set_factors_params", setFactorsParams, 2);
YAP_UserCPredicate ("set_vars_information", setVarsInformation, 2);
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);
YAP_UserCPredicate ("free_ground_network", freeGroundNetwork, 1);
}

View File

@ -8,11 +8,13 @@
#include <sstream>
#include <iomanip>
#include "VarNode.h"
#include "Var.h"
#include "Util.h"
class StatesIndexer {
class StatesIndexer
{
public:
StatesIndexer (const Ranges& ranges, bool calcOffsets = true)
@ -29,14 +31,14 @@ class StatesIndexer {
}
}
StatesIndexer (const VarNodes& vars, bool calcOffsets = true)
StatesIndexer (const Vars& vars, bool calcOffsets = true)
{
size_ = 1;
indices_.resize (vars.size(), 0);
ranges_.reserve (vars.size());
for (unsigned i = 0; i < vars.size(); i++) {
ranges_.push_back (vars[i]->nrStates());
size_ *= vars[i]->nrStates();
ranges_.push_back (vars[i]->range());
size_ *= vars[i]->range();
}
li_ = 0;
if (calcOffsets) {
@ -134,11 +136,11 @@ class StatesIndexer {
return size_ ;
}
friend ostream& operator<< (ostream &out, const StatesIndexer& idx)
friend ostream& operator<< (ostream &os, const StatesIndexer& idx)
{
out << "(" << std::setw (2) << std::setfill('0') << idx.li_ << ") " ;
out << idx.indices_;
return out;
os << "(" << std::setw (2) << std::setfill('0') << idx.li_ << ") " ;
os << idx.indices_;
return os;
}
private:
@ -274,21 +276,14 @@ class MapIndexer
index_ = 0;
}
friend ostream& operator<< (ostream &out, const MapIndexer& idx)
friend ostream& operator<< (ostream &os, const MapIndexer& idx)
{
out << "(" << std::setw (2) << std::setfill('0') << idx.index_ << ") " ;
out << idx.indices_;
return out;
os << "(" << std::setw (2) << std::setfill('0') << idx.index_ << ") " ;
os << idx.indices_;
return os;
}
private:
MapIndexer (const Ranges& ranges) :
ranges_(ranges),
indices_(ranges.size(), 0),
offsets_(ranges.size())
{
index_ = 0;
}
unsigned index_;
bool valid_;
vector<unsigned> ranges_;

View File

@ -95,26 +95,37 @@ ostream& operator<< (ostream &os, const Ground& gr)
void
ObservedFormula::addTuple (const Tuple& t)
LogVars
Substitution::getDiscardedLogVars (void) const
{
if (constr_ == 0) {
LogVars lvs (arity_);
for (unsigned i = 0; i < arity_; i++) {
lvs[i] = i;
LogVars discardedLvs;
set<LogVar> doneLvs;
unordered_map<LogVar, LogVar>::const_iterator it;
it = subs_.begin();
while (it != subs_.end()) {
if (Util::contains (doneLvs, it->second)) {
discardedLvs.push_back (it->first);
} else {
doneLvs.insert (it->second);
}
constr_ = new ConstraintTree (lvs);
it ++;
}
constr_->addTuple (t);
return discardedLvs;
}
ostream& operator<< (ostream &os, const ObservedFormula of)
ostream& operator<< (ostream &os, const Substitution& theta)
{
os << of.functor_ << "/" << of.arity_;
os << "|" << of.constr_->tupleSet();
os << " [evidence=" << of.evidence_ << "]";
unordered_map<LogVar, LogVar>::const_iterator it;
os << "[" ;
it = theta.subs_.begin();
while (it != theta.subs_.end()) {
if (it != theta.subs_.begin()) os << ", " ;
os << it->first << "->" << it->second ;
++ it;
}
os << "]" ;
return os;
}

View File

@ -18,11 +18,17 @@ class Symbol
{
public:
Symbol (void) : id_(numeric_limits<unsigned>::max()) { }
Symbol (unsigned id) : id_(id) { }
operator unsigned (void) const { return id_; }
bool valid (void) const { return id_ != numeric_limits<unsigned>::max(); }
static Symbol invalid (void) { return Symbol(); }
friend ostream& operator<< (ostream &os, const Symbol& s);
private:
unsigned id_;
};
@ -32,7 +38,9 @@ class LogVar
{
public:
LogVar (void) : id_(numeric_limits<unsigned>::max()) { }
LogVar (unsigned id) : id_(id) { }
operator unsigned (void) const { return id_; }
LogVar& operator++ (void)
@ -48,6 +56,7 @@ class LogVar
}
friend ostream& operator<< (ostream &os, const LogVar& X);
private:
unsigned id_;
};
@ -79,8 +88,8 @@ ostream& operator<< (ostream &os, const Tuple& t);
namespace LiftedUtils {
Symbol getSymbol (const string&);
void printSymbolDictionary (void);
Symbol getSymbol (const string&);
void printSymbolDictionary (void);
}
@ -89,71 +98,56 @@ class Ground
{
public:
Ground (Symbol f) : functor_(f) { }
Ground (Symbol f, const Symbols& args) : functor_(f), args_(args) { }
Symbol functor (void) const { return functor_; }
Symbols args (void) const { return args_; }
unsigned arity (void) const { return args_.size(); }
bool isAtom (void) const { return args_.size() == 0; }
Symbol functor (void) const { return functor_; }
Symbols args (void) const { return args_; }
unsigned arity (void) const { return args_.size(); }
bool isAtom (void) const { return args_.size() == 0; }
friend ostream& operator<< (ostream &os, const Ground& gr);
private:
Symbol functor_;
Symbols args_;
Symbol functor_;
Symbols args_;
};
typedef vector<Ground> Grounds;
class ConstraintTree;
class ObservedFormula
{
public:
ObservedFormula (Symbol f, unsigned a, unsigned ev)
: functor_(f), arity_(a), evidence_(ev), constr_(0) { }
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(0)
{
addTuple (tuple);
}
Symbol functor (void) const { return functor_; }
unsigned arity (void) const { return arity_; }
unsigned evidence (void) const { return evidence_; }
ConstraintTree* constr (void) const { return constr_; }
bool isAtom (void) const { return arity_ == 0; }
void addTuple (const Tuple& t);
friend ostream& operator<< (ostream &os, const ObservedFormula opv);
private:
Symbol functor_;
unsigned arity_;
unsigned evidence_;
ConstraintTree* constr_;
};
typedef vector<ObservedFormula*> ObservedFormulas;
class Substitution
{
public:
void add (LogVar X_old, LogVar X_new)
{
assert (Util::contains (subs_, X_old) == false);
subs_.insert (make_pair (X_old, X_new));
}
void rename (LogVar X_old, LogVar X_new)
{
assert (subs_.find (X_old) != subs_.end());
assert (Util::contains (subs_, X_old));
subs_.find (X_old)->second = X_new;
}
LogVar newNameFor (LogVar X) const
{
assert (subs_.find (X) != subs_.end());
assert (Util::contains (subs_, X));
return subs_.find (X)->second;
}
LogVars getDiscardedLogVars (void) const;
friend ostream& operator<< (ostream &os, const Substitution& theta);
private:
unordered_map<LogVar, LogVar> subs_;
};

View File

@ -45,9 +45,8 @@ CWD=$(PWD)
HEADERS = \
$(srcdir)/GraphicalModel.h \
$(srcdir)/BayesNet.h \
$(srcdir)/BayesNode.h \
$(srcdir)/BayesBall.h \
$(srcdir)/ElimGraph.h \
$(srcdir)/FactorGraph.h \
$(srcdir)/Factor.h \
@ -55,12 +54,10 @@ HEADERS = \
$(srcdir)/ConstraintTree.h \
$(srcdir)/Solver.h \
$(srcdir)/VarElimSolver.h \
$(srcdir)/BnBpSolver.h \
$(srcdir)/FgBpSolver.h \
$(srcdir)/BpSolver.h \
$(srcdir)/CbpSolver.h \
$(srcdir)/FoveSolver.h \
$(srcdir)/VarNode.h \
$(srcdir)/Distribution.h \
$(srcdir)/Var.h \
$(srcdir)/Indexer.h \
$(srcdir)/Parfactor.h \
$(srcdir)/ProbFormula.h \
@ -69,22 +66,20 @@ HEADERS = \
$(srcdir)/LiftedUtils.h \
$(srcdir)/TinySet.h \
$(srcdir)/Util.h \
$(srcdir)/Horus.h \
$(srcdir)/xmlParser/xmlParser.h
$(srcdir)/Horus.h
CPP_SOURCES = \
$(srcdir)/BayesNet.cpp \
$(srcdir)/BayesNode.cpp \
$(srcdir)/BayesBall.cpp \
$(srcdir)/ElimGraph.cpp \
$(srcdir)/FactorGraph.cpp \
$(srcdir)/Factor.cpp \
$(srcdir)/CFactorGraph.cpp \
$(srcdir)/ConstraintTree.cpp \
$(srcdir)/VarNode.cpp \
$(srcdir)/Var.cpp \
$(srcdir)/Solver.cpp \
$(srcdir)/VarElimSolver.cpp \
$(srcdir)/BnBpSolver.cpp \
$(srcdir)/FgBpSolver.cpp \
$(srcdir)/BpSolver.cpp \
$(srcdir)/CbpSolver.cpp \
$(srcdir)/FoveSolver.cpp \
$(srcdir)/Parfactor.cpp \
@ -94,22 +89,20 @@ CPP_SOURCES = \
$(srcdir)/LiftedUtils.cpp \
$(srcdir)/Util.cpp \
$(srcdir)/HorusYap.cpp \
$(srcdir)/HorusCli.cpp \
$(srcdir)/xmlParser/xmlParser.cpp
$(srcdir)/HorusCli.cpp
OBJS = \
BayesNet.o \
BayesNode.o \
BayesBall.o \
ElimGraph.o \
FactorGraph.o \
Factor.o \
CFactorGraph.o \
ConstraintTree.o \
VarNode.o \
Var.o \
Solver.o \
VarElimSolver.o \
BnBpSolver.o \
FgBpSolver.o \
BpSolver.o \
CbpSolver.o \
FoveSolver.o \
Parfactor.o \
@ -122,17 +115,16 @@ OBJS = \
HCLI_OBJS = \
BayesNet.o \
BayesNode.o \
BayesBall.o \
ElimGraph.o \
FactorGraph.o \
Factor.o \
CFactorGraph.o \
ConstraintTree.o \
VarNode.o \
Var.o \
Solver.o \
VarElimSolver.o \
BnBpSolver.o \
FgBpSolver.o \
BpSolver.o \
CbpSolver.o \
FoveSolver.o \
Parfactor.o \
@ -141,7 +133,6 @@ HCLI_OBJS = \
ParfactorList.o \
LiftedUtils.o \
Util.o \
xmlParser/xmlParser.o \
HorusCli.o
SOBJS=horus.@SO@
@ -154,10 +145,6 @@ all: $(SOBJS) hcli
$(CXX) -c $(CXXFLAGS) $< -o $@
xmlParser/xmlParser.o : $(srcdir)/xmlParser/xmlParser.cpp
$(CXX) -c $(CXXFLAGS) $< -o $@
@DO_SECOND_LD@horus.@SO@: $(OBJS)
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o horus.@SO@ $(OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
@ -171,7 +158,7 @@ install: all
clean:
rm -f *.o *~ $(OBJS) $(SOBJS) *.BAK hcli xmlParser/*.o
rm -f *.o *~ $(OBJS) $(SOBJS) *.BAK hcli
erase_dots:

View File

@ -2,6 +2,7 @@
#include "Parfactor.h"
#include "Histogram.h"
#include "Indexer.h"
#include "Util.h"
#include "Horus.h"
@ -11,55 +12,58 @@ Parfactor::Parfactor (
const Tuples& tuples,
unsigned distId)
{
formulas_ = formulas;
params_ = params;
distId_ = distId;
args_ = formulas;
params_ = params;
distId_ = distId;
LogVars logVars;
for (unsigned i = 0; i < formulas_.size(); i++) {
ranges_.push_back (formulas_[i].range());
const LogVars& lvs = formulas_[i].logVars();
for (unsigned i = 0; i < args_.size(); i++) {
ranges_.push_back (args_[i].range());
const LogVars& lvs = args_[i].logVars();
for (unsigned j = 0; j < lvs.size(); j++) {
if (std::find (logVars.begin(), logVars.end(), lvs[j]) ==
logVars.end()) {
if (Util::contains (logVars, lvs[j]) == false) {
logVars.push_back (lvs[j]);
}
}
}
constr_ = new ConstraintTree (logVars, tuples);
assert (params_.size() == Util::expectedSize (ranges_));
}
Parfactor::Parfactor (const Parfactor* g, const Tuple& tuple)
{
formulas_ = g->formulas();
params_ = g->params();
ranges_ = g->ranges();
distId_ = g->distId();
constr_ = new ConstraintTree (g->logVars(), {tuple});
args_ = g->arguments();
params_ = g->params();
ranges_ = g->ranges();
distId_ = g->distId();
constr_ = new ConstraintTree (g->logVars(), {tuple});
assert (params_.size() == Util::expectedSize (ranges_));
}
Parfactor::Parfactor (const Parfactor* g, ConstraintTree* constr)
{
formulas_ = g->formulas();
params_ = g->params();
ranges_ = g->ranges();
distId_ = g->distId();
constr_ = constr;
args_ = g->arguments();
params_ = g->params();
ranges_ = g->ranges();
distId_ = g->distId();
constr_ = constr;
assert (params_.size() == Util::expectedSize (ranges_));
}
Parfactor::Parfactor (const Parfactor& g)
{
formulas_ = g.formulas();
params_ = g.params();
ranges_ = g.ranges();
distId_ = g.distId();
constr_ = new ConstraintTree (*g.constr());
args_ = g.arguments();
params_ = g.params();
ranges_ = g.ranges();
distId_ = g.distId();
constr_ = new ConstraintTree (*g.constr());
assert (params_.size() == Util::expectedSize (ranges_));
}
@ -75,9 +79,9 @@ LogVarSet
Parfactor::countedLogVars (void) const
{
LogVarSet set;
for (unsigned i = 0; i < formulas_.size(); i++) {
if (formulas_[i].isCounting()) {
set.insert (formulas_[i].countedLogVar());
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i].isCounting()) {
set.insert (args_[i].countedLogVar());
}
}
return set;
@ -107,14 +111,14 @@ Parfactor::elimLogVars (void) const
LogVarSet
Parfactor::exclusiveLogVars (unsigned fIdx) const
{
assert (fIdx < formulas_.size());
assert (fIdx < args_.size());
LogVarSet remaining;
for (unsigned i = 0; i < formulas_.size(); i++) {
for (unsigned i = 0; i < args_.size(); i++) {
if (i != fIdx) {
remaining |= formulas_[i].logVarSet();
remaining |= args_[i].logVarSet();
}
}
return formulas_[fIdx].logVarSet() - remaining;
return args_[fIdx].logVarSet() - remaining;
}
@ -131,44 +135,51 @@ Parfactor::setConstraintTree (ConstraintTree* newTree)
void
Parfactor::sumOut (unsigned fIdx)
{
assert (fIdx < formulas_.size());
assert (formulas_[fIdx].contains (elimLogVars()));
assert (fIdx < args_.size());
assert (args_[fIdx].contains (elimLogVars()));
LogVarSet excl = exclusiveLogVars (fIdx);
unsigned condCount = constr_->getConditionalCount (excl);
Util::pow (params_, condCount);
if (args_[fIdx].isCounting()) {
LogAware::pow (params_, constr_->getConditionalCount (
excl - args_[fIdx].countedLogVar()));
} else {
LogAware::pow (params_, constr_->getConditionalCount (excl));
}
vector<unsigned> numAssigns (ranges_[fIdx], 1);
if (formulas_[fIdx].isCounting()) {
if (args_[fIdx].isCounting()) {
unsigned N = constr_->getConditionalCount (
formulas_[fIdx].countedLogVar());
unsigned R = formulas_[fIdx].range();
unsigned H = ranges_[fIdx];
HistogramSet hs (N, R);
unsigned N_factorial = Util::factorial (N);
for (unsigned h = 0; h < H; h++) {
unsigned prod = 1;
for (unsigned r = 0; r < R; r++) {
prod *= Util::factorial (hs[r]);
args_[fIdx].countedLogVar());
unsigned R = args_[fIdx].range();
vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
StatesIndexer sindexer (ranges_, fIdx);
while (sindexer.valid()) {
unsigned h = sindexer[fIdx];
if (Globals::logDomain) {
params_[sindexer] += numAssigns[h];
} else {
params_[sindexer] *= numAssigns[h];
}
numAssigns[h] = N_factorial / prod;
hs.nextHistogram();
++ sindexer;
}
cout << endl;
}
Params copy = params_;
params_.clear();
params_.resize (copy.size() / ranges_[fIdx], 0.0);
params_.resize (copy.size() / ranges_[fIdx], LogAware::addIdenty());
MapIndexer indexer (ranges_, fIdx);
for (unsigned i = 0; i < copy.size(); i++) {
unsigned h = indexer[fIdx];
// TODO NOT LOG DOMAIN AWARE :(
params_[indexer] += numAssigns[h] * copy[i];
++ indexer;
if (Globals::logDomain) {
for (unsigned i = 0; i < copy.size(); i++) {
params_[indexer] = Util::logSum (params_[indexer], copy[i]);
++ indexer;
}
} else {
for (unsigned i = 0; i < copy.size(); i++) {
params_[indexer] += copy[i];
++ indexer;
}
}
formulas_.erase (formulas_.begin() + fIdx);
args_.erase (args_.begin() + fIdx);
ranges_.erase (ranges_.begin() + fIdx);
constr_->remove (excl);
}
@ -179,55 +190,7 @@ void
Parfactor::multiply (Parfactor& g)
{
alignAndExponentiate (this, &g);
bool sharedVars = false;
vector<unsigned> g_varpos;
const ProbFormulas& g_formulas = g.formulas();
const Params& g_params = g.params();
const Ranges& g_ranges = g.ranges();
for (unsigned i = 0; i < g_formulas.size(); i++) {
int group = g_formulas[i].group();
if (indexOfFormulaWithGroup (group) == -1) {
insertDimension (g.ranges()[i]);
formulas_.push_back (g_formulas[i]);
g_varpos.push_back (formulas_.size() - 1);
} else {
sharedVars = true;
g_varpos.push_back (indexOfFormulaWithGroup (group));
}
}
if (sharedVars == false) {
unsigned count = 0;
for (unsigned i = 0; i < params_.size(); i++) {
if (Globals::logDomain) {
params_[i] += g_params[count];
} else {
params_[i] *= g_params[count];
}
count ++;
if (count >= g_params.size()) {
count = 0;
}
}
} else {
StatesIndexer indexer (ranges_, false);
while (indexer.valid()) {
unsigned g_li = 0;
unsigned prod = 1;
for (int j = g_varpos.size() - 1; j >= 0; j--) {
g_li += indexer[g_varpos[j]] * prod;
prod *= g_ranges[j];
}
if (Globals::logDomain) {
params_[indexer] += g_params[g_li];
} else {
params_[indexer] *= g_params[g_li];
}
++ indexer;
}
}
TFactor<ProbFormula>::multiply (g);
constr_->join (g.constr(), true);
}
@ -236,7 +199,7 @@ Parfactor::multiply (Parfactor& g)
void
Parfactor::countConvert (LogVar X)
{
int fIdx = indexOfFormulaWithLogVar (X);
int fIdx = indexOfLogVar (X);
assert (fIdx != -1);
assert (constr_->isCountNormalized (X));
assert (constr_->getConditionalCount (X) > 1);
@ -248,12 +211,12 @@ Parfactor::countConvert (LogVar X)
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
StatesIndexer indexer (ranges_);
vector<Params> summout (params_.size() / R);
vector<Params> sumout (params_.size() / R);
unsigned count = 0;
while (indexer.valid()) {
summout[count].reserve (R);
sumout[count].reserve (R);
for (unsigned r = 0; r < R; r++) {
summout[count].push_back (params_[indexer]);
sumout[count].push_back (params_[indexer]);
indexer.increment (fIdx);
}
count ++;
@ -262,45 +225,42 @@ Parfactor::countConvert (LogVar X)
}
params_.clear();
params_.reserve (summout.size() * H);
params_.reserve (sumout.size() * H);
vector<bool> mapDims (ranges_.size(), true);
ranges_[fIdx] = H;
mapDims[fIdx] = false;
MapIndexer mapIndexer (ranges_, mapDims);
MapIndexer mapIndexer (ranges_, fIdx);
while (mapIndexer.valid()) {
double prod = 1.0;
double prod = LogAware::multIdenty();
unsigned i = mapIndexer.mappedIndex();
unsigned h = mapIndexer[fIdx];
for (unsigned r = 0; r < R; r++) {
// TODO not log domain aware
prod *= Util::pow (summout[i][r], histograms[h][r]);
if (Globals::logDomain) {
prod += LogAware::pow (sumout[i][r], histograms[h][r]);
} else {
prod *= LogAware::pow (sumout[i][r], histograms[h][r]);
}
}
params_.push_back (prod);
++ mapIndexer;
}
formulas_[fIdx].setCountedLogVar (X);
args_[fIdx].setCountedLogVar (X);
}
void
Parfactor::expandPotential (
LogVar X,
LogVar X_new1,
LogVar X_new2)
Parfactor::expand (LogVar X, LogVar X_new1, LogVar X_new2)
{
int fIdx = indexOfFormulaWithLogVar (X);
int fIdx = indexOfLogVar (X);
assert (fIdx != -1);
assert (formulas_[fIdx].isCounting());
assert (args_[fIdx].isCounting());
unsigned N1 = constr_->getConditionalCount (X_new1);
unsigned N2 = constr_->getConditionalCount (X_new2);
unsigned N = N1 + N2;
unsigned R = formulas_[fIdx].range();
unsigned R = args_[fIdx].range();
unsigned H1 = HistogramSet::nrHistograms (N1, R);
unsigned H2 = HistogramSet::nrHistograms (N2, R);
unsigned H = ranges_[fIdx];
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R);
@ -320,48 +280,11 @@ Parfactor::expandPotential (
}
}
unsigned size = (params_.size() / H) * H1 * H2;
Params copy = params_;
params_.clear();
params_.reserve (size);
expandPotential (fIdx, H1 * H2, sumIndexes);
unsigned prod = 1;
vector<unsigned> offsets_ (ranges_.size());
for (int i = ranges_.size() - 1; i >= 0; i--) {
offsets_[i] = prod;
prod *= ranges_[i];
}
unsigned index = 0;
ranges_[fIdx] = H1 * H2;
vector<unsigned> indices (ranges_.size(), 0);
for (unsigned k = 0; k < size; k++) {
params_.push_back (copy[index]);
for (int i = ranges_.size() - 1; i >= 0; i--) {
indices[i] ++;
if (i == fIdx) {
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
index += diff * offsets_[i];
} else {
index += offsets_[i];
}
if (indices[i] != ranges_[i]) {
break;
} else {
if (i == fIdx) {
int diff = sumIndexes[0] - sumIndexes[indices[i]];
index += diff * offsets_[i];
} else {
index -= offsets_[i] * ranges_[i];
}
indices[i] = 0;
}
}
}
formulas_.insert (formulas_.begin() + fIdx + 1, formulas_[fIdx]);
formulas_[fIdx].rename (X, X_new1);
formulas_[fIdx + 1].rename (X, X_new2);
args_.insert (args_.begin() + fIdx + 1, args_[fIdx]);
args_[fIdx].rename (X, X_new1);
args_[fIdx + 1].rename (X, X_new2);
ranges_.insert (ranges_.begin() + fIdx + 1, H2);
ranges_[fIdx] = H1;
}
@ -371,13 +294,12 @@ Parfactor::expandPotential (
void
Parfactor::fullExpand (LogVar X)
{
int fIdx = indexOfFormulaWithLogVar (X);
int fIdx = indexOfLogVar (X);
assert (fIdx != -1);
assert (formulas_[fIdx].isCounting());
assert (args_[fIdx].isCounting());
unsigned N = constr_->getConditionalCount (X);
unsigned R = formulas_[fIdx].range();
unsigned H = ranges_[fIdx];
unsigned R = args_[fIdx].range();
vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
@ -400,54 +322,17 @@ Parfactor::fullExpand (LogVar X)
++ indexer;
}
unsigned size = (params_.size() / H) * std::pow (R, N);
Params copy = params_;
params_.clear();
params_.reserve (size);
expandPotential (fIdx, std::pow (R, N), sumIndexes);
unsigned prod = 1;
vector<unsigned> offsets_ (ranges_.size());
for (int i = ranges_.size() - 1; i >= 0; i--) {
offsets_[i] = prod;
prod *= ranges_[i];
}
unsigned index = 0;
ranges_[fIdx] = std::pow (R, N);
vector<unsigned> indices (ranges_.size(), 0);
for (unsigned k = 0; k < size; k++) {
params_.push_back (copy[index]);
for (int i = ranges_.size() - 1; i >= 0; i--) {
indices[i] ++;
if (i == fIdx) {
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
index += diff * offsets_[i];
} else {
index += offsets_[i];
}
if (indices[i] != ranges_[i]) {
break;
} else {
if (i == fIdx) {
int diff = sumIndexes[0] - sumIndexes[indices[i]];
index += diff * offsets_[i];
} else {
index -= offsets_[i] * ranges_[i];
}
indices[i] = 0;
}
}
}
ProbFormula f = formulas_[fIdx];
formulas_.erase (formulas_.begin() + fIdx);
ProbFormula f = args_[fIdx];
args_.erase (args_.begin() + fIdx);
ranges_.erase (ranges_.begin() + fIdx);
LogVars newLvs = constr_->expand (X);
assert (newLvs.size() == N);
for (unsigned i = 0 ; i < N; i++) {
ProbFormula newFormula (f.functor(), f.logVars(), f.range());
newFormula.rename (X, newLvs[i]);
formulas_.insert (formulas_.begin() + fIdx + i, newFormula);
args_.insert (args_.begin() + fIdx + i, newFormula);
ranges_.insert (ranges_.begin() + fIdx + i, R);
}
}
@ -459,117 +344,43 @@ Parfactor::reorderAccordingGrounds (const Grounds& grounds)
{
ProbFormulas newFormulas;
for (unsigned i = 0; i < grounds.size(); i++) {
for (unsigned j = 0; j < formulas_.size(); j++) {
if (grounds[i].functor() == formulas_[j].functor() &&
grounds[i].arity() == formulas_[j].arity()) {
constr_->moveToTop (formulas_[j].logVars());
for (unsigned j = 0; j < args_.size(); j++) {
if (grounds[i].functor() == args_[j].functor() &&
grounds[i].arity() == args_[j].arity()) {
constr_->moveToTop (args_[j].logVars());
if (constr_->containsTuple (grounds[i].args())) {
newFormulas.push_back (formulas_[j]);
newFormulas.push_back (args_[j]);
break;
}
}
}
assert (newFormulas.size() == i + 1);
}
reorderFormulas (newFormulas);
reorderArguments (newFormulas);
}
void
Parfactor::reorderFormulas (const ProbFormulas& newFormulas)
{
assert (newFormulas.size() == formulas_.size());
if (newFormulas == formulas_) {
return;
}
Ranges newRanges;
vector<unsigned> positions;
for (unsigned i = 0; i < newFormulas.size(); i++) {
unsigned idx = indexOf (newFormulas[i]);
newRanges.push_back (ranges_[idx]);
positions.push_back (idx);
}
unsigned N = ranges_.size();
Params newParams (params_.size());
for (unsigned i = 0; i < params_.size(); i++) {
unsigned li = i;
// calculate vector index corresponding to linear index
vector<unsigned> vi (N);
for (int k = N-1; k >= 0; k--) {
vi[k] = li % ranges_[k];
li /= ranges_[k];
}
// convert permuted vector index to corresponding linear index
unsigned prod = 1;
unsigned new_li = 0;
for (int k = N - 1; k >= 0; k--) {
new_li += vi[positions[k]] * prod;
prod *= ranges_[positions[k]];
}
newParams[new_li] = params_[i];
}
formulas_ = newFormulas;
ranges_ = newRanges;
params_ = newParams;
}
void
Parfactor::absorveEvidence (unsigned fIdx, unsigned evidence)
Parfactor::absorveEvidence (const ProbFormula& formula, unsigned evidence)
{
int fIdx = indexOf (formula);
assert (fIdx != -1);
LogVarSet excl = exclusiveLogVars (fIdx);
assert (fIdx < formulas_.size());
assert (evidence < formulas_[fIdx].range());
assert (formulas_[fIdx].isCounting() == false);
assert (args_[fIdx].isCounting() == false);
assert (constr_->isCountNormalized (excl));
Util::pow (params_, constr_->getConditionalCount (excl));
Params copy = params_;
params_.clear();
params_.reserve (copy.size() / formulas_[fIdx].range());
StatesIndexer indexer (ranges_);
for (unsigned i = 0; i < evidence; i++) {
indexer.increment (fIdx);
}
while (indexer.valid()) {
params_.push_back (copy[indexer]);
indexer.incrementExcluding (fIdx);
}
formulas_.erase (formulas_.begin() + fIdx);
ranges_.erase (ranges_.begin() + fIdx);
LogAware::pow (params_, constr_->getConditionalCount (excl));
TFactor<ProbFormula>::absorveEvidence (formula, evidence);
constr_->remove (excl);
}
void
Parfactor::normalize (void)
{
Util::normalize (params_);
}
void
Parfactor::setFormulaGroup (const ProbFormula& f, int group)
{
assert (indexOf (f) != -1);
formulas_[indexOf (f)].setGroup (group);
}
void
Parfactor::setNewGroups (void)
{
for (unsigned i = 0; i < formulas_.size(); i++) {
formulas_[i].setGroup (ProbFormula::getNewGroup());
for (unsigned i = 0; i < args_.size(); i++) {
args_[i].setGroup (ProbFormula::getNewGroup());
}
}
@ -578,14 +389,14 @@ Parfactor::setNewGroups (void)
void
Parfactor::applySubstitution (const Substitution& theta)
{
for (unsigned i = 0; i < formulas_.size(); i++) {
LogVars& lvs = formulas_[i].logVars();
for (unsigned i = 0; i < args_.size(); i++) {
LogVars& lvs = args_[i].logVars();
for (unsigned j = 0; j < lvs.size(); j++) {
lvs[j] = theta.newNameFor (lvs[j]);
}
if (formulas_[i].isCounting()) {
LogVar clv = formulas_[i].countedLogVar();
formulas_[i].setCountedLogVar (theta.newNameFor (clv));
if (args_[i].isCounting()) {
LogVar clv = args_[i].countedLogVar();
args_[i].setCountedLogVar (theta.newNameFor (clv));
}
}
constr_->applySubstitution (theta);
@ -593,19 +404,29 @@ Parfactor::applySubstitution (const Substitution& theta)
bool
Parfactor::containsGround (const Ground& ground) const
int
Parfactor::findGroup (const Ground& ground) const
{
for (unsigned i = 0; i < formulas_.size(); i++) {
if (formulas_[i].functor() == ground.functor() &&
formulas_[i].arity() == ground.arity()) {
constr_->moveToTop (formulas_[i].logVars());
int group = -1;
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i].functor() == ground.functor() &&
args_[i].arity() == ground.arity()) {
constr_->moveToTop (args_[i].logVars());
if (constr_->containsTuple (ground.args())) {
return true;
group = args_[i].group();
break;
}
}
}
return false;
return group;
}
bool
Parfactor::containsGround (const Ground& ground) const
{
return findGroup (ground) != -1;
}
@ -613,8 +434,8 @@ Parfactor::containsGround (const Ground& ground) const
bool
Parfactor::containsGroup (unsigned group) const
{
for (unsigned i = 0; i < formulas_.size(); i++) {
if (formulas_[i].group() == group) {
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i].group() == group) {
return true;
}
}
@ -623,30 +444,12 @@ Parfactor::containsGroup (unsigned group) const
const ProbFormula&
Parfactor::formula (unsigned fIdx) const
{
assert (fIdx < formulas_.size());
return formulas_[fIdx];
}
unsigned
Parfactor::range (unsigned fIdx) const
{
assert (fIdx < ranges_.size());
return ranges_[fIdx];
}
unsigned
Parfactor::nrFormulas (LogVar X) const
{
unsigned count = 0;
for (unsigned i = 0; i < formulas_.size(); i++) {
if (formulas_[i].contains (X)) {
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i].contains (X)) {
count ++;
}
}
@ -656,27 +459,12 @@ Parfactor::nrFormulas (LogVar X) const
int
Parfactor::indexOf (const ProbFormula& f) const
{
int idx = -1;
for (unsigned i = 0; i < formulas_.size(); i++) {
if (f == formulas_[i]) {
idx = i;
break;
}
}
return idx;
}
int
Parfactor::indexOfFormulaWithLogVar (LogVar X) const
Parfactor::indexOfLogVar (LogVar X) const
{
int idx = -1;
assert (nrFormulas (X) == 1);
for (unsigned i = 0; i < formulas_.size(); i++) {
if (formulas_[i].contains (X)) {
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i].contains (X)) {
idx = i;
break;
}
@ -687,11 +475,11 @@ Parfactor::indexOfFormulaWithLogVar (LogVar X) const
int
Parfactor::indexOfFormulaWithGroup (unsigned group) const
Parfactor::indexOfGroup (unsigned group) const
{
int pos = -1;
for (unsigned i = 0; i < formulas_.size(); i++) {
if (formulas_[i].group() == group) {
for (unsigned i = 0; i < args_.size(); i++) {
if (args_[i].group() == group) {
pos = i;
break;
}
@ -704,9 +492,9 @@ Parfactor::indexOfFormulaWithGroup (unsigned group) const
vector<unsigned>
Parfactor::getAllGroups (void) const
{
vector<unsigned> groups (formulas_.size());
for (unsigned i = 0; i < formulas_.size(); i++) {
groups[i] = formulas_[i].group();
vector<unsigned> groups (args_.size());
for (unsigned i = 0; i < args_.size(); i++) {
groups[i] = args_[i].group();
}
return groups;
}
@ -714,13 +502,13 @@ Parfactor::getAllGroups (void) const
string
Parfactor::getHeaderString (void) const
Parfactor::getLabel (void) const
{
stringstream ss;
ss << "phi(" ;
for (unsigned i = 0; i < formulas_.size(); i++) {
for (unsigned i = 0; i < args_.size(); i++) {
if (i != 0) ss << "," ;
ss << formulas_[i];
ss << args_[i];
}
ss << ")" ;
ConstraintTree copy (*constr_);
@ -735,32 +523,37 @@ void
Parfactor::print (bool printParams) const
{
cout << "Formulas: " ;
for (unsigned i = 0; i < formulas_.size(); i++) {
for (unsigned i = 0; i < args_.size(); i++) {
if (i != 0) cout << ", " ;
cout << formulas_[i];
cout << args_[i];
}
cout << endl;
vector<string> groups;
for (unsigned i = 0; i < formulas_.size(); i++) {
groups.push_back (string ("g") + Util::toString (formulas_[i].group()));
if (args_[0].group() != Util::maxUnsigned()) {
vector<string> groups;
for (unsigned i = 0; i < args_.size(); i++) {
groups.push_back (string ("g") + Util::toString (args_[i].group()));
}
cout << "Groups: " << groups << endl;
}
cout << "Groups: " << groups << endl;
cout << "LogVars: " << constr_->logVars() << endl;
cout << "LogVars: " << constr_->logVarSet() << endl;
cout << "Ranges: " << ranges_ << endl;
if (printParams == false) {
cout << "Params: " << params_ << endl;
}
cout << "Tuples: " << constr_->tupleSet() << endl;
ConstraintTree copy (*constr_);
copy.moveToTop (copy.logVarSet().elements());
cout << "Tuples: " << copy.tupleSet() << endl;
if (printParams) {
vector<string> jointStrings;
StatesIndexer indexer (ranges_);
while (indexer.valid()) {
stringstream ss;
for (unsigned i = 0; i < formulas_.size(); i++) {
for (unsigned i = 0; i < args_.size(); i++) {
if (i != 0) ss << ", " ;
if (formulas_[i].isCounting()) {
unsigned N = constr_->getConditionalCount (formulas_[i].countedLogVar());
HistogramSet hs (N, formulas_[i].range());
if (args_[i].isCounting()) {
unsigned N = constr_->getConditionalCount (
args_[i].countedLogVar());
HistogramSet hs (N, args_[i].range());
unsigned c = 0;
while (c < indexer[i]) {
hs.nextHistogram();
@ -779,22 +572,56 @@ Parfactor::print (bool printParams) const
cout << " = " << params_[i] << endl;
}
}
cout << endl;
}
void
Parfactor::insertDimension (unsigned range)
Parfactor::expandPotential (
int fIdx,
unsigned newRange,
const vector<unsigned>& sumIndexes)
{
unsigned size = (params_.size() / ranges_[fIdx]) * newRange;
Params copy = params_;
params_.clear();
params_.reserve (copy.size() * range);
for (unsigned i = 0; i < copy.size(); i++) {
for (unsigned reps = 0; reps < range; reps++) {
params_.push_back (copy[i]);
params_.reserve (size);
unsigned prod = 1;
vector<unsigned> offsets_ (ranges_.size());
for (int i = ranges_.size() - 1; i >= 0; i--) {
offsets_[i] = prod;
prod *= ranges_[i];
}
unsigned index = 0;
ranges_[fIdx] = newRange;
vector<unsigned> indices (ranges_.size(), 0);
for (unsigned k = 0; k < size; k++) {
params_.push_back (copy[index]);
for (int i = ranges_.size() - 1; i >= 0; i--) {
indices[i] ++;
if (i == fIdx) {
assert (indices[i] - 1 < sumIndexes.size());
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
index += diff * offsets_[i];
} else {
index += offsets_[i];
}
if (indices[i] != ranges_[i]) {
break;
} else {
if (i == fIdx) {
int diff = sumIndexes[0] - sumIndexes[indices[i]];
index += diff * offsets_[i];
} else {
index -= offsets_[i] * ranges_[i];
}
indices[i] = 0;
}
}
}
ranges_.push_back (range);
}
@ -803,29 +630,27 @@ void
Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
{
LogVars X_1, X_2;
const ProbFormulas& formulas1 = g1->formulas();
const ProbFormulas& formulas2 = g2->formulas();
const ProbFormulas& formulas1 = g1->arguments();
const ProbFormulas& formulas2 = g2->arguments();
for (unsigned i = 0; i < formulas1.size(); i++) {
for (unsigned j = 0; j < formulas2.size(); j++) {
if (formulas1[i].group() == formulas2[j].group()) {
X_1.insert (X_1.end(),
formulas1[i].logVars().begin(),
formulas1[i].logVars().end());
X_2.insert (X_2.end(),
formulas2[j].logVars().begin(),
formulas2[j].logVars().end());
Util::addToVector (X_1, formulas1[i].logVars());
Util::addToVector (X_2, formulas2[j].logVars());
}
}
}
align (g1, X_1, g2, X_2);
LogVarSet Y_1 = g1->logVarSet() - LogVarSet (X_1);
LogVarSet Y_2 = g2->logVarSet() - LogVarSet (X_2);
assert (g1->constr()->isCountNormalized (Y_1));
assert (g2->constr()->isCountNormalized (Y_2));
unsigned condCount1 = g1->constr()->getConditionalCount (Y_1);
unsigned condCount2 = g2->constr()->getConditionalCount (Y_2);
Util::pow (g1->params(), 1.0 / condCount2);
Util::pow (g2->params(), 1.0 / condCount1);
LogAware::pow (g1->params(), 1.0 / condCount2);
LogAware::pow (g2->params(), 1.0 / condCount1);
// this must be done in the end or else X_1 and X_2
// will refer the old log var names in the code above
align (g1, X_1, g2, X_2);
}
@ -838,7 +663,6 @@ Parfactor::align (
LogVar freeLogVar = 0;
Substitution theta1;
Substitution theta2;
const LogVarSet& allLvs1 = g1->logVarSet();
for (unsigned i = 0; i < allLvs1.size(); i++) {
theta1.add (allLvs1[i], freeLogVar);
@ -850,7 +674,7 @@ Parfactor::align (
theta2.add (allLvs2[i], freeLogVar);
++ freeLogVar;
}
assert (alignLvs1.size() == alignLvs2.size());
for (unsigned i = 0; i < alignLvs1.size(); i++) {
theta1.rename (alignLvs1[i], theta2.newNameFor (alignLvs2[i]));

View File

@ -9,8 +9,9 @@
#include "LiftedUtils.h"
#include "Horus.h"
#include "Factor.h"
class Parfactor
class Parfactor : public TFactor<ProbFormula>
{
public:
Parfactor (
@ -18,27 +19,15 @@ class Parfactor
const Params&,
const Tuples&,
unsigned);
Parfactor (const Parfactor*, const Tuple&);
Parfactor (const Parfactor*, ConstraintTree*);
Parfactor (const Parfactor&);
~Parfactor (void);
ProbFormulas& formulas (void) { return formulas_; }
const ProbFormulas& formulas (void) const { return formulas_; }
unsigned nrFormulas (void) const { return formulas_.size(); }
Params& params (void) { return params_; }
const Params& params (void) const { return params_; }
unsigned size (void) const { return params_.size(); }
const Ranges& ranges (void) const { return ranges_; }
unsigned distId (void) const { return distId_; }
ConstraintTree* constr (void) { return constr_; }
const ConstraintTree* constr (void) const { return constr_; }
@ -57,64 +46,52 @@ class Parfactor
void setConstraintTree (ConstraintTree*);
void sumOut (unsigned);
void sumOut (unsigned fIdx);
void multiply (Parfactor&);
void countConvert (LogVar);
void expandPotential (LogVar, LogVar, LogVar);
void expand (LogVar, LogVar, LogVar);
void fullExpand (LogVar);
void reorderAccordingGrounds (const Grounds&);
void reorderFormulas (const ProbFormulas&);
void absorveEvidence (unsigned, unsigned);
void normalize (void);
void setFormulaGroup (const ProbFormula&, int);
void absorveEvidence (const ProbFormula&, unsigned);
void setNewGroups (void);
void applySubstitution (const Substitution&);
int findGroup (const Ground&) const;
bool containsGround (const Ground&) const;
bool containsGroup (unsigned) const;
const ProbFormula& formula (unsigned) const;
unsigned range (unsigned) const;
unsigned nrFormulas (LogVar) const;
int indexOf (const ProbFormula&) const;
int indexOfLogVar (LogVar) const;
int indexOfFormulaWithLogVar (LogVar) const;
int indexOfFormulaWithGroup (unsigned) const;
int indexOfGroup (unsigned) const;
vector<unsigned> getAllGroups (void) const;
void print (bool = false) const;
string getHeaderString (void) const;
string getLabel (void) const;
private:
void expandPotential (int fIdx, unsigned newRange,
const vector<unsigned>& sumIndexes);
static void alignAndExponentiate (Parfactor*, Parfactor*);
static void align (
Parfactor*, const LogVars&, Parfactor*, const LogVars&);
void insertDimension (unsigned);
ProbFormulas formulas_;
Ranges ranges_;
Params params_;
unsigned distId_;
ConstraintTree* constr_;
ConstraintTree* constr_;
};

View File

@ -3,9 +3,32 @@
#include "ParfactorList.h"
ParfactorList::ParfactorList (Parfactors& pfs)
ParfactorList::ParfactorList (const ParfactorList& pfList)
{
pfList_.insert (pfList_.end(), pfs.begin(), pfs.end());
ParfactorList::const_iterator it = pfList.begin();
while (it != pfList.end()) {
addShattered (new Parfactor (**it));
++ it;
}
}
ParfactorList::ParfactorList (const Parfactors& pfs)
{
add (pfs);
}
ParfactorList::~ParfactorList (void)
{
ParfactorList::const_iterator it = pfList_.begin();
while (it != pfList_.end()) {
delete *it;
++ it;
}
}
@ -14,17 +37,17 @@ void
ParfactorList::add (Parfactor* pf)
{
pf->setNewGroups();
pfList_.push_back (pf);
addToShatteredList (pf);
}
void
ParfactorList::add (Parfactors& pfs)
ParfactorList::add (const Parfactors& pfs)
{
for (unsigned i = 0; i < pfs.size(); i++) {
pfs[i]->setNewGroups();
pfList_.push_back (pfs[i]);
addToShatteredList (pfs[i]);
}
}
@ -33,7 +56,20 @@ ParfactorList::add (Parfactors& pfs)
void
ParfactorList::addShattered (Parfactor* pf)
{
assert (isAllShattered());
pfList_.push_back (pf);
assert (isAllShattered());
}
list<Parfactor*>::iterator
ParfactorList::insertShattered (
list<Parfactor*>::iterator it,
Parfactor* pf)
{
return pfList_.insert (it, pf);
assert (isAllShattered());
}
@ -47,7 +83,7 @@ ParfactorList::remove (list<Parfactor*>::iterator it)
list<Parfactor*>::iterator
ParfactorList::deleteAndRemove (list<Parfactor*>::iterator it)
ParfactorList::removeAndDelete (list<Parfactor*>::iterator it)
{
delete *it;
return pfList_.erase (it);
@ -55,58 +91,21 @@ ParfactorList::deleteAndRemove (list<Parfactor*>::iterator it)
void
ParfactorList::shatter (void)
bool
ParfactorList::isAllShattered (void) const
{
list<Parfactor*> tempList;
Parfactors newPfs;
newPfs.insert (newPfs.end(), pfList_.begin(), pfList_.end());
while (newPfs.empty() == false) {
tempList.insert (tempList.end(), newPfs.begin(), newPfs.end());
newPfs.clear();
list<Parfactor*>::iterator iter1 = tempList.begin();
while (tempList.size() > 1 && iter1 != -- tempList.end()) {
list<Parfactor*>::iterator iter2 = iter1;
++ iter2;
bool incIter1 = true;
while (iter2 != tempList.end()) {
assert (iter1 != iter2);
std::pair<Parfactors, Parfactors> res = shatter (
(*iter1)->formulas(), *iter1, (*iter2)->formulas(), *iter2);
bool incIter2 = true;
if (res.second.empty() == false) {
// cout << "second unshattered" << endl;
delete *iter2;
iter2 = tempList.erase (iter2);
incIter2 = false;
newPfs.insert (
newPfs.begin(), res.second.begin(), res.second.end());
}
if (res.first.empty() == false) {
// cout << "first unshattered" << endl;
delete *iter1;
iter1 = tempList.erase (iter1);
newPfs.insert (
newPfs.begin(), res.first.begin(), res.first.end());
incIter1 = false;
break;
}
if (incIter2) {
++ iter2;
}
}
if (incIter1) {
++ iter1;
if (pfList_.size() <= 1) {
return true;
}
vector<Parfactor*> pfs (pfList_.begin(), pfList_.end());
for (unsigned i = 0; i < pfs.size() - 1; i++) {
for (unsigned j = i + 1; j < pfs.size(); j++) {
if (isShattered (pfs[i], pfs[j]) == false) {
return false;
}
}
// cout << "|||||||||||||||||||||||||||||||||||||||||||||||||" << endl;
// cout << "||||||||||||| SHATTERING ITERATION ||||||||||||||" << endl;
// cout << "|||||||||||||||||||||||||||||||||||||||||||||||||" << endl;
// printParfactors (newPfs);
// cout << "|||||||||||||||||||||||||||||||||||||||||||||||||" << endl;
}
pfList_.clear();
pfList_.insert (pfList_.end(), tempList.begin(), tempList.end());
return true;
}
@ -117,25 +116,88 @@ ParfactorList::print (void) const
list<Parfactor*>::const_iterator it;
for (it = pfList_.begin(); it != pfList_.end(); ++it) {
(*it)->print();
cout << endl;
}
}
std::pair<Parfactors, Parfactors>
ParfactorList::shatter (
ProbFormulas& formulas1,
Parfactor* g1,
ProbFormulas& formulas2,
Parfactor* g2)
bool
ParfactorList::isShattered (
const Parfactor* g1,
const Parfactor* g2) const
{
assert (g1 != g2);
const ProbFormulas& fms1 = g1->arguments();
const ProbFormulas& fms2 = g2->arguments();
for (unsigned i = 0; i < fms1.size(); i++) {
for (unsigned j = 0; j < fms2.size(); j++) {
if (fms1[i].group() == fms2[j].group()) {
if (identical (
fms1[i], *(g1->constr()),
fms2[j], *(g2->constr())) == false) {
return false;
}
} else {
if (disjoint (
fms1[i], *(g1->constr()),
fms2[j], *(g2->constr())) == false) {
return false;
}
}
}
}
return true;
}
void
ParfactorList::addToShatteredList (Parfactor* g)
{
queue<Parfactor*> residuals;
residuals.push (g);
while (residuals.empty() == false) {
Parfactor* pf = residuals.front();
bool pfSplitted = false;
list<Parfactor*>::iterator pfIter;
pfIter = pfList_.begin();
while (pfIter != pfList_.end()) {
std::pair<Parfactors, Parfactors> shattRes;
shattRes = shatter (*pfIter, pf);
if (shattRes.first.empty() == false) {
pfIter = removeAndDelete (pfIter);
Util::addToQueue (residuals, shattRes.first);
} else {
++ pfIter;
}
if (shattRes.second.empty() == false) {
delete pf;
Util::addToQueue (residuals, shattRes.second);
pfSplitted = true;
break;
}
}
residuals.pop();
if (pfSplitted == false) {
addShattered (pf);
}
}
assert (isAllShattered());
}
std::pair<Parfactors, Parfactors>
ParfactorList::shatter (Parfactor* g1, Parfactor* g2)
{
ProbFormulas& formulas1 = g1->arguments();
ProbFormulas& formulas2 = g2->arguments();
assert (g1 != 0 && g2 != 0 && g1 != g2);
for (unsigned i = 0; i < formulas1.size(); i++) {
for (unsigned j = 0; j < formulas2.size(); j++) {
if (formulas1[i].sameSkeletonAs (formulas2[j])) {
std::pair<Parfactors, Parfactors> res
= shatter (formulas1[i], g1, formulas2[j], g2);
std::pair<Parfactors, Parfactors> res;
res = shatter (i, g1, j, g2);
if (res.first.empty() == false ||
res.second.empty() == false) {
return res;
@ -150,21 +212,22 @@ ParfactorList::shatter (
std::pair<Parfactors, Parfactors>
ParfactorList::shatter (
ProbFormula& f1,
Parfactor* g1,
ProbFormula& f2,
Parfactor* g2)
unsigned fIdx1, Parfactor* g1,
unsigned fIdx2, Parfactor* g2)
{
ProbFormula& f1 = g1->argument (fIdx1);
ProbFormula& f2 = g2->argument (fIdx2);
// cout << endl;
// cout << "-------------------------------------------------" << endl;
// Util::printDashedLine();
// cout << "-> SHATTERING (#" << g1 << ", #" << g2 << ")" << endl;
// g1->print();
// cout << "-> WITH" << endl;
// g2->print();
// cout << "-> ON: " << f1.toString (g1->constr()) << endl;
// cout << "-> ON: " << f2.toString (g2->constr()) << endl;
// cout << "-------------------------------------------------" << endl;
// cout << "-> ON: " << f1 << "|" ;
// cout << g1->constr()->tupleSet (f1.logVars()) << endl;
// cout << "-> ON: " << f2 << "|" ;
// cout << g2->constr()->tupleSet (f2.logVars()) << endl;
// Util::printDashedLine();
if (f1.isAtom()) {
unsigned group = (f1.group() < f2.group()) ? f1.group() : f2.group();
f1.setGroup (group);
@ -174,7 +237,7 @@ ParfactorList::shatter (
assert (g1->constr()->empty() == false);
assert (g2->constr()->empty() == false);
if (f1.group() == f2.group()) {
// assert (identical (f1, g1->constr(), f2, g2->constr()));
assert (identical (f1, *(g1->constr()), f2, *(g2->constr())));
return { };
}
@ -201,21 +264,24 @@ ParfactorList::shatter (
assert (commCt1->tupleSet (f1.arity()) ==
commCt2->tupleSet (f2.arity()));
// stringstream ss1; ss1 << "" << count << "_A.dot" ;
// stringstream ss2; ss2 << "" << count << "_B.dot" ;
// stringstream ss3; ss3 << "" << count << "_A_comm.dot" ;
// stringstream ss4; ss4 << "" << count << "_A_excl.dot" ;
// stringstream ss5; ss5 << "" << count << "_B_comm.dot" ;
// stringstream ss6; ss6 << "" << count << "_B_excl.dot" ;
// ct1->exportToGraphViz (ss1.str().c_str(), true);
// ct2->exportToGraphViz (ss2.str().c_str(), true);
// commCt1->exportToGraphViz (ss3.str().c_str(), true);
// exclCt1->exportToGraphViz (ss4.str().c_str(), true);
// commCt2->exportToGraphViz (ss5.str().c_str(), true);
// exclCt2->exportToGraphViz (ss6.str().c_str(), true);
// unsigned static count = 0; count ++;
// stringstream ss1; ss1 << "" << count << "_A.dot" ;
// stringstream ss2; ss2 << "" << count << "_B.dot" ;
// stringstream ss3; ss3 << "" << count << "_A_comm.dot" ;
// stringstream ss4; ss4 << "" << count << "_A_excl.dot" ;
// stringstream ss5; ss5 << "" << count << "_B_comm.dot" ;
// stringstream ss6; ss6 << "" << count << "_B_excl.dot" ;
// g1->constr()->exportToGraphViz (ss1.str().c_str(), true);
// g2->constr()->exportToGraphViz (ss2.str().c_str(), true);
// commCt1->exportToGraphViz (ss3.str().c_str(), true);
// exclCt1->exportToGraphViz (ss4.str().c_str(), true);
// commCt2->exportToGraphViz (ss5.str().c_str(), true);
// exclCt2->exportToGraphViz (ss6.str().c_str(), true);
if (exclCt1->empty() && exclCt2->empty()) {
unsigned group = (f1.group() < f2.group()) ? f1.group() : f2.group();
unsigned group = (f1.group() < f2.group())
? f1.group()
: f2.group();
// identical
f1.setGroup (group);
f2.setGroup (group);
@ -235,8 +301,8 @@ ParfactorList::shatter (
} else {
group = ProbFormula::getNewGroup();
}
Parfactors res1 = shatter (g1, f1, commCt1, exclCt1, group);
Parfactors res2 = shatter (g2, f2, commCt2, exclCt2, group);
Parfactors res1 = shatter (g1, fIdx1, commCt1, exclCt1, group);
Parfactors res2 = shatter (g2, fIdx2, commCt2, exclCt2, group);
return make_pair (res1, res2);
}
@ -245,11 +311,19 @@ ParfactorList::shatter (
Parfactors
ParfactorList::shatter (
Parfactor* g,
const ProbFormula& f,
unsigned fIdx,
ConstraintTree* commCt,
ConstraintTree* exclCt,
unsigned commGroup)
{
ProbFormula& f = g->argument (fIdx);
if (exclCt->empty()) {
delete commCt;
delete exclCt;
f.setGroup (commGroup);
return { };
}
Parfactors result;
if (f.isCounting()) {
LogVar X_new1 = g->constr()->logVarSet().back() + 1;
@ -259,7 +333,7 @@ ParfactorList::shatter (
for (unsigned i = 0; i < cts.size(); i++) {
Parfactor* newPf = new Parfactor (g, cts[i]);
if (cts[i]->nrLogVars() == g->constr()->nrLogVars() + 1) {
newPf->expandPotential (f.countedLogVar(), X_new1, X_new2);
newPf->expand (f.countedLogVar(), X_new1, X_new2);
assert (g->constr()->getConditionalCount (f.countedLogVar()) ==
cts[i]->getConditionalCount (X_new1) +
cts[i]->getConditionalCount (X_new2));
@ -270,20 +344,16 @@ ParfactorList::shatter (
newPf->setNewGroups();
result.push_back (newPf);
}
delete commCt;
delete exclCt;
} else {
if (exclCt->empty()) {
delete commCt;
delete exclCt;
g->setFormulaGroup (f, commGroup);
} else {
Parfactor* newPf = new Parfactor (g, commCt);
newPf->setNewGroups();
newPf->setFormulaGroup (f, commGroup);
result.push_back (newPf);
newPf = new Parfactor (g, exclCt);
newPf->setNewGroups();
result.push_back (newPf);
}
Parfactor* newPf = new Parfactor (g, commCt);
newPf->setNewGroups();
newPf->argument (fIdx).setGroup (commGroup);
result.push_back (newPf);
newPf = new Parfactor (g, exclCt);
newPf->setNewGroups();
result.push_back (newPf);
}
return result;
}
@ -296,7 +366,7 @@ ParfactorList::unifyGroups (unsigned group1, unsigned group2)
unsigned newGroup = ProbFormula::getNewGroup();
for (ParfactorList::iterator it = pfList_.begin();
it != pfList_.end(); it++) {
ProbFormulas& formulas = (*it)->formulas();
ProbFormulas& formulas = (*it)->arguments();
for (unsigned i = 0; i < formulas.size(); i++) {
if (formulas[i].group() == group1 ||
formulas[i].group() == group2) {
@ -306,3 +376,52 @@ ParfactorList::unifyGroups (unsigned group1, unsigned group2)
}
}
bool
ParfactorList::proper (
const ProbFormula& f1, ConstraintTree c1,
const ProbFormula& f2, ConstraintTree c2) const
{
return disjoint (f1, c1, f2, c2)
|| identical (f1, c1, f2, c2);
}
bool
ParfactorList::identical (
const ProbFormula& f1, ConstraintTree c1,
const ProbFormula& f2, ConstraintTree c2) const
{
if (f1.sameSkeletonAs (f2) == false) {
return false;
}
if (f1.isAtom()) {
return true;
}
c1.moveToTop (f1.logVars());
c2.moveToTop (f2.logVars());
return ConstraintTree::identical (
&c1, &c2, f1.logVars().size());
}
bool
ParfactorList::disjoint (
const ProbFormula& f1, ConstraintTree c1,
const ProbFormula& f2, ConstraintTree c2) const
{
if (f1.sameSkeletonAs (f2) == false) {
return true;
}
if (f1.isAtom()) {
return true;
}
c1.moveToTop (f1.logVars());
c2.moveToTop (f2.logVars());
return ConstraintTree::overlap (
&c1, &c2, f1.arity()) == false;
}

View File

@ -2,6 +2,7 @@
#define HORUS_PARFACTORLIST_H
#include <list>
#include <queue>
#include "Parfactor.h"
#include "ProbFormula.h"
@ -14,56 +15,82 @@ class ParfactorList
{
public:
ParfactorList (void) { }
ParfactorList (Parfactors&);
list<Parfactor*>& getParfactors (void) { return pfList_; }
const list<Parfactor*>& getParfactors (void) const { return pfList_; }
void add (Parfactor* pf);
void add (Parfactors& pfs);
void addShattered (Parfactor* pf);
list<Parfactor*>::iterator remove (list<Parfactor*>::iterator);
list<Parfactor*>::iterator deleteAndRemove (list<Parfactor*>::iterator);
ParfactorList (const ParfactorList&);
void clear (void) { pfList_.clear(); }
unsigned size (void) const { return pfList_.size(); }
ParfactorList (const Parfactors&);
void shatter (void);
~ParfactorList (void);
const list<Parfactor*>& parfactors (void) const { return pfList_; }
void clear (void) { pfList_.clear(); }
unsigned size (void) const { return pfList_.size(); }
typedef std::list<Parfactor*>::iterator iterator;
iterator begin (void) { return pfList_.begin(); }
iterator end (void) { return pfList_.end(); }
iterator end (void) { return pfList_.end(); }
typedef std::list<Parfactor*>::const_iterator const_iterator;
const_iterator begin (void) const { return pfList_.begin(); }
const_iterator end (void) const { return pfList_.end(); }
const_iterator end (void) const { return pfList_.end(); }
void add (Parfactor* pf);
void add (const Parfactors& pfs);
void addShattered (Parfactor* pf);
list<Parfactor*>::iterator insertShattered (
list<Parfactor*>::iterator, Parfactor*);
list<Parfactor*>::iterator remove (list<Parfactor*>::iterator);
list<Parfactor*>::iterator removeAndDelete (list<Parfactor*>::iterator);
bool isAllShattered (void) const;
void print (void) const;
private:
bool isShattered (const Parfactor*, const Parfactor*) const;
static std::pair<Parfactors, Parfactors> shatter (
ProbFormulas&,
Parfactor*,
ProbFormulas&,
Parfactor*);
void addToShatteredList (Parfactor*);
std::pair<Parfactors, Parfactors> shatter (
Parfactor*, Parfactor*);
static std::pair<Parfactors, Parfactors> shatter (
ProbFormula&,
Parfactor*,
ProbFormula&,
Parfactor*);
std::pair<Parfactors, Parfactors> shatter (
unsigned, Parfactor*, unsigned, Parfactor*);
static Parfactors shatter (
Parfactor*,
const ProbFormula&,
ConstraintTree*,
ConstraintTree*,
unsigned);
Parfactors shatter (
Parfactor*,
unsigned,
ConstraintTree*,
ConstraintTree*,
unsigned);
void unifyGroups (unsigned group1, unsigned group2);
void unifyGroups (unsigned group1, unsigned group2);
list<Parfactor*> pfList_;
bool proper (
const ProbFormula&, ConstraintTree,
const ProbFormula&, ConstraintTree) const;
bool identical (
const ProbFormula&, ConstraintTree,
const ProbFormula&, ConstraintTree) const;
bool disjoint (
const ProbFormula&, ConstraintTree,
const ProbFormula&, ConstraintTree) const;
list<Parfactor*> pfList_;
};
#endif // HORUS_PARFACTORLIST_H

View File

@ -16,8 +16,7 @@ ProbFormula::sameSkeletonAs (const ProbFormula& f) const
bool
ProbFormula::contains (LogVar lv) const
{
return std::find (logVars_.begin(), logVars_.end(), lv) !=
logVars_.end();
return Util::contains (logVars_, lv);
}
@ -77,16 +76,15 @@ ProbFormula::rename (LogVar oldName, LogVar newName)
}
bool
ProbFormula::operator== (const ProbFormula& f) const
bool operator== (const ProbFormula& f1, const ProbFormula& f2)
{
return functor_ == f.functor_ && logVars_ == f.logVars_ ;
return f1.group_ == f2.group_;
//return functor_ == f.functor_ && logVars_ == f.logVars_ ;
}
ostream& operator<< (ostream &os, const ProbFormula& f)
std::ostream& operator<< (ostream &os, const ProbFormula& f)
{
os << f.functor_;
if (f.isAtom() == false) {
@ -113,3 +111,13 @@ ProbFormula::getNewGroup (void)
return freeGroup_;
}
ostream& operator<< (ostream &os, const ObservedFormula& of)
{
os << of.functor_ << "/" << of.arity_;
os << "|" << of.constr_.tupleSet();
os << " [evidence=" << of.evidence_ << "]";
return os;
}

View File

@ -8,14 +8,16 @@
#include "Horus.h"
class ProbFormula
{
public:
ProbFormula (Symbol f, const LogVars& lvs, unsigned range)
: functor_(f), logVars_(lvs), range_(range),
countedLogVar_() { }
countedLogVar_(), group_(Util::maxUnsigned()) { }
ProbFormula (Symbol f, unsigned r) : functor_(f), range_(r) { }
ProbFormula (Symbol f, unsigned r)
: functor_(f), range_(r), group_(Util::maxUnsigned()) { }
Symbol functor (void) const { return functor_; }
@ -29,9 +31,9 @@ class ProbFormula
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
unsigned group (void) const { return groupId_; }
unsigned group (void) const { return group_; }
void setGroup (unsigned g) { groupId_ = g; }
void setGroup (unsigned g) { group_ = g; }
bool sameSkeletonAs (const ProbFormula&) const;
@ -49,23 +51,58 @@ class ProbFormula
void rename (LogVar, LogVar);
bool operator== (const ProbFormula& f) const;
friend ostream& operator<< (ostream &out, const ProbFormula& f);
static unsigned getNewGroup (void);
friend std::ostream& operator<< (ostream &os, const ProbFormula& f);
friend bool operator== (const ProbFormula& f1, const ProbFormula& f2);
private:
Symbol functor_;
LogVars logVars_;
unsigned range_;
LogVar countedLogVar_;
unsigned groupId_;
static int freeGroup_;
Symbol functor_;
LogVars logVars_;
unsigned range_;
LogVar countedLogVar_;
unsigned group_;
static int freeGroup_;
};
typedef vector<ProbFormula> ProbFormulas;
class ObservedFormula
{
public:
ObservedFormula (Symbol f, unsigned a, unsigned ev)
: functor_(f), arity_(a), evidence_(ev), constr_(a) { }
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(arity_)
{
constr_.addTuple (tuple);
}
Symbol functor (void) const { return functor_; }
unsigned arity (void) const { return arity_; }
unsigned evidence (void) const { return evidence_; }
ConstraintTree& constr (void) { return constr_; }
bool isAtom (void) const { return arity_ == 0; }
void addTuple (const Tuple& tuple) { constr_.addTuple (tuple); }
friend ostream& operator<< (ostream &os, const ObservedFormula& of);
private:
Symbol functor_;
unsigned arity_;
unsigned evidence_;
ConstraintTree constr_;
};
typedef vector<ObservedFormula> ObservedFormulas;
#endif // HORUS_PROBFORMULA_H

View File

@ -3,51 +3,35 @@
void
Solver::printAllPosterioris (void)
Solver::printAnswer (const VarIds& vids)
{
const VarNodes& vars = gm_->getVariableNodes();
for (unsigned i = 0; i < vars.size(); i++) {
printPosterioriOf (vars[i]->varId());
}
}
void
Solver::printPosterioriOf (VarId vid)
{
VarNode* var = gm_->getVariableNode (vid);
const Params& posterioriDist = getPosterioriOf (vid);
const States& states = var->states();
for (unsigned i = 0; i < states.size(); i++) {
cout << "P(" << var->label() << "=" << states[i] << ") = " ;
cout << setprecision (PRECISION) << posterioriDist[i];
cout << endl;
}
cout << endl;
}
void
Solver::printJointDistributionOf (const VarIds& vids)
{
VarNodes vars;
VarIds vidsWithoutEvidence;
Vars unobservedVars;
VarIds unobservedVids;
for (unsigned i = 0; i < vids.size(); i++) {
VarNode* var = gm_->getVariableNode (vids[i]);
if (var->hasEvidence() == false) {
vars.push_back (var);
vidsWithoutEvidence.push_back (vids[i]);
VarNode* vn = fg.getVarNode (vids[i]);
if (vn->hasEvidence() == false) {
unobservedVars.push_back (vn);
unobservedVids.push_back (vids[i]);
}
}
const Params& jointDist = getJointDistributionOf (vidsWithoutEvidence);
vector<string> jointStrings = Util::getJointStateStrings (vars);
for (unsigned i = 0; i < jointDist.size(); i++) {
cout << "P(" << jointStrings[i] << ") = " ;
cout << setprecision (PRECISION) << jointDist[i];
Params res = solveQuery (unobservedVids);
vector<string> stateLines = Util::getStateLines (unobservedVars);
for (unsigned i = 0; i < res.size(); i++) {
cout << "P(" << stateLines[i] << ") = " ;
cout << std::setprecision (Constants::PRECISION) << res[i];
cout << endl;
}
cout << endl;
}
void
Solver::printAllPosterioris (void)
{
const VarNodes& vars = fg.varNodes();
for (unsigned i = 0; i < vars.size(); i++) {
printAnswer ({vars[i]->varId()});
}
}

View File

@ -3,29 +3,27 @@
#include <iomanip>
#include "GraphicalModel.h"
#include "VarNode.h"
#include "Var.h"
#include "FactorGraph.h"
using namespace std;
class Solver
{
public:
Solver (const GraphicalModel* gm)
{
gm_ = gm;
}
virtual ~Solver() {} // to ensure that subclass destructor is called
virtual void runSolver (void) = 0;
virtual Params getPosterioriOf (VarId) = 0;
virtual Params getJointDistributionOf (const VarIds&) = 0;
Solver (const FactorGraph& factorGraph) : fg(factorGraph) { }
virtual ~Solver() { } // ensure that subclass destructor is called
virtual Params solveQuery (VarIds queryVids) = 0;
void printAnswer (const VarIds& vids);
void printAllPosterioris (void);
void printPosterioriOf (VarId vid);
void printJointDistributionOf (const VarIds& vids);
private:
const GraphicalModel* gm_;
protected:
const FactorGraph& fg;
};
#endif // HORUS_SOLVER_H

View File

@ -0,0 +1,4 @@
TODO
- add a way to calculate combinations and factorials with large numbers
- refactor sumOut in parfactor -> is really ugly code
- Indexer: start receiving ranges as constant reference

View File

@ -1,21 +1,22 @@
#include <limits>
#include <sstream>
#include <fstream>
#include "Util.h"
#include "Indexer.h"
#include "GraphicalModel.h"
namespace Globals {
bool logDomain = false;
bool logDomain = false;
//InfAlgs infAlgorithm = InfAlgorithms::VE;
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
//InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
InfAlgorithms infAlgorithm = InfAlgorithms::CBP;
};
namespace InfAlgorithms {
//InfAlgs infAlgorithm = InfAlgorithms::VE;
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
//InfAlgs infAlgorithm = InfAlgorithms::CBP;
}
namespace BpOptions {
@ -23,13 +24,11 @@ Schedule schedule = BpOptions::Schedule::SEQ_FIXED;
//Schedule schedule = BpOptions::Schedule::SEQ_RANDOM;
//Schedule schedule = BpOptions::Schedule::PARALLEL;
//Schedule schedule = BpOptions::Schedule::MAX_RESIDUAL;
double accuracy = 0.0001;
unsigned maxIter = 1000;
double accuracy = 0.0001;
unsigned maxIter = 1000;
}
unordered_map<VarId,VariableInfo> GraphicalModel::varsInfo_;
unordered_map<unsigned,Distribution*> GraphicalModel::distsInfo_;
vector<NetInfo> Statistics::netInfo_;
vector<CompressInfo> Statistics::compressInfo_;
@ -58,76 +57,6 @@ fromLog (Params& v)
void
normalize (Params& v)
{
double sum;
if (Globals::logDomain) {
sum = addIdenty();
for (unsigned i = 0; i < v.size(); i++) {
logSum (sum, v[i]);
}
assert (sum != -numeric_limits<double>::infinity());
for (unsigned i = 0; i < v.size(); i++) {
v[i] -= sum;
}
} else {
sum = 0.0;
for (unsigned i = 0; i < v.size(); i++) {
sum += v[i];
}
assert (sum != 0.0);
for (unsigned i = 0; i < v.size(); i++) {
v[i] /= sum;
}
}
}
void
pow (Params& v, double expoent)
{
if (Globals::logDomain) {
for (unsigned i = 0; i < v.size(); i++) {
v[i] *= expoent;
}
} else {
for (unsigned i = 0; i < v.size(); i++) {
v[i] = std::pow (v[i], expoent);
}
}
}
void
pow (Params& v, unsigned expoent)
{
if (expoent == 1) {
return;
}
if (Globals::logDomain) {
for (unsigned i = 0; i < v.size(); i++) {
v[i] *= expoent;
}
} else {
for (unsigned i = 0; i < v.size(); i++) {
v[i] = std::pow (v[i], expoent);
}
}
}
double
pow (double p, unsigned expoent)
{
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
}
double
factorial (double num)
{
@ -153,6 +82,151 @@ nrCombinations (unsigned n, unsigned r)
unsigned
expectedSize (const Ranges& ranges)
{
unsigned prod = 1;
for (unsigned i = 0; i < ranges.size(); i++) {
prod *= ranges[i];
}
return prod;
}
unsigned
getNumberOfDigits (int number)
{
unsigned count = 1;
while (number >= 10) {
number /= 10;
count ++;
}
return count;
}
bool
isInteger (const string& s)
{
stringstream ss1 (s);
stringstream ss2;
int integer;
ss1 >> integer;
ss2 << integer;
return (ss1.str() == ss2.str());
}
string
parametersToString (const Params& v, unsigned precision)
{
stringstream ss;
ss.precision (precision);
ss << "[" ;
for (unsigned i = 0; i < v.size(); i++) {
if (i != 0) ss << ", " ;
ss << v[i];
}
ss << "]" ;
return ss.str();
}
vector<string>
getStateLines (const Vars& vars)
{
StatesIndexer idx (vars);
vector<string> jointStrings;
while (idx.valid()) {
stringstream ss;
for (unsigned i = 0; i < vars.size(); i++) {
if (i != 0) ss << ", " ;
ss << vars[i]->label() << "=" << vars[i]->states()[(idx[i])];
}
jointStrings.push_back (ss.str());
++ idx;
}
return jointStrings;
}
void
printHeader (string header, std::ostream& os)
{
printAsteriskLine (os);
os << header << endl;
printAsteriskLine (os);
}
void
printSubHeader (string header, std::ostream& os)
{
printDashedLine (os);
os << header << endl;
printDashedLine (os);
}
void
printAsteriskLine (std::ostream& os)
{
os << "********************************" ;
os << "********************************" ;
os << endl;
}
void
printDashedLine (std::ostream& os)
{
os << "--------------------------------" ;
os << "--------------------------------" ;
os << endl;
}
}
namespace LogAware {
void
normalize (Params& v)
{
double sum;
if (Globals::logDomain) {
sum = LogAware::addIdenty();
for (unsigned i = 0; i < v.size(); i++) {
sum = Util::logSum (sum, v[i]);
}
assert (sum != -numeric_limits<double>::infinity());
for (unsigned i = 0; i < v.size(); i++) {
v[i] -= sum;
}
} else {
sum = 0.0;
for (unsigned i = 0; i < v.size(); i++) {
sum += v[i];
}
assert (sum != 0.0);
for (unsigned i = 0; i < v.size(); i++) {
v[i] /= sum;
}
}
}
double
getL1Distance (const Params& v1, const Params& v2)
{
@ -196,67 +270,57 @@ getMaxNorm (const Params& v1, const Params& v2)
}
double
pow (double p, unsigned expoent)
{
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
}
unsigned
getNumberOfDigits (int number) {
unsigned count = 1;
while (number >= 10) {
number /= 10;
count ++;
double
pow (double p, double expoent)
{
// assumes that `expoent' is never in log domain
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
}
void
pow (Params& v, unsigned expoent)
{
if (expoent == 1) {
return;
}
return count;
}
bool
isInteger (const string& s)
{
stringstream ss1 (s);
stringstream ss2;
int integer;
ss1 >> integer;
ss2 << integer;
return (ss1.str() == ss2.str());
}
string
parametersToString (const Params& v, unsigned precision)
{
stringstream ss;
ss.precision (precision);
ss << "[" ;
for (unsigned i = 0; i < v.size(); i++) {
if (i != 0) ss << ", " ;
ss << v[i];
}
ss << "]" ;
return ss.str();
}
vector<string>
getJointStateStrings (const VarNodes& vars)
{
StatesIndexer idx (vars);
vector<string> jointStrings;
while (idx.valid()) {
stringstream ss;
for (unsigned i = 0; i < vars.size(); i++) {
if (i != 0) ss << ", " ;
ss << vars[i]->label() << "=" << vars[i]->states()[(idx[i])];
if (Globals::logDomain) {
for (unsigned i = 0; i < v.size(); i++) {
v[i] *= expoent;
}
} else {
for (unsigned i = 0; i < v.size(); i++) {
v[i] = std::pow (v[i], expoent);
}
jointStrings.push_back (ss.str());
++ idx;
}
return jointStrings;
}
void
pow (Params& v, double expoent)
{
// assumes that `expoent' is never in log domain
if (Globals::logDomain) {
for (unsigned i = 0; i < v.size(); i++) {
v[i] *= expoent;
}
} else {
for (unsigned i = 0; i < v.size(); i++) {
v[i] = std::pow (v[i], expoent);
}
}
}
}
@ -286,8 +350,11 @@ Statistics::getPrimaryNetworksCounting (void)
void
Statistics::updateStatistics (unsigned size, bool loopy,
unsigned nIters, double time)
Statistics::updateStatistics (
unsigned size,
bool loopy,
unsigned nIters,
double time)
{
netInfo_.push_back (NetInfo (size, loopy, nIters, time));
}
@ -303,12 +370,12 @@ Statistics::printStatistics (void)
void
Statistics::writeStatisticsToFile (const char* fileName)
Statistics::writeStatistics (const char* fileName)
{
ofstream out (fileName);
if (!out.is_open()) {
cerr << "error: cannot open file to write at " ;
cerr << "Statistics::writeStatisticsToFile()" << endl;
cerr << "Statistics::writeStats()" << endl;
abort();
}
out << getStatisticString();
@ -318,13 +385,14 @@ Statistics::writeStatisticsToFile (const char* fileName)
void
Statistics::updateCompressingStatistics (unsigned nGroundVars,
unsigned nGroundFactors,
unsigned nClusterVars,
unsigned nClusterFactors,
unsigned nWithoutNeighs) {
compressInfo_.push_back (CompressInfo (nGroundVars, nGroundFactors,
nClusterVars, nClusterFactors, nWithoutNeighs));
Statistics::updateCompressingStatistics (
unsigned nrGroundVars,
unsigned nrGroundFactors,
unsigned nrClusterVars,
unsigned nrClusterFactors,
unsigned nrNeighborless) {
compressInfo_.push_back (CompressInfo (nrGroundVars, nrGroundFactors,
nrClusterVars, nrClusterFactors, nrNeighborless));
}
@ -334,26 +402,30 @@ Statistics::getStatisticString (void)
{
stringstream ss2, ss3, ss4, ss1;
ss1 << "running mode: " ;
switch (InfAlgorithms::infAlgorithm) {
case InfAlgorithms::VE: ss1 << "ve" << endl; break;
case InfAlgorithms::BN_BP: ss1 << "bn_bp" << endl; break;
case InfAlgorithms::FG_BP: ss1 << "fg_bp" << endl; break;
case InfAlgorithms::CBP: ss1 << "cbp" << endl; break;
switch (Globals::infAlgorithm) {
case InfAlgorithms::VE: ss1 << "ve" << endl; break;
case InfAlgorithms::BP: ss1 << "bp" << endl; break;
case InfAlgorithms::CBP: ss1 << "cbp" << endl; break;
}
ss1 << "message schedule: " ;
switch (BpOptions::schedule) {
case BpOptions::Schedule::SEQ_FIXED: ss1 << "sequential fixed" << endl; break;
case BpOptions::Schedule::SEQ_RANDOM: ss1 << "sequential random" << endl; break;
case BpOptions::Schedule::PARALLEL: ss1 << "parallel" << endl; break;
case BpOptions::Schedule::MAX_RESIDUAL: ss1 << "max residual" << endl; break;
case BpOptions::Schedule::SEQ_FIXED:
ss1 << "sequential fixed" << endl;
break;
case BpOptions::Schedule::SEQ_RANDOM:
ss1 << "sequential random" << endl;
break;
case BpOptions::Schedule::PARALLEL:
ss1 << "parallel" << endl;
break;
case BpOptions::Schedule::MAX_RESIDUAL:
ss1 << "max residual" << endl;
break;
}
ss1 << "max iterations: " << BpOptions::maxIter << endl;
ss1 << "accuracy " << BpOptions::accuracy << endl;
ss1 << endl << endl;
ss2 << "---------------------------------------------------" << endl;
ss2 << " Network information" << endl;
ss2 << "---------------------------------------------------" << endl;
Util::printSubHeader ("Network information", ss2);
ss2 << left;
ss2 << setw (15) << "Network Size" ;
ss2 << setw (9) << "Loopy" ;
@ -387,24 +459,22 @@ Statistics::getStatisticString (void)
unsigned c1 = 0, c2 = 0, c3 = 0, c4 = 0;
if (compressInfo_.size() > 0) {
ss3 << "---------------------------------------------------" << endl;
ss3 << " Compression information" << endl;
ss3 << "---------------------------------------------------" << endl;
Util::printSubHeader ("Compress information", ss3);
ss3 << left;
ss3 << "Ground Cluster Ground Cluster Neighborless" << endl;
ss3 << "Vars Vars Factors Factors Vars" << endl;
for (unsigned i = 0; i < compressInfo_.size(); i++) {
ss3 << setw (9) << compressInfo_[i].nGroundVars;
ss3 << setw (10) << compressInfo_[i].nClusterVars;
ss3 << setw (10) << compressInfo_[i].nGroundFactors;
ss3 << setw (10) << compressInfo_[i].nClusterFactors;
ss3 << setw (10) << compressInfo_[i].nWithoutNeighs;
ss3 << setw (9) << compressInfo_[i].nrGroundVars;
ss3 << setw (10) << compressInfo_[i].nrClusterVars;
ss3 << setw (10) << compressInfo_[i].nrGroundFactors;
ss3 << setw (10) << compressInfo_[i].nrClusterFactors;
ss3 << setw (10) << compressInfo_[i].nrNeighborless;
ss3 << endl;
c1 += compressInfo_[i].nGroundVars - compressInfo_[i].nWithoutNeighs;
c2 += compressInfo_[i].nClusterVars;
c3 += compressInfo_[i].nGroundFactors - compressInfo_[i].nWithoutNeighs;
c4 += compressInfo_[i].nClusterFactors;
if (compressInfo_[i].nWithoutNeighs != 0) {
c1 += compressInfo_[i].nrGroundVars - compressInfo_[i].nrNeighborless;
c2 += compressInfo_[i].nrClusterVars;
c3 += compressInfo_[i].nrGroundFactors - compressInfo_[i].nrNeighborless;
c4 += compressInfo_[i].nrClusterFactors;
if (compressInfo_[i].nrNeighborless != 0) {
c2 --;
c4 --;
}

View File

@ -1,53 +1,141 @@
#ifndef HORUS_UTIL_H
#define HORUS_UTIL_H
#include <cmath>
#include <cassert>
#include <limits>
#include <algorithm>
#include <vector>
#include <set>
#include <queue>
#include <unordered_map>
#include <sstream>
#include <iostream>
#include "Horus.h"
using namespace std;
namespace Util {
void toLog (Params&);
void fromLog (Params&);
void normalize (Params&);
void logSum (double&, double);
void multiply (Params&, const Params&);
void multiply (Params&, const Params&, unsigned);
void add (Params&, const Params&);
void add (Params&, const Params&, unsigned);
void pow (Params&, double);
void pow (Params&, unsigned);
double pow (double, unsigned);
double factorial (double);
unsigned nrCombinations (unsigned, unsigned);
double getL1Distance (const Params&, const Params&);
double getMaxNorm (const Params&, const Params&);
unsigned getNumberOfDigits (int);
bool isInteger (const string&);
string parametersToString (const Params&, unsigned = PRECISION);
vector<string> getJointStateStrings (const VarNodes&);
double tl (double);
double fl (double);
double multIdenty();
double addIdenty();
double withEvidence();
double noEvidence();
double one();
double zero();
template <typename T> void addToVector (vector<T>&, const vector<T>&);
template <typename T> void addToSet (set<T>&, const vector<T>&);
template <typename T> void addToQueue (queue<T>&, const vector<T>&);
template <typename T> bool contains (const vector<T>&, const T&);
template <typename T> bool contains (const set<T>&, const T&);
template <typename K, typename V> bool contains (
const unordered_map<K, V>&, const K&);
template <typename T> std::string toString (const T&);
void toLog (Params&);
void fromLog (Params&);
double logSum (double, double);
void multiply (Params&, const Params&);
void multiply (Params&, const Params&, unsigned);
void add (Params&, const Params&);
void add (Params&, const Params&, unsigned);
double factorial (double);
unsigned nrCombinations (unsigned, unsigned);
unsigned expectedSize (const Ranges&);
unsigned getNumberOfDigits (int);
bool isInteger (const string&);
string parametersToString (const Params&, unsigned = Constants::PRECISION);
vector<string> getStateLines (const Vars&);
void printHeader (string, std::ostream& os = std::cout);
void printSubHeader (string, std::ostream& os = std::cout);
void printAsteriskLine (std::ostream& os = std::cout);
void printDashedLine (std::ostream& os = std::cout);
unsigned maxUnsigned (void);
};
template <class T>
std::string toString (const T& t)
template <typename T> void
Util::addToVector (vector<T>& v, const vector<T>& elements)
{
v.insert (v.end(), elements.begin(), elements.end());
}
template <typename T> void
Util::addToSet (set<T>& s, const vector<T>& elements)
{
s.insert (elements.begin(), elements.end());
}
template <typename T> void
Util::addToQueue (queue<T>& q, const vector<T>& elements)
{
for (unsigned i = 0; i < elements.size(); i++) {
q.push (elements[i]);
}
}
template <typename T> bool
Util::contains (const vector<T>& v, const T& e)
{
return std::find (v.begin(), v.end(), e) != v.end();
}
template <typename T> bool
Util::contains (const set<T>& s, const T& e)
{
return s.find (e) != s.end();
}
template <typename K, typename V> bool
Util::contains (const unordered_map<K, V>& m, const K& k)
{
return m.find (k) != m.end();
}
template <typename T> std::string
Util::toString (const T& t)
{
std::stringstream ss;
ss << t;
return ss.str();
}
};
template <typename T>
@ -62,28 +150,31 @@ std::ostream& operator << (std::ostream& os, const vector<T>& v)
}
namespace {
const double INF = -numeric_limits<double>::infinity();
};
inline void
Util::logSum (double& x, double y)
inline double
Util::logSum (double x, double y)
{
x = log (exp (x) + exp (y)); return;
return log (exp (x) + exp (y));
assert (isfinite (x) && isfinite (y));
// If one value is much smaller than the other, keep the larger value.
if (x < (y - log (1e200))) {
x = y;
return;
return y;
}
if (y < (x - log (1e200))) {
return;
return x;
}
double diff = x - y;
assert (isfinite (diff) && isfinite (x) && isfinite (y));
if (!isfinite (exp (diff))) { // difference is too large
x = x > y ? x : y;
} else { // otherwise return the sum.
x = y + log (static_cast<double>(1.0) + exp (diff));
if (!isfinite (exp (diff))) {
// difference is too large
return x > y ? x : y;
}
// otherwise return the sum.
return y + log (static_cast<double>(1.0) + exp (diff));
}
@ -140,52 +231,87 @@ Util::add (Params& v1, const Params& v2, unsigned repetitions)
inline double
Util::tl (double v)
inline unsigned
Util::maxUnsigned (void)
{
return Globals::logDomain ? log(v) : v;
return numeric_limits<unsigned>::max();
}
inline double
Util::fl (double v)
{
return Globals::logDomain ? exp(v) : v;
}
namespace LogAware {
inline double
Util::multIdenty() {
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
Util::addIdenty()
{
return Globals::logDomain ? INF : 0.0;
}
inline double
Util::withEvidence()
one()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
Util::noEvidence() {
return Globals::logDomain ? INF : 0.0;
}
inline double
Util::one()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
Util::zero() {
zero() {
return Globals::logDomain ? INF : 0.0 ;
}
inline double
addIdenty()
{
return Globals::logDomain ? INF : 0.0;
}
inline double
multIdenty()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
withEvidence()
{
return Globals::logDomain ? 0.0 : 1.0;
}
inline double
noEvidence() {
return Globals::logDomain ? INF : 0.0;
}
inline double
tl (double v)
{
return Globals::logDomain ? log (v) : v;
}
inline double
fl (double v)
{
return Globals::logDomain ? exp (v) : v;
}
void normalize (Params&);
double getL1Distance (const Params&, const Params&);
double getMaxNorm (const Params&, const Params&);
double pow (double, unsigned);
double pow (double, double);
void pow (Params&, unsigned);
void pow (Params&, double);
};
struct NetInfo
{
NetInfo (unsigned size, bool loopy, unsigned nIters, double time)
@ -206,17 +332,17 @@ struct CompressInfo
{
CompressInfo (unsigned a, unsigned b, unsigned c, unsigned d, unsigned e)
{
nGroundVars = a;
nGroundFactors = b;
nClusterVars = c;
nClusterFactors = d;
nWithoutNeighs = e;
nrGroundVars = a;
nrGroundFactors = b;
nrClusterVars = c;
nrClusterFactors = d;
nrNeighborless = e;
}
unsigned nGroundVars;
unsigned nGroundFactors;
unsigned nClusterVars;
unsigned nClusterFactors;
unsigned nWithoutNeighs;
unsigned nrGroundVars;
unsigned nrGroundFactors;
unsigned nrClusterVars;
unsigned nrClusterFactors;
unsigned nrNeighborless;
};
@ -224,11 +350,17 @@ class Statistics
{
public:
static unsigned getSolvedNetworksCounting (void);
static void incrementPrimaryNetworksCounting (void);
static unsigned getPrimaryNetworksCounting (void);
static void updateStatistics (unsigned, bool, unsigned, double);
static void printStatistics (void);
static void writeStatisticsToFile (const char*);
static void writeStatistics (const char*);
static void updateCompressingStatistics (
unsigned, unsigned, unsigned, unsigned, unsigned);

View File

@ -0,0 +1,102 @@
#include <algorithm>
#include <sstream>
#include "Var.h"
using namespace std;
unordered_map<VarId, VarInfo> Var::varsInfo_;
Var::Var (const Var* v)
{
varId_ = v->varId();
range_ = v->range();
evidence_ = v->getEvidence();
index_ = std::numeric_limits<unsigned>::max();
}
Var::Var (VarId varId, unsigned range, int evidence)
{
assert (range != 0);
assert (evidence < (int) range);
varId_ = varId;
range_ = range;
evidence_ = evidence;
index_ = std::numeric_limits<unsigned>::max();
}
bool
Var::isValidState (int stateIndex)
{
return stateIndex >= 0 && stateIndex < (int) range_;
}
bool
Var::isValidState (const string& stateName)
{
States states = Var::getVarInfo (varId_).states;
return Util::contains (states, stateName);
}
void
Var::setEvidence (int ev)
{
assert (ev < (int) range_);
evidence_ = ev;
}
void
Var::setEvidence (const string& ev)
{
States states = Var::getVarInfo (varId_).states;
for (unsigned i = 0; i < states.size(); i++) {
if (states[i] == ev) {
evidence_ = i;
return;
}
}
assert (false);
}
string
Var::label (void) const
{
if (Var::varsHaveInfo()) {
return Var::getVarInfo (varId_).label;
}
stringstream ss;
ss << "x" << varId_;
return ss.str();
}
States
Var::states (void) const
{
if (Var::varsHaveInfo()) {
return Var::getVarInfo (varId_).states;
}
States states;
for (unsigned i = 0; i < range_; i++) {
stringstream ss;
ss << i ;
states.push_back (ss.str());
}
return states;
}

View File

@ -0,0 +1,108 @@
#ifndef HORUS_Var_H
#define HORUS_Var_H
#include <cassert>
#include <iostream>
#include "Util.h"
#include "Horus.h"
using namespace std;
struct VarInfo
{
VarInfo (string l, const States& sts) : label(l), states(sts) { }
string label;
States states;
};
class Var
{
public:
Var (const Var*);
Var (VarId, unsigned, int = Constants::NO_EVIDENCE);
virtual ~Var (void) { };
unsigned varId (void) const { return varId_; }
unsigned range (void) const { return range_; }
int getEvidence (void) const { return evidence_; }
unsigned getIndex (void) const { return index_; }
void setIndex (unsigned idx) { index_ = idx; }
operator unsigned () const { return index_; }
bool hasEvidence (void) const
{
return evidence_ != Constants::NO_EVIDENCE;
}
bool operator== (const Var& var) const
{
assert (!(varId_ == var.varId() && range_ != var.range()));
return varId_ == var.varId();
}
bool operator!= (const Var& var) const
{
assert (!(varId_ == var.varId() && range_ != var.range()));
return varId_ != var.varId();
}
bool isValidState (int);
bool isValidState (const string&);
void setEvidence (int);
void setEvidence (const string&);
string label (void) const;
States states (void) const;
static void addVarInfo (
VarId vid, string label, const States& states)
{
assert (Util::contains (varsInfo_, vid) == false);
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
}
static VarInfo getVarInfo (VarId vid)
{
assert (Util::contains (varsInfo_, vid));
return varsInfo_.find (vid)->second;
}
static bool varsHaveInfo (void)
{
return varsInfo_.size() != 0;
}
static void clearVarsInfo (void)
{
varsInfo_.clear();
}
private:
VarId varId_;
unsigned range_;
int evidence_;
unsigned index_;
static unordered_map<VarId, VarInfo> varsInfo_;
};
#endif // BP_Var_H

View File

@ -6,61 +6,27 @@
#include "Util.h"
VarElimSolver::VarElimSolver (const BayesNet& bn) : Solver (&bn)
{
bayesNet_ = &bn;
factorGraph_ = new FactorGraph (bn);
}
VarElimSolver::VarElimSolver (const FactorGraph& fg) : Solver (&fg)
{
bayesNet_ = 0;
factorGraph_ = &fg;
}
VarElimSolver::~VarElimSolver (void)
{
if (bayesNet_) {
delete factorGraph_;
}
delete factorList_.back();
}
Params
VarElimSolver::getPosterioriOf (VarId vid)
{
assert (factorGraph_->getFgVarNode (vid));
FgVarNode* vn = factorGraph_->getFgVarNode (vid);
if (vn->hasEvidence()) {
Params params (vn->nrStates(), 0.0);
params[vn->getEvidence()] = 1.0;
return params;
}
return getJointDistributionOf (VarIds() = {vid});
}
Params
VarElimSolver::getJointDistributionOf (const VarIds& vids)
VarElimSolver::solveQuery (VarIds queryVids)
{
factorList_.clear();
varFactors_.clear();
elimOrder_.clear();
createFactorList();
introduceEvidence();
chooseEliminationOrder (vids);
processFactorList (vids);
Params params = factorList_.back()->getParameters();
absorveEvidence();
findEliminationOrder (queryVids);
processFactorList (queryVids);
Params params = factorList_.back()->params();
if (Globals::logDomain) {
Util::fromLog (params);
}
delete factorList_.back();
return params;
}
@ -69,11 +35,11 @@ VarElimSolver::getJointDistributionOf (const VarIds& vids)
void
VarElimSolver::createFactorList (void)
{
const FgFacSet& factorNodes = factorGraph_->getFactorNodes();
factorList_.reserve (factorNodes.size() * 2);
for (unsigned i = 0; i < factorNodes.size(); i++) {
factorList_.push_back (new Factor (*factorNodes[i]->factor()));
const FgVarSet& neighs = factorNodes[i]->neighbors();
const FacNodes& facNodes = fg.facNodes();
factorList_.reserve (facNodes.size() * 2);
for (unsigned i = 0; i < facNodes.size(); i++) {
factorList_.push_back (new Factor (facNodes[i]->factor()));
const VarNodes& neighs = facNodes[i]->neighbors();
for (unsigned j = 0; j < neighs.size(); j++) {
unordered_map<VarId,vector<unsigned> >::iterator it
= varFactors_.find (neighs[j]->varId());
@ -89,16 +55,16 @@ VarElimSolver::createFactorList (void)
void
VarElimSolver::introduceEvidence (void)
VarElimSolver::absorveEvidence (void)
{
const FgVarSet& varNodes = factorGraph_->getVarNodes();
const VarNodes& varNodes = fg.varNodes();
for (unsigned i = 0; i < varNodes.size(); i++) {
if (varNodes[i]->hasEvidence()) {
const vector<unsigned>& idxs =
varFactors_.find (varNodes[i]->varId())->second;
for (unsigned j = 0; j < idxs.size(); j++) {
Factor* factor = factorList_[idxs[j]];
if (factor->nrVariables() == 1) {
if (factor->nrArguments() == 1) {
factorList_[idxs[j]] = 0;
} else {
factorList_[idxs[j]]->absorveEvidence (
@ -112,21 +78,9 @@ VarElimSolver::introduceEvidence (void)
void
VarElimSolver::chooseEliminationOrder (const VarIds& vids)
VarElimSolver::findEliminationOrder (const VarIds& vids)
{
if (bayesNet_) {
ElimGraph graph (*bayesNet_);
elimOrder_ = graph.getEliminatingOrder (vids);
} else {
const FgVarSet& varNodes = factorGraph_->getVarNodes();
for (unsigned i = 0; i < varNodes.size(); i++) {
VarId vid = varNodes[i]->varId();
if (std::find (vids.begin(), vids.end(), vid) == vids.end()
&& !varNodes[i]->hasEvidence()) {
elimOrder_.push_back (vid);
}
}
}
elimOrder_ = ElimGraph::getEliminationOrder (factorList_, vids);
}
@ -149,12 +103,12 @@ VarElimSolver::processFactorList (const VarIds& vids)
VarIds unobservedVids;
for (unsigned i = 0; i < vids.size(); i++) {
if (factorGraph_->getFgVarNode (vids[i])->hasEvidence() == false) {
if (fg.getVarNode (vids[i])->hasEvidence() == false) {
unobservedVids.push_back (vids[i]);
}
}
finalFactor->reorderVariables (unobservedVids);
finalFactor->reorderArguments (unobservedVids);
finalFactor->normalize();
factorList_.push_back (finalFactor);
}
@ -165,13 +119,12 @@ void
VarElimSolver::eliminate (VarId elimVar)
{
Factor* result = 0;
FgVarNode* vn = factorGraph_->getFgVarNode (elimVar);
vector<unsigned>& idxs = varFactors_.find (elimVar)->second;
for (unsigned i = 0; i < idxs.size(); i++) {
unsigned idx = idxs[i];
if (factorList_[idx]) {
if (result == 0) {
result = new Factor(*factorList_[idx]);
result = new Factor (*factorList_[idx]);
} else {
result->multiply (*factorList_[idx]);
}
@ -179,10 +132,10 @@ VarElimSolver::eliminate (VarId elimVar)
factorList_[idx] = 0;
}
}
if (result != 0 && result->nrVariables() != 1) {
result->sumOut (vn->varId());
if (result != 0 && result->nrArguments() != 1) {
result->sumOut (elimVar);
factorList_.push_back (result);
const VarIds& resultVarIds = result->getVarIds();
const VarIds& resultVarIds = result->arguments();
for (unsigned i = 0; i < resultVarIds.size(); i++) {
vector<unsigned>& idxs =
varFactors_.find (resultVarIds[i])->second;
@ -199,7 +152,6 @@ VarElimSolver::printActiveFactors (void)
for (unsigned i = 0; i < factorList_.size(); i++) {
if (factorList_[i] != 0) {
factorList_[i]->print();
cout << endl;
}
}
}

View File

@ -5,7 +5,6 @@
#include "Solver.h"
#include "FactorGraph.h"
#include "BayesNet.h"
#include "Horus.h"
@ -15,23 +14,25 @@ using namespace std;
class VarElimSolver : public Solver
{
public:
VarElimSolver (const BayesNet&);
VarElimSolver (const FactorGraph&);
VarElimSolver (const FactorGraph& fg) : Solver (fg) { }
~VarElimSolver (void);
void runSolver (void) { }
Params getPosterioriOf (VarId);
Params getJointDistributionOf (const VarIds&);
Params solveQuery (VarIds);
private:
void createFactorList (void);
void introduceEvidence (void);
void chooseEliminationOrder (const VarIds&);
void absorveEvidence (void);
void findEliminationOrder (const VarIds&);
void processFactorList (const VarIds&);
void eliminate (VarId);
void printActiveFactors (void);
const BayesNet* bayesNet_;
const FactorGraph* factorGraph_;
vector<Factor*> factorList_;
VarIds elimOrder_;
unordered_map<VarId, vector<unsigned>> varFactors_;

View File

@ -1,100 +0,0 @@
#include <algorithm>
#include <sstream>
#include "VarNode.h"
#include "GraphicalModel.h"
using namespace std;
VarNode::VarNode (const VarNode* v)
{
varId_ = v->varId();
nrStates_ = v->nrStates();
evidence_ = v->getEvidence();
index_ = std::numeric_limits<unsigned>::max();
}
VarNode::VarNode (VarId varId, unsigned nrStates, int evidence)
{
assert (nrStates != 0);
assert (evidence < (int) nrStates);
varId_ = varId;
nrStates_ = nrStates;
evidence_ = evidence;
index_ = std::numeric_limits<unsigned>::max();
}
bool
VarNode::isValidState (int stateIndex)
{
return stateIndex >= 0 && stateIndex < (int) nrStates_;
}
bool
VarNode::isValidState (const string& stateName)
{
States states = GraphicalModel::getVariableInformation (varId_).states;
return find (states.begin(), states.end(), stateName) != states.end();
}
void
VarNode::setEvidence (int ev)
{
assert (ev < (int) nrStates_);
evidence_ = ev;
}
void
VarNode::setEvidence (const string& ev)
{
States states = GraphicalModel::getVariableInformation (varId_).states;
for (unsigned i = 0; i < states.size(); i++) {
if (states[i] == ev) {
evidence_ = i;
return;
}
}
assert (false);
}
string
VarNode::label (void) const
{
if (GraphicalModel::variablesHaveInformation()) {
return GraphicalModel::getVariableInformation (varId_).label;
}
stringstream ss;
ss << "x" << varId_;
return ss.str();
}
States
VarNode::states (void) const
{
if (GraphicalModel::variablesHaveInformation()) {
return GraphicalModel::getVariableInformation (varId_).states;
}
States states;
for (unsigned i = 0; i < nrStates_; i++) {
stringstream ss;
ss << i ;
states.push_back (ss.str());
}
return states;
}

View File

@ -1,54 +0,0 @@
#ifndef HORUS_VARNODE_H
#define HORUS_VARNODE_H
#include "Horus.h"
using namespace std;
class VarNode
{
public:
VarNode (const VarNode*);
VarNode (VarId, unsigned, int = NO_EVIDENCE);
virtual ~VarNode (void) {};
bool isValidState (int);
bool isValidState (const string&);
void setEvidence (int);
void setEvidence (const string&);
string label (void) const;
States states (void) const;
unsigned varId (void) const { return varId_; }
unsigned nrStates (void) const { return nrStates_; }
bool hasEvidence (void) const { return evidence_ != NO_EVIDENCE; }
int getEvidence (void) const { return evidence_; }
unsigned getIndex (void) const { return index_; }
void setIndex (unsigned idx) { index_ = idx; }
operator unsigned () const { return index_; }
bool operator== (const VarNode& var) const
{
cout << "equal operator called" << endl;
assert (!(varId_ == var.varId() && nrStates_ != var.nrStates()));
return varId_ == var.varId();
}
bool operator!= (const VarNode& var) const
{
cout << "diff operator called" << endl;
assert (!(varId_ == var.varId() && nrStates_ != var.nrStates()));
return varId_ != var.varId();
}
private:
VarId varId_;
unsigned nrStates_;
int evidence_;
unsigned index_;
};
#endif // BP_VARNODE_H

View File

@ -0,0 +1,35 @@
if [ $1 ] && [ $1 == "clear" ]; then
rm *~
rm -f school/*.log school/*~
rm -f city/*.log city/*~
rm -f workshop_attrs/*.log workshop_attrs/*~
fi
function run_solver
{
constraint=$1
solver_flag=true
if [ -n "$2" ]; then
if [ $SOLVER = hve ]; then
extra_flag=clpbn_horus:set_horus_flag\(elim_heuristic,$2\)
elif [ $SOLVER = bp ]; then
extra_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
elif [ $SOLVER = cbp ]; then
extra_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
else
echo "unknow flag $2"
fi
fi
/usr/bin/time -o $LOG_FILE -a -f "real:%E\tuser:%U\tsys:%S" \
$YAP << EOF >> $LOG_FILE 2>> ignore.$LOG_FILE
[$NETWORK].
[$constraint].
clpbn_horus:set_solver($SOLVER).
clpbn_horus:set_horus_flag(use_logarithms, true).
$solver_flag.
$QUERY.
open("$LOG_FILE", 'append', S), format(S, '$constraint: ~15+ ', []), close(S).
EOF
}

View File

@ -1,50 +0,0 @@
#!/bin/bash
cp ~/bin/yap ~/bin/town_bnbp
YAP=~/bin/town_bnbp
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
OUT_FILE_NAME=bnbp.log
rm -f $OUT_FILE_NAME
rm -f ignore.$OUT_FILE_NAME
function run_solver
{
if [ $2 = bp ]
then
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
else
extra_flag1=true
extra_flag2=true
fi
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
[$1].
clpbn:set_clpbn_flag(solver,$2),
clpbn_bp:set_horus_flag(use_logarithms, true),
$extra_flag1, $extra_flag2,
run_query(_R),
open("$OUT_FILE_NAME", 'append',S),
format(S, '$3: ~15+ ',[]),
close(S).
EOF
}
function run_all_graphs
{
echo "*******************************************************************" >> "$OUT_FILE_NAME"
echo "results for solver $2" >> $OUT_FILE_NAME
echo "*******************************************************************" >> "$OUT_FILE_NAME"
run_solver town_1000 $1 town_1000 $3 $4 $5
run_solver town_5000 $1 town_5000 $3 $4 $5
run_solver town_10000 $1 town_10000 $3 $4 $5
run_solver town_50000 $1 town_50000 $3 $4 $5
run_solver town_100000 $1 town_100000 $3 $4 $5
run_solver town_500000 $1 town_500000 $3 $4 $5
run_solver town_1000000 $1 town_1000000 $3 $4 $5
}
run_all_graphs bp "bn_bp(seq_fixed) " bn_bp seq_fixed

View File

@ -0,0 +1,17 @@
#!/bin/bash
source city.sh
source ../benchs.sh
SOLVER="bp"
YAP=~/bin/$SHORTNAME-$SOLVER
LOG_FILE=$SOLVER.log
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
rm -f $LOG_FILE
rm -f ignore.$LOG_FILE
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed

View File

@ -1,54 +1,17 @@
#!/bin/bash
cp ~/bin/yap ~/bin/town_cbp
YAP=~/bin/town_cbp
source city.sh
source ../benchs.sh
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
OUT_FILE_NAME=cbp.log
rm -f $OUT_FILE_NAME
rm -f ignore.$OUT_FILE_NAME
SOLVER="cbp"
YAP=~/bin/$SHORTNAME-$SOLVER
function run_solver
{
if [ $2 = bp ]
then
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
else
extra_flag1=true
extra_flag2=true
fi
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
[$1].
clpbn:set_clpbn_flag(solver,$2),
clpbn_bp:set_horus_flag(use_logarithms, true),
$extra_flag1, $extra_flag2,
run_query(_R),
open("$OUT_FILE_NAME", 'append',S),
format(S, '$3: ~15+ ',[]),
close(S).
EOF
}
LOG_FILE=$SOLVER.log
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
rm -f $LOG_FILE
rm -f ignore.$LOG_FILE
function run_all_graphs
{
echo "*******************************************************************" >> "$OUT_FILE_NAME"
echo "results for solver $2" >> $OUT_FILE_NAME
echo "*******************************************************************" >> "$OUT_FILE_NAME"
run_solver town_1000 $1 town_1000 $3 $4 $5
run_solver town_5000 $1 town_5000 $3 $4 $5
run_solver town_10000 $1 town_10000 $3 $4 $5
run_solver town_50000 $1 town_50000 $3 $4 $5
run_solver town_100000 $1 town_100000 $3 $4 $5
run_solver town_500000 $1 town_500000 $3 $4 $5
run_solver town_1000000 $1 town_1000000 $3 $4 $5
run_solver town_2500000 $1 town_2500000 $3 $4 $5
run_solver town_5000000 $1 town_5000000 $3 $4 $5
run_solver town_7500000 $1 town_7500000 $3 $4 $5
run_solver town_10000000 $1 town_10000000 $3 $4 $5
}
run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed

View File

@ -0,0 +1,25 @@
#!/bin/bash
NETWORK="'../../examples/city'"
SHORTNAME="city"
QUERY="is_joe_guilty(X)"
function run_all_graphs
{
cp ~/bin/yap $YAP
echo -n "**********************************" >> $LOG_FILE
echo "**********************************" >> $LOG_FILE
echo "results for solver $1" >> $LOG_FILE
echo -n "**********************************" >> $LOG_FILE
echo "**********************************" >> $LOG_FILE
run_solver city_5 $2
#run_solver city_1000 $2
#run_solver city_5000 $2
#run_solver city_10000 $2
#run_solver city_50000 $2
#run_solver city_100000 $2
#run_solver city_500000 $2
#run_solver city_1000000 $2
}

View File

@ -0,0 +1,37 @@
#!/home/tiago/bin/yap -L --
:- initialization(main).
main :-
unix(argv([H])),
generate_town(H).
generate_town(N) :-
atomic_concat(['city_', N, '.yap'], FileName),
open(FileName, 'write', S),
atom_number(N, N2),
generate_people(S, N2, 4),
write(S, '\n'),
generate_query(S, N2, 4),
write(S, '\n'),
close(S).
generate_people(S, N, Counting) :-
Counting > N, !.
generate_people(S, N, Counting) :-
format(S, 'people(p~w, nyc).~n', [Counting]),
Counting1 is Counting + 1,
generate_people(S, N, Counting1).
generate_query(S, N, Counting) :-
Counting > N, !.
generate_query(S, N, Counting) :- !,
format(S, 'ev(descn(p~w, t)).~n', [Counting]),
Counting1 is Counting + 1,
generate_query(S, N, Counting1).

View File

@ -1,50 +0,0 @@
#!/bin/bash
cp ~/bin/yap ~/bin/town_fgbp
YAP=~/bin/town_fgbp
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
OUT_FILE_NAME=fb_bp.log
rm -f $OUT_FILE_NAME
rm -f ignore.$OUT_FILE_NAME
function run_solver
{
if [ $2 = bp ]
then
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
else
extra_flag1=true
extra_flag2=true
fi
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
[$1].
clpbn:set_clpbn_flag(solver,$2),
clpbn_bp:set_horus_flag(use_logarithms, true),
$extra_flag1, $extra_flag2,
run_query(_R),
open("$OUT_FILE_NAME", 'append',S),
format(S, '$3: ~15+ ',[]),
close(S).
EOF
}
function run_all_graphs
{
echo "*******************************************************************" >> "$OUT_FILE_NAME"
echo "results for solver $2" >> $OUT_FILE_NAME
echo "*******************************************************************" >> "$OUT_FILE_NAME"
run_solver town_1000 $1 town_1000 $3 $4 $5
#run_solver town_5000 $1 town_5000 $3 $4 $5
#run_solver town_10000 $1 town_10000 $3 $4 $5
#run_solver town_50000 $1 town_50000 $3 $4 $5
#run_solver town_100000 $1 town_100000 $3 $4 $5
#run_solver town_500000 $1 town_500000 $3 $4 $5
#run_solver town_1000000 $1 town_1000000 $3 $4 $5
}
run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed

View File

@ -0,0 +1,17 @@
#!/bin/bash
source city.sh
source ../benchs.sh
SOLVER="fove"
YAP=~/bin/$SHORTNAME-$SOLVER
LOG_FILE=$SOLVER.log
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
rm -f $LOG_FILE
rm -f ignore.$LOG_FILEE
run_all_graphs "fove "

View File

@ -1,50 +0,0 @@
#!/bin/bash
cp ~/bin/yap ~/bin/town_gibbs
YAP=~/bin/town_gibbs
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
OUT_FILE_NAME=gibbs.log
rm -f $OUT_FILE_NAME
rm -f ignore.$OUT_FILE_NAME
function run_solver
{
if [ $2 = bp ]
then
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
else
extra_flag1=true
extra_flag2=true
fi
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
[$1].
clpbn:set_clpbn_flag(solver,$2),
clpbn_bp:set_horus_flag(use_logarithms, true),
$extra_flag1, $extra_flag2,
run_query(_R),
open("$OUT_FILE_NAME", 'append',S),
format(S, '$3: ~15+ ',[]),
close(S).
EOF
}
function run_all_graphs
{
echo "*******************************************************************" >> "$OUT_FILE_NAME"
echo "results for solver $2" >> $OUT_FILE_NAME
echo "*******************************************************************" >> "$OUT_FILE_NAME"
run_solver town_1000 $1 town_1000 $3 $4 $5
run_solver town_5000 $1 town_5000 $3 $4 $5
run_solver town_10000 $1 town_10000 $3 $4 $5
run_solver town_50000 $1 town_50000 $3 $4 $5
run_solver town_100000 $1 town_100000 $3 $4 $5
run_solver town_500000 $1 town_500000 $3 $4 $5
run_solver town_1000000 $1 town_1000000 $3 $4 $5
}
run_all_graphs gibbs "gibbs "

View File

@ -0,0 +1,17 @@
#!/bin/bash
source city.sh
source ../benchs.sh
SOLVER="hve"
YAP=~/bin/$SHORTNAME-$SOLVER
LOG_FILE=$SOLVER.log
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
rm -f $LOG_FILE
rm -f ignore.$LOG_FILE
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors

View File

@ -1,50 +0,0 @@
#!/bin/bash
cp ~/bin/yap ~/bin/town_jt
YAP=~/bin/town_jt
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
OUT_FILE_NAME=jt.log
rm -f $OUT_FILE_NAME
rm -f ignore.$OUT_FILE_NAME
function run_solver
{
if [ $2 = bp ]
then
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
else
extra_flag1=true
extra_flag2=true
fi
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
[$1].
clpbn:set_clpbn_flag(solver,$2),
clpbn_bp:set_horus_flag(use_logarithms, true),
$extra_flag1, $extra_flag2,
run_query(_R),
open("$OUT_FILE_NAME", 'append',S),
format(S, '$3: ~15+ ',[]),
close(S).
EOF
}
function run_all_graphs
{
echo "*******************************************************************" >> "$OUT_FILE_NAME"
echo "results for solver $2" >> $OUT_FILE_NAME
echo "*******************************************************************" >> "$OUT_FILE_NAME"
run_solver town_1000 $1 town_1000 $3 $4 $5
run_solver town_5000 $1 town_5000 $3 $4 $5
run_solver town_10000 $1 town_10000 $3 $4 $5
run_solver town_50000 $1 town_50000 $3 $4 $5
run_solver town_100000 $1 town_100000 $3 $4 $5
run_solver town_500000 $1 town_500000 $3 $4 $5
run_solver town_1000000 $1 town_1000000 $3 $4 $5
}
run_all_graphs jt "jt "

View File

@ -1,65 +0,0 @@
conservative_city(City, Cons) :-
cons_table(City, ConsDist),
{ Cons = conservative_city(City) with p([y,n], ConsDist) }.
gender(X, Gender) :-
gender_table(X, GenderDist),
{ Gender = gender(X) with p([m,f], GenderDist) }.
hair_color(X, Color) :-
lives(X, City),
conservative_city(City, Cons),
hair_color_table(X,ColorTable),
{ Color = hair_color(X) with
p([t,f], ColorTable,[Cons]) }.
car_color(X, Color) :-
hair_color(X, HColor),
car_color_table(X,CColorTable),
{ Color = car_color(X) with
p([t,f], CColorTable,[HColor]) }.
height(X, Height) :-
gender(X, Gender),
height_table(X,HeightTable),
{ Height = height(X) with
p([t,f], HeightTable,[Gender]) }.
shoe_size(X, Shoesize) :-
height(X, Height),
shoe_size_table(X,ShoesizeTable),
{ Shoesize = shoe_size(X) with
p([t,f], ShoesizeTable,[Height]) }.
guilty(X, Guilt) :-
guilty_table(X, GuiltDist),
{ Guilt = guilty(X) with p([y,n], GuiltDist) }.
descn(X, Descn) :-
car_color(X, Car),
hair_color(X, Hair),
height(X, Height),
guilty(X, Guilt),
descn_table(X, DescTable),
{ Descn = descn(X) with
p([t,f], DescTable,[Car,Hair,Height,Guilt]) }.
witness(City, Witness) :-
descn(joe, DescnJ),
descn(p2, Descn2),
wit_table(WitTable),
{ Witness = witness(City) with
p([t,f], WitTable,[DescnJ, Descn2]) }.
:- ensure_loaded(tables).

View File

@ -1,46 +0,0 @@
cons_table(amsterdam, [0.2, 0.8]) :- !.
cons_table(_, [0.8, 0.2]).
gender_table(_, [0.55, 0.44]).
hair_color_table(_,
/* conservative_city */
/* y n */
[ 0.05, 0.1,
0.95, 0.9 ]).
car_color_table(_,
/* t f */
[ 0.9, 0.2,
0.1, 0.8 ]).
height_table(_,
/* m f */
[ 0.6, 0.4,
0.4, 0.6 ]).
shoe_size_table(_,
/* t f */
[ 0.9, 0.1,
0.1, 0.9 ]).
guilty_table(_, [0.23, 0.77]).
descn_table(_,
/* color, hair, height, guilt */
/* ttttt tttf ttft ttff tfttt tftf tfft tfff ttttt fttf ftft ftff ffttt fftf ffft ffff */
[ 0.99, 0.5, 0.23, 0.88, 0.41, 0.3, 0.76, 0.87, 0.44, 0.43, 0.29, 0.72, 0.33, 0.91, 0.95, 0.92,
0.01, 0.5, 0.77, 0.12, 0.59, 0.7, 0.24, 0.13, 0.56, 0.57, 0.61, 0.28, 0.77, 0.09, 0.05, 0.08]).
wit_table([0.2, 0.45, 0.24, 0.34,
0.8, 0.55, 0.76, 0.66]).

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More