Merge branch 'master' of git://yap.git.sourceforge.net/gitroot/yap/yap-6.3
This commit is contained in:
commit
b66b261972
@ -490,7 +490,6 @@ Yap_HasOp(Atom a)
|
|||||||
OpEntry *
|
OpEntry *
|
||||||
Yap_OpPropForModule(Atom a, Term mod)
|
Yap_OpPropForModule(Atom a, Term mod)
|
||||||
{ /* look property list of atom a for kind */
|
{ /* look property list of atom a for kind */
|
||||||
CACHE_REGS
|
|
||||||
AtomEntry *ae = RepAtom(a);
|
AtomEntry *ae = RepAtom(a);
|
||||||
PropEntry *pp;
|
PropEntry *pp;
|
||||||
OpEntry *info = NULL;
|
OpEntry *info = NULL;
|
||||||
@ -767,6 +766,7 @@ ExpandPredHash(void)
|
|||||||
Prop
|
Prop
|
||||||
Yap_NewPredPropByFunctor(FunctorEntry *fe, Term cur_mod)
|
Yap_NewPredPropByFunctor(FunctorEntry *fe, Term cur_mod)
|
||||||
{
|
{
|
||||||
|
CACHE_REGS
|
||||||
PredEntry *p = (PredEntry *) Yap_AllocAtomSpace(sizeof(*p));
|
PredEntry *p = (PredEntry *) Yap_AllocAtomSpace(sizeof(*p));
|
||||||
|
|
||||||
if (p == NULL) {
|
if (p == NULL) {
|
||||||
@ -902,6 +902,7 @@ Yap_NewThreadPred(PredEntry *ap USES_REGS)
|
|||||||
Prop
|
Prop
|
||||||
Yap_NewPredPropByAtom(AtomEntry *ae, Term cur_mod)
|
Yap_NewPredPropByAtom(AtomEntry *ae, Term cur_mod)
|
||||||
{
|
{
|
||||||
|
CACHE_REGS
|
||||||
Prop p0;
|
Prop p0;
|
||||||
PredEntry *p = (PredEntry *) Yap_AllocAtomSpace(sizeof(*p));
|
PredEntry *p = (PredEntry *) Yap_AllocAtomSpace(sizeof(*p));
|
||||||
|
|
||||||
|
@ -2053,6 +2053,7 @@ a_try(op_numbers opcode, CELL lab, CELL opr, int nofalts, int hascut, yamop *cod
|
|||||||
yamop *newcp;
|
yamop *newcp;
|
||||||
/* emit a special instruction and then a label for backpatching */
|
/* emit a special instruction and then a label for backpatching */
|
||||||
if (pass_no) {
|
if (pass_no) {
|
||||||
|
CACHE_REGS
|
||||||
UInt size = (UInt)NEXTOP((yamop *)NULL,OtaLl);
|
UInt size = (UInt)NEXTOP((yamop *)NULL,OtaLl);
|
||||||
if ((newcp = (yamop *)Yap_AllocCodeSpace(size)) == NULL) {
|
if ((newcp = (yamop *)Yap_AllocCodeSpace(size)) == NULL) {
|
||||||
/* OOOPS, got in trouble, must do a longjmp and recover space */
|
/* OOOPS, got in trouble, must do a longjmp and recover space */
|
||||||
|
@ -561,6 +561,7 @@ X_API YAP_tag_t STD_PROTO(YAP_TagOfTerm,(Term));
|
|||||||
X_API size_t STD_PROTO(YAP_ExportTerm,(Term, char *, size_t));
|
X_API size_t STD_PROTO(YAP_ExportTerm,(Term, char *, size_t));
|
||||||
X_API size_t STD_PROTO(YAP_SizeOfExportedTerm,(char *));
|
X_API size_t STD_PROTO(YAP_SizeOfExportedTerm,(char *));
|
||||||
X_API Term STD_PROTO(YAP_ImportTerm,(char *));
|
X_API Term STD_PROTO(YAP_ImportTerm,(char *));
|
||||||
|
X_API int STD_PROTO(YAP_RequiresExtraStack,(size_t));
|
||||||
|
|
||||||
static UInt
|
static UInt
|
||||||
current_arity(void)
|
current_arity(void)
|
||||||
@ -2705,7 +2706,6 @@ YAP_InitConsult(int mode, char *filename)
|
|||||||
X_API IOSTREAM *
|
X_API IOSTREAM *
|
||||||
YAP_TermToStream(Term t)
|
YAP_TermToStream(Term t)
|
||||||
{
|
{
|
||||||
CACHE_REGS
|
|
||||||
IOSTREAM *s;
|
IOSTREAM *s;
|
||||||
BACKUP_MACHINE_REGS();
|
BACKUP_MACHINE_REGS();
|
||||||
|
|
||||||
@ -2937,7 +2937,13 @@ YAP_Init(YAP_init_args *yap_init)
|
|||||||
int restore_result;
|
int restore_result;
|
||||||
int do_bootstrap = (yap_init->YapPrologBootFile != NULL);
|
int do_bootstrap = (yap_init->YapPrologBootFile != NULL);
|
||||||
CELL Trail = 0, Stack = 0, Heap = 0, Atts = 0;
|
CELL Trail = 0, Stack = 0, Heap = 0, Atts = 0;
|
||||||
static char boot_file[256];
|
char boot_file[256];
|
||||||
|
static int initialised = FALSE;
|
||||||
|
|
||||||
|
/* ignore repeated calls to YAP_Init */
|
||||||
|
if (initialised)
|
||||||
|
return YAP_BOOT_DONE_BEFOREHAND;
|
||||||
|
initialised = TRUE;
|
||||||
|
|
||||||
Yap_InitPageSize(); /* init memory page size, required by later functions */
|
Yap_InitPageSize(); /* init memory page size, required by later functions */
|
||||||
#if defined(YAPOR_COPY) || defined(YAPOR_COW) || defined(YAPOR_SBA)
|
#if defined(YAPOR_COPY) || defined(YAPOR_COW) || defined(YAPOR_SBA)
|
||||||
@ -3612,9 +3618,22 @@ YAP_ListToFloats(Term t, double *dblp, size_t sz)
|
|||||||
if (!IsPairTerm(t))
|
if (!IsPairTerm(t))
|
||||||
return -1;
|
return -1;
|
||||||
hd = HeadOfTerm(t);
|
hd = HeadOfTerm(t);
|
||||||
if (!IsFloatTerm(hd))
|
if (IsFloatTerm(hd)) {
|
||||||
return -1;
|
dblp[i++] = FloatOfTerm(hd);
|
||||||
dblp[i++] = FloatOfTerm(hd);
|
} else {
|
||||||
|
extern double Yap_gmp_to_float(Term hd);
|
||||||
|
|
||||||
|
if (IsIntTerm(hd))
|
||||||
|
dblp[i++] = IntOfTerm(hd);
|
||||||
|
else if (IsLongIntTerm(hd))
|
||||||
|
dblp[i++] = LongIntOfTerm(hd);
|
||||||
|
#if USE_GMP
|
||||||
|
else if (IsBigIntTerm(hd))
|
||||||
|
dblp[i++] = Yap_gmp_to_float(hd);
|
||||||
|
#endif
|
||||||
|
else
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
if (i == sz)
|
if (i == sz)
|
||||||
return sz;
|
return sz;
|
||||||
t = TailOfTerm(t);
|
t = TailOfTerm(t);
|
||||||
@ -4108,3 +4127,24 @@ YAP_ImportTerm(char * buf) {
|
|||||||
return Yap_ImportTerm(buf);
|
return Yap_ImportTerm(buf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
X_API int
|
||||||
|
YAP_RequiresExtraStack(size_t sz) {
|
||||||
|
CACHE_REGS
|
||||||
|
|
||||||
|
if (sz < 16*1024)
|
||||||
|
sz = 16*1024;
|
||||||
|
if (H <= ASP-sz) {
|
||||||
|
return FALSE;
|
||||||
|
}
|
||||||
|
BACKUP_H();
|
||||||
|
while (H > ASP-sz) {
|
||||||
|
CACHE_REGS
|
||||||
|
RECOVER_H();
|
||||||
|
if (!dogc( 0, NULL PASS_REGS )) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
BACKUP_H();
|
||||||
|
}
|
||||||
|
RECOVER_H();
|
||||||
|
return TRUE;
|
||||||
|
}
|
||||||
|
@ -5107,6 +5107,8 @@ p_continue_static_clause( USES_REGS1 )
|
|||||||
static void
|
static void
|
||||||
add_code_in_lu_index(LogUpdIndex *cl, PredEntry *pp)
|
add_code_in_lu_index(LogUpdIndex *cl, PredEntry *pp)
|
||||||
{
|
{
|
||||||
|
CACHE_REGS
|
||||||
|
|
||||||
char *code_end = (char *)cl + cl->ClSize;
|
char *code_end = (char *)cl + cl->ClSize;
|
||||||
Yap_inform_profiler_of_clause(cl, code_end, pp, GPROF_LU_INDEX);
|
Yap_inform_profiler_of_clause(cl, code_end, pp, GPROF_LU_INDEX);
|
||||||
cl = cl->ChildIndex;
|
cl = cl->ChildIndex;
|
||||||
@ -5119,6 +5121,7 @@ add_code_in_lu_index(LogUpdIndex *cl, PredEntry *pp)
|
|||||||
static void
|
static void
|
||||||
add_code_in_static_index(StaticIndex *cl, PredEntry *pp)
|
add_code_in_static_index(StaticIndex *cl, PredEntry *pp)
|
||||||
{
|
{
|
||||||
|
CACHE_REGS
|
||||||
char *code_end = (char *)cl + cl->ClSize;
|
char *code_end = (char *)cl + cl->ClSize;
|
||||||
Yap_inform_profiler_of_clause(cl, code_end, pp, GPROF_STATIC_INDEX);
|
Yap_inform_profiler_of_clause(cl, code_end, pp, GPROF_STATIC_INDEX);
|
||||||
cl = cl->ChildIndex;
|
cl = cl->ChildIndex;
|
||||||
@ -5131,6 +5134,7 @@ add_code_in_static_index(StaticIndex *cl, PredEntry *pp)
|
|||||||
|
|
||||||
static void
|
static void
|
||||||
add_code_in_pred(PredEntry *pp) {
|
add_code_in_pred(PredEntry *pp) {
|
||||||
|
CACHE_REGS
|
||||||
yamop *clcode;
|
yamop *clcode;
|
||||||
|
|
||||||
PELOCK(49,pp);
|
PELOCK(49,pp);
|
||||||
@ -5202,6 +5206,7 @@ add_code_in_pred(PredEntry *pp) {
|
|||||||
|
|
||||||
void
|
void
|
||||||
Yap_dump_code_area_for_profiler(void) {
|
Yap_dump_code_area_for_profiler(void) {
|
||||||
|
CACHE_REGS
|
||||||
ModEntry *me = CurrentModules;
|
ModEntry *me = CurrentModules;
|
||||||
|
|
||||||
while (me) {
|
while (me) {
|
||||||
|
@ -1887,6 +1887,7 @@ Yap_new_ludbe(Term t, PredEntry *pe, UInt nargs)
|
|||||||
static LogUpdClause *
|
static LogUpdClause *
|
||||||
record_lu(PredEntry *pe, Term t, int position)
|
record_lu(PredEntry *pe, Term t, int position)
|
||||||
{
|
{
|
||||||
|
CACHE_REGS
|
||||||
LogUpdClause *cl;
|
LogUpdClause *cl;
|
||||||
|
|
||||||
if ((cl = new_lu_db_entry(t, pe)) == NULL) {
|
if ((cl = new_lu_db_entry(t, pe)) == NULL) {
|
||||||
|
1
C/exec.c
1
C/exec.c
@ -1066,6 +1066,7 @@ do_goal(Term t, yamop *CodeAdr, int arity, CELL *pt, int top USES_REGS)
|
|||||||
S = CellPtr (RepPredProp (PredPropByFunc (Yap_MkFunctor(AtomCall, 1),0))); /* A1 mishaps */
|
S = CellPtr (RepPredProp (PredPropByFunc (Yap_MkFunctor(AtomCall, 1),0))); /* A1 mishaps */
|
||||||
|
|
||||||
out = exec_absmi(top PASS_REGS);
|
out = exec_absmi(top PASS_REGS);
|
||||||
|
Yap_flush();
|
||||||
// if (out) {
|
// if (out) {
|
||||||
// out = Yap_GetFromSlot(sl);
|
// out = Yap_GetFromSlot(sl);
|
||||||
// }
|
// }
|
||||||
|
27
C/gprof.c
27
C/gprof.c
@ -168,6 +168,7 @@ RBfree(rb_red_blk_node *ptr)
|
|||||||
|
|
||||||
static rb_red_blk_node *
|
static rb_red_blk_node *
|
||||||
RBTreeCreate(void) {
|
RBTreeCreate(void) {
|
||||||
|
CACHE_REGS
|
||||||
rb_red_blk_node* temp;
|
rb_red_blk_node* temp;
|
||||||
|
|
||||||
/* see the comment in the rb_red_blk_tree structure in red_black_tree.h */
|
/* see the comment in the rb_red_blk_tree structure in red_black_tree.h */
|
||||||
@ -210,6 +211,7 @@ RBTreeCreate(void) {
|
|||||||
|
|
||||||
static void
|
static void
|
||||||
LeftRotate(rb_red_blk_node* x) {
|
LeftRotate(rb_red_blk_node* x) {
|
||||||
|
CACHE_REGS
|
||||||
rb_red_blk_node* y;
|
rb_red_blk_node* y;
|
||||||
rb_red_blk_node* nil=LOCAL_ProfilerNil;
|
rb_red_blk_node* nil=LOCAL_ProfilerNil;
|
||||||
|
|
||||||
@ -266,6 +268,7 @@ LeftRotate(rb_red_blk_node* x) {
|
|||||||
|
|
||||||
static void
|
static void
|
||||||
RightRotate(rb_red_blk_node* y) {
|
RightRotate(rb_red_blk_node* y) {
|
||||||
|
CACHE_REGS
|
||||||
rb_red_blk_node* x;
|
rb_red_blk_node* x;
|
||||||
rb_red_blk_node* nil=LOCAL_ProfilerNil;
|
rb_red_blk_node* nil=LOCAL_ProfilerNil;
|
||||||
|
|
||||||
@ -318,6 +321,7 @@ RightRotate(rb_red_blk_node* y) {
|
|||||||
|
|
||||||
static void
|
static void
|
||||||
TreeInsertHelp(rb_red_blk_node* z) {
|
TreeInsertHelp(rb_red_blk_node* z) {
|
||||||
|
CACHE_REGS
|
||||||
/* This function should only be called by InsertRBTree (see above) */
|
/* This function should only be called by InsertRBTree (see above) */
|
||||||
rb_red_blk_node* x;
|
rb_red_blk_node* x;
|
||||||
rb_red_blk_node* y;
|
rb_red_blk_node* y;
|
||||||
@ -369,6 +373,7 @@ TreeInsertHelp(rb_red_blk_node* z) {
|
|||||||
|
|
||||||
static rb_red_blk_node *
|
static rb_red_blk_node *
|
||||||
RBTreeInsert(yamop *key, yamop *lim) {
|
RBTreeInsert(yamop *key, yamop *lim) {
|
||||||
|
CACHE_REGS
|
||||||
rb_red_blk_node * y;
|
rb_red_blk_node * y;
|
||||||
rb_red_blk_node * x;
|
rb_red_blk_node * x;
|
||||||
rb_red_blk_node * newNode;
|
rb_red_blk_node * newNode;
|
||||||
@ -440,6 +445,7 @@ RBTreeInsert(yamop *key, yamop *lim) {
|
|||||||
|
|
||||||
static rb_red_blk_node*
|
static rb_red_blk_node*
|
||||||
RBExactQuery(yamop* q) {
|
RBExactQuery(yamop* q) {
|
||||||
|
CACHE_REGS
|
||||||
rb_red_blk_node* x;
|
rb_red_blk_node* x;
|
||||||
rb_red_blk_node* nil=LOCAL_ProfilerNil;
|
rb_red_blk_node* nil=LOCAL_ProfilerNil;
|
||||||
|
|
||||||
@ -460,6 +466,7 @@ RBExactQuery(yamop* q) {
|
|||||||
|
|
||||||
static rb_red_blk_node*
|
static rb_red_blk_node*
|
||||||
RBLookup(yamop *entry) {
|
RBLookup(yamop *entry) {
|
||||||
|
CACHE_REGS
|
||||||
rb_red_blk_node *current;
|
rb_red_blk_node *current;
|
||||||
|
|
||||||
if (!LOCAL_ProfilerRoot)
|
if (!LOCAL_ProfilerRoot)
|
||||||
@ -495,6 +502,7 @@ RBLookup(yamop *entry) {
|
|||||||
/***********************************************************************/
|
/***********************************************************************/
|
||||||
|
|
||||||
static void RBDeleteFixUp(rb_red_blk_node* x) {
|
static void RBDeleteFixUp(rb_red_blk_node* x) {
|
||||||
|
CACHE_REGS
|
||||||
rb_red_blk_node* root=LOCAL_ProfilerRoot->left;
|
rb_red_blk_node* root=LOCAL_ProfilerRoot->left;
|
||||||
rb_red_blk_node *w;
|
rb_red_blk_node *w;
|
||||||
|
|
||||||
@ -574,6 +582,7 @@ static void RBDeleteFixUp(rb_red_blk_node* x) {
|
|||||||
|
|
||||||
static rb_red_blk_node*
|
static rb_red_blk_node*
|
||||||
TreeSuccessor(rb_red_blk_node* x) {
|
TreeSuccessor(rb_red_blk_node* x) {
|
||||||
|
CACHE_REGS
|
||||||
rb_red_blk_node* y;
|
rb_red_blk_node* y;
|
||||||
rb_red_blk_node* nil=LOCAL_ProfilerNil;
|
rb_red_blk_node* nil=LOCAL_ProfilerNil;
|
||||||
rb_red_blk_node* root=LOCAL_ProfilerRoot;
|
rb_red_blk_node* root=LOCAL_ProfilerRoot;
|
||||||
@ -612,6 +621,7 @@ TreeSuccessor(rb_red_blk_node* x) {
|
|||||||
|
|
||||||
static void
|
static void
|
||||||
RBDelete(rb_red_blk_node* z){
|
RBDelete(rb_red_blk_node* z){
|
||||||
|
CACHE_REGS
|
||||||
rb_red_blk_node* y;
|
rb_red_blk_node* y;
|
||||||
rb_red_blk_node* x;
|
rb_red_blk_node* x;
|
||||||
rb_red_blk_node* nil=LOCAL_ProfilerNil;
|
rb_red_blk_node* nil=LOCAL_ProfilerNil;
|
||||||
@ -664,7 +674,8 @@ RBDelete(rb_red_blk_node* z){
|
|||||||
|
|
||||||
char *set_profile_dir(char *);
|
char *set_profile_dir(char *);
|
||||||
char *set_profile_dir(char *name){
|
char *set_profile_dir(char *name){
|
||||||
int size=0;
|
CACHE_REGS
|
||||||
|
int size=0;
|
||||||
|
|
||||||
if (name!=NULL) {
|
if (name!=NULL) {
|
||||||
size=strlen(name)+1;
|
size=strlen(name)+1;
|
||||||
@ -687,8 +698,9 @@ return LOCAL_DIRNAME;
|
|||||||
|
|
||||||
char *profile_names(int);
|
char *profile_names(int);
|
||||||
char *profile_names(int k) {
|
char *profile_names(int k) {
|
||||||
static char *FNAME=NULL;
|
CACHE_REGS
|
||||||
int size=200;
|
static char *FNAME=NULL;
|
||||||
|
int size=200;
|
||||||
|
|
||||||
if (LOCAL_DIRNAME==NULL) set_profile_dir(NULL);
|
if (LOCAL_DIRNAME==NULL) set_profile_dir(NULL);
|
||||||
size=strlen(LOCAL_DIRNAME)+40;
|
size=strlen(LOCAL_DIRNAME)+40;
|
||||||
@ -709,6 +721,7 @@ int size=200;
|
|||||||
|
|
||||||
void del_profile_files(void);
|
void del_profile_files(void);
|
||||||
void del_profile_files() {
|
void del_profile_files() {
|
||||||
|
CACHE_REGS
|
||||||
if (LOCAL_DIRNAME!=NULL) {
|
if (LOCAL_DIRNAME!=NULL) {
|
||||||
remove(profile_names(PROFPREDS_FILE));
|
remove(profile_names(PROFPREDS_FILE));
|
||||||
remove(profile_names(PROFILING_FILE));
|
remove(profile_names(PROFILING_FILE));
|
||||||
@ -717,6 +730,7 @@ void del_profile_files() {
|
|||||||
|
|
||||||
void
|
void
|
||||||
Yap_inform_profiler_of_clause__(void *code_start, void *code_end, PredEntry *pe,gprof_info index_code) {
|
Yap_inform_profiler_of_clause__(void *code_start, void *code_end, PredEntry *pe,gprof_info index_code) {
|
||||||
|
CACHE_REGS
|
||||||
buf_ptr b;
|
buf_ptr b;
|
||||||
buf_extra e;
|
buf_extra e;
|
||||||
LOCAL_ProfOn = TRUE;
|
LOCAL_ProfOn = TRUE;
|
||||||
@ -742,6 +756,7 @@ static Int profend( USES_REGS1 );
|
|||||||
|
|
||||||
static void
|
static void
|
||||||
clean_tree(rb_red_blk_node* node) {
|
clean_tree(rb_red_blk_node* node) {
|
||||||
|
CACHE_REGS
|
||||||
if (node == LOCAL_ProfilerNil)
|
if (node == LOCAL_ProfilerNil)
|
||||||
return;
|
return;
|
||||||
clean_tree(node->left);
|
clean_tree(node->left);
|
||||||
@ -751,6 +766,7 @@ clean_tree(rb_red_blk_node* node) {
|
|||||||
|
|
||||||
static void
|
static void
|
||||||
reset_tree(void) {
|
reset_tree(void) {
|
||||||
|
CACHE_REGS
|
||||||
clean_tree(LOCAL_ProfilerRoot);
|
clean_tree(LOCAL_ProfilerRoot);
|
||||||
Yap_FreeCodeSpace((char *)LOCAL_ProfilerNil);
|
Yap_FreeCodeSpace((char *)LOCAL_ProfilerNil);
|
||||||
LOCAL_ProfilerNil = LOCAL_ProfilerRoot = NULL;
|
LOCAL_ProfilerNil = LOCAL_ProfilerRoot = NULL;
|
||||||
@ -760,6 +776,7 @@ reset_tree(void) {
|
|||||||
static int
|
static int
|
||||||
InitProfTree(void)
|
InitProfTree(void)
|
||||||
{
|
{
|
||||||
|
CACHE_REGS
|
||||||
if (LOCAL_ProfilerRoot)
|
if (LOCAL_ProfilerRoot)
|
||||||
reset_tree();
|
reset_tree();
|
||||||
while (!(LOCAL_ProfilerRoot = RBTreeCreate())) {
|
while (!(LOCAL_ProfilerRoot = RBTreeCreate())) {
|
||||||
@ -773,6 +790,7 @@ InitProfTree(void)
|
|||||||
|
|
||||||
static void RemoveCode(CODEADDR clau)
|
static void RemoveCode(CODEADDR clau)
|
||||||
{
|
{
|
||||||
|
CACHE_REGS
|
||||||
rb_red_blk_node* x, *node;
|
rb_red_blk_node* x, *node;
|
||||||
PredEntry *pp;
|
PredEntry *pp;
|
||||||
UInt count;
|
UInt count;
|
||||||
@ -958,6 +976,7 @@ prof_alrm(int signo, siginfo_t *si, void *scv)
|
|||||||
void
|
void
|
||||||
Yap_InformOfRemoval(void *clau)
|
Yap_InformOfRemoval(void *clau)
|
||||||
{
|
{
|
||||||
|
CACHE_REGS
|
||||||
LOCAL_ProfOn = TRUE;
|
LOCAL_ProfOn = TRUE;
|
||||||
if (LOCAL_FPreds != NULL) {
|
if (LOCAL_FPreds != NULL) {
|
||||||
/* just store info about what is going on */
|
/* just store info about what is going on */
|
||||||
@ -1048,6 +1067,7 @@ static Int profinit( USES_REGS1 )
|
|||||||
|
|
||||||
static Int start_profilers(int msec)
|
static Int start_profilers(int msec)
|
||||||
{
|
{
|
||||||
|
CACHE_REGS
|
||||||
struct itimerval t;
|
struct itimerval t;
|
||||||
struct sigaction sa;
|
struct sigaction sa;
|
||||||
|
|
||||||
@ -1157,6 +1177,7 @@ static Int profres0( USES_REGS1 ) {
|
|||||||
void
|
void
|
||||||
Yap_InitLowProf(void)
|
Yap_InitLowProf(void)
|
||||||
{
|
{
|
||||||
|
CACHE_REGS
|
||||||
#if LOW_PROF
|
#if LOW_PROF
|
||||||
LOCAL_ProfCalls = 0;
|
LOCAL_ProfCalls = 0;
|
||||||
LOCAL_ProfilerOn = FALSE;
|
LOCAL_ProfilerOn = FALSE;
|
||||||
|
5
C/grow.c
5
C/grow.c
@ -718,6 +718,11 @@ AdjustScannerStacks(TokEntry **tksp, VarEntry **vep USES_REGS)
|
|||||||
TokEntry *tktmp;
|
TokEntry *tktmp;
|
||||||
|
|
||||||
switch (tks->Tok) {
|
switch (tks->Tok) {
|
||||||
|
case Number_tok:
|
||||||
|
if (IsApplTerm(tks->TokInfo)) {
|
||||||
|
tks->TokInfo = AdjustAppl(tks->TokInfo PASS_REGS);
|
||||||
|
}
|
||||||
|
break;
|
||||||
case Var_tok:
|
case Var_tok:
|
||||||
case String_tok:
|
case String_tok:
|
||||||
if (IsOldTrail(tks->TokInfo))
|
if (IsOldTrail(tks->TokInfo))
|
||||||
|
@ -1888,6 +1888,7 @@ emit_single_switch_case(ClauseDef *min, struct intermediates *cint, int first, i
|
|||||||
static UInt
|
static UInt
|
||||||
suspend_indexing(ClauseDef *min, ClauseDef *max, PredEntry *ap, struct intermediates *cint)
|
suspend_indexing(ClauseDef *min, ClauseDef *max, PredEntry *ap, struct intermediates *cint)
|
||||||
{
|
{
|
||||||
|
CACHE_REGS
|
||||||
UInt tcls = ap->cs.p_code.NOfClauses;
|
UInt tcls = ap->cs.p_code.NOfClauses;
|
||||||
UInt cls = (max-min)+1;
|
UInt cls = (max-min)+1;
|
||||||
|
|
||||||
|
12
C/qlyr.c
12
C/qlyr.c
@ -993,6 +993,15 @@ p_read_module_preds( USES_REGS1 )
|
|||||||
return TRUE;
|
return TRUE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void
|
||||||
|
ReInitCatch(void)
|
||||||
|
{
|
||||||
|
Term t = Yap_MkNewApplTerm(PredHandleThrow->FunctorOfPred, PredHandleThrow->ArityOfPE);
|
||||||
|
YAP_RunGoalOnce(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
static Int
|
static Int
|
||||||
p_read_program( USES_REGS1 )
|
p_read_program( USES_REGS1 )
|
||||||
{
|
{
|
||||||
@ -1016,7 +1025,7 @@ p_read_program( USES_REGS1 )
|
|||||||
Sclose( stream );
|
Sclose( stream );
|
||||||
/* back to the top level we go */
|
/* back to the top level we go */
|
||||||
Yap_CloseSlots(PASS_REGS1);
|
Yap_CloseSlots(PASS_REGS1);
|
||||||
|
ReInitCatch();
|
||||||
Yap_RestartYap( 3 );
|
Yap_RestartYap( 3 );
|
||||||
return TRUE;
|
return TRUE;
|
||||||
}
|
}
|
||||||
@ -1030,6 +1039,7 @@ Yap_Restore(char *s, char *lib_dir)
|
|||||||
return -1;
|
return -1;
|
||||||
read_module(stream);
|
read_module(stream);
|
||||||
Sclose( stream );
|
Sclose( stream );
|
||||||
|
ReInitCatch();
|
||||||
return DO_ONLY_CODE;
|
return DO_ONLY_CODE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1619,6 +1619,7 @@ InteractSIGINT(int ch) {
|
|||||||
static int
|
static int
|
||||||
ProcessSIGINT(void)
|
ProcessSIGINT(void)
|
||||||
{
|
{
|
||||||
|
CACHE_REGS
|
||||||
int ch, out;
|
int ch, out;
|
||||||
|
|
||||||
LOCAL_PrologMode |= AsyncIntMode;
|
LOCAL_PrologMode |= AsyncIntMode;
|
||||||
|
12
C/tracer.c
12
C/tracer.c
@ -52,17 +52,7 @@ send_tracer_message(char *start, char *name, Int arity, char *mname, CELL *args)
|
|||||||
if (args) {
|
if (args) {
|
||||||
for (i= 0; i < arity; i++) {
|
for (i= 0; i < arity; i++) {
|
||||||
if (i > 0) fprintf(GLOBAL_stderr, ",");
|
if (i > 0) fprintf(GLOBAL_stderr, ",");
|
||||||
#if DEBUG
|
Yap_plwrite(args[i], NULL, 15, Handle_vars_f|AttVar_Portray_f, 1200);
|
||||||
#if COROUTINING
|
|
||||||
Yap_Portray_delays = TRUE;
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
Yap_plwrite(args[i], NULL, 15, Handle_vars_f, 1200);
|
|
||||||
#if DEBUG
|
|
||||||
#if COROUTINING
|
|
||||||
Yap_Portray_delays = FALSE;
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
if (arity) {
|
if (arity) {
|
||||||
fprintf(GLOBAL_stderr, ")");
|
fprintf(GLOBAL_stderr, ")");
|
||||||
|
@ -4255,7 +4255,7 @@ p_is_list_or_partial_list( USES_REGS1 )
|
|||||||
}
|
}
|
||||||
|
|
||||||
static Term
|
static Term
|
||||||
numbervar(Int id)
|
numbervar(Int id USES_REGS)
|
||||||
{
|
{
|
||||||
Term ts[1];
|
Term ts[1];
|
||||||
ts[0] = MkIntegerTerm(id);
|
ts[0] = MkIntegerTerm(id);
|
||||||
@ -4263,7 +4263,7 @@ numbervar(Int id)
|
|||||||
}
|
}
|
||||||
|
|
||||||
static Term
|
static Term
|
||||||
numbervar_singleton(void)
|
numbervar_singleton(USES_REGS1)
|
||||||
{
|
{
|
||||||
Term ts[1];
|
Term ts[1];
|
||||||
ts[0] = MkIntegerTerm(-1);
|
ts[0] = MkIntegerTerm(-1);
|
||||||
@ -4356,9 +4356,9 @@ static Int numbervars_in_complex_term(register CELL *pt0, register CELL *pt0_end
|
|||||||
derefa_body(d0, ptd0, vars_in_term_unk, vars_in_term_nvar);
|
derefa_body(d0, ptd0, vars_in_term_unk, vars_in_term_nvar);
|
||||||
/* do or pt2 are unbound */
|
/* do or pt2 are unbound */
|
||||||
if (singles)
|
if (singles)
|
||||||
*ptd0 = numbervar_singleton();
|
*ptd0 = numbervar_singleton( PASS_REGS1 );
|
||||||
else
|
else
|
||||||
*ptd0 = numbervar(numbv++);
|
*ptd0 = numbervar(numbv++ PASS_REGS);
|
||||||
/* leave an empty slot to fill in later */
|
/* leave an empty slot to fill in later */
|
||||||
if (H+1024 > ASP) {
|
if (H+1024 > ASP) {
|
||||||
goto global_overflow;
|
goto global_overflow;
|
||||||
@ -4450,10 +4450,10 @@ Yap_NumberVars( Term inp, Int numbv, int handle_singles ) /* numbervariables in
|
|||||||
CELL *ptd0 = VarOfTerm(t);
|
CELL *ptd0 = VarOfTerm(t);
|
||||||
TrailTerm(TR++) = (CELL)ptd0;
|
TrailTerm(TR++) = (CELL)ptd0;
|
||||||
if (handle_singles) {
|
if (handle_singles) {
|
||||||
*ptd0 = numbervar_singleton();
|
*ptd0 = numbervar_singleton( PASS_REGS1 );
|
||||||
return numbv;
|
return numbv;
|
||||||
} else {
|
} else {
|
||||||
*ptd0 = numbervar(numbv);
|
*ptd0 = numbervar(numbv PASS_REGS);
|
||||||
return numbv+1;
|
return numbv+1;
|
||||||
}
|
}
|
||||||
} else if (IsPrimitiveTerm(t)) {
|
} else if (IsPrimitiveTerm(t)) {
|
||||||
|
164
C/write.c
164
C/write.c
@ -66,7 +66,7 @@ typedef struct rewind_term {
|
|||||||
|
|
||||||
typedef struct write_globs {
|
typedef struct write_globs {
|
||||||
void *stream;
|
void *stream;
|
||||||
int Quote_illegal, Ignore_ops, Handle_vars, Use_portray;
|
int Quote_illegal, Ignore_ops, Handle_vars, Use_portray, Portray_delays;
|
||||||
int Keep_terms;
|
int Keep_terms;
|
||||||
int Write_Loops;
|
int Write_Loops;
|
||||||
int Write_strings;
|
int Write_strings;
|
||||||
@ -90,28 +90,50 @@ STATIC_PROTO(void writeTerm, (Term, int, int, int, struct write_globs *, struct
|
|||||||
|
|
||||||
#define wrputc(X,WF) Sputcode(X,WF) /* writes a character */
|
#define wrputc(X,WF) Sputcode(X,WF) /* writes a character */
|
||||||
|
|
||||||
|
/*
|
||||||
|
protect bracket from merging with previoous character.
|
||||||
|
avoid stuff like not (2,3) -> not(2,3) or
|
||||||
|
*/
|
||||||
static void
|
static void
|
||||||
protect_open_number(struct write_globs *wglb, int minus_required)
|
wropen_bracket(struct write_globs *wglb, int protect)
|
||||||
{
|
{
|
||||||
wrf stream = wglb->stream;
|
wrf stream = wglb->stream;
|
||||||
|
|
||||||
if (lastw == symbol && last_minus && !minus_required) {
|
if (lastw != separator && protect)
|
||||||
if (!wglb->Ignore_ops) {
|
wrputc(' ', stream);
|
||||||
/* protect against collating - with number, and getting - 1 ^2 as (-(1))^2 */
|
wrputc('(', stream);
|
||||||
wrputc(' ', wglb->stream);
|
lastw = separator;
|
||||||
}
|
|
||||||
wrputc('(', wglb->stream);
|
|
||||||
} else if (lastw == alphanum) {
|
|
||||||
wrputc(' ', stream);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void
|
static void
|
||||||
protect_close_number(struct write_globs *wglb, int minus_required)
|
wrclose_bracket(struct write_globs *wglb, int protect)
|
||||||
{
|
{
|
||||||
if (lastw == symbol && last_minus && !minus_required) {
|
wrf stream = wglb->stream;
|
||||||
wrputc(')', wglb->stream);
|
|
||||||
lastw = separator;
|
wrputc(')', stream);
|
||||||
|
lastw = separator;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int
|
||||||
|
protect_open_number(struct write_globs *wglb, int lm, int minus_required)
|
||||||
|
{
|
||||||
|
wrf stream = wglb->stream;
|
||||||
|
|
||||||
|
if (lastw == symbol && lm && !minus_required) {
|
||||||
|
wropen_bracket(wglb, TRUE);
|
||||||
|
return TRUE;
|
||||||
|
} else if (lastw == alphanum ||
|
||||||
|
(lastw == symbol && minus_required)) {
|
||||||
|
wrputc(' ', stream);
|
||||||
|
}
|
||||||
|
return FALSE;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void
|
||||||
|
protect_close_number(struct write_globs *wglb, int used_bracket)
|
||||||
|
{
|
||||||
|
if (used_bracket) {
|
||||||
|
wrclose_bracket(wglb, TRUE);
|
||||||
} else {
|
} else {
|
||||||
lastw = alphanum;
|
lastw = alphanum;
|
||||||
}
|
}
|
||||||
@ -125,8 +147,9 @@ wrputn(Int n, struct write_globs *wglb) /* writes an integer */
|
|||||||
wrf stream = wglb->stream;
|
wrf stream = wglb->stream;
|
||||||
char s[256], *s1=s; /* that should be enough for most integers */
|
char s[256], *s1=s; /* that should be enough for most integers */
|
||||||
int has_minus = (n < 0);
|
int has_minus = (n < 0);
|
||||||
|
int ob;
|
||||||
|
|
||||||
protect_open_number(wglb, has_minus);
|
ob = protect_open_number(wglb, last_minus, has_minus);
|
||||||
#if HAVE_SNPRINTF
|
#if HAVE_SNPRINTF
|
||||||
snprintf(s, 256, Int_FORMAT, n);
|
snprintf(s, 256, Int_FORMAT, n);
|
||||||
#else
|
#else
|
||||||
@ -134,7 +157,7 @@ wrputn(Int n, struct write_globs *wglb) /* writes an integer */
|
|||||||
#endif
|
#endif
|
||||||
while (*s1)
|
while (*s1)
|
||||||
wrputc(*s1++, stream);
|
wrputc(*s1++, stream);
|
||||||
protect_close_number(wglb, has_minus);
|
protect_close_number(wglb, ob);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define wrputs(s, stream) Sfputs(s, stream)
|
#define wrputs(s, stream) Sfputs(s, stream)
|
||||||
@ -190,9 +213,10 @@ static void
|
|||||||
write_mpint(MP_INT *big, struct write_globs *wglb) {
|
write_mpint(MP_INT *big, struct write_globs *wglb) {
|
||||||
char *s;
|
char *s;
|
||||||
int has_minus = mpz_sgn(big);
|
int has_minus = mpz_sgn(big);
|
||||||
|
int ob;
|
||||||
|
|
||||||
s = ensure_space(3+mpz_sizeinbase(big, 10));
|
s = ensure_space(3+mpz_sizeinbase(big, 10));
|
||||||
protect_open_number(wglb, has_minus);
|
ob = protect_open_number(wglb, last_minus, has_minus);
|
||||||
if (!s) {
|
if (!s) {
|
||||||
s = mpz_get_str(NULL, 10, big);
|
s = mpz_get_str(NULL, 10, big);
|
||||||
if (!s)
|
if (!s)
|
||||||
@ -203,7 +227,7 @@ write_mpint(MP_INT *big, struct write_globs *wglb) {
|
|||||||
mpz_get_str(s, 10, big);
|
mpz_get_str(s, 10, big);
|
||||||
wrputs(s,wglb->stream);
|
wrputs(s,wglb->stream);
|
||||||
}
|
}
|
||||||
protect_close_number(wglb, has_minus);
|
protect_close_number(wglb, ob);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -271,6 +295,8 @@ wrputf(Float f, struct write_globs *wglb) /* writes a float */
|
|||||||
char s[256];
|
char s[256];
|
||||||
wrf stream = wglb->stream;
|
wrf stream = wglb->stream;
|
||||||
int sgn;
|
int sgn;
|
||||||
|
int ob;
|
||||||
|
|
||||||
|
|
||||||
#if HAVE_ISNAN || defined(__WIN32)
|
#if HAVE_ISNAN || defined(__WIN32)
|
||||||
if (isnan(f)) {
|
if (isnan(f)) {
|
||||||
@ -291,7 +317,7 @@ wrputf(Float f, struct write_globs *wglb) /* writes a float */
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
protect_open_number(wglb, sgn);
|
ob = protect_open_number(wglb, last_minus, sgn);
|
||||||
#if THREADS
|
#if THREADS
|
||||||
/* old style writing */
|
/* old style writing */
|
||||||
int found_dot = FALSE, found_exp = FALSE;
|
int found_dot = FALSE, found_exp = FALSE;
|
||||||
@ -343,7 +369,7 @@ wrputf(Float f, struct write_globs *wglb) /* writes a float */
|
|||||||
if (!buf) return;
|
if (!buf) return;
|
||||||
wrputs(buf, stream);
|
wrputs(buf, stream);
|
||||||
#endif
|
#endif
|
||||||
protect_close_number(wglb, sgn);
|
protect_close_number(wglb, ob);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* writes a data base reference */
|
/* writes a data base reference */
|
||||||
@ -423,7 +449,7 @@ AtomIsSymbols(unsigned char *s) /* Is this atom just formed by symbols ? */
|
|||||||
return(separator);
|
return(separator);
|
||||||
while ((ch = *s++) != '\0') {
|
while ((ch = *s++) != '\0') {
|
||||||
if (Yap_chtype[ch] != SY)
|
if (Yap_chtype[ch] != SY)
|
||||||
return(alphanum);
|
return alphanum;
|
||||||
}
|
}
|
||||||
return symbol;
|
return symbol;
|
||||||
}
|
}
|
||||||
@ -669,15 +695,14 @@ write_var(CELL *t, struct write_globs *wglb, struct rewind_term *rwt)
|
|||||||
/* make sure we don't get no creepy spaces where they shouldn't be */
|
/* make sure we don't get no creepy spaces where they shouldn't be */
|
||||||
lastw = separator;
|
lastw = separator;
|
||||||
if (IsAttVar(t)) {
|
if (IsAttVar(t)) {
|
||||||
#if defined(COROUTINING) && defined(DEBUG)
|
|
||||||
Int vcount = (t-H0);
|
Int vcount = (t-H0);
|
||||||
if (Yap_Portray_delays) {
|
if (wglb->Portray_delays) {
|
||||||
exts ext = ExtFromCell(t);
|
exts ext = ExtFromCell(t);
|
||||||
struct rewind_term nrwt;
|
struct rewind_term nrwt;
|
||||||
nrwt.parent = rwt;
|
nrwt.parent = rwt;
|
||||||
nrwt.u.s.ptr = 0;
|
nrwt.u.s.ptr = 0;
|
||||||
|
|
||||||
Yap_Portray_delays = FALSE;
|
wglb->Portray_delays = FALSE;
|
||||||
if (ext == attvars_ext) {
|
if (ext == attvars_ext) {
|
||||||
attvar_record *attv = RepAttVar(t);
|
attvar_record *attv = RepAttVar(t);
|
||||||
CELL *l = &attv->Value; /* dirty low-level hack, check atts.h */
|
CELL *l = &attv->Value; /* dirty low-level hack, check atts.h */
|
||||||
@ -691,14 +716,13 @@ write_var(CELL *t, struct write_globs *wglb, struct rewind_term *rwt)
|
|||||||
l += 2;
|
l += 2;
|
||||||
writeTerm(from_pointer(l, &nrwt, wglb), 999, 1, FALSE, wglb, &nrwt);
|
writeTerm(from_pointer(l, &nrwt, wglb), 999, 1, FALSE, wglb, &nrwt);
|
||||||
restore_from_write(&nrwt, wglb);
|
restore_from_write(&nrwt, wglb);
|
||||||
wrputc(')', wglb->stream);
|
wrclose_bracket(wglb, TRUE);
|
||||||
}
|
}
|
||||||
Yap_Portray_delays = TRUE;
|
wglb->Portray_delays = TRUE;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
wrputc('D', wglb->stream);
|
wrputc('D', wglb->stream);
|
||||||
wrputn(vcount,wglb);
|
wrputn(vcount,wglb);
|
||||||
#endif
|
|
||||||
} else {
|
} else {
|
||||||
wrputn(((Int) (t- H0)),wglb);
|
wrputn(((Int) (t- H0)),wglb);
|
||||||
}
|
}
|
||||||
@ -790,6 +814,7 @@ write_list(Term t, int direction, int depth, struct write_globs *wglb, struct re
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
static void
|
static void
|
||||||
writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, struct rewind_term *rwt)
|
writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, struct rewind_term *rwt)
|
||||||
/* term to write */
|
/* term to write */
|
||||||
@ -823,8 +848,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
|
|||||||
wrputs(",",wglb->stream);
|
wrputs(",",wglb->stream);
|
||||||
writeTerm(from_pointer(RepPair(t)+1, &nrwt, wglb), 999, depth + 1, FALSE, wglb, &nrwt);
|
writeTerm(from_pointer(RepPair(t)+1, &nrwt, wglb), 999, depth + 1, FALSE, wglb, &nrwt);
|
||||||
restore_from_write(&nrwt, wglb);
|
restore_from_write(&nrwt, wglb);
|
||||||
wrputc(')', wglb->stream);
|
wrclose_bracket(wglb, TRUE);
|
||||||
lastw = separator;
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (wglb->Use_portray) {
|
if (wglb->Use_portray) {
|
||||||
@ -886,7 +910,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
|
|||||||
int argno = 1;
|
int argno = 1;
|
||||||
CELL *p = ArgsOfSFTerm(t);
|
CELL *p = ArgsOfSFTerm(t);
|
||||||
putAtom(atom, wglb->Quote_illegal, wglb);
|
putAtom(atom, wglb->Quote_illegal, wglb);
|
||||||
wrputc('(', wglb->stream);
|
wropen_bracket(wglb, FALSE);
|
||||||
lastw = separator;
|
lastw = separator;
|
||||||
while (*p) {
|
while (*p) {
|
||||||
Int sl = 0;
|
Int sl = 0;
|
||||||
@ -904,8 +928,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
|
|||||||
wrputc(',', wglb->stream);
|
wrputc(',', wglb->stream);
|
||||||
argno++;
|
argno++;
|
||||||
}
|
}
|
||||||
wrputc(')', wglb->stream);
|
wrclose_bracket(wglb, TRUE);
|
||||||
lastw = separator;
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@ -934,28 +957,22 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
|
|||||||
!IsVarTerm(tright) && IsAtomTerm(tright) &&
|
!IsVarTerm(tright) && IsAtomTerm(tright) &&
|
||||||
Yap_IsOp(AtomOfTerm(tright));
|
Yap_IsOp(AtomOfTerm(tright));
|
||||||
if (op > p) {
|
if (op > p) {
|
||||||
/* avoid stuff such as \+ (a,b) being written as \+(a,b) */
|
wropen_bracket(wglb, TRUE);
|
||||||
if (lastw != separator && !rinfixarg)
|
|
||||||
wrputc(' ', wglb->stream);
|
|
||||||
wrputc('(', wglb->stream);
|
|
||||||
lastw = separator;
|
|
||||||
}
|
}
|
||||||
putAtom(atom, wglb->Quote_illegal, wglb);
|
putAtom(atom, wglb->Quote_illegal, wglb);
|
||||||
if (bracket_right) {
|
if (bracket_right) {
|
||||||
wrputc('(', wglb->stream);
|
/* avoid stuff such as \+ (a,b) being written as \+(a,b) */
|
||||||
lastw = separator;
|
wropen_bracket(wglb, TRUE);
|
||||||
} else if (atom == AtomMinus) {
|
} else if (atom == AtomMinus) {
|
||||||
last_minus = TRUE;
|
last_minus = TRUE;
|
||||||
}
|
}
|
||||||
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), rp, depth + 1, FALSE, wglb, &nrwt);
|
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), rp, depth + 1, TRUE, wglb, &nrwt);
|
||||||
restore_from_write(&nrwt, wglb);
|
restore_from_write(&nrwt, wglb);
|
||||||
if (bracket_right) {
|
if (bracket_right) {
|
||||||
wrputc(')', wglb->stream);
|
wrclose_bracket(wglb, TRUE);
|
||||||
lastw = separator;
|
|
||||||
}
|
}
|
||||||
if (op > p) {
|
if (op > p) {
|
||||||
wrputc(')', wglb->stream);
|
wrclose_bracket(wglb, TRUE);
|
||||||
lastw = separator;
|
|
||||||
}
|
}
|
||||||
} else if (!wglb->Ignore_ops &&
|
} else if (!wglb->Ignore_ops &&
|
||||||
Arity == 1 &&
|
Arity == 1 &&
|
||||||
@ -963,29 +980,24 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
|
|||||||
Term tleft = ArgOfTerm(1, t);
|
Term tleft = ArgOfTerm(1, t);
|
||||||
|
|
||||||
int bracket_left =
|
int bracket_left =
|
||||||
!IsVarTerm(tleft) && IsAtomTerm(tleft) &&
|
!IsVarTerm(tleft) &&
|
||||||
|
IsAtomTerm(tleft) &&
|
||||||
Yap_IsOp(AtomOfTerm(tleft));
|
Yap_IsOp(AtomOfTerm(tleft));
|
||||||
if (op > p) {
|
if (op > p) {
|
||||||
/* avoid stuff such as \+ (a,b) being written as \+(a,b) */
|
/* avoid stuff such as \+ (a,b) being written as \+(a,b) */
|
||||||
if (lastw != separator && !rinfixarg)
|
wropen_bracket(wglb, TRUE);
|
||||||
wrputc(' ', wglb->stream);
|
|
||||||
wrputc('(', wglb->stream);
|
|
||||||
lastw = separator;
|
|
||||||
}
|
}
|
||||||
if (bracket_left) {
|
if (bracket_left) {
|
||||||
wrputc('(', wglb->stream);
|
wropen_bracket(wglb, TRUE);
|
||||||
lastw = separator;
|
|
||||||
}
|
}
|
||||||
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), lp, depth + 1, rinfixarg, wglb, &nrwt);
|
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), lp, depth + 1, rinfixarg, wglb, &nrwt);
|
||||||
restore_from_write(&nrwt, wglb);
|
restore_from_write(&nrwt, wglb);
|
||||||
if (bracket_left) {
|
if (bracket_left) {
|
||||||
wrputc(')', wglb->stream);
|
wrclose_bracket(wglb, TRUE);
|
||||||
lastw = separator;
|
|
||||||
}
|
}
|
||||||
putAtom(atom, wglb->Quote_illegal, wglb);
|
putAtom(atom, wglb->Quote_illegal, wglb);
|
||||||
if (op > p) {
|
if (op > p) {
|
||||||
wrputc(')', wglb->stream);
|
wrclose_bracket(wglb, TRUE);
|
||||||
lastw = separator;
|
|
||||||
}
|
}
|
||||||
} else if (!wglb->Ignore_ops &&
|
} else if (!wglb->Ignore_ops &&
|
||||||
Arity == 2 && Yap_IsInfixOp(atom, &op, &lp,
|
Arity == 2 && Yap_IsInfixOp(atom, &op, &lp,
|
||||||
@ -1001,41 +1013,36 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
|
|||||||
|
|
||||||
if (op > p) {
|
if (op > p) {
|
||||||
/* avoid stuff such as \+ (a,b) being written as \+(a,b) */
|
/* avoid stuff such as \+ (a,b) being written as \+(a,b) */
|
||||||
if (lastw != separator && !rinfixarg)
|
wropen_bracket(wglb, TRUE);
|
||||||
wrputc(' ', wglb->stream);
|
|
||||||
wrputc('(', wglb->stream);
|
|
||||||
lastw = separator;
|
lastw = separator;
|
||||||
}
|
}
|
||||||
if (bracket_left) {
|
if (bracket_left) {
|
||||||
wrputc('(', wglb->stream);
|
wropen_bracket(wglb, TRUE);
|
||||||
lastw = separator;
|
|
||||||
}
|
}
|
||||||
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), lp, depth + 1, rinfixarg, wglb, &nrwt);
|
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), lp, depth + 1, rinfixarg, wglb, &nrwt);
|
||||||
t = AbsAppl(restore_from_write(&nrwt, wglb)-1);
|
t = AbsAppl(restore_from_write(&nrwt, wglb)-1);
|
||||||
if (bracket_left) {
|
if (bracket_left) {
|
||||||
wrputc(')', wglb->stream);
|
wrclose_bracket(wglb, TRUE);
|
||||||
lastw = separator;
|
|
||||||
}
|
}
|
||||||
/* avoid quoting commas */
|
/* avoid quoting commas and bars */
|
||||||
if (strcmp(RepAtom(atom)->StrOfAE,","))
|
if (!strcmp(RepAtom(atom)->StrOfAE,",")) {
|
||||||
putAtom(atom, wglb->Quote_illegal, wglb);
|
|
||||||
else {
|
|
||||||
wrputc(',', wglb->stream);
|
wrputc(',', wglb->stream);
|
||||||
lastw = separator;
|
lastw = separator;
|
||||||
}
|
} else if (!strcmp(RepAtom(atom)->StrOfAE,"|")) {
|
||||||
if (bracket_right) {
|
wrputc('|', wglb->stream);
|
||||||
wrputc('(', wglb->stream);
|
|
||||||
lastw = separator;
|
lastw = separator;
|
||||||
|
} else
|
||||||
|
putAtom(atom, wglb->Quote_illegal, wglb);
|
||||||
|
if (bracket_right) {
|
||||||
|
wropen_bracket(wglb, TRUE);
|
||||||
}
|
}
|
||||||
writeTerm(from_pointer(RepAppl(t)+2, &nrwt, wglb), rp, depth + 1, TRUE, wglb, &nrwt);
|
writeTerm(from_pointer(RepAppl(t)+2, &nrwt, wglb), rp, depth + 1, TRUE, wglb, &nrwt);
|
||||||
restore_from_write(&nrwt, wglb);
|
restore_from_write(&nrwt, wglb);
|
||||||
if (bracket_right) {
|
if (bracket_right) {
|
||||||
wrputc(')', wglb->stream);
|
wrclose_bracket(wglb, TRUE);
|
||||||
lastw = separator;
|
|
||||||
}
|
}
|
||||||
if (op > p) {
|
if (op > p) {
|
||||||
wrputc(')', wglb->stream);
|
wrclose_bracket(wglb, TRUE);
|
||||||
lastw = separator;
|
|
||||||
}
|
}
|
||||||
} else if (wglb->Handle_vars && functor == LOCAL_FunctorVar) {
|
} else if (wglb->Handle_vars && functor == LOCAL_FunctorVar) {
|
||||||
Term ti = ArgOfTerm(1, t);
|
Term ti = ArgOfTerm(1, t);
|
||||||
@ -1068,8 +1075,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
|
|||||||
lastw = separator;
|
lastw = separator;
|
||||||
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), 999, depth + 1, FALSE, wglb, &nrwt);
|
writeTerm(from_pointer(RepAppl(t)+1, &nrwt, wglb), 999, depth + 1, FALSE, wglb, &nrwt);
|
||||||
restore_from_write(&nrwt, wglb);
|
restore_from_write(&nrwt, wglb);
|
||||||
wrputc(')', wglb->stream);
|
wrclose_bracket(wglb, TRUE);
|
||||||
lastw = separator;
|
|
||||||
}
|
}
|
||||||
} else if (!wglb->Ignore_ops && functor == FunctorBraces) {
|
} else if (!wglb->Ignore_ops && functor == FunctorBraces) {
|
||||||
wrputc('{', wglb->stream);
|
wrputc('{', wglb->stream);
|
||||||
@ -1098,7 +1104,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
|
|||||||
} else {
|
} else {
|
||||||
putAtom(atom, wglb->Quote_illegal, wglb);
|
putAtom(atom, wglb->Quote_illegal, wglb);
|
||||||
lastw = separator;
|
lastw = separator;
|
||||||
wrputc('(', wglb->stream);
|
wropen_bracket(wglb, FALSE);
|
||||||
for (op = 1; op <= Arity; ++op) {
|
for (op = 1; op <= Arity; ++op) {
|
||||||
if (op == wglb->MaxArgs) {
|
if (op == wglb->MaxArgs) {
|
||||||
wrputc('.', wglb->stream);
|
wrputc('.', wglb->stream);
|
||||||
@ -1113,8 +1119,7 @@ writeTerm(Term t, int p, int depth, int rinfixarg, struct write_globs *wglb, str
|
|||||||
lastw = separator;
|
lastw = separator;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
wrputc(')', wglb->stream);
|
wrclose_bracket(wglb, TRUE);
|
||||||
lastw = separator;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1138,6 +1143,7 @@ Yap_plwrite(Term t, void *mywrite, int max_depth, int flags, int priority)
|
|||||||
wglb.Quote_illegal = flags & Quote_illegal_f;
|
wglb.Quote_illegal = flags & Quote_illegal_f;
|
||||||
wglb.Handle_vars = flags & Handle_vars_f;
|
wglb.Handle_vars = flags & Handle_vars_f;
|
||||||
wglb.Use_portray = flags & Use_portray_f;
|
wglb.Use_portray = flags & Use_portray_f;
|
||||||
|
wglb.Portray_delays = flags & AttVar_Portray_f;
|
||||||
wglb.MaxDepth = max_depth;
|
wglb.MaxDepth = max_depth;
|
||||||
wglb.MaxArgs = max_depth;
|
wglb.MaxArgs = max_depth;
|
||||||
/* notice: we must have ASP well set when using portray, otherwise
|
/* notice: we must have ASP well set when using portray, otherwise
|
||||||
|
@ -498,6 +498,7 @@ void STD_PROTO(Yap_init_optyap_preds,(void));
|
|||||||
|
|
||||||
/* pl-file.c */
|
/* pl-file.c */
|
||||||
struct PL_local_data *Yap_InitThreadIO(int wid);
|
struct PL_local_data *Yap_InitThreadIO(int wid);
|
||||||
|
void Yap_flush(void);
|
||||||
|
|
||||||
static inline
|
static inline
|
||||||
yamop *
|
yamop *
|
||||||
|
@ -159,6 +159,9 @@ typedef enum
|
|||||||
#ifdef HAVE_LOCALE_H
|
#ifdef HAVE_LOCALE_H
|
||||||
#include <locale.h>
|
#include <locale.h>
|
||||||
#endif
|
#endif
|
||||||
|
#ifdef HAVE_LIMITS_H /* get MAXPATHLEN */
|
||||||
|
#include <limits.h>
|
||||||
|
#endif
|
||||||
#include <setjmp.h>
|
#include <setjmp.h>
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
#if HAVE_SYS_PARAM_H
|
#if HAVE_SYS_PARAM_H
|
||||||
|
@ -323,12 +323,6 @@ int STD_PROTO(Yap_growtrail_in_parser, (tr_fr_ptr *, TokEntry **, VarEntry **)
|
|||||||
extern int errno;
|
extern int errno;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef DEBUG
|
|
||||||
#if COROUTINING
|
|
||||||
extern int Yap_Portray_delays;
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
EXTERN inline UInt STD_PROTO(HashFunction, (unsigned char *));
|
EXTERN inline UInt STD_PROTO(HashFunction, (unsigned char *));
|
||||||
EXTERN inline UInt STD_PROTO(WideHashFunction, (wchar_t *));
|
EXTERN inline UInt STD_PROTO(WideHashFunction, (wchar_t *));
|
||||||
|
|
||||||
|
@ -710,6 +710,7 @@ all: startup.yss
|
|||||||
@ENABLE_CPLINT@ (cd packages/cplint; $(MAKE))
|
@ENABLE_CPLINT@ (cd packages/cplint; $(MAKE))
|
||||||
@ENABLE_CPLINT@ (cd packages/cplint/slipcase; $(MAKE))
|
@ENABLE_CPLINT@ (cd packages/cplint/slipcase; $(MAKE))
|
||||||
@ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE))
|
@ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE))
|
||||||
|
@ENABLE_BDDLIB@ (cd packages/bdd; $(MAKE))
|
||||||
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE))
|
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE))
|
||||||
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE))
|
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE))
|
||||||
@ENABLE_JPL@ @INSTALL_DLLS@ (cd packages/jpl; $(MAKE))
|
@ENABLE_JPL@ @INSTALL_DLLS@ (cd packages/jpl; $(MAKE))
|
||||||
@ -788,6 +789,7 @@ install_unix: startup.yss libYap.a
|
|||||||
@ENABLE_CPLINT@ (cd packages/cplint/approx/simplecuddLPADs; $(MAKE) install)
|
@ENABLE_CPLINT@ (cd packages/cplint/approx/simplecuddLPADs; $(MAKE) install)
|
||||||
@ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE) install)
|
@ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE) install)
|
||||||
@ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) install)
|
@ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) install)
|
||||||
|
@ENABLE_BDDLIB@ (cd packages/bdd; $(MAKE) install)
|
||||||
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE) install)
|
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE) install)
|
||||||
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE) install)
|
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE) install)
|
||||||
|
|
||||||
@ -840,6 +842,7 @@ install_win32: startup.yss @ENABLE_WINCONSOLE@ pl-yap@EXEC_SUFFIX@
|
|||||||
@ENABLE_CPLINT@ (cd packages/cplint; $(MAKE) install)
|
@ENABLE_CPLINT@ (cd packages/cplint; $(MAKE) install)
|
||||||
@ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE) install)
|
@ENABLE_PRISM@ (cd packages/prism/src/c; $(MAKE) install)
|
||||||
@ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) install)
|
@ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) install)
|
||||||
|
@ENABLE_BDDLIB@ (cd packages/bdd; $(MAKE) install)
|
||||||
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE) install)
|
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE) install)
|
||||||
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE) install)
|
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE) install)
|
||||||
|
|
||||||
@ -904,6 +907,7 @@ clean: clean_docs
|
|||||||
@ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) clean)
|
@ENABLE_PRISM@ (cd packages/prism/src/prolog; $(MAKE) clean)
|
||||||
@ENABLE_CPLINT@ (cd packages/cplint/approx/simplecuddLPADs; $(MAKE) clean)
|
@ENABLE_CPLINT@ (cd packages/cplint/approx/simplecuddLPADs; $(MAKE) clean)
|
||||||
@ENABLE_CPLINT@ (cd packages/cplint; $(MAKE) clean)
|
@ENABLE_CPLINT@ (cd packages/cplint; $(MAKE) clean)
|
||||||
|
@ENABLE_BDDLIB@ (cd packages/bdd; $(MAKE) clean)
|
||||||
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE) clean)
|
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd; $(MAKE) clean)
|
||||||
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE) clean)
|
@ENABLE_CUDD@ (cd packages/ProbLog/simplecudd_lfi; $(MAKE) clean)
|
||||||
@ENABLE_JPL@ @INSTALL_DLLS@ (cd packages/jpl; $(MAKE) clean)
|
@ENABLE_JPL@ @INSTALL_DLLS@ (cd packages/jpl; $(MAKE) clean)
|
||||||
|
@ -257,6 +257,10 @@
|
|||||||
#undef HAVE_WAITPID
|
#undef HAVE_WAITPID
|
||||||
#undef HAVE_MPZ_XOR
|
#undef HAVE_MPZ_XOR
|
||||||
|
|
||||||
|
#if HAVE_GETHOSTNAME==1
|
||||||
|
#define HAS_GETHOSTNAME 1
|
||||||
|
#endif
|
||||||
|
|
||||||
#undef HAVE_SIGINFO
|
#undef HAVE_SIGINFO
|
||||||
#undef HAVE_SIGSEGV
|
#undef HAVE_SIGSEGV
|
||||||
#undef HAVE_SIGPROF
|
#undef HAVE_SIGPROF
|
||||||
|
22
configure
vendored
22
configure
vendored
@ -625,6 +625,7 @@ ENABLE_REAL
|
|||||||
ENABLE_MINISAT
|
ENABLE_MINISAT
|
||||||
CUDD_CPPFLAGS
|
CUDD_CPPFLAGS
|
||||||
CUDD_LDFLAGS
|
CUDD_LDFLAGS
|
||||||
|
ENABLE_BDDLIB
|
||||||
ENABLE_CUDD
|
ENABLE_CUDD
|
||||||
EXTRA_INCLUDES_FOR_WIN32
|
EXTRA_INCLUDES_FOR_WIN32
|
||||||
ENABLE_WINCONSOLE
|
ENABLE_WINCONSOLE
|
||||||
@ -788,6 +789,7 @@ enable_depth_limit
|
|||||||
enable_wam_profile
|
enable_wam_profile
|
||||||
enable_low_level_tracer
|
enable_low_level_tracer
|
||||||
enable_threads
|
enable_threads
|
||||||
|
enable_bddlib
|
||||||
enable_pthread_locking
|
enable_pthread_locking
|
||||||
enable_max_performance
|
enable_max_performance
|
||||||
enable_max_memory
|
enable_max_memory
|
||||||
@ -1462,6 +1464,7 @@ Optional Features:
|
|||||||
--enable-wam-profile support low level profiling of abstract machine
|
--enable-wam-profile support low level profiling of abstract machine
|
||||||
--enable-low-level-tracer support support for procedure-call tracing
|
--enable-low-level-tracer support support for procedure-call tracing
|
||||||
--enable-threads support system threads
|
--enable-threads support system threads
|
||||||
|
--enable-bddlib dynamic bdd library
|
||||||
--enable-pthread-locking use pthread locking primitives for internal locking (requires threads)
|
--enable-pthread-locking use pthread locking primitives for internal locking (requires threads)
|
||||||
--enable-max-performance try using the best flags for specific architecture
|
--enable-max-performance try using the best flags for specific architecture
|
||||||
--enable-max-memory try using the best flags for using the memory to the most
|
--enable-max-memory try using the best flags for using the memory to the most
|
||||||
@ -4486,6 +4489,13 @@ else
|
|||||||
threads=no
|
threads=no
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# Check whether --enable-bddlib was given.
|
||||||
|
if test "${enable_bddlib+set}" = set; then :
|
||||||
|
enableval=$enable_bddlib; dynamic_bdd="$enableval"
|
||||||
|
else
|
||||||
|
dynamic_bdd=no
|
||||||
|
fi
|
||||||
|
|
||||||
# Check whether --enable-pthread-locking was given.
|
# Check whether --enable-pthread-locking was given.
|
||||||
if test "${enable_pthread_locking+set}" = set; then :
|
if test "${enable_pthread_locking+set}" = set; then :
|
||||||
enableval=$enable_pthread_locking; pthreadlocking="$enableval"
|
enableval=$enable_pthread_locking; pthreadlocking="$enableval"
|
||||||
@ -5000,7 +5010,14 @@ fi
|
|||||||
if test "$yap_cv_cudd" = no
|
if test "$yap_cv_cudd" = no
|
||||||
then
|
then
|
||||||
ENABLE_CUDD="@# "
|
ENABLE_CUDD="@# "
|
||||||
|
ENABLE_BDDLIB="@# "
|
||||||
else
|
else
|
||||||
|
if test "$dynamic_bdd" = yes
|
||||||
|
then
|
||||||
|
ENABLE_BDDLIB=""
|
||||||
|
else
|
||||||
|
ENABLE_BDDLIB="@# "
|
||||||
|
fi
|
||||||
ENABLE_CUDD=""
|
ENABLE_CUDD=""
|
||||||
fi
|
fi
|
||||||
|
|
||||||
@ -9220,6 +9237,7 @@ fi
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for gcc threaded code" >&5
|
{ $as_echo "$as_me:${as_lineno-$LINENO}: checking for gcc threaded code" >&5
|
||||||
@ -10454,6 +10472,7 @@ mkdir -p LGPL/clp
|
|||||||
mkdir -p LGPL/swi_console
|
mkdir -p LGPL/swi_console
|
||||||
mkdir -p GPL
|
mkdir -p GPL
|
||||||
mkdir -p packages/
|
mkdir -p packages/
|
||||||
|
mkdir -p packages/bdd
|
||||||
mkdir -p packages/clib
|
mkdir -p packages/clib
|
||||||
mkdir -p packages/clib/sha1
|
mkdir -p packages/clib/sha1
|
||||||
mkdir -p packages/clib/maildrop
|
mkdir -p packages/clib/maildrop
|
||||||
@ -10616,6 +10635,8 @@ ac_config_files="$ac_config_files packages/zlib/Makefile"
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
if test "$ENABLE_CUDD" = ""; then
|
if test "$ENABLE_CUDD" = ""; then
|
||||||
|
ac_config_files="$ac_config_files packages/bdd/Makefile"
|
||||||
|
|
||||||
ac_config_files="$ac_config_files packages/ProbLog/simplecudd/Makefile"
|
ac_config_files="$ac_config_files packages/ProbLog/simplecudd/Makefile"
|
||||||
|
|
||||||
ac_config_files="$ac_config_files packages/ProbLog/simplecudd_lfi/Makefile"
|
ac_config_files="$ac_config_files packages/ProbLog/simplecudd_lfi/Makefile"
|
||||||
@ -11398,6 +11419,7 @@ do
|
|||||||
"packages/semweb/Makefile") CONFIG_FILES="$CONFIG_FILES packages/semweb/Makefile" ;;
|
"packages/semweb/Makefile") CONFIG_FILES="$CONFIG_FILES packages/semweb/Makefile" ;;
|
||||||
"packages/sgml/Makefile") CONFIG_FILES="$CONFIG_FILES packages/sgml/Makefile" ;;
|
"packages/sgml/Makefile") CONFIG_FILES="$CONFIG_FILES packages/sgml/Makefile" ;;
|
||||||
"packages/zlib/Makefile") CONFIG_FILES="$CONFIG_FILES packages/zlib/Makefile" ;;
|
"packages/zlib/Makefile") CONFIG_FILES="$CONFIG_FILES packages/zlib/Makefile" ;;
|
||||||
|
"packages/bdd/Makefile") CONFIG_FILES="$CONFIG_FILES packages/bdd/Makefile" ;;
|
||||||
"packages/ProbLog/simplecudd/Makefile") CONFIG_FILES="$CONFIG_FILES packages/ProbLog/simplecudd/Makefile" ;;
|
"packages/ProbLog/simplecudd/Makefile") CONFIG_FILES="$CONFIG_FILES packages/ProbLog/simplecudd/Makefile" ;;
|
||||||
"packages/ProbLog/simplecudd_lfi/Makefile") CONFIG_FILES="$CONFIG_FILES packages/ProbLog/simplecudd_lfi/Makefile" ;;
|
"packages/ProbLog/simplecudd_lfi/Makefile") CONFIG_FILES="$CONFIG_FILES packages/ProbLog/simplecudd_lfi/Makefile" ;;
|
||||||
"packages/swi-minisat2/Makefile") CONFIG_FILES="$CONFIG_FILES packages/swi-minisat2/Makefile" ;;
|
"packages/swi-minisat2/Makefile") CONFIG_FILES="$CONFIG_FILES packages/swi-minisat2/Makefile" ;;
|
||||||
|
13
configure.in
13
configure.in
@ -158,6 +158,9 @@ AC_ARG_ENABLE(low-level-tracer,
|
|||||||
AC_ARG_ENABLE(threads,
|
AC_ARG_ENABLE(threads,
|
||||||
[ --enable-threads support system threads ],
|
[ --enable-threads support system threads ],
|
||||||
threads="$enableval", threads=no)
|
threads="$enableval", threads=no)
|
||||||
|
AC_ARG_ENABLE(bddlib,
|
||||||
|
[ --enable-bddlib dynamic bdd library ],
|
||||||
|
dynamic_bdd="$enableval", dynamic_bdd=no)
|
||||||
AC_ARG_ENABLE(pthread-locking,
|
AC_ARG_ENABLE(pthread-locking,
|
||||||
[ --enable-pthread-locking use pthread locking primitives for internal locking (requires threads) ],
|
[ --enable-pthread-locking use pthread locking primitives for internal locking (requires threads) ],
|
||||||
pthreadlocking="$enableval", pthreadlocking=no)
|
pthreadlocking="$enableval", pthreadlocking=no)
|
||||||
@ -510,7 +513,14 @@ fi
|
|||||||
if test "$yap_cv_cudd" = no
|
if test "$yap_cv_cudd" = no
|
||||||
then
|
then
|
||||||
ENABLE_CUDD="@# "
|
ENABLE_CUDD="@# "
|
||||||
|
ENABLE_BDDLIB="@# "
|
||||||
else
|
else
|
||||||
|
if test "$dynamic_bdd" = yes
|
||||||
|
then
|
||||||
|
ENABLE_BDDLIB=""
|
||||||
|
else
|
||||||
|
ENABLE_BDDLIB="@# "
|
||||||
|
fi
|
||||||
ENABLE_CUDD=""
|
ENABLE_CUDD=""
|
||||||
fi
|
fi
|
||||||
|
|
||||||
@ -1789,6 +1799,7 @@ AC_SUBST(ENABLE_WINCONSOLE)
|
|||||||
AC_SUBST(EXTRA_INCLUDES_FOR_WIN32)
|
AC_SUBST(EXTRA_INCLUDES_FOR_WIN32)
|
||||||
|
|
||||||
AC_SUBST(ENABLE_CUDD)
|
AC_SUBST(ENABLE_CUDD)
|
||||||
|
AC_SUBST(ENABLE_BDDLIB)
|
||||||
AC_SUBST(CUDD_LDFLAGS)
|
AC_SUBST(CUDD_LDFLAGS)
|
||||||
AC_SUBST(CUDD_CPPFLAGS)
|
AC_SUBST(CUDD_CPPFLAGS)
|
||||||
AC_SUBST(ENABLE_MINISAT)
|
AC_SUBST(ENABLE_MINISAT)
|
||||||
@ -2269,6 +2280,7 @@ mkdir -p LGPL/clp
|
|||||||
mkdir -p LGPL/swi_console
|
mkdir -p LGPL/swi_console
|
||||||
mkdir -p GPL
|
mkdir -p GPL
|
||||||
mkdir -p packages/
|
mkdir -p packages/
|
||||||
|
mkdir -p packages/bdd
|
||||||
mkdir -p packages/clib
|
mkdir -p packages/clib
|
||||||
mkdir -p packages/clib/sha1
|
mkdir -p packages/clib/sha1
|
||||||
mkdir -p packages/clib/maildrop
|
mkdir -p packages/clib/maildrop
|
||||||
@ -2392,6 +2404,7 @@ AC_CONFIG_FILES([packages/zlib/Makefile])
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
if test "$ENABLE_CUDD" = ""; then
|
if test "$ENABLE_CUDD" = ""; then
|
||||||
|
AC_CONFIG_FILES([packages/bdd/Makefile])
|
||||||
AC_CONFIG_FILES([packages/ProbLog/simplecudd/Makefile])
|
AC_CONFIG_FILES([packages/ProbLog/simplecudd/Makefile])
|
||||||
AC_CONFIG_FILES([packages/ProbLog/simplecudd_lfi/Makefile])
|
AC_CONFIG_FILES([packages/ProbLog/simplecudd_lfi/Makefile])
|
||||||
fi
|
fi
|
||||||
|
@ -75,6 +75,9 @@
|
|||||||
#if HAVE_STRING_H
|
#if HAVE_STRING_H
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
#endif
|
#endif
|
||||||
|
#if HAVE_IEEEFP_H
|
||||||
|
#include <ieeefp.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
static void PROTO(do_top_goal,(YAP_Term));
|
static void PROTO(do_top_goal,(YAP_Term));
|
||||||
static void PROTO(exec_top_level,(int, YAP_init_args *));
|
static void PROTO(exec_top_level,(int, YAP_init_args *));
|
||||||
|
22
docs/yap.tex
22
docs/yap.tex
@ -12635,6 +12635,13 @@ The path @var{Path} is a path starting at vertex @var{Vertex} in graph
|
|||||||
The path @var{Path} is a path starting at vertex @var{Vertex} in graph
|
The path @var{Path} is a path starting at vertex @var{Vertex} in graph
|
||||||
@var{Graph}.
|
@var{Graph}.
|
||||||
|
|
||||||
|
@item dgraph_leaves(+@var{Graph}, ?@var{Vertices})
|
||||||
|
@findex dgraph_leaves/2
|
||||||
|
@snindex dgraph_leaves/2
|
||||||
|
@cnindex dgraph_leaves/2
|
||||||
|
The vertices @var{Vertices} have no outgoing edge in graph
|
||||||
|
@var{Graph}.
|
||||||
|
|
||||||
@end table
|
@end table
|
||||||
|
|
||||||
@node UnDGraphs, Lambda , DGraphs, Library
|
@node UnDGraphs, Lambda , DGraphs, Library
|
||||||
@ -16464,6 +16471,21 @@ then allow one to construct functors, and to obtain their name and arity.
|
|||||||
Note that the functor is essentially a pair formed by an atom, and
|
Note that the functor is essentially a pair formed by an atom, and
|
||||||
arity.
|
arity.
|
||||||
|
|
||||||
|
Constructing terms in the stack may lead to overflow. The routine
|
||||||
|
@example
|
||||||
|
int YAP_RequiresExtraStack(size_t @var{min})
|
||||||
|
@end example
|
||||||
|
verifies whether you have at least @var{min} cells free in the stack,
|
||||||
|
and it returns true if it has to ensure enough memory by calling the
|
||||||
|
garbage collector and or stack shifter. The routine returns false if no
|
||||||
|
memory is needed, and a negative number if it cannot provide enough
|
||||||
|
memory.
|
||||||
|
|
||||||
|
You can set @var{min} to zero if you do not know how much room you need
|
||||||
|
but you do know you do not need a big chunk at a single go. Usually, the routine
|
||||||
|
would usually be called together with a long-jump to restart the
|
||||||
|
code. Slots can also be used if there is small state.
|
||||||
|
|
||||||
@node Unifying Terms, Manipulating Strings, Manipulating Terms, C-Interface
|
@node Unifying Terms, Manipulating Strings, Manipulating Terms, C-Interface
|
||||||
@section Unification
|
@section Unification
|
||||||
|
|
||||||
|
@ -599,6 +599,8 @@ extern X_API size_t PROTO(YAP_SizeOfExportedTerm,(char *));
|
|||||||
|
|
||||||
extern X_API YAP_Term PROTO(YAP_ImportTerm,(char *));
|
extern X_API YAP_Term PROTO(YAP_ImportTerm,(char *));
|
||||||
|
|
||||||
|
extern X_API int PROTO(YAP_RequiresExtraStack,(size_t));
|
||||||
|
|
||||||
#define YAP_InitCPred(N,A,F) YAP_UserCPredicate(N,F,A)
|
#define YAP_InitCPred(N,A,F) YAP_UserCPredicate(N,F,A)
|
||||||
|
|
||||||
__END_DECLS
|
__END_DECLS
|
||||||
|
@ -115,6 +115,7 @@ typedef enum {
|
|||||||
#define YAP_BOOT_FROM_SAVED_CODE 1
|
#define YAP_BOOT_FROM_SAVED_CODE 1
|
||||||
#define YAP_BOOT_FROM_SAVED_STACKS 2
|
#define YAP_BOOT_FROM_SAVED_STACKS 2
|
||||||
#define YAP_FULL_BOOT_FROM_PROLOG 4
|
#define YAP_FULL_BOOT_FROM_PROLOG 4
|
||||||
|
#define YAP_BOOT_DONE_BEFOREHAND 8
|
||||||
#define YAP_BOOT_ERROR -1
|
#define YAP_BOOT_ERROR -1
|
||||||
|
|
||||||
#define YAP_WRITE_QUOTED 1
|
#define YAP_WRITE_QUOTED 1
|
||||||
|
@ -32,6 +32,7 @@
|
|||||||
dgraph_min_paths/3,
|
dgraph_min_paths/3,
|
||||||
dgraph_isomorphic/4,
|
dgraph_isomorphic/4,
|
||||||
dgraph_path/3,
|
dgraph_path/3,
|
||||||
|
dgraph_leaves/2,
|
||||||
dgraph_reachable/3
|
dgraph_reachable/3
|
||||||
]).
|
]).
|
||||||
|
|
||||||
@ -414,3 +415,13 @@ reachable([V|Vertices], Done0, DoneF, G, [V|EdgesF], Edges0) :-
|
|||||||
rb_insert(Done0, V, [], Done1),
|
rb_insert(Done0, V, [], Done1),
|
||||||
reachable(Kids, Done1, DoneI, G, EdgesF, EdgesI),
|
reachable(Kids, Done1, DoneI, G, EdgesF, EdgesI),
|
||||||
reachable(Vertices, DoneI, DoneF, G, EdgesI, Edges0).
|
reachable(Vertices, DoneI, DoneF, G, EdgesI, Edges0).
|
||||||
|
|
||||||
|
dgraph_leaves(Graph, Vertices) :-
|
||||||
|
rb_visit(Graph, Pairs),
|
||||||
|
vertices_without_children(Pairs, Vertices).
|
||||||
|
|
||||||
|
vertices_without_children([], []).
|
||||||
|
vertices_without_children((V-[]).Pairs, V.Vertices) :-
|
||||||
|
vertices_without_children(Pairs, Vertices).
|
||||||
|
vertices_without_children(_V-[_|_].Pairs, Vertices) :-
|
||||||
|
vertices_without_children(Pairs, Vertices).
|
||||||
|
@ -25,7 +25,10 @@
|
|||||||
|
|
||||||
:- ensure_loaded(library(lists)).
|
:- ensure_loaded(library(lists)).
|
||||||
|
|
||||||
:- load_foreign_files([matlab], ['eng','mx','ut'], init_matlab).
|
tell_warning :-
|
||||||
|
print_message(warning,functionality(matlab)).
|
||||||
|
|
||||||
|
:- ( catch(load_foreign_files([matlab], ['eng','mx','ut'], init_matlab),_,fail) -> true ; tell_warning).
|
||||||
|
|
||||||
matlab_eval_sequence(S) :-
|
matlab_eval_sequence(S) :-
|
||||||
atomic_concat(S,S1),
|
atomic_concat(S,S1),
|
||||||
|
@ -4671,6 +4671,11 @@ EndPredDefs
|
|||||||
|
|
||||||
#if __YAP_PROLOG__
|
#if __YAP_PROLOG__
|
||||||
|
|
||||||
|
void Yap_flush(void)
|
||||||
|
{
|
||||||
|
flush_output(0);
|
||||||
|
}
|
||||||
|
|
||||||
void *
|
void *
|
||||||
Yap_GetStreamHandle(Atom at)
|
Yap_GetStreamHandle(Atom at)
|
||||||
{ GET_LD
|
{ GET_LD
|
||||||
|
@ -16,6 +16,11 @@ YAPLIBDIR=@libdir@/Yap
|
|||||||
#
|
#
|
||||||
SHAREDIR=$(ROOTDIR)/share/Yap
|
SHAREDIR=$(ROOTDIR)/share/Yap
|
||||||
#
|
#
|
||||||
|
# where YAP should store documentation
|
||||||
|
#
|
||||||
|
DOCDIR=$(ROOTDIR)/share/doc/Yap
|
||||||
|
EXDIR=$(DOCDIR)/examples/CLPBN
|
||||||
|
#
|
||||||
#
|
#
|
||||||
# You shouldn't need to change what follows.
|
# You shouldn't need to change what follows.
|
||||||
#
|
#
|
||||||
@ -35,6 +40,7 @@ CLPBN_EXDIR = $(srcdir)/examples
|
|||||||
|
|
||||||
CLPBN_PROGRAMS= \
|
CLPBN_PROGRAMS= \
|
||||||
$(CLPBN_SRCDIR)/aggregates.yap \
|
$(CLPBN_SRCDIR)/aggregates.yap \
|
||||||
|
$(CLPBN_SRCDIR)/bdd.yap \
|
||||||
$(CLPBN_SRCDIR)/bnt.yap \
|
$(CLPBN_SRCDIR)/bnt.yap \
|
||||||
$(CLPBN_SRCDIR)/bp.yap \
|
$(CLPBN_SRCDIR)/bp.yap \
|
||||||
$(CLPBN_SRCDIR)/connected.yap \
|
$(CLPBN_SRCDIR)/connected.yap \
|
||||||
@ -48,6 +54,7 @@ CLPBN_PROGRAMS= \
|
|||||||
$(CLPBN_SRCDIR)/graphviz.yap \
|
$(CLPBN_SRCDIR)/graphviz.yap \
|
||||||
$(CLPBN_SRCDIR)/ground_factors.yap \
|
$(CLPBN_SRCDIR)/ground_factors.yap \
|
||||||
$(CLPBN_SRCDIR)/hmm.yap \
|
$(CLPBN_SRCDIR)/hmm.yap \
|
||||||
|
$(CLPBN_SRCDIR)/horus.yap \
|
||||||
$(CLPBN_SRCDIR)/jt.yap \
|
$(CLPBN_SRCDIR)/jt.yap \
|
||||||
$(CLPBN_SRCDIR)/matrix_cpt_utils.yap \
|
$(CLPBN_SRCDIR)/matrix_cpt_utils.yap \
|
||||||
$(CLPBN_SRCDIR)/pgrammar.yap \
|
$(CLPBN_SRCDIR)/pgrammar.yap \
|
||||||
@ -72,6 +79,8 @@ CLPBN_SCHOOL_EXAMPLES= \
|
|||||||
$(CLPBN_EXDIR)/School/parschema.yap \
|
$(CLPBN_EXDIR)/School/parschema.yap \
|
||||||
$(CLPBN_EXDIR)/School/school_128.yap \
|
$(CLPBN_EXDIR)/School/school_128.yap \
|
||||||
$(CLPBN_EXDIR)/School/school_32.yap \
|
$(CLPBN_EXDIR)/School/school_32.yap \
|
||||||
|
$(CLPBN_EXDIR)/School/sch32.yap \
|
||||||
|
$(CLPBN_EXDIR)/School/school32_data.yap \
|
||||||
$(CLPBN_EXDIR)/School/school_64.yap \
|
$(CLPBN_EXDIR)/School/school_64.yap \
|
||||||
$(CLPBN_EXDIR)/School/tables.yap
|
$(CLPBN_EXDIR)/School/tables.yap
|
||||||
|
|
||||||
@ -92,12 +101,13 @@ CLPBN_EXAMPLES= \
|
|||||||
install: $(CLBN_TOP) $(CLBN_PROGRAMS) $(CLPBN_PROGRAMS)
|
install: $(CLBN_TOP) $(CLBN_PROGRAMS) $(CLPBN_PROGRAMS)
|
||||||
mkdir -p $(DESTDIR)$(SHAREDIR)/clpbn
|
mkdir -p $(DESTDIR)$(SHAREDIR)/clpbn
|
||||||
mkdir -p $(DESTDIR)$(SHAREDIR)/clpbn/learning
|
mkdir -p $(DESTDIR)$(SHAREDIR)/clpbn/learning
|
||||||
mkdir -p $(DESTDIR)$(SHAREDIR)/clpbn/examples/School
|
mkdir -p $(DESTDIR)$(EXDIR)
|
||||||
mkdir -p $(DESTDIR)$(SHAREDIR)/clpbn/examples/HMMer
|
mkdir -p $(DESTDIR)$(EXDIR)/School
|
||||||
|
mkdir -p $(DESTDIR)$(EXDIR)/HMMer
|
||||||
for h in $(CLPBN_TOP); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR); done
|
for h in $(CLPBN_TOP); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR); done
|
||||||
for h in $(CLPBN_PROGRAMS); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn; done
|
for h in $(CLPBN_PROGRAMS); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn; done
|
||||||
for h in $(CLPBN_LEARNING_PROGRAMS); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn/learning; done
|
for h in $(CLPBN_LEARNING_PROGRAMS); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn/learning; done
|
||||||
for h in $(CLPBN_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn/examples; done
|
for h in $(CLPBN_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(EXDIR); done
|
||||||
for h in $(CLPBN_SCHOOL_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn/examples/School; done
|
for h in $(CLPBN_SCHOOL_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(EXDIR)/School; done
|
||||||
for h in $(CLPBN_HMMER_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(SHAREDIR)/clpbn/examples/HMMer; done
|
for h in $(CLPBN_HMMER_EXAMPLES); do $(INSTALL_DATA) $$h $(DESTDIR)$(EXDIR)/HMMer; done
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@
|
|||||||
clpbn_init_graph/1,
|
clpbn_init_graph/1,
|
||||||
probability/2,
|
probability/2,
|
||||||
conditional_probability/3,
|
conditional_probability/3,
|
||||||
|
use_parfactors/1,
|
||||||
op( 500, xfy, with)]).
|
op( 500, xfy, with)]).
|
||||||
|
|
||||||
:- use_module(library(atts)).
|
:- use_module(library(atts)).
|
||||||
@ -43,6 +44,7 @@
|
|||||||
check_if_bp_done/1,
|
check_if_bp_done/1,
|
||||||
init_bp_solver/4,
|
init_bp_solver/4,
|
||||||
run_bp_solver/3,
|
run_bp_solver/3,
|
||||||
|
call_bp_ground/6,
|
||||||
finalize_bp_solver/1
|
finalize_bp_solver/1
|
||||||
]).
|
]).
|
||||||
|
|
||||||
@ -61,11 +63,17 @@
|
|||||||
run_jt_solver/3
|
run_jt_solver/3
|
||||||
]).
|
]).
|
||||||
|
|
||||||
:- use_module('clpbn/bnt',
|
:- use_module('clpbn/bdd',
|
||||||
[do_bnt/3,
|
[bdd/3,
|
||||||
check_if_bnt_done/1
|
init_bdd_solver/4,
|
||||||
|
run_bdd_solver/3
|
||||||
]).
|
]).
|
||||||
|
|
||||||
|
%% :- use_module('clpbn/bnt',
|
||||||
|
%% [do_bnt/3,
|
||||||
|
%% check_if_bnt_done/1
|
||||||
|
%% ]).
|
||||||
|
|
||||||
:- use_module('clpbn/gibbs',
|
:- use_module('clpbn/gibbs',
|
||||||
[gibbs/3,
|
[gibbs/3,
|
||||||
check_if_gibbs_done/1,
|
check_if_gibbs_done/1,
|
||||||
@ -111,7 +119,7 @@
|
|||||||
[clpbn2gviz/4]).
|
[clpbn2gviz/4]).
|
||||||
|
|
||||||
:- use_module(clpbn/ground_factors,
|
:- use_module(clpbn/ground_factors,
|
||||||
[generate_bn/2]).
|
[generate_network/5]).
|
||||||
|
|
||||||
|
|
||||||
:- dynamic solver/1,output/1,use/1,suppress_attribute_display/1, parameter_softening/1, em_solver/1, use_parfactors/1.
|
:- dynamic solver/1,output/1,use/1,suppress_attribute_display/1, parameter_softening/1, em_solver/1, use_parfactors/1.
|
||||||
@ -223,9 +231,17 @@ clpbn_marginalise(V, Dist) :-
|
|||||||
% called by top-level
|
% called by top-level
|
||||||
% or by call_residue/2
|
% or by call_residue/2
|
||||||
%
|
%
|
||||||
project_attributes(GVars, AVars0) :-
|
project_attributes(GVars, _AVars0) :-
|
||||||
|
use_parfactors(on),
|
||||||
|
clpbn_flag(solver, Solver), Solver \= fove, !,
|
||||||
|
generate_network(GVars, GKeys, Keys, Factors, Evidence),
|
||||||
|
(ground(GVars) ->
|
||||||
|
true
|
||||||
|
;
|
||||||
|
call_ground_solver(Solver, GVars, GKeys, Keys, Factors, Evidence, _Avars0)
|
||||||
|
).
|
||||||
|
project_attributes(GVars, AVars) :-
|
||||||
suppress_attribute_display(false),
|
suppress_attribute_display(false),
|
||||||
generate_vars(GVars, AVars0, AVars),
|
|
||||||
AVars = [_|_],
|
AVars = [_|_],
|
||||||
solver(Solver),
|
solver(Solver),
|
||||||
( GVars = [_|_] ; Solver = graphs), !,
|
( GVars = [_|_] ; Solver = graphs), !,
|
||||||
@ -243,11 +259,6 @@ project_attributes(GVars, AVars0) :-
|
|||||||
).
|
).
|
||||||
project_attributes(_, _).
|
project_attributes(_, _).
|
||||||
|
|
||||||
generate_vars(GVars, _, NewAVars) :-
|
|
||||||
use_parfactors(on), !,
|
|
||||||
generate_bn(GVars, NewAVars).
|
|
||||||
generate_vars(_GVars, AVars, AVars).
|
|
||||||
|
|
||||||
clpbn_vars(AVars, DiffVars, AllVars) :-
|
clpbn_vars(AVars, DiffVars, AllVars) :-
|
||||||
sort_vars_by_key(AVars,SortedAVars,DiffVars),
|
sort_vars_by_key(AVars,SortedAVars,DiffVars),
|
||||||
incorporate_evidence(SortedAVars, AllVars).
|
incorporate_evidence(SortedAVars, AllVars).
|
||||||
@ -289,6 +300,8 @@ write_out(ve, GVars, AVars, DiffVars) :-
|
|||||||
ve(GVars, AVars, DiffVars).
|
ve(GVars, AVars, DiffVars).
|
||||||
write_out(jt, GVars, AVars, DiffVars) :-
|
write_out(jt, GVars, AVars, DiffVars) :-
|
||||||
jt(GVars, AVars, DiffVars).
|
jt(GVars, AVars, DiffVars).
|
||||||
|
write_out(bdd, GVars, AVars, DiffVars) :-
|
||||||
|
bdd(GVars, AVars, DiffVars).
|
||||||
write_out(bp, GVars, AVars, DiffVars) :-
|
write_out(bp, GVars, AVars, DiffVars) :-
|
||||||
bp(GVars, AVars, DiffVars).
|
bp(GVars, AVars, DiffVars).
|
||||||
write_out(gibbs, GVars, AVars, DiffVars) :-
|
write_out(gibbs, GVars, AVars, DiffVars) :-
|
||||||
@ -298,6 +311,11 @@ write_out(bnt, GVars, AVars, DiffVars) :-
|
|||||||
write_out(fove, GVars, AVars, DiffVars) :-
|
write_out(fove, GVars, AVars, DiffVars) :-
|
||||||
fove(GVars, AVars, DiffVars).
|
fove(GVars, AVars, DiffVars).
|
||||||
|
|
||||||
|
% call a solver with keys, not actual variables
|
||||||
|
call_ground_solver(bp, GVars, GoalKeys, Keys, Factors, Evidence, Answ) :-
|
||||||
|
call_bp_ground(GVars, GoalKeys, Keys, Factors, Evidence, Answ).
|
||||||
|
|
||||||
|
|
||||||
get_bnode(Var, Goal) :-
|
get_bnode(Var, Goal) :-
|
||||||
get_atts(Var, [key(Key),dist(Dist,Parents)]),
|
get_atts(Var, [key(Key),dist(Dist,Parents)]),
|
||||||
get_dist(Dist,_,Domain,CPT),
|
get_dist(Dist,_,Domain,CPT),
|
||||||
@ -382,6 +400,9 @@ bind_clpbn(_, Var, _, _, _, _, []) :-
|
|||||||
bind_clpbn(_, Var, _, _, _, _, []) :-
|
bind_clpbn(_, Var, _, _, _, _, []) :-
|
||||||
use(jt),
|
use(jt),
|
||||||
check_if_ve_done(Var), !.
|
check_if_ve_done(Var), !.
|
||||||
|
bind_clpbn(_, Var, _, _, _, _, []) :-
|
||||||
|
use(bdd),
|
||||||
|
check_if_bdd_done(Var), !.
|
||||||
bind_clpbn(T, Var, Key0, _, _, _, []) :-
|
bind_clpbn(T, Var, Key0, _, _, _, []) :-
|
||||||
get_atts(Var, [key(Key)]), !,
|
get_atts(Var, [key(Key)]), !,
|
||||||
(
|
(
|
||||||
@ -397,11 +418,12 @@ fresh_attvar(Var, NVar) :-
|
|||||||
|
|
||||||
% I will now allow two CLPBN variables to be bound together.
|
% I will now allow two CLPBN variables to be bound together.
|
||||||
%bind_clpbns(Key, Dist, Parents, Key, Dist, Parents).
|
%bind_clpbns(Key, Dist, Parents, Key, Dist, Parents).
|
||||||
bind_clpbns(Key, Dist, _Parents, Key1, Dist1, _Parents1) :-
|
bind_clpbns(Key, Dist, Parents, Key1, Dist1, Parents1) :-
|
||||||
Key == Key1, !,
|
Key == Key1, !,
|
||||||
get_dist(Dist,_Type,_Domain,_Table),
|
get_dist(Dist,_Type,_Domain,_Table),
|
||||||
get_dist(Dist1,_Type1,_Domain1,_Table1),
|
get_dist(Dist1,_Type1,_Domain1,_Table1),
|
||||||
Dist = Dist1.
|
Dist = Dist1,
|
||||||
|
Parents = Parents1.
|
||||||
bind_clpbns(Key, _, _, _, Key1, _, _, _) :-
|
bind_clpbns(Key, _, _, _, Key1, _, _, _) :-
|
||||||
Key\=Key1, !, fail.
|
Key\=Key1, !, fail.
|
||||||
bind_clpbns(_, _, _, _, _, _, _, _) :-
|
bind_clpbns(_, _, _, _, _, _, _, _) :-
|
||||||
@ -452,6 +474,8 @@ clpbn_init_solver(bp, LVs, Vs0, VarsWithUnboundKeys, State) :-
|
|||||||
init_bp_solver(LVs, Vs0, VarsWithUnboundKeys, State).
|
init_bp_solver(LVs, Vs0, VarsWithUnboundKeys, State).
|
||||||
clpbn_init_solver(jt, LVs, Vs0, VarsWithUnboundKeys, State) :-
|
clpbn_init_solver(jt, LVs, Vs0, VarsWithUnboundKeys, State) :-
|
||||||
init_jt_solver(LVs, Vs0, VarsWithUnboundKeys, State).
|
init_jt_solver(LVs, Vs0, VarsWithUnboundKeys, State).
|
||||||
|
clpbn_init_solver(bdd, LVs, Vs0, VarsWithUnboundKeys, State) :-
|
||||||
|
init_bdd_solver(LVs, Vs0, VarsWithUnboundKeys, State).
|
||||||
clpbn_init_solver(pcg, LVs, Vs0, VarsWithUnboundKeys, State) :-
|
clpbn_init_solver(pcg, LVs, Vs0, VarsWithUnboundKeys, State) :-
|
||||||
init_pcg_solver(LVs, Vs0, VarsWithUnboundKeys, State).
|
init_pcg_solver(LVs, Vs0, VarsWithUnboundKeys, State).
|
||||||
|
|
||||||
@ -478,6 +502,9 @@ clpbn_run_solver(bp, LVs, LPs, State) :-
|
|||||||
clpbn_run_solver(jt, LVs, LPs, State) :-
|
clpbn_run_solver(jt, LVs, LPs, State) :-
|
||||||
run_jt_solver(LVs, LPs, State).
|
run_jt_solver(LVs, LPs, State).
|
||||||
|
|
||||||
|
clpbn_run_solver(bdd, LVs, LPs, State) :-
|
||||||
|
run_bdd_solver(LVs, LPs, State).
|
||||||
|
|
||||||
clpbn_run_solver(pcg, LVs, LPs, State) :-
|
clpbn_run_solver(pcg, LVs, LPs, State) :-
|
||||||
run_pcg_solver(LVs, LPs, State).
|
run_pcg_solver(LVs, LPs, State).
|
||||||
|
|
||||||
@ -538,4 +565,5 @@ match_probability([p(V0=C)=Prob|_], C, V, Prob) :-
|
|||||||
match_probability([_|Probs], C, V, Prob) :-
|
match_probability([_|Probs], C, V, Prob) :-
|
||||||
match_probability(Probs, C, V, Prob).
|
match_probability(Probs, C, V, Prob).
|
||||||
|
|
||||||
|
:- use_parfactors(on) -> true ; assert(use_parfactors(off)).
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@
|
|||||||
|
|
||||||
:- use_module(library('clpbn/dists'),
|
:- use_module(library('clpbn/dists'),
|
||||||
[
|
[
|
||||||
dist/4,
|
add_dist/6,
|
||||||
get_dist_domain_size/2]).
|
get_dist_domain_size/2]).
|
||||||
|
|
||||||
:- use_module(library('clpbn/matrix_cpt_utils'),
|
:- use_module(library('clpbn/matrix_cpt_utils'),
|
||||||
@ -44,8 +44,9 @@ check_for_agg_vars([_|Vs0], Vs1) :-
|
|||||||
% transform aggregate distribution into tree
|
% transform aggregate distribution into tree
|
||||||
simplify_dist(avg(Domain), V, Key, Parents, Vs0, VsF) :- !,
|
simplify_dist(avg(Domain), V, Key, Parents, Vs0, VsF) :- !,
|
||||||
cpt_average([V|Parents], Key, Domain, NewDist, Vs0, VsF),
|
cpt_average([V|Parents], Key, Domain, NewDist, Vs0, VsF),
|
||||||
dist(NewDist, Id, Key, ParentsF),
|
NewDist = p(Dom, Tab, Ps),
|
||||||
clpbn:put_atts(V, [dist(Id,ParentsF)]).
|
add_dist(Dom, tab, Tab, Ps, Key, Id),
|
||||||
|
clpbn:put_atts(V, [dist(Id,Ps)]).
|
||||||
simplify_dist(_, _, _, _, Vs0, Vs0).
|
simplify_dist(_, _, _, _, Vs0, Vs0).
|
||||||
|
|
||||||
cpt_average(AllVars, Key, Els0, Tab, Vs, NewVs) :-
|
cpt_average(AllVars, Key, Els0, Tab, Vs, NewVs) :-
|
||||||
|
802
packages/CLPBN/clpbn/bdd.yap
Normal file
802
packages/CLPBN/clpbn/bdd.yap
Normal file
@ -0,0 +1,802 @@
|
|||||||
|
|
||||||
|
/************************************************
|
||||||
|
|
||||||
|
BDDs in CLP(BN)
|
||||||
|
|
||||||
|
A variable is represented by the N possible cases it can take
|
||||||
|
|
||||||
|
V = v(Va, Vb, Vc)
|
||||||
|
|
||||||
|
The generic formula is
|
||||||
|
|
||||||
|
V <- X, Y
|
||||||
|
|
||||||
|
Va <- P*X1*Y1 + Q*X2*Y2 + ...
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
**************************************************/
|
||||||
|
|
||||||
|
:- module(clpbn_bdd,
|
||||||
|
[bdd/3,
|
||||||
|
set_solver_parameter/2,
|
||||||
|
init_bdd_solver/4,
|
||||||
|
run_bdd_solver/3,
|
||||||
|
finalize_bdd_solver/1,
|
||||||
|
check_if_bdd_done/1
|
||||||
|
]).
|
||||||
|
|
||||||
|
|
||||||
|
:- use_module(library('clpbn/dists'),
|
||||||
|
[dist/4,
|
||||||
|
get_dist_domain/2,
|
||||||
|
get_dist_domain_size/2,
|
||||||
|
get_dist_all_sizes/2,
|
||||||
|
get_dist_params/2
|
||||||
|
]).
|
||||||
|
|
||||||
|
|
||||||
|
:- use_module(library('clpbn/display'),
|
||||||
|
[clpbn_bind_vals/3]).
|
||||||
|
|
||||||
|
:- use_module(library('clpbn/aggregates'),
|
||||||
|
[check_for_agg_vars/2]).
|
||||||
|
|
||||||
|
|
||||||
|
:- use_module(library(atts)).
|
||||||
|
|
||||||
|
:- use_module(library(hacks)).
|
||||||
|
|
||||||
|
:- use_module(library(lists)).
|
||||||
|
|
||||||
|
:- use_module(library(dgraphs)).
|
||||||
|
|
||||||
|
:- use_module(library(bdd)).
|
||||||
|
|
||||||
|
:- use_module(library(rbtrees)).
|
||||||
|
|
||||||
|
:- use_module(library(bhash)).
|
||||||
|
|
||||||
|
:- use_module(library(matrix)).
|
||||||
|
|
||||||
|
:- dynamic network_counting/1.
|
||||||
|
|
||||||
|
:- attribute order/1.
|
||||||
|
|
||||||
|
check_if_bdd_done(_Var).
|
||||||
|
|
||||||
|
bdd([[]],_,_) :- !.
|
||||||
|
bdd([QueryVars], AllVars, AllDiffs) :-
|
||||||
|
init_bdd_solver(_, AllVars, _, BayesNet),
|
||||||
|
run_bdd_solver([QueryVars], LPs, BayesNet),
|
||||||
|
finalize_bdd_solver(BayesNet),
|
||||||
|
clpbn_bind_vals([QueryVars], [LPs], AllDiffs).
|
||||||
|
|
||||||
|
init_bdd_solver(_, AllVars0, _, bdd(Term, Leaves, Tops)) :-
|
||||||
|
% check_for_agg_vars(AllVars0, AllVars1),
|
||||||
|
sort_vars(AllVars0, AllVars, Leaves),
|
||||||
|
order_vars(AllVars, 0),
|
||||||
|
rb_new(Vars0),
|
||||||
|
rb_new(Pars0),
|
||||||
|
init_tops(Leaves,Tops),
|
||||||
|
get_vars_info(AllVars, Vars0, _Vars, Pars0, _Pars, Leaves, Tops, Term, []).
|
||||||
|
|
||||||
|
order_vars([], _).
|
||||||
|
order_vars([V|AllVars], I0) :-
|
||||||
|
put_atts(V, [order(I0)]),
|
||||||
|
I is I0+1,
|
||||||
|
order_vars(AllVars, I).
|
||||||
|
|
||||||
|
|
||||||
|
init_tops([],[]).
|
||||||
|
init_tops(_.Leaves,_.Tops) :-
|
||||||
|
init_tops(Leaves,Tops).
|
||||||
|
|
||||||
|
sort_vars(AllVars0, AllVars, Leaves) :-
|
||||||
|
dgraph_new(Graph0),
|
||||||
|
build_graph(AllVars0, Graph0, Graph),
|
||||||
|
dgraph_leaves(Graph, Leaves),
|
||||||
|
dgraph_top_sort(Graph, AllVars).
|
||||||
|
|
||||||
|
build_graph([], Graph, Graph).
|
||||||
|
build_graph(V.AllVars0, Graph0, Graph) :-
|
||||||
|
clpbn:get_atts(V, [dist(_DistId, Parents)]), !,
|
||||||
|
dgraph_add_vertex(Graph0, V, Graph1),
|
||||||
|
add_parents(Parents, V, Graph1, GraphI),
|
||||||
|
build_graph(AllVars0, GraphI, Graph).
|
||||||
|
build_graph(_V.AllVars0, Graph0, Graph) :-
|
||||||
|
build_graph(AllVars0, Graph0, Graph).
|
||||||
|
|
||||||
|
add_parents([], _V, Graph, Graph).
|
||||||
|
add_parents(V0.Parents, V, Graph0, GraphF) :-
|
||||||
|
dgraph_add_edge(Graph0, V0, V, GraphI),
|
||||||
|
add_parents(Parents, V, GraphI, GraphF).
|
||||||
|
|
||||||
|
get_vars_info([], Vs, Vs, Ps, Ps, _, _) --> [].
|
||||||
|
get_vars_info([V|MoreVs], Vs, VsF, Ps, PsF, Lvs, Outs) -->
|
||||||
|
{ clpbn:get_atts(V, [dist(DistId, Parents)]) }, !,
|
||||||
|
%{writeln(v:DistId:Parents)},
|
||||||
|
[DIST],
|
||||||
|
{ get_var_info(V, DistId, Parents, Vs, Vs2, Ps, Ps1, Lvs, Outs, DIST) },
|
||||||
|
get_vars_info(MoreVs, Vs2, VsF, Ps1, PsF, Lvs, Outs).
|
||||||
|
get_vars_info([_|MoreVs], Vs0, VsF, Ps0, PsF, VarsInfo, Lvs, Outs) :-
|
||||||
|
get_vars_info(MoreVs, Vs0, VsF, Ps0, PsF, VarsInfo, Lvs, Outs).
|
||||||
|
|
||||||
|
%
|
||||||
|
% let's have some fun with avg
|
||||||
|
%
|
||||||
|
get_var_info(V, avg(Domain), Parents, Vs, Vs2, Ps, Ps, Lvs, Outs, DIST) :- !,
|
||||||
|
length(Domain, DSize),
|
||||||
|
% run_though_avg(V, DSize, Domain, Parents, Vs, Vs2, Lvs, Outs, DIST).
|
||||||
|
top_down_with_tabling(V, DSize, Domain, Parents, Vs, Vs2, Lvs, Outs, DIST).
|
||||||
|
% bup_avg(V, DSize, Domain, Parents, Vs, Vs2, Lvs, Outs, DIST).
|
||||||
|
% standard random variable
|
||||||
|
get_var_info(V, DistId, Parents0, Vs, Vs2, Ps, Ps1, Lvs, Outs, DIST) :-
|
||||||
|
% clpbn:get_atts(V, [key(K)]), writeln(V:K:DistId:Parents),
|
||||||
|
reorder_vars(Parents0, Parents, Map),
|
||||||
|
check_p(DistId, Map, Parms, _ParmVars, Ps, Ps1),
|
||||||
|
unbound_parms(Parms, ParmVars),
|
||||||
|
check_v(V, DistId, DIST, Vs, Vs1),
|
||||||
|
DIST = info(V, Tree, Ev, Values, Formula, ParmVars, Parms),
|
||||||
|
% get a list of form [[P00,P01], [P10,P11], [P20,P21]]
|
||||||
|
get_parents(Parents, PVars, Vs1, Vs2),
|
||||||
|
cross_product(Values, Ev, PVars, ParmVars, Formula0),
|
||||||
|
% (numbervars(Formula0,0,_),writeln(formula0:Ev:Formula0), fail ; true),
|
||||||
|
get_evidence(V, Tree, Ev, Formula0, Formula, Lvs, Outs).
|
||||||
|
%, (numbervars(Formula,0,_),writeln(formula:Formula), fail ; true)
|
||||||
|
|
||||||
|
%
|
||||||
|
% reorder all variables and make sure we get a
|
||||||
|
% map of how the transfer was done.
|
||||||
|
%
|
||||||
|
% position zero is output
|
||||||
|
%
|
||||||
|
reorder_vars(Vs, OVs, Map) :-
|
||||||
|
add_pos(Vs, 1, PVs),
|
||||||
|
keysort(PVs, SVs),
|
||||||
|
remove_key(SVs, OVs, Map).
|
||||||
|
|
||||||
|
add_pos([], _, []).
|
||||||
|
add_pos([V|Vs], I0, [K-(I0,V)|PVs]) :-
|
||||||
|
get_atts(V,[order(K)]),
|
||||||
|
I is I0+1,
|
||||||
|
add_pos(Vs, I, PVs).
|
||||||
|
|
||||||
|
remove_key([], [], []).
|
||||||
|
remove_key([_-(I,V)|SVs], [V|OVs], [I|Map]) :-
|
||||||
|
remove_key(SVs, OVs, Map).
|
||||||
|
|
||||||
|
%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||||
|
%
|
||||||
|
% use top-down to generate average
|
||||||
|
%
|
||||||
|
run_though_avg(V, 3, Domain, Parents0, Vs, Vs2, Lvs, Outs, DIST) :-
|
||||||
|
reorder_vars(Parents0, Parents, _Map),
|
||||||
|
check_v(V, avg(Domain,Parents0), DIST, Vs, Vs1),
|
||||||
|
DIST = info(V, Tree, Ev, [V0,V1,V2], Formula, [], []),
|
||||||
|
get_parents(Parents, PVars, Vs1, Vs2),
|
||||||
|
length(Parents, N),
|
||||||
|
generate_3tree(F00, PVars, 0, 0, 0, N, N0, N1, N2, R, (N1+2*N2 =< N/2), (N1+2*(N2+R) =< N/2)),
|
||||||
|
simplify_exp(F00, F0),
|
||||||
|
% generate_3tree(F1, PVars, 0, 0, 0, N, N0, N1, N2, R, ((N1+2*(N2+R) > N/2, N1+2*N2 < (3*N)/2))),
|
||||||
|
generate_3tree(F20, PVars, 0, 0, 0, N, N0, N1, N2, R, (N1+2*(N2+R) >= (3*N)/2), N1+2*N2 >= (3*N)/2),
|
||||||
|
% simplify_exp(F20, F2),
|
||||||
|
F20=F2,
|
||||||
|
Formula0 = [V0=F0*Ev0,V2=F2*Ev2,V1=not(F0+F2)*Ev1],
|
||||||
|
Ev = [Ev0,Ev1,Ev2],
|
||||||
|
get_evidence(V, Tree, Ev, Formula0, Formula, Lvs, Outs).
|
||||||
|
|
||||||
|
generate_3tree(OUT, _, I00, I10, I20, IR0, N0, N1, N2, R, _Exp, ExpF) :-
|
||||||
|
IR is IR0-1,
|
||||||
|
satisf(I00, I10, I20, IR, N0, N1, N2, R, ExpF),
|
||||||
|
!,
|
||||||
|
OUT = 1.
|
||||||
|
generate_3tree(OUT, [[P0,P1,P2]], I00, I10, I20, IR0, N0, N1, N2, R, Exp, _ExpF) :-
|
||||||
|
IR is IR0-1,
|
||||||
|
( satisf(I00+1, I10, I20, IR, N0, N1, N2, R, Exp) ->
|
||||||
|
L0 = [P0|L1]
|
||||||
|
;
|
||||||
|
L0 = L1
|
||||||
|
),
|
||||||
|
( satisf(I00, I10+1, I20, IR, N0, N1, N2, R, Exp) ->
|
||||||
|
L1 = [P1|L2]
|
||||||
|
;
|
||||||
|
L1 = L2
|
||||||
|
),
|
||||||
|
( satisf(I00, I10, I20+1, IR, N0, N1, N2, R, Exp) ->
|
||||||
|
L2 = [P2]
|
||||||
|
;
|
||||||
|
L2 = []
|
||||||
|
),
|
||||||
|
to_disj(L0, OUT).
|
||||||
|
generate_3tree(OUT, [[P0,P1,P2]|Ps], I00, I10, I20, IR0, N0, N1, N2, R, Exp, ExpF) :-
|
||||||
|
IR is IR0-1,
|
||||||
|
( satisf(I00+1, I10, I20, IR, N0, N1, N2, R, Exp) ->
|
||||||
|
I0 is I00+1, generate_3tree(O0, Ps, I0, I10, I20, IR, N0, N1, N2, R, Exp, ExpF)
|
||||||
|
->
|
||||||
|
L0 = [P0*O0|L1]
|
||||||
|
;
|
||||||
|
L0 = L1
|
||||||
|
),
|
||||||
|
( satisf(I00, I10+1, I20, IR0, N0, N1, N2, R, Exp) ->
|
||||||
|
I1 is I10+1, generate_3tree(O1, Ps, I00, I1, I20, IR, N0, N1, N2, R, Exp, ExpF)
|
||||||
|
->
|
||||||
|
L1 = [P1*O1|L2]
|
||||||
|
;
|
||||||
|
L1 = L2
|
||||||
|
),
|
||||||
|
( satisf(I00, I10, I20+1, IR0, N0, N1, N2, R, Exp) ->
|
||||||
|
I2 is I20+1, generate_3tree(O2, Ps, I00, I10, I2, IR, N0, N1, N2, R, Exp, ExpF)
|
||||||
|
->
|
||||||
|
L2 = [P2*O2]
|
||||||
|
;
|
||||||
|
L2 = []
|
||||||
|
),
|
||||||
|
to_disj(L0, OUT).
|
||||||
|
|
||||||
|
|
||||||
|
satisf(I0, I1, I2, IR, N0, N1, N2, R, Exp) :-
|
||||||
|
\+ \+ ( I0 = N0, I1=N1, I2=N2, IR=R, call(Exp) ).
|
||||||
|
|
||||||
|
not_satisf(I0, I1, I2, IR, N0, N1, N2, R, Exp) :-
|
||||||
|
\+ ( I0 = N0, I1=N1, I2=N2, IR=R, call(Exp) ).
|
||||||
|
|
||||||
|
%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||||
|
%
|
||||||
|
% use top-down to generate average
|
||||||
|
%
|
||||||
|
top_down_with_tabling(V, Size, Domain, Parents0, Vs, Vs2, Lvs, Outs, DIST) :-
|
||||||
|
reorder_vars(Parents0, Parents, _Map),
|
||||||
|
check_v(V, avg(Domain,Parents), DIST, Vs, Vs1),
|
||||||
|
DIST = info(V, Tree, Ev, OVs, Formula, [], []),
|
||||||
|
get_parents(Parents, PVars, Vs1, Vs2),
|
||||||
|
length(Parents, N),
|
||||||
|
Max is (Size-1)*N, % This should be true
|
||||||
|
avg_borders(0, Size, Max, Borders),
|
||||||
|
b_hash_new(H0),
|
||||||
|
avg_trees(0, Max, PVars, Size, F1, 0, Borders, OVs, Ev, H0, H),
|
||||||
|
generate_avg_code(H, Formula, F),
|
||||||
|
% Formula0 = [V0=F0*Ev0,V2=F2*Ev2,V1=not(F0+F2)*Ev1],
|
||||||
|
% Ev = [Ev0,Ev1,Ev2],
|
||||||
|
get_evidence(V, Tree, Ev, F1, F, Lvs, Outs).
|
||||||
|
|
||||||
|
avg_trees(Size, _, _, Size, F0, _, F0, [], [], H, H) :- !.
|
||||||
|
avg_trees(I0, Max, PVars, Size, [V=O*E|F0], Im, [IM|Borders], [V|OVs], [E|Ev], H0, H) :-
|
||||||
|
I is I0+1,
|
||||||
|
avg_tree(PVars, 0, Max, Im, IM, Size, O, H0, HI),
|
||||||
|
Im1 is IM+1,
|
||||||
|
avg_trees(I, Max, PVars, Size, F0, Im1, Borders, OVs, Ev, HI, H).
|
||||||
|
|
||||||
|
avg_tree( _PVars, P, _, Im, IM, _Size, O, H0, H0) :-
|
||||||
|
b_hash_lookup(k(P,Im,IM), O=_Exp, H0), !.
|
||||||
|
avg_tree([], _P, _Max, _Im, _IM, _Size, 1, H, H).
|
||||||
|
avg_tree([Vals|PVars], P, Max, Im, IM, Size, O, H0, HF) :-
|
||||||
|
b_hash_insert(H0, k(P,Im,IM), O=Simp, HI),
|
||||||
|
MaxI is Max-(Size-1),
|
||||||
|
avg_exp(Vals, PVars, 0, P, MaxI, Size, Im, IM, HI, HF, Exp),
|
||||||
|
simplify_exp(Exp, Simp).
|
||||||
|
|
||||||
|
avg_exp([], _, _, _P, _Max, _Size, _Im, _IM, H, H, 0).
|
||||||
|
avg_exp([Val|Vals], PVars, I0, P0, Max, Size, Im, IM, HI, HF, O) :-
|
||||||
|
(Vals = [] -> O=O1 ; O = Val*O1+not(Val)*O2 ),
|
||||||
|
Im1 is max(0, Im-I0),
|
||||||
|
IM1 is IM-I0,
|
||||||
|
( IM1 < 0 -> O1 = 0, H2 = HI; /* we have exceed maximum */
|
||||||
|
Im1 > Max -> O1 = 0, H2 = HI; /* we cannot make to minimum */
|
||||||
|
Im1 = 0, IM1 > Max -> O1 = 1, H2 = HI; /* we cannot exceed maximum */
|
||||||
|
P is P0+1,
|
||||||
|
avg_tree(PVars, P, Max, Im1, IM1, Size, O1, HI, H2)
|
||||||
|
),
|
||||||
|
I is I0+1,
|
||||||
|
avg_exp(Vals, PVars, I, P0, Max, Size, Im, IM, H2, HF, O2).
|
||||||
|
|
||||||
|
generate_avg_code(H, Formula, Formula0) :-
|
||||||
|
b_hash_to_list(H,L),
|
||||||
|
sort(L, S),
|
||||||
|
strip_and_add(S, Formula0, Formula).
|
||||||
|
|
||||||
|
strip_and_add([], F, F).
|
||||||
|
strip_and_add([_-Exp|S], F0, F) :-
|
||||||
|
strip_and_add(S, [Exp|F0], F).
|
||||||
|
|
||||||
|
%%%%%%%%%%%%%%%%%%%%%%%%%
|
||||||
|
%
|
||||||
|
% use bottom-up dynamic programming to generate average
|
||||||
|
%
|
||||||
|
bup_avg(V, Size, Domain, Parents0, Vs, Vs2, Lvs, Outs, DIST) :-
|
||||||
|
reorder_vars(Parents0, Parents, _),
|
||||||
|
check_v(V, avg(Domain,Parents), DIST, Vs, Vs1),
|
||||||
|
DIST = info(V, Tree, Ev, OVs, Formula, [], []),
|
||||||
|
get_parents(Parents, PVars, Vs1, Vs2),
|
||||||
|
length(Parents, N),
|
||||||
|
Max is (Size-1)*N, % This should be true
|
||||||
|
ArraySize is Max+1,
|
||||||
|
functor(Protected, protected, ArraySize),
|
||||||
|
avg_domains(0, Size, 0, Max, LDomains),
|
||||||
|
Domains =.. [d|LDomains],
|
||||||
|
Reach is (Size-1),
|
||||||
|
generate_sums(PVars, Size, Max, Reach, Protected, Domains, ArraySize, Sums, F0),
|
||||||
|
% bin_sums(PVars, Sums, F00),
|
||||||
|
% reverse(F00,F0),
|
||||||
|
% easier to do recursion on lists
|
||||||
|
Sums =.. [_|LSums],
|
||||||
|
generate_avg(0, Size, 0, Max, LSums, OVs, Ev, F1, []),
|
||||||
|
reverse(F0, RF0),
|
||||||
|
get_evidence(V, Tree, Ev, F1, F2, Lvs, Outs),
|
||||||
|
append(RF0, F2, Formula).
|
||||||
|
|
||||||
|
%
|
||||||
|
% use binary approach, like what is standard
|
||||||
|
%
|
||||||
|
bin_sums(Vs, Sums, F) :-
|
||||||
|
vs_to_sums(Vs, Sums0),
|
||||||
|
bin_sums(Sums0, Sums, F, []).
|
||||||
|
|
||||||
|
vs_to_sums([], []).
|
||||||
|
vs_to_sums([V|Vs], [Sum|Sums0]) :-
|
||||||
|
Sum =.. [sum|V],
|
||||||
|
vs_to_sums(Vs, Sums0).
|
||||||
|
|
||||||
|
bin_sums([Sum], Sum) --> !.
|
||||||
|
bin_sums(LSums, Sum) -->
|
||||||
|
{ halve(LSums, Sums1, Sums2) },
|
||||||
|
bin_sums(Sums1, Sum1),
|
||||||
|
bin_sums(Sums2, Sum2),
|
||||||
|
sum(Sum1, Sum2, Sum).
|
||||||
|
|
||||||
|
halve(LSums, Sums1, Sums2) :-
|
||||||
|
length(LSums, L),
|
||||||
|
Take is L div 2,
|
||||||
|
head(Take, LSums, Sums1, Sums2).
|
||||||
|
|
||||||
|
head(0, L, [], L) :- !.
|
||||||
|
head(Take, [H|L], [H|Sums1], Sum2) :-
|
||||||
|
Take1 is Take-1,
|
||||||
|
head(Take1, L, Sums1, Sum2).
|
||||||
|
|
||||||
|
sum(Sum1, Sum2, Sum) -->
|
||||||
|
{ functor(Sum1, _, M1),
|
||||||
|
functor(Sum2, _, M2),
|
||||||
|
Max is M1+M2-2,
|
||||||
|
Max1 is Max+1,
|
||||||
|
Max0 is M2-1,
|
||||||
|
functor(Sum, sum, Max1),
|
||||||
|
Sum1 =.. [_|PVals] },
|
||||||
|
expand_sums(PVals, 0, Max0, Max1, M2, Sum2, Sum).
|
||||||
|
|
||||||
|
%
|
||||||
|
% bottom up step by step
|
||||||
|
%
|
||||||
|
%
|
||||||
|
generate_sums([PVals], Size, Max, _, _Protected, _Domains, _, Sum, []) :- !,
|
||||||
|
Max is Size-1,
|
||||||
|
Sum =.. [sum|PVals].
|
||||||
|
generate_sums([PVals|Parents], Size, Max, Reach, Protected, Domains, ASize, NewSums, F) :-
|
||||||
|
NewReach is Reach+(Size-1),
|
||||||
|
generate_sums(Parents, Size, Max0, NewReach, Protected, Domains, ASize, Sums, F0),
|
||||||
|
Max is Max0+(Size-1),
|
||||||
|
Max1 is Max+1,
|
||||||
|
functor(NewSums, sum, Max1),
|
||||||
|
protect_avg(0, Max0, Protected, Domains, ASize, Reach),
|
||||||
|
expand_sums(PVals, 0, Max0, Max1, Size, Sums, Protected, NewSums, F, F0).
|
||||||
|
|
||||||
|
protect_avg(Max0,Max0,_Protected, _Domains, _ASize, _Reach) :- !.
|
||||||
|
protect_avg(I0, Max0, Protected, Domains, ASize, Reach) :-
|
||||||
|
I is I0+1,
|
||||||
|
Top is I+Reach,
|
||||||
|
( Top > ASize ;
|
||||||
|
arg(I, Domains, CD),
|
||||||
|
arg(Top, Domains, CD)
|
||||||
|
), !,
|
||||||
|
arg(I, Protected, yes),
|
||||||
|
protect_avg(I, Max0, Protected, Domains, ASize, Reach).
|
||||||
|
protect_avg(I0, Max0, Protected, Domains, ASize, Reach) :-
|
||||||
|
I is I0+1,
|
||||||
|
protect_avg(I, Max0, Protected, Domains, ASize, Reach).
|
||||||
|
|
||||||
|
|
||||||
|
%
|
||||||
|
% outer loop: generate array of sums at level j= Sum[j0...jMax]
|
||||||
|
%
|
||||||
|
expand_sums(_Parents, Max, _, Max, _Size, _Sums, _P, _NewSums, F0, F0) :- !.
|
||||||
|
expand_sums(Parents, I0, Max0, Max, Size, Sums, Prot, NewSums, [O=SUM|F], F0) :-
|
||||||
|
I is I0+1,
|
||||||
|
arg(I, Prot, P),
|
||||||
|
var(P), !,
|
||||||
|
arg(I, NewSums, O),
|
||||||
|
sum_all(Parents, 0, I0, Max0, Sums, List),
|
||||||
|
to_disj(List, SUM),
|
||||||
|
expand_sums(Parents, I, Max0, Max, Size, Sums, Prot, NewSums, F, F0).
|
||||||
|
expand_sums(Parents, I0, Max0, Max, Size, Sums, Prot, NewSums, F, F0) :-
|
||||||
|
I is I0+1,
|
||||||
|
arg(I, Sums, O),
|
||||||
|
arg(I, NewSums, O),
|
||||||
|
expand_sums(Parents, I, Max0, Max, Size, Sums, Prot, NewSums, F, F0).
|
||||||
|
|
||||||
|
%
|
||||||
|
%inner loop: find all parents that contribute to A_ji,
|
||||||
|
% that is generate Pk*Sum_(j-1)l and k+l st k+l = i
|
||||||
|
%
|
||||||
|
sum_all([], _, _, _, _, []).
|
||||||
|
sum_all([V|Vs], Pos, I, Max0, Sums, [O|List]) :-
|
||||||
|
J is I-Pos,
|
||||||
|
J >= 0,
|
||||||
|
J =< Max0, !,
|
||||||
|
J1 is J+1,
|
||||||
|
arg(J1, Sums, S0),
|
||||||
|
( J < I -> O = V*S0 ; O = S0*V ),
|
||||||
|
Pos1 is Pos+1,
|
||||||
|
sum_all(Vs, Pos1, I, Max0, Sums, List).
|
||||||
|
sum_all([_V|Vs], Pos, I, Max0, Sums, List) :-
|
||||||
|
Pos1 is Pos+1,
|
||||||
|
sum_all(Vs, Pos1, I, Max0, Sums, List).
|
||||||
|
|
||||||
|
gen_arg(J, Sums, Max, S0) :-
|
||||||
|
gen_arg(0, Max, J, Sums, S0).
|
||||||
|
|
||||||
|
gen_arg(Max, Max, J, Sums, S0) :- !,
|
||||||
|
I is Max+1,
|
||||||
|
arg(I, Sums, A),
|
||||||
|
( Max = J -> S0 = A ; S0 = not(A)).
|
||||||
|
gen_arg(I0, Max, J, Sums, S) :-
|
||||||
|
I is I0+1,
|
||||||
|
arg(I, Sums, A),
|
||||||
|
( I0 = J -> S = A*S0 ; S = not(A)*S0),
|
||||||
|
gen_arg(I, Max, J, Sums, S0).
|
||||||
|
|
||||||
|
|
||||||
|
avg_borders(Size, Size, _Max, []) :- !.
|
||||||
|
avg_borders(I0, Size, Max, [J|Vals]) :-
|
||||||
|
I is I0+1,
|
||||||
|
Border is (I*Max)/Size,
|
||||||
|
J is integer(round(Border)),
|
||||||
|
avg_borders(I, Size, Max, Vals).
|
||||||
|
|
||||||
|
avg_domains(Size, Size, _J, _Max, []).
|
||||||
|
avg_domains(I0, Size, J0, Max, Vals) :-
|
||||||
|
I is I0+1,
|
||||||
|
Border is (I*Max)/Size,
|
||||||
|
fetch_domain_for_avg(J0, Border, J, I0, Vals, ValsI),
|
||||||
|
avg_domains(I, Size, J, Max, ValsI).
|
||||||
|
|
||||||
|
fetch_domain_for_avg(J, Border, J, _, Vals, Vals) :-
|
||||||
|
J > Border, !.
|
||||||
|
fetch_domain_for_avg(J0, Border, J, I0, [I0|LVals], RLVals) :-
|
||||||
|
J1 is J0+1,
|
||||||
|
fetch_domain_for_avg(J1, Border, J, I0, LVals, RLVals).
|
||||||
|
|
||||||
|
generate_avg(Size, Size, _J, _Max, [], [], [], F, F).
|
||||||
|
generate_avg(I0, Size, J0, Max, LSums, [O|OVs], [Ev|Evs], [O=Ev*Disj|F], F0) :-
|
||||||
|
I is I0+1,
|
||||||
|
Border is (I*Max)/Size,
|
||||||
|
fetch_for_avg(J0, Border, J, LSums, MySums, RSums),
|
||||||
|
to_disj(MySums, Disj),
|
||||||
|
generate_avg(I, Size, J, Max, RSums, OVs, Evs, F, F0).
|
||||||
|
|
||||||
|
fetch_for_avg(J, Border, J, RSums, [], RSums) :-
|
||||||
|
J > Border, !.
|
||||||
|
fetch_for_avg(J0, Border, J, [S|LSums], [S|MySums], RSums) :-
|
||||||
|
J1 is J0+1,
|
||||||
|
fetch_for_avg(J1, Border, J, LSums, MySums, RSums).
|
||||||
|
|
||||||
|
|
||||||
|
to_disj([], 0).
|
||||||
|
to_disj([V], V).
|
||||||
|
to_disj([V,V1|Vs], Out) :-
|
||||||
|
to_disj2([V1|Vs], V, Out).
|
||||||
|
|
||||||
|
to_disj2([V], V0, V0+V).
|
||||||
|
to_disj2([V,V1|Vs], V0, Out) :-
|
||||||
|
to_disj2([V1|Vs], V0+V, Out).
|
||||||
|
|
||||||
|
|
||||||
|
%
|
||||||
|
% look for parameters in the rb-tree, or add a new.
|
||||||
|
% distid is the key
|
||||||
|
%
|
||||||
|
check_p(DistId, Map, Parms, ParmVars, Ps, Ps) :-
|
||||||
|
rb_lookup(DistId-Map, theta(Parms, ParmVars), Ps), !.
|
||||||
|
check_p(DistId, Map, Parms, ParmVars, Ps, PsF) :-
|
||||||
|
get_dist_params(DistId, Parms0),
|
||||||
|
get_dist_all_sizes(DistId, Sizes),
|
||||||
|
swap_parms(Parms0, Sizes, [0|Map], Parms1),
|
||||||
|
length(Parms1, L0),
|
||||||
|
get_dist_domain_size(DistId, Size),
|
||||||
|
L1 is L0 div Size,
|
||||||
|
L is L0-L1,
|
||||||
|
initial_maxes(L1, Multipliers),
|
||||||
|
copy(L, Multipliers, NextMults, NextMults, Parms1, Parms, ParmVars),
|
||||||
|
%writeln(t:Size:Parms0:Parms:ParmVars),
|
||||||
|
rb_insert(Ps, DistId-Map, theta(Parms, ParmVars), PsF).
|
||||||
|
|
||||||
|
swap_parms(Parms0, Sizes, Map, Parms1) :-
|
||||||
|
matrix_new(floats, Sizes, Parms0, T0),
|
||||||
|
matrix_shuffle(T0,Map,TF),
|
||||||
|
matrix_to_list(TF, Parms1).
|
||||||
|
|
||||||
|
%
|
||||||
|
% we are using switches by two
|
||||||
|
%
|
||||||
|
initial_maxes(0, []) :- !.
|
||||||
|
initial_maxes(Size, [1.0|Multipliers]) :- !,
|
||||||
|
Size1 is Size-1,
|
||||||
|
initial_maxes(Size1, Multipliers).
|
||||||
|
|
||||||
|
copy(0, [], [], _, _Parms0, [], []) :- !.
|
||||||
|
copy(N, [], [], Ms, Parms0, Parms, ParmVars) :-!,
|
||||||
|
copy(N, Ms, NewMs, NewMs, Parms0, Parms, ParmVars).
|
||||||
|
copy(N, D.Ds, ND.NDs, New, El.Parms0, NEl.Parms, V.ParmVars) :-
|
||||||
|
N1 is N-1,
|
||||||
|
(El == 0.0 ->
|
||||||
|
NEl = 0,
|
||||||
|
V = NEl,
|
||||||
|
ND = D
|
||||||
|
;El == 1.0 ->
|
||||||
|
NEl = 1,
|
||||||
|
V = NEl,
|
||||||
|
ND = 0.0
|
||||||
|
;El == 0 ->
|
||||||
|
NEl = 0,
|
||||||
|
V = NEl,
|
||||||
|
ND = D
|
||||||
|
;El =:= 1 ->
|
||||||
|
NEl = 1,
|
||||||
|
V = NEl,
|
||||||
|
ND = 0.0,
|
||||||
|
V = NEl
|
||||||
|
;
|
||||||
|
NEl is El/D,
|
||||||
|
ND is D-El,
|
||||||
|
V = NEl
|
||||||
|
),
|
||||||
|
copy(N1, Ds, NDs, New, Parms0, Parms, ParmVars).
|
||||||
|
|
||||||
|
unbound_parms([], []).
|
||||||
|
unbound_parms(_.Parms, _.ParmVars) :-
|
||||||
|
unbound_parms(Parms, ParmVars).
|
||||||
|
|
||||||
|
check_v(V, _, INFO, Vs, Vs) :-
|
||||||
|
rb_lookup(V, INFO, Vs), !.
|
||||||
|
check_v(V, DistId, INFO, Vs0, Vs) :-
|
||||||
|
get_dist_domain_size(DistId, Size),
|
||||||
|
length(Values, Size),
|
||||||
|
length(Ev, Size),
|
||||||
|
INFO = info(V, _Tree, Ev, Values, _Formula, _, _),
|
||||||
|
rb_insert(Vs0, V, INFO, Vs).
|
||||||
|
|
||||||
|
get_parents([], [], Vs, Vs).
|
||||||
|
get_parents(V.Parents, Values.PVars, Vs0, Vs) :-
|
||||||
|
clpbn:get_atts(V, [dist(DistId, _)]),
|
||||||
|
check_v(V, DistId, INFO, Vs0, Vs1),
|
||||||
|
INFO = info(V, _Parent, _Ev, Values, _, _, _),
|
||||||
|
get_parents(Parents, PVars, Vs1, Vs).
|
||||||
|
|
||||||
|
%
|
||||||
|
% construct the formula, this is the key...
|
||||||
|
%
|
||||||
|
cross_product(Values, Ev, PVars, ParmVars, Formulas) :-
|
||||||
|
arrangements(PVars, Arranges),
|
||||||
|
apply_parents_first(Values, Ev, ParmCombos, ParmCombos, Arranges, Formulas, ParmVars).
|
||||||
|
|
||||||
|
%
|
||||||
|
% if we have the parent variables with two values, we get
|
||||||
|
% [[XP,YP],[XP,YN],[XN,YP],[XN,YN]]
|
||||||
|
%
|
||||||
|
arrangements([], [[]]).
|
||||||
|
arrangements([L1|Ls],O) :-
|
||||||
|
arrangements(Ls, LN),
|
||||||
|
expand(L1, LN, O, []).
|
||||||
|
|
||||||
|
expand([], _LN) --> [].
|
||||||
|
expand([H|L1], LN) -->
|
||||||
|
concatenate_all(H, LN),
|
||||||
|
expand(L1, LN).
|
||||||
|
|
||||||
|
concatenate_all(_H, []) --> [].
|
||||||
|
concatenate_all(H, L.LN) -->
|
||||||
|
[[H|L]],
|
||||||
|
concatenate_all(H, LN).
|
||||||
|
|
||||||
|
%
|
||||||
|
% core of algorithm
|
||||||
|
%
|
||||||
|
% Values -> Output Vars for BDD
|
||||||
|
% Es -> Evidence variables
|
||||||
|
% Previous -> top of difference list with parameters used so far
|
||||||
|
% P0 -> end of difference list with parameters used so far
|
||||||
|
% Pvars -> Parents
|
||||||
|
% Eqs -> Output Equations
|
||||||
|
% Pars -> Output Theta Parameters
|
||||||
|
%
|
||||||
|
apply_parents_first([Value], [E], Previous, [], PVars, [Value=Disj*E], Parameters) :- !,
|
||||||
|
apply_last_parent(PVars, Previous, Disj),
|
||||||
|
flatten(Previous, Parameters).
|
||||||
|
apply_parents_first([Value|Values], [E|Ev], Previous, P0, PVars, (Value=Disj*E).Formulas, Parameters) :-
|
||||||
|
P0 = [TheseParents|End],
|
||||||
|
apply_first_parent(PVars, Disj, TheseParents),
|
||||||
|
apply_parents_second(Values, Ev, Previous, End, PVars, Formulas, Parameters).
|
||||||
|
|
||||||
|
apply_parents_second([Value], [E], Previous, [], PVars, [Value=Disj*E], Parameters) :- !,
|
||||||
|
apply_last_parent(PVars, Previous, Disj),
|
||||||
|
flatten(Previous, Parameters).
|
||||||
|
apply_parents_second([Value|Values], [E|Ev], Previous, P0, PVars, (Value=Disj*E).Formulas, Parameters) :-
|
||||||
|
apply_middle_parent(PVars, Previous, Disj, TheseParents),
|
||||||
|
% this must be done after applying middle parents because of the var
|
||||||
|
% test.
|
||||||
|
P0 = [TheseParents|End],
|
||||||
|
apply_parents_second(Values, Ev, Previous, End, PVars, Formulas, Parameters).
|
||||||
|
|
||||||
|
apply_first_parent([Parents], Conj, [Theta]) :- !,
|
||||||
|
parents_to_conj(Parents,Theta,Conj).
|
||||||
|
apply_first_parent(Parents.PVars, Conj+Disj, Theta.TheseParents) :-
|
||||||
|
parents_to_conj(Parents,Theta,Conj),
|
||||||
|
apply_first_parent(PVars, Disj, TheseParents).
|
||||||
|
|
||||||
|
apply_middle_parent([Parents], Other, Conj, [ThetaPar]) :- !,
|
||||||
|
skim_for_theta(Other, Theta, _, ThetaPar),
|
||||||
|
parents_to_conj(Parents,Theta,Conj).
|
||||||
|
apply_middle_parent(Parents.PVars, Other, Conj+Disj, ThetaPar.TheseParents) :-
|
||||||
|
skim_for_theta(Other, Theta, Remaining, ThetaPar),
|
||||||
|
parents_to_conj(Parents,(Theta),Conj),
|
||||||
|
apply_middle_parent(PVars, Remaining, Disj, TheseParents).
|
||||||
|
|
||||||
|
apply_last_parent([Parents], Other, Conj) :- !,
|
||||||
|
parents_to_conj(Parents,(Theta),Conj),
|
||||||
|
skim_for_theta(Other, Theta, _, _).
|
||||||
|
apply_last_parent(Parents.PVars, Other, Conj+Disj) :-
|
||||||
|
parents_to_conj(Parents,(Theta),Conj),
|
||||||
|
skim_for_theta(Other, Theta, Remaining, _),
|
||||||
|
apply_last_parent(PVars, Remaining, Disj).
|
||||||
|
|
||||||
|
%
|
||||||
|
%
|
||||||
|
% simplify stuff, removing process that is cancelled by 0s
|
||||||
|
%
|
||||||
|
parents_to_conj([], Theta, Theta) :- !.
|
||||||
|
parents_to_conj(Ps, Theta, Theta*Conj) :-
|
||||||
|
parents_to_conj2(Ps, Conj).
|
||||||
|
|
||||||
|
parents_to_conj2([P],P) :- !.
|
||||||
|
parents_to_conj2(P.Ps,P*Conj) :-
|
||||||
|
parents_to_conj2(Ps,Conj).
|
||||||
|
|
||||||
|
%
|
||||||
|
% first case we haven't reached the end of the list so we need
|
||||||
|
% to create a new parameter variable
|
||||||
|
%
|
||||||
|
skim_for_theta([[P|Other]|V], not(P)*New, [Other|_], New) :- var(V), !.
|
||||||
|
%
|
||||||
|
% last theta, it is just negation of the other ones
|
||||||
|
%
|
||||||
|
skim_for_theta([[P|Other]], not(P), [Other], _) :- !.
|
||||||
|
%
|
||||||
|
% recursive case, build-up
|
||||||
|
%
|
||||||
|
skim_for_theta([[P|Other]|More], not(P)*Ps, [Other|Left], New ) :-
|
||||||
|
skim_for_theta(More, Ps, Left, New ).
|
||||||
|
|
||||||
|
get_evidence(V, Tree, Ev, F0, F, Leaves, Finals) :-
|
||||||
|
clpbn:get_atts(V, [evidence(Pos)]), !,
|
||||||
|
zero_pos(0, Pos, Ev),
|
||||||
|
insert_output(Leaves, V, Finals, Tree, Outs, SendOut),
|
||||||
|
get_outs(F0, F, SendOut, Outs).
|
||||||
|
% hidden deterministic node, can be removed.
|
||||||
|
get_evidence(V, _Tree, Ev, F0, [], _Leaves, _Finals) :-
|
||||||
|
clpbn:get_atts(V, [key(K)]),
|
||||||
|
functor(K, Name, 2),
|
||||||
|
( Name = 'AVG' ; Name = 'MAX' ; Name = 'MIN' ),
|
||||||
|
!,
|
||||||
|
one_list(Ev),
|
||||||
|
eval_outs(F0).
|
||||||
|
%% no evidence !!!
|
||||||
|
get_evidence(V, Tree, _Values, F0, F1, Leaves, Finals) :-
|
||||||
|
insert_output(Leaves, V, Finals, Tree, Outs, SendOut),
|
||||||
|
get_outs(F0, F1, SendOut, Outs).
|
||||||
|
|
||||||
|
zero_pos(_, _Pos, []).
|
||||||
|
zero_pos(Pos, Pos, 1.Values) :- !,
|
||||||
|
I is Pos+1,
|
||||||
|
zero_pos(I, Pos, Values).
|
||||||
|
zero_pos(I0, Pos, 0.Values) :-
|
||||||
|
I is I0+1,
|
||||||
|
zero_pos(I, Pos, Values).
|
||||||
|
|
||||||
|
one_list([]).
|
||||||
|
one_list(1.Ev) :-
|
||||||
|
one_list(Ev).
|
||||||
|
|
||||||
|
%
|
||||||
|
% insert a node with the disj of all alternatives, this is only done if node ends up to be in the output
|
||||||
|
%
|
||||||
|
insert_output([], _V, [], _Out, _Outs, []).
|
||||||
|
insert_output(V._Leaves, V0, [Top|_], Top, Outs, [Top = Outs]) :- V == V0, !.
|
||||||
|
insert_output(_.Leaves, V, _.Finals, Top, Outs, SendOut) :-
|
||||||
|
insert_output(Leaves, V, Finals, Top, Outs, SendOut).
|
||||||
|
|
||||||
|
|
||||||
|
get_outs([V=F], [V=NF|End], End, V) :- !,
|
||||||
|
% writeln(f0:F),
|
||||||
|
simplify_exp(F,NF).
|
||||||
|
get_outs((V=F).Outs, (V=NF).NOuts, End, (F0 + V)) :-
|
||||||
|
% writeln(f0:F),
|
||||||
|
simplify_exp(F,NF),
|
||||||
|
get_outs(Outs, NOuts, End, F0).
|
||||||
|
|
||||||
|
eval_outs([]).
|
||||||
|
eval_outs((V=F).Outs) :-
|
||||||
|
simplify_exp(F,NF),
|
||||||
|
V = NF,
|
||||||
|
eval_outs(Outs).
|
||||||
|
|
||||||
|
run_bdd_solver([[V]], LPs, bdd(Term, _Leaves, Nodes)) :-
|
||||||
|
build_out_node(Nodes, Node),
|
||||||
|
findall(Prob, get_prob(Term, Node, V, Prob),TermProbs),
|
||||||
|
sumlist(TermProbs, Sum),
|
||||||
|
writeln(TermProbs:Sum),
|
||||||
|
normalise(TermProbs, Sum, LPs).
|
||||||
|
|
||||||
|
build_out_node([_Top], []).
|
||||||
|
build_out_node([T,T1|Tops], [Top = T*Top]) :-
|
||||||
|
build_out_node2(T1.Tops, Top).
|
||||||
|
|
||||||
|
build_out_node2([Top], Top).
|
||||||
|
build_out_node2([T,T1|Tops], T*Top) :-
|
||||||
|
build_out_node2(T1.Tops, Top).
|
||||||
|
|
||||||
|
|
||||||
|
get_prob(Term, Node, V, SP) :-
|
||||||
|
bind_all(Term, Node, Bindings, V, AllParms, AllParmValues),
|
||||||
|
% reverse(AllParms, RAllParms),
|
||||||
|
term_variables(AllParms, NVs),
|
||||||
|
build_bdd(Bindings, NVs, AllParms, AllParmValues, Bdd),
|
||||||
|
bdd_to_probability_sum_product(Bdd, SP),
|
||||||
|
bdd_close(Bdd).
|
||||||
|
|
||||||
|
build_bdd(Bindings, NVs, VTheta, Theta, Bdd) :-
|
||||||
|
bdd_from_list(Bindings, NVs, Bdd),
|
||||||
|
bdd_size(Bdd, Len),
|
||||||
|
number_codes(Len,Codes),
|
||||||
|
atom_codes(Name,Codes),
|
||||||
|
bdd_print(Bdd, Name),
|
||||||
|
writeln(length=Len),
|
||||||
|
VTheta = Theta.
|
||||||
|
|
||||||
|
bind_all([], End, End, _V, [], []).
|
||||||
|
bind_all(info(V, _Tree, Ev, _Values, Formula, ParmVars, Parms).Term, End, BindsF, V0, ParmVars.AllParms, Parms.AllTheta) :-
|
||||||
|
V0 == V, !,
|
||||||
|
set_to_one_zeros(Ev),
|
||||||
|
bind_formula(Formula, BindsF, BindsI),
|
||||||
|
bind_all(Term, End, BindsI, V0, AllParms, AllTheta).
|
||||||
|
bind_all(info(_V, _Tree, Ev, _Values, Formula, ParmVars, Parms).Term, End, BindsF, V0, ParmVars.AllParms, Parms.AllTheta) :-
|
||||||
|
set_to_ones(Ev),!,
|
||||||
|
bind_formula(Formula, BindsF, BindsI),
|
||||||
|
bind_all(Term, End, BindsI, V0, AllParms, AllTheta).
|
||||||
|
% evidence: no need to add any stuff.
|
||||||
|
bind_all(info(_V, _Tree, _Ev, _Values, Formula, ParmVars, Parms).Term, End, BindsF, V0, ParmVars.AllParms, Parms.AllTheta) :-
|
||||||
|
bind_formula(Formula, BindsF, BindsI),
|
||||||
|
bind_all(Term, End, BindsI, V0, AllParms, AllTheta).
|
||||||
|
|
||||||
|
bind_formula([], L, L).
|
||||||
|
bind_formula(B.Formula, B.BsF, Bs0) :-
|
||||||
|
bind_formula(Formula, BsF, Bs0).
|
||||||
|
|
||||||
|
set_to_one_zeros([1|Values]) :-
|
||||||
|
set_to_zeros(Values).
|
||||||
|
set_to_one_zeros([0|Values]) :-
|
||||||
|
set_to_one_zeros(Values).
|
||||||
|
|
||||||
|
set_to_zeros([]).
|
||||||
|
set_to_zeros(0.Values) :-
|
||||||
|
set_to_zeros(Values).
|
||||||
|
|
||||||
|
set_to_ones([]).
|
||||||
|
set_to_ones(1.Values) :-
|
||||||
|
set_to_ones(Values).
|
||||||
|
|
||||||
|
normalise([], _Sum, []).
|
||||||
|
normalise(P.TermProbs, Sum, NP.LPs) :-
|
||||||
|
NP is P/Sum,
|
||||||
|
normalise(TermProbs, Sum, LPs).
|
||||||
|
|
||||||
|
finalize_bdd_solver(_).
|
||||||
|
|
@ -1,16 +1,16 @@
|
|||||||
|
|
||||||
/************************************************
|
/*******************************************************
|
||||||
|
|
||||||
Belief Propagation in CLP(BN)
|
Belief Propagation and Variable Elimination Interface
|
||||||
|
|
||||||
**************************************************/
|
********************************************************/
|
||||||
|
|
||||||
:- module(clpbn_bp,
|
:- module(clpbn_bp,
|
||||||
[bp/3,
|
[bp/3,
|
||||||
check_if_bp_done/1,
|
check_if_bp_done/1,
|
||||||
set_horus_flag/2,
|
|
||||||
init_bp_solver/4,
|
init_bp_solver/4,
|
||||||
run_bp_solver/3,
|
run_bp_solver/3,
|
||||||
|
call_bp_ground/6,
|
||||||
finalize_bp_solver/1
|
finalize_bp_solver/1
|
||||||
]).
|
]).
|
||||||
|
|
||||||
@ -24,154 +24,143 @@
|
|||||||
|
|
||||||
|
|
||||||
:- use_module(library('clpbn/display'),
|
:- use_module(library('clpbn/display'),
|
||||||
[clpbn_bind_vals/3]).
|
[clpbn_bind_vals/3]).
|
||||||
|
|
||||||
|
|
||||||
:- use_module(library('clpbn/aggregates'),
|
:- use_module(library('clpbn/aggregates'),
|
||||||
[check_for_agg_vars/2]).
|
[check_for_agg_vars/2]).
|
||||||
|
|
||||||
|
|
||||||
|
:- use_module(library(charsio),
|
||||||
|
[term_to_atom/2]).
|
||||||
|
|
||||||
|
|
||||||
|
:- use_module(library(pfl),
|
||||||
|
[skolem/2,
|
||||||
|
get_pfl_parameters/2
|
||||||
|
]).
|
||||||
|
|
||||||
|
|
||||||
|
:- use_module(library(lists)).
|
||||||
|
|
||||||
:- use_module(library(atts)).
|
:- use_module(library(atts)).
|
||||||
:- use_module(library(lists)).
|
|
||||||
:- use_module(library(charsio)).
|
|
||||||
|
|
||||||
:- load_foreign_files(['horus'], [], init_predicates).
|
:- use_module(library(bhash)).
|
||||||
|
|
||||||
:- attribute id/1.
|
|
||||||
|
|
||||||
|
|
||||||
%:- set_horus_flag(inf_alg, ve).
|
:- use_module(horus,
|
||||||
:- set_horus_flag(inf_alg, bn_bp).
|
[create_ground_network/4,
|
||||||
%:- set_horus_flag(inf_alg, fg_bp).
|
set_factors_params/2,
|
||||||
%: -set_horus_flag(inf_alg, cbp).
|
run_ground_solver/3,
|
||||||
|
set_vars_information/2,
|
||||||
|
free_ground_network/1
|
||||||
|
]).
|
||||||
|
|
||||||
:- set_horus_flag(schedule, seq_fixed).
|
|
||||||
%:- set_horus_flag(schedule, seq_random).
|
|
||||||
%:- set_horus_flag(schedule, parallel).
|
|
||||||
%:- set_horus_flag(schedule, max_residual).
|
|
||||||
|
|
||||||
:- set_horus_flag(accuracy, 0.0001).
|
call_bp_ground(QueryVars, QueryKeys, AllKeys, Factors, Evidence, Output) :-
|
||||||
|
writeln(here:Factors),
|
||||||
|
b_hash_new(Hash0),
|
||||||
|
keys_to_ids(AllKeys, 0, Hash0, Hash),
|
||||||
|
get_factors_type(Factors, Type),
|
||||||
|
evidence_to_ids(Evidence, Hash, EvidenceIds),
|
||||||
|
factors_to_ids(Factors, Hash, FactorIds),
|
||||||
|
writeln(type:Type), writeln(''),
|
||||||
|
writeln(allKeys:AllKeys), writeln(''),
|
||||||
|
writeln(factors:Factors), writeln(''),
|
||||||
|
writeln(factorIds:FactorIds), writeln(''),
|
||||||
|
writeln(evidence:Evidence), writeln(''),
|
||||||
|
writeln(evidenceIds:EvidenceIds), writeln(''),
|
||||||
|
create_ground_network(Type, FactorIds, EvidenceIds, Network),
|
||||||
|
%get_vars_information(AllKeys, StatesNames),
|
||||||
|
%set_vars_information(AllKeys, StatesNames),
|
||||||
|
run_solver(ground(Network,Hash), QueryKeys, Solutions),
|
||||||
|
writeln(answer:Solutions),
|
||||||
|
clpbn_bind_vals([QueryVars], Solutions, Output),
|
||||||
|
free_ground_network(Network).
|
||||||
|
|
||||||
:- set_horus_flag(max_iter, 1000).
|
|
||||||
|
|
||||||
:- set_horus_flag(use_logarithms, false).
|
run_solver(ground(Network,Hash), QueryKeys, Solutions) :-
|
||||||
%:- set_horus_flag(use_logarithms, true).
|
%get_dists_parameters(DistIds, DistsParams),
|
||||||
|
%set_factors_params(Network, DistsParams),
|
||||||
|
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
||||||
|
writeln(queryKeys:QueryKeys), writeln(''),
|
||||||
|
writeln(queryIds:QueryIds), writeln(''),
|
||||||
|
list_of_keys_to_ids(QueryKeys, Hash, QueryIds),
|
||||||
|
run_ground_solver(Network, [QueryIds], Solutions).
|
||||||
|
|
||||||
:- set_horus_flag(order_factor_variables, false).
|
|
||||||
%:- set_horus_flag(order_factor_variables, true).
|
|
||||||
|
|
||||||
|
keys_to_ids([], _, Hash, Hash).
|
||||||
|
keys_to_ids([Key|AllKeys], I0, Hash0, Hash) :-
|
||||||
|
b_hash_insert(Hash0, Key, I0, HashI),
|
||||||
|
I is I0+1,
|
||||||
|
keys_to_ids(AllKeys, I, HashI, Hash).
|
||||||
|
|
||||||
|
|
||||||
|
get_factors_type([f(bayes, _, _, _)|_], bayes) :- ! .
|
||||||
|
get_factors_type([f(markov, _, _, _)|_], markov) :- ! .
|
||||||
|
|
||||||
|
|
||||||
|
list_of_keys_to_ids([], _, []).
|
||||||
|
list_of_keys_to_ids([Key|QueryKeys], Hash, [Id|QueryIds]) :-
|
||||||
|
b_hash_lookup(Key, Id, Hash),
|
||||||
|
list_of_keys_to_ids(QueryKeys, Hash, QueryIds).
|
||||||
|
|
||||||
|
|
||||||
|
factors_to_ids([], _, []).
|
||||||
|
factors_to_ids([f(_, DistId, Keys, CPT)|Fs], Hash, [f(Ids, Ranges, CPT, DistId)|NFs]) :-
|
||||||
|
list_of_keys_to_ids(Keys, Hash, Ids),
|
||||||
|
get_ranges(Keys, Ranges),
|
||||||
|
factors_to_ids(Fs, Hash, NFs).
|
||||||
|
|
||||||
|
|
||||||
|
get_ranges([],[]).
|
||||||
|
get_ranges(K.Ks, Range.Rs) :- !,
|
||||||
|
skolem(K,Domain),
|
||||||
|
length(Domain,Range),
|
||||||
|
get_ranges(Ks, Rs).
|
||||||
|
|
||||||
|
|
||||||
|
evidence_to_ids([], _, []).
|
||||||
|
evidence_to_ids([Key=Ev|QueryKeys], Hash, [Id=Ev|QueryIds]) :-
|
||||||
|
b_hash_lookup(Key, Id, Hash),
|
||||||
|
evidence_to_ids(QueryKeys, Hash, QueryIds).
|
||||||
|
|
||||||
|
|
||||||
|
get_vars_information([], []).
|
||||||
|
get_vars_information(Key.QueryKeys, Domain.StatesNames) :-
|
||||||
|
pfl:skolem(Key, Domain),
|
||||||
|
get_vars_information(QueryKeys, StatesNames).
|
||||||
|
|
||||||
|
|
||||||
|
finalize_bp_solver(bp(Network, _)) :-
|
||||||
|
free_ground_network(Network).
|
||||||
|
|
||||||
|
|
||||||
bp([[]],_,_) :- !.
|
bp([[]],_,_) :- !.
|
||||||
bp([QueryVars], AllVars, Output) :-
|
bp([QueryVars], AllVars, Output) :-
|
||||||
init_bp_solver(_, AllVars, _, Network),
|
init_bp_solver(_, AllVars, _, Network),
|
||||||
run_bp_solver([QueryVars], LPs, Network),
|
run_bp_solver([QueryVars], LPs, Network),
|
||||||
finalize_bp_solver(Network),
|
finalize_bp_solver(Network),
|
||||||
clpbn_bind_vals([QueryVars], LPs, Output).
|
clpbn_bind_vals([QueryVars], LPs, Output).
|
||||||
|
|
||||||
|
|
||||||
init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
|
init_bp_solver(_, AllVars0, _, bp(BayesNet, DistIds)) :-
|
||||||
check_for_agg_vars(AllVars0, AllVars),
|
%check_for_agg_vars(AllVars0, AllVars),
|
||||||
writeln('clpbn_vars:'),
|
get_vars_info(AllVars0, VarsInfo, DistIds0),
|
||||||
print_clpbn_vars(AllVars),
|
sort(DistIds0, DistIds),
|
||||||
assign_ids(AllVars, 0),
|
create_ground_network(VarsInfo, BayesNet),
|
||||||
get_vars_info(AllVars, VarsInfo, DistIds0),
|
true.
|
||||||
sort(DistIds0, DistIds),
|
|
||||||
create_ground_network(VarsInfo, BayesNet).
|
|
||||||
%get_extra_vars_info(AllVars, ExtraVarsInfo),
|
|
||||||
%set_extra_vars_info(BayesNet, ExtraVarsInfo).
|
|
||||||
|
|
||||||
|
|
||||||
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
|
run_bp_solver(QueryVars, Solutions, bp(Network, DistIds)) :-
|
||||||
get_dists_parameters(DistIds, DistsParams),
|
get_dists_parameters(DistIds, DistsParams),
|
||||||
set_bayes_net_params(Network, DistsParams),
|
set_factors_params(Network, DistsParams),
|
||||||
flatten_1_element_sublists(QueryVars, QueryVars1),
|
vars_to_ids(QueryVars, QueryVarsIds),
|
||||||
vars_to_ids(QueryVars1, QueryVarsIds),
|
run_ground_solver(Network, QueryVarsIds, Solutions).
|
||||||
run_other_solvers(Network, QueryVarsIds, Solutions).
|
|
||||||
|
|
||||||
|
|
||||||
finalize_bp_solver(bp(Network, _)) :-
|
|
||||||
free_bayesian_network(Network).
|
|
||||||
|
|
||||||
|
|
||||||
assign_ids([], _).
|
|
||||||
assign_ids([V|Vs], Count) :-
|
|
||||||
put_atts(V, [id(Count)]),
|
|
||||||
Count1 is Count + 1,
|
|
||||||
assign_ids(Vs, Count1).
|
|
||||||
|
|
||||||
|
|
||||||
get_vars_info([], [], []).
|
|
||||||
get_vars_info(V.Vs,
|
|
||||||
var(VarId,DS,Ev,PIds,DistId).VarsInfo,
|
|
||||||
DistId.DistIds) :-
|
|
||||||
clpbn:get_atts(V, [dist(DistId, Parents)]), !,
|
|
||||||
get_atts(V, [id(VarId)]),
|
|
||||||
get_dist_domain_size(DistId, DS),
|
|
||||||
get_evidence(V, Ev),
|
|
||||||
vars_to_ids(Parents, PIds),
|
|
||||||
get_vars_info(Vs, VarsInfo, DistIds).
|
|
||||||
|
|
||||||
|
|
||||||
get_evidence(V, Ev) :-
|
|
||||||
clpbn:get_atts(V, [evidence(Ev)]), !.
|
|
||||||
get_evidence(_V, -1). % no evidence !!!
|
|
||||||
|
|
||||||
|
|
||||||
vars_to_ids([], []).
|
|
||||||
vars_to_ids([L|Vars], [LIds|Ids]) :-
|
|
||||||
is_list(L), !,
|
|
||||||
vars_to_ids(L, LIds),
|
|
||||||
vars_to_ids(Vars, Ids).
|
|
||||||
vars_to_ids([V|Vars], [VarId|Ids]) :-
|
|
||||||
get_atts(V, [id(VarId)]),
|
|
||||||
vars_to_ids(Vars, Ids).
|
|
||||||
|
|
||||||
|
|
||||||
get_extra_vars_info([], []).
|
|
||||||
get_extra_vars_info([V|Vs], [v(VarId, Label, Domain)|VarsInfo]) :-
|
|
||||||
get_atts(V, [id(VarId)]), !,
|
|
||||||
clpbn:get_atts(V, [key(Key),dist(DistId, _)]),
|
|
||||||
term_to_atom(Key, Label),
|
|
||||||
get_dist_domain(DistId, Domain0),
|
|
||||||
numbers_to_atoms(Domain0, Domain),
|
|
||||||
get_extra_vars_info(Vs, VarsInfo).
|
|
||||||
get_extra_vars_info([_|Vs], VarsInfo) :-
|
|
||||||
get_extra_vars_info(Vs, VarsInfo).
|
|
||||||
|
|
||||||
|
|
||||||
get_dists_parameters([],[]).
|
get_dists_parameters([],[]).
|
||||||
get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
|
get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
|
||||||
get_dist_params(Id, Params),
|
get_dist_params(Id, Params),
|
||||||
get_dists_parameters(Ids, DistsInfo).
|
get_dists_parameters(Ids, DistsInfo).
|
||||||
|
|
||||||
|
|
||||||
numbers_to_atoms([], []).
|
|
||||||
numbers_to_atoms([Atom|L0], [Atom|L]) :-
|
|
||||||
atom(Atom), !,
|
|
||||||
numbers_to_atoms(L0, L).
|
|
||||||
numbers_to_atoms([Number|L0], [Atom|L]) :-
|
|
||||||
number_atom(Number, Atom),
|
|
||||||
numbers_to_atoms(L0, L).
|
|
||||||
|
|
||||||
|
|
||||||
flatten_1_element_sublists([],[]).
|
|
||||||
flatten_1_element_sublists([[H|[]]|T],[H|R]) :- !,
|
|
||||||
flatten_1_element_sublists(T,R).
|
|
||||||
flatten_1_element_sublists([H|T],[H|R]) :-
|
|
||||||
flatten_1_element_sublists(T,R).
|
|
||||||
|
|
||||||
|
|
||||||
print_clpbn_vars(Var.AllVars) :-
|
|
||||||
clpbn:get_atts(Var, [key(Key),dist(DistId,Parents)]),
|
|
||||||
parents_to_keys(Parents, ParentKeys),
|
|
||||||
writeln(Var:Key:ParentKeys:DistId),
|
|
||||||
print_clpbn_vars(AllVars).
|
|
||||||
print_clpbn_vars([]).
|
|
||||||
|
|
||||||
|
|
||||||
parents_to_keys([], []).
|
|
||||||
parents_to_keys(Var.Parents, Key.Keys) :-
|
|
||||||
clpbn:get_atts(Var, [key(Key)]),
|
|
||||||
parents_to_keys(Parents, Keys).
|
|
||||||
|
|
||||||
|
77
packages/CLPBN/clpbn/bp/BayesBall.cpp
Normal file
77
packages/CLPBN/clpbn/bp/BayesBall.cpp
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
#include <cstdlib>
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <fstream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "BayesBall.h"
|
||||||
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FactorGraph*
|
||||||
|
BayesBall::getMinimalFactorGraph (const VarIds& queryIds)
|
||||||
|
{
|
||||||
|
assert (fg_.isFromBayesNetwork());
|
||||||
|
|
||||||
|
Scheduling scheduling;
|
||||||
|
for (unsigned i = 0; i < queryIds.size(); i++) {
|
||||||
|
assert (dag_.getNode (queryIds[i]));
|
||||||
|
DAGraphNode* n = dag_.getNode (queryIds[i]);
|
||||||
|
scheduling.push (ScheduleInfo (n, false, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
while (!scheduling.empty()) {
|
||||||
|
ScheduleInfo& sch = scheduling.front();
|
||||||
|
DAGraphNode* n = sch.node;
|
||||||
|
n->setAsVisited();
|
||||||
|
if (n->hasEvidence() == false && sch.visitedFromChild) {
|
||||||
|
if (n->isMarkedOnTop() == false) {
|
||||||
|
n->markOnTop();
|
||||||
|
scheduleParents (n, scheduling);
|
||||||
|
}
|
||||||
|
if (n->isMarkedOnBottom() == false) {
|
||||||
|
n->markOnBottom();
|
||||||
|
scheduleChilds (n, scheduling);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (sch.visitedFromParent) {
|
||||||
|
if (n->hasEvidence() && n->isMarkedOnTop() == false) {
|
||||||
|
n->markOnTop();
|
||||||
|
scheduleParents (n, scheduling);
|
||||||
|
}
|
||||||
|
if (n->hasEvidence() == false && n->isMarkedOnBottom() == false) {
|
||||||
|
n->markOnBottom();
|
||||||
|
scheduleChilds (n, scheduling);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
scheduling.pop();
|
||||||
|
}
|
||||||
|
|
||||||
|
FactorGraph* fg = new FactorGraph();
|
||||||
|
constructGraph (fg);
|
||||||
|
return fg;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BayesBall::constructGraph (FactorGraph* fg) const
|
||||||
|
{
|
||||||
|
const FacNodes& facNodes = fg_.facNodes();
|
||||||
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
|
const DAGraphNode* n = dag_.getNode (
|
||||||
|
facNodes[i]->factor().argument (0));
|
||||||
|
if (n->isMarkedOnTop()) {
|
||||||
|
fg->addFactor (Factor (facNodes[i]->factor()));
|
||||||
|
} else if (n->hasEvidence() && n->isVisited()) {
|
||||||
|
VarIds varIds = { facNodes[i]->factor().argument (0) };
|
||||||
|
Ranges ranges = { facNodes[i]->factor().range (0) };
|
||||||
|
Params params (ranges[0], LogAware::noEvidence());
|
||||||
|
params[n->getEvidence()] = LogAware::withEvidence();
|
||||||
|
fg->addFactor (Factor (varIds, ranges, params));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
85
packages/CLPBN/clpbn/bp/BayesBall.h
Normal file
85
packages/CLPBN/clpbn/bp/BayesBall.h
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
#ifndef HORUS_BAYESBALL_H
|
||||||
|
#define HORUS_BAYESBALL_H
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <queue>
|
||||||
|
#include <list>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
#include "BayesNet.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
struct ScheduleInfo
|
||||||
|
{
|
||||||
|
ScheduleInfo (DAGraphNode* n, bool vfp, bool vfc) :
|
||||||
|
node(n), visitedFromParent(vfp), visitedFromChild(vfc) { }
|
||||||
|
|
||||||
|
DAGraphNode* node;
|
||||||
|
bool visitedFromParent;
|
||||||
|
bool visitedFromChild;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
typedef queue<ScheduleInfo, list<ScheduleInfo>> Scheduling;
|
||||||
|
|
||||||
|
|
||||||
|
class BayesBall
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
BayesBall (FactorGraph& fg)
|
||||||
|
: fg_(fg) , dag_(fg.getStructure())
|
||||||
|
{
|
||||||
|
dag_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
FactorGraph* getMinimalFactorGraph (const VarIds&);
|
||||||
|
|
||||||
|
static FactorGraph* getMinimalFactorGraph (FactorGraph& fg, VarIds vids)
|
||||||
|
{
|
||||||
|
BayesBall bb (fg);
|
||||||
|
return bb.getMinimalFactorGraph (vids);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
void constructGraph (FactorGraph* fg) const;
|
||||||
|
|
||||||
|
void scheduleParents (const DAGraphNode* n, Scheduling& sch) const;
|
||||||
|
|
||||||
|
void scheduleChilds (const DAGraphNode* n, Scheduling& sch) const;
|
||||||
|
|
||||||
|
FactorGraph& fg_;
|
||||||
|
|
||||||
|
DAGraph& dag_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
inline void
|
||||||
|
BayesBall::scheduleParents (const DAGraphNode* n, Scheduling& sch) const
|
||||||
|
{
|
||||||
|
const vector<DAGraphNode*>& ps = n->parents();
|
||||||
|
for (vector<DAGraphNode*>::const_iterator it = ps.begin();
|
||||||
|
it != ps.end(); it++) {
|
||||||
|
sch.push (ScheduleInfo (*it, false, true));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
inline void
|
||||||
|
BayesBall::scheduleChilds (const DAGraphNode* n, Scheduling& sch) const
|
||||||
|
{
|
||||||
|
const vector<DAGraphNode*>& cs = n->childs();
|
||||||
|
for (vector<DAGraphNode*>::const_iterator it = cs.begin();
|
||||||
|
it != cs.end(); it++) {
|
||||||
|
sch.push (ScheduleInfo (*it, true, false));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // HORUS_BAYESBALL_H
|
||||||
|
|
@ -5,381 +5,57 @@
|
|||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "xmlParser/xmlParser.h"
|
|
||||||
|
|
||||||
#include "BayesNet.h"
|
#include "BayesNet.h"
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
BayesNet::~BayesNet (void)
|
DAGraph::addNode (DAGraphNode* n)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
assert (Util::contains (varMap_, n->varId()) == false);
|
||||||
delete nodes_[i];
|
nodes_.push_back (n);
|
||||||
}
|
varMap_[n->varId()] = n;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BayesNet::readFromBifFormat (const char* fileName)
|
DAGraph::addEdge (VarId vid1, VarId vid2)
|
||||||
{
|
{
|
||||||
XMLNode xMainNode = XMLNode::openFileHelper (fileName, "BIF");
|
unordered_map<VarId, DAGraphNode*>::iterator it1;
|
||||||
// only the first network is parsed, others are ignored
|
unordered_map<VarId, DAGraphNode*>::iterator it2;
|
||||||
XMLNode xNode = xMainNode.getChildNode ("NETWORK");
|
it1 = varMap_.find (vid1);
|
||||||
unsigned nVars = xNode.nChildNode ("VARIABLE");
|
it2 = varMap_.find (vid2);
|
||||||
for (unsigned i = 0; i < nVars; i++) {
|
assert (it1 != varMap_.end());
|
||||||
XMLNode var = xNode.getChildNode ("VARIABLE", i);
|
assert (it2 != varMap_.end());
|
||||||
if (string (var.getAttribute ("TYPE")) != "nature") {
|
it1->second->addChild (it2->second);
|
||||||
cerr << "error: only \"nature\" variables are supported" << endl;
|
it2->second->addParent (it1->second);
|
||||||
abort();
|
|
||||||
}
|
|
||||||
States states;
|
|
||||||
string label = var.getChildNode("NAME").getText();
|
|
||||||
unsigned nrStates = var.nChildNode ("OUTCOME");
|
|
||||||
for (unsigned j = 0; j < nrStates; j++) {
|
|
||||||
if (var.getChildNode("OUTCOME", j).getText() == 0) {
|
|
||||||
stringstream ss;
|
|
||||||
ss << j + 1;
|
|
||||||
states.push_back (ss.str());
|
|
||||||
} else {
|
|
||||||
states.push_back (var.getChildNode("OUTCOME", j).getText());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
addNode (label, states);
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned nDefs = xNode.nChildNode ("DEFINITION");
|
|
||||||
if (nVars != nDefs) {
|
|
||||||
cerr << "error: different number of variables and definitions" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < nDefs; i++) {
|
|
||||||
XMLNode def = xNode.getChildNode ("DEFINITION", i);
|
|
||||||
string label = def.getChildNode("FOR").getText();
|
|
||||||
BayesNode* node = getBayesNode (label);
|
|
||||||
if (!node) {
|
|
||||||
cerr << "error: unknow variable `" << label << "'" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
BnNodeSet parents;
|
|
||||||
unsigned nParams = node->nrStates();
|
|
||||||
for (int j = 0; j < def.nChildNode ("GIVEN"); j++) {
|
|
||||||
string parentLabel = def.getChildNode("GIVEN", j).getText();
|
|
||||||
BayesNode* parentNode = getBayesNode (parentLabel);
|
|
||||||
if (!parentNode) {
|
|
||||||
cerr << "error: unknow variable `" << parentLabel << "'" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
nParams *= parentNode->nrStates();
|
|
||||||
parents.push_back (parentNode);
|
|
||||||
}
|
|
||||||
node->setParents (parents);
|
|
||||||
unsigned count = 0;
|
|
||||||
Params params (nParams);
|
|
||||||
stringstream s (def.getChildNode("TABLE").getText());
|
|
||||||
while (!s.eof() && count < nParams) {
|
|
||||||
s >> params[count];
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
if (count != nParams) {
|
|
||||||
cerr << "error: invalid number of parameters " ;
|
|
||||||
cerr << "for variable `" << label << "'" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
params = reorderParameters (params, node->nrStates());
|
|
||||||
Distribution* dist = new Distribution (params);
|
|
||||||
node->setDistribution (dist);
|
|
||||||
addDistribution (dist);
|
|
||||||
}
|
|
||||||
|
|
||||||
setIndexes();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
distributionsToLogs();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNode*
|
const DAGraphNode*
|
||||||
BayesNet::addNode (string label, const States& states)
|
DAGraph::getNode (VarId vid) const
|
||||||
{
|
{
|
||||||
VarId vid = nodes_.size();
|
unordered_map<VarId, DAGraphNode*>::const_iterator it;
|
||||||
varMap_.insert (make_pair (vid, nodes_.size()));
|
it = varMap_.find (vid);
|
||||||
GraphicalModel::addVariableInformation (vid, label, states);
|
return it != varMap_.end() ? it->second : 0;
|
||||||
BayesNode* node = new BayesNode (VarNode (vid, states.size()));
|
|
||||||
nodes_.push_back (node);
|
|
||||||
return node;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNode*
|
DAGraphNode*
|
||||||
BayesNet::addNode (VarId vid, unsigned dsize, int evidence, Distribution* dist)
|
DAGraph::getNode (VarId vid)
|
||||||
{
|
{
|
||||||
varMap_.insert (make_pair (vid, nodes_.size()));
|
unordered_map<VarId, DAGraphNode*>::const_iterator it;
|
||||||
nodes_.push_back (new BayesNode (vid, dsize, evidence, dist));
|
it = varMap_.find (vid);
|
||||||
return nodes_.back();
|
return it != varMap_.end() ? it->second : 0;
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNode*
|
|
||||||
BayesNet::getBayesNode (VarId vid) const
|
|
||||||
{
|
|
||||||
IndexMap::const_iterator it = varMap_.find (vid);
|
|
||||||
if (it == varMap_.end()) {
|
|
||||||
return 0;
|
|
||||||
} else {
|
|
||||||
return nodes_[it->second];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNode*
|
|
||||||
BayesNet::getBayesNode (string label) const
|
|
||||||
{
|
|
||||||
BayesNode* node = 0;
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
if (nodes_[i]->label() == label) {
|
|
||||||
node = nodes_[i];
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarNode*
|
|
||||||
BayesNet::getVariableNode (VarId vid) const
|
|
||||||
{
|
|
||||||
BayesNode* node = getBayesNode (vid);
|
|
||||||
assert (node);
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarNodes
|
|
||||||
BayesNet::getVariableNodes (void) const
|
|
||||||
{
|
|
||||||
VarNodes vars;
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
vars.push_back (nodes_[i]);
|
|
||||||
}
|
|
||||||
return vars;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BayesNet::addDistribution (Distribution* dist)
|
DAGraph::setIndexes (void)
|
||||||
{
|
|
||||||
dists_.push_back (dist);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Distribution*
|
|
||||||
BayesNet::getDistribution (unsigned distId) const
|
|
||||||
{
|
|
||||||
Distribution* dist = 0;
|
|
||||||
for (unsigned i = 0; i < dists_.size(); i++) {
|
|
||||||
if (dists_[i]->id == (int) distId) {
|
|
||||||
dist = dists_[i];
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return dist;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
const BnNodeSet&
|
|
||||||
BayesNet::getBayesNodes (void) const
|
|
||||||
{
|
|
||||||
return nodes_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
|
||||||
BayesNet::nrNodes (void) const
|
|
||||||
{
|
|
||||||
return nodes_.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BnNodeSet
|
|
||||||
BayesNet::getRootNodes (void) const
|
|
||||||
{
|
|
||||||
BnNodeSet roots;
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
if (nodes_[i]->isRoot()) {
|
|
||||||
roots.push_back (nodes_[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return roots;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BnNodeSet
|
|
||||||
BayesNet::getLeafNodes (void) const
|
|
||||||
{
|
|
||||||
BnNodeSet leafs;
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
if (nodes_[i]->isLeaf()) {
|
|
||||||
leafs.push_back (nodes_[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return leafs;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNet*
|
|
||||||
BayesNet::getMinimalRequesiteNetwork (VarId vid) const
|
|
||||||
{
|
|
||||||
return getMinimalRequesiteNetwork (VarIds() = {vid});
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNet*
|
|
||||||
BayesNet::getMinimalRequesiteNetwork (const VarIds& queryVarIds) const
|
|
||||||
{
|
|
||||||
BnNodeSet queryVars;
|
|
||||||
Scheduling scheduling;
|
|
||||||
for (unsigned i = 0; i < queryVarIds.size(); i++) {
|
|
||||||
BayesNode* n = getBayesNode (queryVarIds[i]);
|
|
||||||
assert (n);
|
|
||||||
queryVars.push_back (n);
|
|
||||||
scheduling.push (ScheduleInfo (n, false, true));
|
|
||||||
}
|
|
||||||
|
|
||||||
vector<StateInfo*> states (nodes_.size(), 0);
|
|
||||||
|
|
||||||
while (!scheduling.empty()) {
|
|
||||||
ScheduleInfo& sch = scheduling.front();
|
|
||||||
StateInfo* state = states[sch.node->getIndex()];
|
|
||||||
if (!state) {
|
|
||||||
state = new StateInfo();
|
|
||||||
states[sch.node->getIndex()] = state;
|
|
||||||
} else {
|
|
||||||
state->visited = true;
|
|
||||||
}
|
|
||||||
if (!sch.node->hasEvidence() && sch.visitedFromChild) {
|
|
||||||
if (!state->markedOnTop) {
|
|
||||||
state->markedOnTop = true;
|
|
||||||
scheduleParents (sch.node, scheduling);
|
|
||||||
}
|
|
||||||
if (!state->markedOnBottom) {
|
|
||||||
state->markedOnBottom = true;
|
|
||||||
scheduleChilds (sch.node, scheduling);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (sch.visitedFromParent) {
|
|
||||||
if (sch.node->hasEvidence() && !state->markedOnTop) {
|
|
||||||
state->markedOnTop = true;
|
|
||||||
scheduleParents (sch.node, scheduling);
|
|
||||||
}
|
|
||||||
if (!sch.node->hasEvidence() && !state->markedOnBottom) {
|
|
||||||
state->markedOnBottom = true;
|
|
||||||
scheduleChilds (sch.node, scheduling);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
scheduling.pop();
|
|
||||||
}
|
|
||||||
/*
|
|
||||||
cout << "\t\ttop\tbottom" << endl;
|
|
||||||
cout << "variable\t\tmarked\tmarked\tvisited\tobserved" << endl;
|
|
||||||
cout << "----------------------------------------------------------" ;
|
|
||||||
cout << endl;
|
|
||||||
for (unsigned i = 0; i < states.size(); i++) {
|
|
||||||
cout << nodes_[i]->label() << ":\t\t" ;
|
|
||||||
if (states[i]) {
|
|
||||||
states[i]->markedOnTop ? cout << "yes\t" : cout << "no\t" ;
|
|
||||||
states[i]->markedOnBottom ? cout << "yes\t" : cout << "no\t" ;
|
|
||||||
states[i]->visited ? cout << "yes\t" : cout << "no\t" ;
|
|
||||||
nodes_[i]->hasEvidence() ? cout << "yes" : cout << "no" ;
|
|
||||||
cout << endl;
|
|
||||||
} else {
|
|
||||||
cout << "no\tno\tno\t" ;
|
|
||||||
nodes_[i]->hasEvidence() ? cout << "yes" : cout << "no" ;
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
*/
|
|
||||||
BayesNet* bn = new BayesNet();
|
|
||||||
constructGraph (bn, states);
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
delete states[i];
|
|
||||||
}
|
|
||||||
return bn;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BayesNet::constructGraph (BayesNet* bn,
|
|
||||||
const vector<StateInfo*>& states) const
|
|
||||||
{
|
|
||||||
BnNodeSet mrnNodes;
|
|
||||||
vector<VarIds> parents;
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
bool isRequired = false;
|
|
||||||
if (states[i]) {
|
|
||||||
isRequired = (nodes_[i]->hasEvidence() && states[i]->visited)
|
|
||||||
||
|
|
||||||
states[i]->markedOnTop;
|
|
||||||
}
|
|
||||||
if (isRequired) {
|
|
||||||
parents.push_back (VarIds());
|
|
||||||
if (states[i]->markedOnTop) {
|
|
||||||
const BnNodeSet& ps = nodes_[i]->getParents();
|
|
||||||
for (unsigned j = 0; j < ps.size(); j++) {
|
|
||||||
parents.back().push_back (ps[j]->varId());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert (bn->getBayesNode (nodes_[i]->varId()) == 0);
|
|
||||||
BayesNode* mrnNode = bn->addNode (nodes_[i]->varId(),
|
|
||||||
nodes_[i]->nrStates(),
|
|
||||||
nodes_[i]->getEvidence(),
|
|
||||||
nodes_[i]->getDistribution());
|
|
||||||
mrnNodes.push_back (mrnNode);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < mrnNodes.size(); i++) {
|
|
||||||
BnNodeSet ps;
|
|
||||||
for (unsigned j = 0; j < parents[i].size(); j++) {
|
|
||||||
assert (bn->getBayesNode (parents[i][j]) != 0);
|
|
||||||
ps.push_back (bn->getBayesNode (parents[i][j]));
|
|
||||||
}
|
|
||||||
mrnNodes[i]->setParents (ps);
|
|
||||||
}
|
|
||||||
bn->setIndexes();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
BayesNet::isPolyTree (void) const
|
|
||||||
{
|
|
||||||
return !containsUndirectedCycle();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BayesNet::setIndexes (void)
|
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
nodes_[i]->setIndex (i);
|
nodes_[i]->setIndex (i);
|
||||||
@ -389,233 +65,43 @@ BayesNet::setIndexes (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BayesNet::distributionsToLogs (void)
|
DAGraph::clear (void)
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < dists_.size(); i++) {
|
|
||||||
Util::toLog (dists_[i]->params);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BayesNet::freeDistributions (void)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < dists_.size(); i++) {
|
|
||||||
delete dists_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BayesNet::printGraphicalModel (void) const
|
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
cout << *nodes_[i];
|
nodes_[i]->clear();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
BayesNet::exportToGraphViz (const char* fileName,
|
DAGraph::exportToGraphViz (const char* fileName)
|
||||||
bool showNeighborless,
|
|
||||||
const VarIds& highlightVarIds) const
|
|
||||||
{
|
{
|
||||||
ofstream out (fileName);
|
ofstream out (fileName);
|
||||||
if (!out.is_open()) {
|
if (!out.is_open()) {
|
||||||
cerr << "error: cannot open file to write at " ;
|
cerr << "error: cannot open file to write at " ;
|
||||||
cerr << "BayesNet::exportToDotFile()" << endl;
|
cerr << "DAGraph::exportToDotFile()" << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
|
|
||||||
out << "digraph {" << endl;
|
out << "digraph {" << endl;
|
||||||
out << "ranksep=1" << endl;
|
out << "ranksep=1" << endl;
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
if (showNeighborless || nodes_[i]->hasNeighbors()) {
|
out << nodes_[i]->varId() ;
|
||||||
out << nodes_[i]->varId() ;
|
out << " [" ;
|
||||||
if (nodes_[i]->hasEvidence()) {
|
out << "label=\"" << nodes_[i]->label() << "\"" ;
|
||||||
out << " [" ;
|
if (nodes_[i]->hasEvidence()) {
|
||||||
out << "label=\"" << nodes_[i]->label() << "\"," ;
|
out << ",style=filled, fillcolor=yellow" ;
|
||||||
out << "style=filled, fillcolor=yellow" ;
|
|
||||||
out << "]" ;
|
|
||||||
} else {
|
|
||||||
out << " [" ;
|
|
||||||
out << "label=\"" << nodes_[i]->label() << "\"" ;
|
|
||||||
out << "]" ;
|
|
||||||
}
|
|
||||||
out << endl;
|
|
||||||
}
|
}
|
||||||
|
out << "]" << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
|
|
||||||
BayesNode* node = getBayesNode (highlightVarIds[i]);
|
|
||||||
if (node) {
|
|
||||||
out << node->varId() ;
|
|
||||||
out << " [shape=box3d]" << endl;
|
|
||||||
} else {
|
|
||||||
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
const BnNodeSet& childs = nodes_[i]->getChilds();
|
const vector<DAGraphNode*>& childs = nodes_[i]->childs();
|
||||||
for (unsigned j = 0; j < childs.size(); j++) {
|
for (unsigned j = 0; j < childs.size(); j++) {
|
||||||
out << nodes_[i]->varId() << " -> " << childs[j]->varId() << " [style=bold]" << endl ;
|
out << nodes_[i]->varId() << " -> " << childs[j]->varId();
|
||||||
|
out << " [style=bold]" << endl ;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
out << "}" << endl;
|
out << "}" << endl;
|
||||||
out.close();
|
out.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BayesNet::exportToBifFormat (const char* fileName) const
|
|
||||||
{
|
|
||||||
ofstream out (fileName);
|
|
||||||
if(!out.is_open()) {
|
|
||||||
cerr << "error: cannot open file to write at " ;
|
|
||||||
cerr << "BayesNet::exportToBifFile()" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
out << "<?xml version=\"1.0\" encoding=\"US-ASCII\"?>" << endl;
|
|
||||||
out << "<BIF VERSION=\"0.3\">" << endl;
|
|
||||||
out << "<NETWORK>" << endl;
|
|
||||||
out << "<NAME>" << fileName << "</NAME>" << endl << endl;
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
out << "<VARIABLE TYPE=\"nature\">" << endl;
|
|
||||||
out << "\t<NAME>" << nodes_[i]->label() << "</NAME>" << endl;
|
|
||||||
const States& states = nodes_[i]->states();
|
|
||||||
for (unsigned j = 0; j < states.size(); j++) {
|
|
||||||
out << "\t<OUTCOME>" << states[j] << "</OUTCOME>" << endl;
|
|
||||||
}
|
|
||||||
out << "</VARIABLE>" << endl << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
out << "<DEFINITION>" << endl;
|
|
||||||
out << "\t<FOR>" << nodes_[i]->label() << "</FOR>" << endl;
|
|
||||||
const BnNodeSet& parents = nodes_[i]->getParents();
|
|
||||||
for (unsigned j = 0; j < parents.size(); j++) {
|
|
||||||
out << "\t<GIVEN>" << parents[j]->label();
|
|
||||||
out << "</GIVEN>" << endl;
|
|
||||||
}
|
|
||||||
Params params = revertParameterReorder (nodes_[i]->getParameters(),
|
|
||||||
nodes_[i]->nrStates());
|
|
||||||
out << "\t<TABLE>" ;
|
|
||||||
for (unsigned j = 0; j < params.size(); j++) {
|
|
||||||
out << " " << params[j];
|
|
||||||
}
|
|
||||||
out << " </TABLE>" << endl;
|
|
||||||
out << "</DEFINITION>" << endl << endl;
|
|
||||||
}
|
|
||||||
out << "</NETWORK>" << endl;
|
|
||||||
out << "</BIF>" << endl << endl;
|
|
||||||
out.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
BayesNet::containsUndirectedCycle (void) const
|
|
||||||
{
|
|
||||||
vector<bool> visited (nodes_.size(), false);
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
int v = nodes_[i]->getIndex();
|
|
||||||
if (!visited[v]) {
|
|
||||||
if (containsUndirectedCycle (v, -1, visited)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
BayesNet::containsUndirectedCycle (int v, int p, vector<bool>& visited) const
|
|
||||||
{
|
|
||||||
visited[v] = true;
|
|
||||||
vector<int> adjacencies = getAdjacentNodes (v);
|
|
||||||
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
|
||||||
int w = adjacencies[i];
|
|
||||||
if (!visited[w]) {
|
|
||||||
if (containsUndirectedCycle (w, v, visited)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else if (visited[w] && w != p) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false; // no cycle detected in this component
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<int>
|
|
||||||
BayesNet::getAdjacentNodes (int v) const
|
|
||||||
{
|
|
||||||
vector<int> adjacencies;
|
|
||||||
const BnNodeSet& parents = nodes_[v]->getParents();
|
|
||||||
const BnNodeSet& childs = nodes_[v]->getChilds();
|
|
||||||
for (unsigned i = 0; i < parents.size(); i++) {
|
|
||||||
adjacencies.push_back (parents[i]->getIndex());
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < childs.size(); i++) {
|
|
||||||
adjacencies.push_back (childs[i]->getIndex());
|
|
||||||
}
|
|
||||||
return adjacencies;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BayesNet::reorderParameters (const Params& params, unsigned dsize) const
|
|
||||||
{
|
|
||||||
// the interchange format for bayesian networks keeps the probabilities
|
|
||||||
// in the following order:
|
|
||||||
// p(a1|b1,c1) p(a2|b1,c1) p(a1|b1,c2) p(a2|b1,c2) p(a1|b2,c1) p(a2|b2,c1)
|
|
||||||
// p(a1|b2,c2) p(a2|b2,c2).
|
|
||||||
//
|
|
||||||
// however, in clpbn we keep the probabilities in this order:
|
|
||||||
// p(a1|b1,c1) p(a1|b1,c2) p(a1|b2,c1) p(a1|b2,c2) p(a2|b1,c1) p(a2|b1,c2)
|
|
||||||
// p(a2|b2,c1) p(a2|b2,c2).
|
|
||||||
unsigned count = 0;
|
|
||||||
unsigned rowSize = params.size() / dsize;
|
|
||||||
Params reordered;
|
|
||||||
while (reordered.size() < params.size()) {
|
|
||||||
unsigned idx = count;
|
|
||||||
for (unsigned i = 0; i < rowSize; i++) {
|
|
||||||
reordered.push_back (params[idx]);
|
|
||||||
idx += dsize ;
|
|
||||||
}
|
|
||||||
count++;
|
|
||||||
}
|
|
||||||
return reordered;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BayesNet::revertParameterReorder (const Params& params, unsigned dsize) const
|
|
||||||
{
|
|
||||||
unsigned count = 0;
|
|
||||||
unsigned rowSize = params.size() / dsize;
|
|
||||||
Params reordered;
|
|
||||||
while (reordered.size() < params.size()) {
|
|
||||||
unsigned idx = count;
|
|
||||||
for (unsigned i = 0; i < dsize; i++) {
|
|
||||||
reordered.push_back (params[idx]);
|
|
||||||
idx += rowSize;
|
|
||||||
}
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
return reordered;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
@ -6,118 +6,83 @@
|
|||||||
#include <list>
|
#include <list>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
|
||||||
#include "GraphicalModel.h"
|
#include "Var.h"
|
||||||
#include "BayesNode.h"
|
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
class Distribution;
|
|
||||||
|
|
||||||
struct ScheduleInfo
|
class Var;
|
||||||
{
|
|
||||||
ScheduleInfo (BayesNode* n, bool vfp, bool vfc)
|
|
||||||
{
|
|
||||||
node = n;
|
|
||||||
visitedFromParent = vfp;
|
|
||||||
visitedFromChild = vfc;
|
|
||||||
}
|
|
||||||
BayesNode* node;
|
|
||||||
bool visitedFromParent;
|
|
||||||
bool visitedFromChild;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
class DAGraphNode : public Var
|
||||||
struct StateInfo
|
|
||||||
{
|
|
||||||
StateInfo (void)
|
|
||||||
{
|
|
||||||
visited = true;
|
|
||||||
markedOnTop = false;
|
|
||||||
markedOnBottom = false;
|
|
||||||
}
|
|
||||||
bool visited;
|
|
||||||
bool markedOnTop;
|
|
||||||
bool markedOnBottom;
|
|
||||||
};
|
|
||||||
|
|
||||||
typedef vector<Distribution*> DistSet;
|
|
||||||
typedef queue<ScheduleInfo, list<ScheduleInfo> > Scheduling;
|
|
||||||
|
|
||||||
|
|
||||||
class BayesNet : public GraphicalModel
|
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
BayesNet (void) {};
|
DAGraphNode (Var* v) : Var (v) , visited_(false),
|
||||||
~BayesNet (void);
|
markedOnTop_(false), markedOnBottom_(false) { }
|
||||||
|
|
||||||
void readFromBifFormat (const char*);
|
const vector<DAGraphNode*>& childs (void) const { return childs_; }
|
||||||
BayesNode* addNode (string, const States&);
|
|
||||||
// BayesNode* addNode (VarId, unsigned, int, BnNodeSet&, Distribution*);
|
vector<DAGraphNode*>& childs (void) { return childs_; }
|
||||||
BayesNode* addNode (VarId, unsigned, int, Distribution*);
|
|
||||||
BayesNode* getBayesNode (VarId) const;
|
const vector<DAGraphNode*>& parents (void) const { return parents_; }
|
||||||
BayesNode* getBayesNode (string) const;
|
|
||||||
VarNode* getVariableNode (VarId) const;
|
vector<DAGraphNode*>& parents (void) { return parents_; }
|
||||||
VarNodes getVariableNodes (void) const;
|
|
||||||
void addDistribution (Distribution*);
|
void addParent (DAGraphNode* p) { parents_.push_back (p); }
|
||||||
Distribution* getDistribution (unsigned) const;
|
|
||||||
const BnNodeSet& getBayesNodes (void) const;
|
void addChild (DAGraphNode* c) { childs_.push_back (c); }
|
||||||
unsigned nrNodes (void) const;
|
|
||||||
BnNodeSet getRootNodes (void) const;
|
bool isVisited (void) const { return visited_; }
|
||||||
BnNodeSet getLeafNodes (void) const;
|
|
||||||
BayesNet* getMinimalRequesiteNetwork (VarId) const;
|
void setAsVisited (void) { visited_ = true; }
|
||||||
BayesNet* getMinimalRequesiteNetwork (const VarIds&) const;
|
|
||||||
void constructGraph (
|
bool isMarkedOnTop (void) const { return markedOnTop_; }
|
||||||
BayesNet*, const vector<StateInfo*>&) const;
|
|
||||||
bool isPolyTree (void) const;
|
void markOnTop (void) { markedOnTop_ = true; }
|
||||||
void setIndexes (void);
|
|
||||||
void distributionsToLogs (void);
|
bool isMarkedOnBottom (void) const { return markedOnBottom_; }
|
||||||
void freeDistributions (void);
|
|
||||||
void printGraphicalModel (void) const;
|
void markOnBottom (void) { markedOnBottom_ = true; }
|
||||||
void exportToGraphViz (const char*, bool = true,
|
|
||||||
const VarIds& = VarIds()) const;
|
void clear (void) { visited_ = markedOnTop_ = markedOnBottom_ = false; }
|
||||||
void exportToBifFormat (const char*) const;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (BayesNet);
|
bool visited_;
|
||||||
|
bool markedOnTop_;
|
||||||
|
bool markedOnBottom_;
|
||||||
|
|
||||||
bool containsUndirectedCycle (void) const;
|
vector<DAGraphNode*> childs_;
|
||||||
bool containsUndirectedCycle (int, int, vector<bool>&)const;
|
vector<DAGraphNode*> parents_;
|
||||||
vector<int> getAdjacentNodes (int) const;
|
|
||||||
Params reorderParameters (const Params&, unsigned) const;
|
|
||||||
Params revertParameterReorder (const Params&, unsigned) const;
|
|
||||||
void scheduleParents (const BayesNode*, Scheduling&) const;
|
|
||||||
void scheduleChilds (const BayesNode*, Scheduling&) const;
|
|
||||||
|
|
||||||
BnNodeSet nodes_;
|
|
||||||
DistSet dists_;
|
|
||||||
|
|
||||||
typedef unordered_map<unsigned, unsigned> IndexMap;
|
|
||||||
IndexMap varMap_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class DAGraph
|
||||||
inline void
|
|
||||||
BayesNet::scheduleParents (const BayesNode* n, Scheduling& sch) const
|
|
||||||
{
|
{
|
||||||
const BnNodeSet& ps = n->getParents();
|
public:
|
||||||
for (BnNodeSet::const_iterator it = ps.begin(); it != ps.end(); it++) {
|
DAGraph (void) { }
|
||||||
sch.push (ScheduleInfo (*it, false, true));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
void addNode (DAGraphNode* n);
|
||||||
|
|
||||||
|
void addEdge (VarId vid1, VarId vid2);
|
||||||
|
|
||||||
inline void
|
const DAGraphNode* getNode (VarId vid) const;
|
||||||
BayesNet::scheduleChilds (const BayesNode* n, Scheduling& sch) const
|
|
||||||
{
|
DAGraphNode* getNode (VarId vid);
|
||||||
const BnNodeSet& cs = n->getChilds();
|
|
||||||
for (BnNodeSet::const_iterator it = cs.begin(); it != cs.end(); it++) {
|
bool empty (void) const { return nodes_.empty(); }
|
||||||
sch.push (ScheduleInfo (*it, true, false));
|
|
||||||
}
|
void setIndexes (void);
|
||||||
}
|
|
||||||
|
void clear (void);
|
||||||
|
|
||||||
|
void exportToGraphViz (const char*);
|
||||||
|
|
||||||
|
private:
|
||||||
|
vector<DAGraphNode*> nodes_;
|
||||||
|
|
||||||
|
unordered_map<VarId, DAGraphNode*> varMap_;
|
||||||
|
};
|
||||||
|
|
||||||
#endif // HORUS_BAYESNET_H
|
#endif // HORUS_BAYESNET_H
|
||||||
|
|
||||||
|
@ -1,291 +0,0 @@
|
|||||||
#include <cstdlib>
|
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
#include <iomanip>
|
|
||||||
#include <iostream>
|
|
||||||
#include <sstream>
|
|
||||||
|
|
||||||
#include "BayesNode.h"
|
|
||||||
|
|
||||||
|
|
||||||
BayesNode::BayesNode (VarId vid,
|
|
||||||
unsigned dsize,
|
|
||||||
int evidence,
|
|
||||||
Distribution* dist)
|
|
||||||
: VarNode (vid, dsize, evidence)
|
|
||||||
{
|
|
||||||
dist_ = dist;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BayesNode::BayesNode (VarId vid,
|
|
||||||
unsigned dsize,
|
|
||||||
int evidence,
|
|
||||||
const BnNodeSet& parents,
|
|
||||||
Distribution* dist)
|
|
||||||
: VarNode (vid, dsize, evidence)
|
|
||||||
{
|
|
||||||
parents_ = parents;
|
|
||||||
dist_ = dist;
|
|
||||||
for (unsigned int i = 0; i < parents.size(); i++) {
|
|
||||||
parents[i]->addChild (this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BayesNode::setParents (const BnNodeSet& parents)
|
|
||||||
{
|
|
||||||
parents_ = parents;
|
|
||||||
for (unsigned int i = 0; i < parents.size(); i++) {
|
|
||||||
parents[i]->addChild (this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BayesNode::addChild (BayesNode* node)
|
|
||||||
{
|
|
||||||
childs_.push_back (node);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BayesNode::setDistribution (Distribution* dist)
|
|
||||||
{
|
|
||||||
assert (dist);
|
|
||||||
dist_ = dist;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Distribution*
|
|
||||||
BayesNode::getDistribution (void)
|
|
||||||
{
|
|
||||||
return dist_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
const Params&
|
|
||||||
BayesNode::getParameters (void)
|
|
||||||
{
|
|
||||||
return dist_->params;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BayesNode::getRow (int rowIndex) const
|
|
||||||
{
|
|
||||||
int rowSize = getRowSize();
|
|
||||||
int offset = rowSize * rowIndex;
|
|
||||||
Params row (rowSize);
|
|
||||||
for (int i = 0; i < rowSize; i++) {
|
|
||||||
row[i] = dist_->params[offset + i] ;
|
|
||||||
}
|
|
||||||
return row;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
BayesNode::isRoot (void)
|
|
||||||
{
|
|
||||||
return getParents().empty();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
BayesNode::isLeaf (void)
|
|
||||||
{
|
|
||||||
return getChilds().empty();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
BayesNode::hasNeighbors (void) const
|
|
||||||
{
|
|
||||||
return childs_.size() != 0 || parents_.size() != 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
int
|
|
||||||
BayesNode::getCptSize (void)
|
|
||||||
{
|
|
||||||
return dist_->params.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
|
||||||
BayesNode::getIndexOfParent (const BayesNode* parent) const
|
|
||||||
{
|
|
||||||
for (unsigned int i = 0; i < parents_.size(); i++) {
|
|
||||||
if (parents_[i] == parent) {
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
string
|
|
||||||
BayesNode::cptEntryToString (
|
|
||||||
int row,
|
|
||||||
const vector<unsigned>& stateConf) const
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
ss << "p(" ;
|
|
||||||
ss << states()[row];
|
|
||||||
if (parents_.size() > 0) {
|
|
||||||
ss << "|" ;
|
|
||||||
for (unsigned int i = 0; i < stateConf.size(); i++) {
|
|
||||||
if (i != 0) {
|
|
||||||
ss << ",";
|
|
||||||
}
|
|
||||||
ss << parents_[i]->states()[stateConf[i]];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ss << ")" ;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<string>
|
|
||||||
BayesNode::getDomainHeaders (void) const
|
|
||||||
{
|
|
||||||
unsigned nParents = parents_.size();
|
|
||||||
unsigned rowSize = getRowSize();
|
|
||||||
unsigned nReps = 1;
|
|
||||||
vector<string> headers (rowSize);
|
|
||||||
for (int i = nParents - 1; i >= 0; i--) {
|
|
||||||
States states = parents_[i]->states();
|
|
||||||
unsigned index = 0;
|
|
||||||
while (index < rowSize) {
|
|
||||||
for (unsigned j = 0; j < parents_[i]->nrStates(); j++) {
|
|
||||||
for (unsigned r = 0; r < nReps; r++) {
|
|
||||||
if (headers[index] != "") {
|
|
||||||
headers[index] = states[j] + "," + headers[index];
|
|
||||||
} else {
|
|
||||||
headers[index] = states[j];
|
|
||||||
}
|
|
||||||
index++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
nReps *= parents_[i]->nrStates();
|
|
||||||
}
|
|
||||||
return headers;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ostream&
|
|
||||||
operator << (ostream& o, const BayesNode& node)
|
|
||||||
{
|
|
||||||
o << "variable " << node.getIndex() << endl;
|
|
||||||
o << "Var Id: " << node.varId() << endl;
|
|
||||||
o << "Label: " << node.label() << endl;
|
|
||||||
|
|
||||||
o << "Evidence: " ;
|
|
||||||
if (node.hasEvidence()) {
|
|
||||||
o << node.getEvidence();
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
o << "no" ;
|
|
||||||
}
|
|
||||||
o << endl;
|
|
||||||
|
|
||||||
o << "Parents: " ;
|
|
||||||
const BnNodeSet& parents = node.getParents();
|
|
||||||
if (parents.size() != 0) {
|
|
||||||
for (unsigned int i = 0; i < parents.size() - 1; i++) {
|
|
||||||
o << parents[i]->label() << ", " ;
|
|
||||||
}
|
|
||||||
o << parents[parents.size() - 1]->label();
|
|
||||||
}
|
|
||||||
o << endl;
|
|
||||||
|
|
||||||
o << "Childs: " ;
|
|
||||||
const BnNodeSet& childs = node.getChilds();
|
|
||||||
if (childs.size() != 0) {
|
|
||||||
for (unsigned int i = 0; i < childs.size() - 1; i++) {
|
|
||||||
o << childs[i]->label() << ", " ;
|
|
||||||
}
|
|
||||||
o << childs[childs.size() - 1]->label();
|
|
||||||
}
|
|
||||||
o << endl;
|
|
||||||
|
|
||||||
o << "Domain: " ;
|
|
||||||
States states = node.states();
|
|
||||||
for (unsigned int i = 0; i < states.size() - 1; i++) {
|
|
||||||
o << states[i] << ", " ;
|
|
||||||
}
|
|
||||||
if (states.size() != 0) {
|
|
||||||
o << states[states.size() - 1];
|
|
||||||
}
|
|
||||||
o << endl;
|
|
||||||
|
|
||||||
// min width of first column
|
|
||||||
const unsigned int MIN_DOMAIN_WIDTH = 4;
|
|
||||||
// min width of following columns
|
|
||||||
const unsigned int MIN_COMBO_WIDTH = 12;
|
|
||||||
|
|
||||||
unsigned int domainWidth = states[0].length();
|
|
||||||
for (unsigned int i = 1; i < states.size(); i++) {
|
|
||||||
if (states[i].length() > domainWidth) {
|
|
||||||
domainWidth = states[i].length();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
domainWidth = (domainWidth < MIN_DOMAIN_WIDTH)
|
|
||||||
? MIN_DOMAIN_WIDTH
|
|
||||||
: domainWidth;
|
|
||||||
|
|
||||||
o << left << setw (domainWidth) << "cpt" << right;
|
|
||||||
|
|
||||||
vector<int> widths;
|
|
||||||
int lineWidth = domainWidth;
|
|
||||||
vector<string> headers = node.getDomainHeaders();
|
|
||||||
|
|
||||||
if (!headers.empty()) {
|
|
||||||
for (unsigned int i = 0; i < headers.size(); i++) {
|
|
||||||
unsigned int len = headers[i].length();
|
|
||||||
int w = (len < MIN_COMBO_WIDTH) ? MIN_COMBO_WIDTH : len;
|
|
||||||
widths.push_back (w);
|
|
||||||
o << setw (w) << headers[i];
|
|
||||||
lineWidth += w;
|
|
||||||
}
|
|
||||||
o << endl;
|
|
||||||
} else {
|
|
||||||
cout << endl;
|
|
||||||
widths.push_back (domainWidth);
|
|
||||||
lineWidth += MIN_COMBO_WIDTH;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < lineWidth; i++) {
|
|
||||||
o << "-" ;
|
|
||||||
}
|
|
||||||
o << endl;
|
|
||||||
|
|
||||||
for (unsigned int i = 0; i < states.size(); i++) {
|
|
||||||
Params row = node.getRow (i);
|
|
||||||
o << left << setw (domainWidth) << states[i] << right;
|
|
||||||
for (unsigned j = 0; j < node.getRowSize(); j++) {
|
|
||||||
o << setw (widths[j]) << row[j];
|
|
||||||
}
|
|
||||||
o << endl;
|
|
||||||
}
|
|
||||||
o << endl;
|
|
||||||
|
|
||||||
return o;
|
|
||||||
}
|
|
||||||
|
|
@ -1,61 +0,0 @@
|
|||||||
#ifndef HORUS_BAYESNODE_H
|
|
||||||
#define HORUS_BAYESNODE_H
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "VarNode.h"
|
|
||||||
#include "Distribution.h"
|
|
||||||
#include "Horus.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
|
|
||||||
class BayesNode : public VarNode
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
BayesNode (const VarNode& v) : VarNode (v) {}
|
|
||||||
BayesNode (VarId, unsigned, int, Distribution*);
|
|
||||||
BayesNode (VarId, unsigned, int, const BnNodeSet&, Distribution*);
|
|
||||||
|
|
||||||
void setParents (const BnNodeSet&);
|
|
||||||
void addChild (BayesNode*);
|
|
||||||
void setDistribution (Distribution*);
|
|
||||||
Distribution* getDistribution (void);
|
|
||||||
const Params& getParameters (void);
|
|
||||||
Params getRow (int) const;
|
|
||||||
bool isRoot (void);
|
|
||||||
bool isLeaf (void);
|
|
||||||
bool hasNeighbors (void) const;
|
|
||||||
int getCptSize (void);
|
|
||||||
int getIndexOfParent (const BayesNode*) const;
|
|
||||||
string cptEntryToString (int, const vector<unsigned>&) const;
|
|
||||||
|
|
||||||
const BnNodeSet& getParents (void) const { return parents_; }
|
|
||||||
const BnNodeSet& getChilds (void) const { return childs_; }
|
|
||||||
|
|
||||||
unsigned getRowSize (void) const
|
|
||||||
{
|
|
||||||
return dist_->params.size() / nrStates();
|
|
||||||
}
|
|
||||||
|
|
||||||
double getProbability (int row, unsigned col)
|
|
||||||
{
|
|
||||||
int idx = (row * getRowSize()) + col;
|
|
||||||
return dist_->params[idx];
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
DISALLOW_COPY_AND_ASSIGN (BayesNode);
|
|
||||||
|
|
||||||
States getDomainHeaders (void) const;
|
|
||||||
friend ostream& operator << (ostream&, const BayesNode&);
|
|
||||||
|
|
||||||
BnNodeSet parents_;
|
|
||||||
BnNodeSet childs_;
|
|
||||||
Distribution* dist_;
|
|
||||||
};
|
|
||||||
|
|
||||||
ostream& operator << (ostream&, const BayesNode&);
|
|
||||||
|
|
||||||
#endif // HORUS_BAYESNODE_H
|
|
||||||
|
|
@ -1,803 +0,0 @@
|
|||||||
#include <cstdlib>
|
|
||||||
#include <limits>
|
|
||||||
#include <time.h>
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
#include <sstream>
|
|
||||||
#include <iomanip>
|
|
||||||
|
|
||||||
#include "BnBpSolver.h"
|
|
||||||
#include "Indexer.h"
|
|
||||||
|
|
||||||
BnBpSolver::BnBpSolver (const BayesNet& bn) : Solver (&bn)
|
|
||||||
{
|
|
||||||
bayesNet_ = &bn;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BnBpSolver::~BnBpSolver (void)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < nodesI_.size(); i++) {
|
|
||||||
delete nodesI_[i];
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
delete links_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::runSolver (void)
|
|
||||||
{
|
|
||||||
clock_t start;
|
|
||||||
if (COLLECT_STATISTICS) {
|
|
||||||
start = clock();
|
|
||||||
}
|
|
||||||
initializeSolver();
|
|
||||||
runLoopySolver();
|
|
||||||
if (DL >= 2) {
|
|
||||||
cout << endl;
|
|
||||||
if (nIters_ < BpOptions::maxIter) {
|
|
||||||
cout << "Belief propagation converged in " ;
|
|
||||||
cout << nIters_ << " iterations" << endl;
|
|
||||||
} else {
|
|
||||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned size = bayesNet_->nrNodes();
|
|
||||||
if (COLLECT_STATISTICS) {
|
|
||||||
unsigned nIters = 0;
|
|
||||||
bool loopy = bayesNet_->isPolyTree() == false;
|
|
||||||
if (loopy) nIters = nIters_;
|
|
||||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
|
||||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
|
||||||
}
|
|
||||||
if (EXPORT_TO_GRAPHVIZ && size > EXPORT_MINIMAL_SIZE) {
|
|
||||||
stringstream ss;
|
|
||||||
ss << Statistics::getSolvedNetworksCounting() << "." << size << ".dot" ;
|
|
||||||
bayesNet_->exportToGraphViz (ss.str().c_str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BnBpSolver::getPosterioriOf (VarId vid)
|
|
||||||
{
|
|
||||||
BayesNode* node = bayesNet_->getBayesNode (vid);
|
|
||||||
assert (node);
|
|
||||||
return nodesI_[node->getIndex()]->getBeliefs();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BnBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
|
||||||
{
|
|
||||||
if (DL >= 2) {
|
|
||||||
cout << "calculating joint distribution on: " ;
|
|
||||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
|
||||||
VarNode* var = bayesNet_->getBayesNode (jointVarIds[i]);
|
|
||||||
cout << var->label() << " " ;
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
return getJointByConditioning (jointVarIds);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::initializeSolver (void)
|
|
||||||
{
|
|
||||||
const BnNodeSet& nodes = bayesNet_->getBayesNodes();
|
|
||||||
for (unsigned i = 0; i < nodesI_.size(); i++) {
|
|
||||||
delete nodesI_[i];
|
|
||||||
}
|
|
||||||
nodesI_.clear();
|
|
||||||
nodesI_.reserve (nodes.size());
|
|
||||||
links_.clear();
|
|
||||||
sortedOrder_.clear();
|
|
||||||
linkMap_.clear();
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
||||||
nodesI_.push_back (new BpNodeInfo (nodes[i]));
|
|
||||||
}
|
|
||||||
|
|
||||||
BnNodeSet roots = bayesNet_->getRootNodes();
|
|
||||||
for (unsigned i = 0; i < roots.size(); i++) {
|
|
||||||
const Params& params = roots[i]->getParameters();
|
|
||||||
Params& piVals = ninf(roots[i])->getPiValues();
|
|
||||||
for (unsigned ri = 0; ri < roots[i]->nrStates(); ri++) {
|
|
||||||
piVals[ri] = params[ri];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
||||||
const BnNodeSet& parents = nodes[i]->getParents();
|
|
||||||
for (unsigned j = 0; j < parents.size(); j++) {
|
|
||||||
BpLink* newLink = new BpLink (
|
|
||||||
parents[j], nodes[i], LinkOrientation::DOWN);
|
|
||||||
links_.push_back (newLink);
|
|
||||||
ninf(nodes[i])->addIncomingParentLink (newLink);
|
|
||||||
ninf(parents[j])->addOutcomingChildLink (newLink);
|
|
||||||
}
|
|
||||||
const BnNodeSet& childs = nodes[i]->getChilds();
|
|
||||||
for (unsigned j = 0; j < childs.size(); j++) {
|
|
||||||
BpLink* newLink = new BpLink (
|
|
||||||
childs[j], nodes[i], LinkOrientation::UP);
|
|
||||||
links_.push_back (newLink);
|
|
||||||
ninf(nodes[i])->addIncomingChildLink (newLink);
|
|
||||||
ninf(childs[j])->addOutcomingParentLink (newLink);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
||||||
if (nodes[i]->hasEvidence()) {
|
|
||||||
Params& piVals = ninf(nodes[i])->getPiValues();
|
|
||||||
Params& ldVals = ninf(nodes[i])->getLambdaValues();
|
|
||||||
for (unsigned xi = 0; xi < nodes[i]->nrStates(); xi++) {
|
|
||||||
piVals[xi] = Util::noEvidence();
|
|
||||||
ldVals[xi] = Util::noEvidence();
|
|
||||||
}
|
|
||||||
piVals[nodes[i]->getEvidence()] = Util::withEvidence();
|
|
||||||
ldVals[nodes[i]->getEvidence()] = Util::withEvidence();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::runLoopySolver()
|
|
||||||
{
|
|
||||||
nIters_ = 0;
|
|
||||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
|
||||||
|
|
||||||
nIters_++;
|
|
||||||
if (DL >= 2) {
|
|
||||||
cout << "****************************************" ;
|
|
||||||
cout << "****************************************" ;
|
|
||||||
cout << endl;
|
|
||||||
cout << " Iteration " << nIters_ << endl;
|
|
||||||
cout << "****************************************" ;
|
|
||||||
cout << "****************************************" ;
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (BpOptions::schedule) {
|
|
||||||
|
|
||||||
case BpOptions::Schedule::SEQ_RANDOM:
|
|
||||||
random_shuffle (links_.begin(), links_.end());
|
|
||||||
// no break
|
|
||||||
|
|
||||||
case BpOptions::Schedule::SEQ_FIXED:
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
calculateAndUpdateMessage (links_[i]);
|
|
||||||
updateValues (links_[i]);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case BpOptions::Schedule::PARALLEL:
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
calculateMessage (links_[i]);
|
|
||||||
}
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
updateMessage (links_[i]);
|
|
||||||
updateValues (links_[i]);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
|
||||||
maxResidualSchedule();
|
|
||||||
break;
|
|
||||||
|
|
||||||
}
|
|
||||||
if (DL >= 2) {
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
BnBpSolver::converged (void) const
|
|
||||||
{
|
|
||||||
// this can happen if the graph is fully disconnected
|
|
||||||
if (links_.size() == 0) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (nIters_ == 0 || nIters_ == 1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
bool converged = true;
|
|
||||||
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
|
||||||
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
|
||||||
if (maxResidual < BpOptions::accuracy) {
|
|
||||||
converged = true;
|
|
||||||
} else {
|
|
||||||
converged = false;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
double residual = links_[i]->getResidual();
|
|
||||||
if (DL >= 2) {
|
|
||||||
cout << links_[i]->toString() + " residual change = " ;
|
|
||||||
cout << residual << endl;
|
|
||||||
}
|
|
||||||
if (residual > BpOptions::accuracy) {
|
|
||||||
converged = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return converged;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::maxResidualSchedule (void)
|
|
||||||
{
|
|
||||||
if (nIters_ == 1) {
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
calculateMessage (links_[i]);
|
|
||||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
|
||||||
linkMap_.insert (make_pair (links_[i], it));
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned c = 0; c < sortedOrder_.size(); c++) {
|
|
||||||
if (DL >= 2) {
|
|
||||||
cout << "current residuals:" << endl;
|
|
||||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
|
||||||
it != sortedOrder_.end(); it ++) {
|
|
||||||
cout << " " << setw (30) << left << (*it)->toString();
|
|
||||||
cout << "residual = " << (*it)->getResidual() << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SortedOrder::iterator it = sortedOrder_.begin();
|
|
||||||
BpLink* link = *it;
|
|
||||||
if (link->getResidual() < BpOptions::accuracy) {
|
|
||||||
sortedOrder_.erase (it);
|
|
||||||
it = sortedOrder_.begin();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
updateMessage (link);
|
|
||||||
updateValues (link);
|
|
||||||
link->clearResidual();
|
|
||||||
sortedOrder_.erase (it);
|
|
||||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
|
||||||
|
|
||||||
const BpLinkSet& outParentLinks =
|
|
||||||
ninf(link->getDestination())->getOutcomingParentLinks();
|
|
||||||
for (unsigned i = 0; i < outParentLinks.size(); i++) {
|
|
||||||
if (outParentLinks[i]->getDestination() != link->getSource()
|
|
||||||
&& outParentLinks[i]->getDestination()->hasEvidence() == false) {
|
|
||||||
calculateMessage (outParentLinks[i]);
|
|
||||||
BpLinkMap::iterator iter = linkMap_.find (outParentLinks[i]);
|
|
||||||
sortedOrder_.erase (iter->second);
|
|
||||||
iter->second = sortedOrder_.insert (outParentLinks[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
const BpLinkSet& outChildLinks =
|
|
||||||
ninf(link->getDestination())->getOutcomingChildLinks();
|
|
||||||
for (unsigned i = 0; i < outChildLinks.size(); i++) {
|
|
||||||
if (outChildLinks[i]->getDestination() != link->getSource()) {
|
|
||||||
calculateMessage (outChildLinks[i]);
|
|
||||||
BpLinkMap::iterator iter = linkMap_.find (outChildLinks[i]);
|
|
||||||
sortedOrder_.erase (iter->second);
|
|
||||||
iter->second = sortedOrder_.insert (outChildLinks[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (DL >= 2) {
|
|
||||||
cout << "----------------------------------------" ;
|
|
||||||
cout << "----------------------------------------" << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::updatePiValues (BayesNode* x)
|
|
||||||
{
|
|
||||||
// π(Xi)
|
|
||||||
if (DL >= 3) {
|
|
||||||
cout << "updating " << PI_SYMBOL << " values for " << x->label() << endl;
|
|
||||||
}
|
|
||||||
Params& piValues = ninf(x)->getPiValues();
|
|
||||||
const BpLinkSet& parentLinks = ninf(x)->getIncomingParentLinks();
|
|
||||||
const BnNodeSet& ps = x->getParents();
|
|
||||||
Ranges ranges;
|
|
||||||
for (unsigned i = 0; i < ps.size(); i++) {
|
|
||||||
ranges.push_back (ps[i]->nrStates());
|
|
||||||
}
|
|
||||||
StatesIndexer indexer (ranges, false);
|
|
||||||
stringstream* calcs1 = 0;
|
|
||||||
stringstream* calcs2 = 0;
|
|
||||||
|
|
||||||
Params messageProducts (indexer.size());
|
|
||||||
for (unsigned k = 0; k < indexer.size(); k++) {
|
|
||||||
if (DL >= 5) {
|
|
||||||
calcs1 = new stringstream;
|
|
||||||
calcs2 = new stringstream;
|
|
||||||
}
|
|
||||||
double messageProduct = Util::multIdenty();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < parentLinks.size(); i++) {
|
|
||||||
messageProduct += parentLinks[i]->getMessage()[indexer[i]];
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < parentLinks.size(); i++) {
|
|
||||||
messageProduct *= parentLinks[i]->getMessage()[indexer[i]];
|
|
||||||
if (DL >= 5) {
|
|
||||||
if (i != 0) *calcs1 << " + " ;
|
|
||||||
if (i != 0) *calcs2 << " + " ;
|
|
||||||
*calcs1 << parentLinks[i]->toString (indexer[i]);
|
|
||||||
*calcs2 << parentLinks[i]->getMessage()[indexer[i]];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
messageProducts[k] = messageProduct;
|
|
||||||
if (DL >= 5) {
|
|
||||||
cout << " mp" << k;
|
|
||||||
cout << " = " << (*calcs1).str();
|
|
||||||
if (parentLinks.size() == 1) {
|
|
||||||
cout << " = " << messageProduct << endl;
|
|
||||||
} else {
|
|
||||||
cout << " = " << (*calcs2).str();
|
|
||||||
cout << " = " << messageProduct << endl;
|
|
||||||
}
|
|
||||||
delete calcs1;
|
|
||||||
delete calcs2;
|
|
||||||
}
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
|
||||||
double sum = Util::addIdenty();
|
|
||||||
if (DL >= 5) {
|
|
||||||
calcs1 = new stringstream;
|
|
||||||
calcs2 = new stringstream;
|
|
||||||
}
|
|
||||||
indexer.reset();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned k = 0; k < indexer.size(); k++) {
|
|
||||||
Util::logSum (sum,
|
|
||||||
x->getProbability(xi, indexer.linearIndex()) + messageProducts[k]);
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned k = 0; k < indexer.size(); k++) {
|
|
||||||
sum += x->getProbability (xi, indexer.linearIndex()) * messageProducts[k];
|
|
||||||
if (DL >= 5) {
|
|
||||||
if (k != 0) *calcs1 << " + " ;
|
|
||||||
if (k != 0) *calcs2 << " + " ;
|
|
||||||
*calcs1 << x->cptEntryToString (xi, indexer.indices());
|
|
||||||
*calcs1 << ".mp" << k;
|
|
||||||
*calcs2 << Util::fl (x->getProbability (xi, indexer.linearIndex()));
|
|
||||||
*calcs2 << "*" << messageProducts[k];
|
|
||||||
}
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
piValues[xi] = sum;
|
|
||||||
if (DL >= 5) {
|
|
||||||
cout << " " << PI_SYMBOL << "(" << x->label() << ")" ;
|
|
||||||
cout << "[" << x->states()[xi] << "]" ;
|
|
||||||
cout << " = " << (*calcs1).str();
|
|
||||||
cout << " = " << (*calcs2).str();
|
|
||||||
cout << " = " << piValues[xi] << endl;
|
|
||||||
delete calcs1;
|
|
||||||
delete calcs2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::updateLambdaValues (BayesNode* x)
|
|
||||||
{
|
|
||||||
// λ(Xi)
|
|
||||||
if (DL >= 3) {
|
|
||||||
cout << "updating " << LD_SYMBOL << " values for " << x->label() << endl;
|
|
||||||
}
|
|
||||||
Params& lambdaValues = ninf(x)->getLambdaValues();
|
|
||||||
const BpLinkSet& childLinks = ninf(x)->getIncomingChildLinks();
|
|
||||||
stringstream* calcs1 = 0;
|
|
||||||
stringstream* calcs2 = 0;
|
|
||||||
|
|
||||||
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
|
||||||
if (DL >= 5) {
|
|
||||||
calcs1 = new stringstream;
|
|
||||||
calcs2 = new stringstream;
|
|
||||||
}
|
|
||||||
double product = Util::multIdenty();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < childLinks.size(); i++) {
|
|
||||||
product += childLinks[i]->getMessage()[xi];
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < childLinks.size(); i++) {
|
|
||||||
product *= childLinks[i]->getMessage()[xi];
|
|
||||||
if (DL >= 5) {
|
|
||||||
if (i != 0) *calcs1 << "." ;
|
|
||||||
if (i != 0) *calcs2 << "*" ;
|
|
||||||
*calcs1 << childLinks[i]->toString (xi);
|
|
||||||
*calcs2 << childLinks[i]->getMessage()[xi];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
lambdaValues[xi] = product;
|
|
||||||
if (DL >= 5) {
|
|
||||||
cout << " " << LD_SYMBOL << "(" << x->label() << ")" ;
|
|
||||||
cout << "[" << x->states()[xi] << "]" ;
|
|
||||||
cout << " = " << (*calcs1).str();
|
|
||||||
if (childLinks.size() == 1) {
|
|
||||||
cout << " = " << product << endl;
|
|
||||||
} else {
|
|
||||||
cout << " = " << (*calcs2).str();
|
|
||||||
cout << " = " << lambdaValues[xi] << endl;
|
|
||||||
}
|
|
||||||
delete calcs1;
|
|
||||||
delete calcs2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::calculatePiMessage (BpLink* link)
|
|
||||||
{
|
|
||||||
// πX(Zi)
|
|
||||||
BayesNode* z = link->getSource();
|
|
||||||
BayesNode* x = link->getDestination();
|
|
||||||
Params& zxPiNextMessage = link->getNextMessage();
|
|
||||||
const BpLinkSet& zChildLinks = ninf(z)->getIncomingChildLinks();
|
|
||||||
stringstream* calcs1 = 0;
|
|
||||||
stringstream* calcs2 = 0;
|
|
||||||
|
|
||||||
const Params& zPiValues = ninf(z)->getPiValues();
|
|
||||||
for (unsigned zi = 0; zi < z->nrStates(); zi++) {
|
|
||||||
double product = zPiValues[zi];
|
|
||||||
if (DL >= 5) {
|
|
||||||
calcs1 = new stringstream;
|
|
||||||
calcs2 = new stringstream;
|
|
||||||
*calcs1 << PI_SYMBOL << "(" << z->label() << ")";
|
|
||||||
*calcs1 << "[" << z->states()[zi] << "]" ;
|
|
||||||
*calcs2 << product;
|
|
||||||
}
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < zChildLinks.size(); i++) {
|
|
||||||
if (zChildLinks[i]->getSource() != x) {
|
|
||||||
product += zChildLinks[i]->getMessage()[zi];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < zChildLinks.size(); i++) {
|
|
||||||
if (zChildLinks[i]->getSource() != x) {
|
|
||||||
product *= zChildLinks[i]->getMessage()[zi];
|
|
||||||
if (DL >= 5) {
|
|
||||||
*calcs1 << "." << zChildLinks[i]->toString (zi);
|
|
||||||
*calcs2 << " * " << zChildLinks[i]->getMessage()[zi];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
zxPiNextMessage[zi] = product;
|
|
||||||
if (DL >= 5) {
|
|
||||||
cout << " " << link->toString();
|
|
||||||
cout << "[" << z->states()[zi] << "]" ;
|
|
||||||
cout << " = " << (*calcs1).str();
|
|
||||||
if (zChildLinks.size() == 1) {
|
|
||||||
cout << " = " << product << endl;
|
|
||||||
} else {
|
|
||||||
cout << " = " << (*calcs2).str();
|
|
||||||
cout << " = " << product << endl;
|
|
||||||
}
|
|
||||||
delete calcs1;
|
|
||||||
delete calcs2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Util::normalize (zxPiNextMessage);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::calculateLambdaMessage (BpLink* link)
|
|
||||||
{
|
|
||||||
// λY(Xi)
|
|
||||||
BayesNode* y = link->getSource();
|
|
||||||
BayesNode* x = link->getDestination();
|
|
||||||
if (x->hasEvidence()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
Params& yxLambdaNextMessage = link->getNextMessage();
|
|
||||||
const BpLinkSet& yParentLinks = ninf(y)->getIncomingParentLinks();
|
|
||||||
const Params& yLambdaValues = ninf(y)->getLambdaValues();
|
|
||||||
int parentIndex = y->getIndexOfParent (x);
|
|
||||||
stringstream* calcs1 = 0;
|
|
||||||
stringstream* calcs2 = 0;
|
|
||||||
|
|
||||||
const BnNodeSet& ps = y->getParents();
|
|
||||||
Ranges ranges;
|
|
||||||
for (unsigned i = 0; i < ps.size(); i++) {
|
|
||||||
ranges.push_back (ps[i]->nrStates());
|
|
||||||
}
|
|
||||||
StatesIndexer indexer (ranges, false);
|
|
||||||
|
|
||||||
|
|
||||||
unsigned N = indexer.size() / x->nrStates();
|
|
||||||
Params messageProducts (N);
|
|
||||||
for (unsigned k = 0; k < N; k++) {
|
|
||||||
while (indexer[parentIndex] != 0) {
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
if (DL >= 5) {
|
|
||||||
calcs1 = new stringstream;
|
|
||||||
calcs2 = new stringstream;
|
|
||||||
}
|
|
||||||
double messageProduct = Util::multIdenty();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < yParentLinks.size(); i++) {
|
|
||||||
if (yParentLinks[i]->getSource() != x) {
|
|
||||||
messageProduct += yParentLinks[i]->getMessage()[indexer[i]];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < yParentLinks.size(); i++) {
|
|
||||||
if (yParentLinks[i]->getSource() != x) {
|
|
||||||
if (DL >= 5) {
|
|
||||||
if (messageProduct != Util::multIdenty()) *calcs1 << "*" ;
|
|
||||||
if (messageProduct != Util::multIdenty()) *calcs2 << "*" ;
|
|
||||||
*calcs1 << yParentLinks[i]->toString (indexer[i]);
|
|
||||||
*calcs2 << yParentLinks[i]->getMessage()[indexer[i]];
|
|
||||||
}
|
|
||||||
messageProduct *= yParentLinks[i]->getMessage()[indexer[i]];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
messageProducts[k] = messageProduct;
|
|
||||||
++ indexer;
|
|
||||||
if (DL >= 5) {
|
|
||||||
cout << " mp" << k;
|
|
||||||
cout << " = " << (*calcs1).str();
|
|
||||||
if (yParentLinks.size() == 1) {
|
|
||||||
cout << 1 << endl;
|
|
||||||
} else if (yParentLinks.size() == 2) {
|
|
||||||
cout << " = " << messageProduct << endl;
|
|
||||||
} else {
|
|
||||||
cout << " = " << (*calcs2).str();
|
|
||||||
cout << " = " << messageProduct << endl;
|
|
||||||
}
|
|
||||||
delete calcs1;
|
|
||||||
delete calcs2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
|
||||||
if (DL >= 5) {
|
|
||||||
calcs1 = new stringstream;
|
|
||||||
calcs2 = new stringstream;
|
|
||||||
}
|
|
||||||
double outerSum = Util::addIdenty();
|
|
||||||
for (unsigned yi = 0; yi < y->nrStates(); yi++) {
|
|
||||||
if (DL >= 5) {
|
|
||||||
(yi != 0) ? *calcs1 << " + {" : *calcs1 << "{" ;
|
|
||||||
(yi != 0) ? *calcs2 << " + {" : *calcs2 << "{" ;
|
|
||||||
}
|
|
||||||
double innerSum = Util::addIdenty();
|
|
||||||
indexer.reset();
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned k = 0; k < N; k++) {
|
|
||||||
while (indexer[parentIndex] != xi) {
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
Util::logSum (innerSum, y->getProbability (
|
|
||||||
yi, indexer.linearIndex()) + messageProducts[k]);
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
Util::logSum (outerSum, innerSum + yLambdaValues[yi]);
|
|
||||||
} else {
|
|
||||||
for (unsigned k = 0; k < N; k++) {
|
|
||||||
while (indexer[parentIndex] != xi) {
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
if (DL >= 5) {
|
|
||||||
if (k != 0) *calcs1 << " + " ;
|
|
||||||
if (k != 0) *calcs2 << " + " ;
|
|
||||||
*calcs1 << y->cptEntryToString (yi, indexer.indices());
|
|
||||||
*calcs1 << ".mp" << k;
|
|
||||||
*calcs2 << y->getProbability (yi, indexer.linearIndex());
|
|
||||||
*calcs2 << "*" << messageProducts[k];
|
|
||||||
}
|
|
||||||
innerSum += y->getProbability (
|
|
||||||
yi, indexer.linearIndex()) * messageProducts[k];
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
outerSum += innerSum * yLambdaValues[yi];
|
|
||||||
}
|
|
||||||
if (DL >= 5) {
|
|
||||||
*calcs1 << "}." << LD_SYMBOL << "(" << y->label() << ")" ;
|
|
||||||
*calcs1 << "[" << y->states()[yi] << "]";
|
|
||||||
*calcs2 << "}*" << yLambdaValues[yi];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
yxLambdaNextMessage[xi] = outerSum;
|
|
||||||
if (DL >= 5) {
|
|
||||||
cout << " " << link->toString();
|
|
||||||
cout << "[" << x->states()[xi] << "]" ;
|
|
||||||
cout << " = " << (*calcs1).str();
|
|
||||||
cout << " = " << (*calcs2).str();
|
|
||||||
cout << " = " << yxLambdaNextMessage[xi] << endl;
|
|
||||||
delete calcs1;
|
|
||||||
delete calcs2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Util::normalize (yxLambdaNextMessage);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BnBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
|
||||||
{
|
|
||||||
BnNodeSet jointVars;
|
|
||||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
|
||||||
assert (bayesNet_->getBayesNode (jointVarIds[i]));
|
|
||||||
jointVars.push_back (bayesNet_->getBayesNode (jointVarIds[i]));
|
|
||||||
}
|
|
||||||
|
|
||||||
BayesNet* mrn = bayesNet_->getMinimalRequesiteNetwork (jointVarIds[0]);
|
|
||||||
BnBpSolver solver (*mrn);
|
|
||||||
solver.runSolver();
|
|
||||||
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
|
||||||
delete mrn;
|
|
||||||
|
|
||||||
VarIds observedVids = {jointVars[0]->varId()};
|
|
||||||
|
|
||||||
for (unsigned i = 1; i < jointVarIds.size(); i++) {
|
|
||||||
assert (jointVars[i]->hasEvidence() == false);
|
|
||||||
VarIds reqVars = {jointVarIds[i]};
|
|
||||||
reqVars.insert (reqVars.end(), observedVids.begin(), observedVids.end());
|
|
||||||
mrn = bayesNet_->getMinimalRequesiteNetwork (reqVars);
|
|
||||||
Params newBeliefs;
|
|
||||||
VarNodes observedVars;
|
|
||||||
for (unsigned j = 0; j < observedVids.size(); j++) {
|
|
||||||
observedVars.push_back (mrn->getBayesNode (observedVids[j]));
|
|
||||||
}
|
|
||||||
StatesIndexer idx (observedVars, false);
|
|
||||||
while (idx.valid()) {
|
|
||||||
for (unsigned j = 0; j < observedVars.size(); j++) {
|
|
||||||
observedVars[j]->setEvidence (idx[j]);
|
|
||||||
}
|
|
||||||
BnBpSolver solver (*mrn);
|
|
||||||
solver.runSolver();
|
|
||||||
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
|
||||||
for (unsigned k = 0; k < beliefs.size(); k++) {
|
|
||||||
newBeliefs.push_back (beliefs[k]);
|
|
||||||
}
|
|
||||||
++ idx;
|
|
||||||
}
|
|
||||||
|
|
||||||
int count = -1;
|
|
||||||
for (unsigned j = 0; j < newBeliefs.size(); j++) {
|
|
||||||
if (j % jointVars[i]->nrStates() == 0) {
|
|
||||||
count ++;
|
|
||||||
}
|
|
||||||
newBeliefs[j] *= prevBeliefs[count];
|
|
||||||
}
|
|
||||||
prevBeliefs = newBeliefs;
|
|
||||||
observedVids.push_back (jointVars[i]->varId());
|
|
||||||
delete mrn;
|
|
||||||
}
|
|
||||||
return prevBeliefs;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::printPiLambdaValues (const BayesNode* var) const
|
|
||||||
{
|
|
||||||
cout << left;
|
|
||||||
cout << setw (10) << "states" ;
|
|
||||||
cout << setw (20) << PI_SYMBOL << "(" + var->label() + ")" ;
|
|
||||||
cout << setw (20) << LD_SYMBOL << "(" + var->label() + ")" ;
|
|
||||||
cout << setw (16) << "belief" ;
|
|
||||||
cout << endl;
|
|
||||||
cout << "--------------------------------" ;
|
|
||||||
cout << "--------------------------------" ;
|
|
||||||
cout << endl;
|
|
||||||
const States& states = var->states();
|
|
||||||
const Params& piVals = ninf(var)->getPiValues();
|
|
||||||
const Params& ldVals = ninf(var)->getLambdaValues();
|
|
||||||
const Params& beliefs = ninf(var)->getBeliefs();
|
|
||||||
for (unsigned xi = 0; xi < var->nrStates(); xi++) {
|
|
||||||
cout << setw (10) << states[xi];
|
|
||||||
cout << setw (19) << piVals[xi];
|
|
||||||
cout << setw (19) << ldVals[xi];
|
|
||||||
cout.precision (PRECISION);
|
|
||||||
cout << setw (16) << beliefs[xi];
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
BnBpSolver::printAllMessageStatus (void) const
|
|
||||||
{
|
|
||||||
const BnNodeSet& nodes = bayesNet_->getBayesNodes();
|
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
||||||
printPiLambdaValues (nodes[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BpNodeInfo::BpNodeInfo (BayesNode* node)
|
|
||||||
{
|
|
||||||
node_ = node;
|
|
||||||
piVals_.resize (node->nrStates(), Util::one());
|
|
||||||
ldVals_.resize (node->nrStates(), Util::one());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
BpNodeInfo::getBeliefs (void) const
|
|
||||||
{
|
|
||||||
double sum = 0.0;
|
|
||||||
Params beliefs (node_->nrStates());
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
|
|
||||||
beliefs[xi] = exp (piVals_[xi] + ldVals_[xi]);
|
|
||||||
sum += beliefs[xi];
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
|
|
||||||
beliefs[xi] = piVals_[xi] * ldVals_[xi];
|
|
||||||
sum += beliefs[xi];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert (sum);
|
|
||||||
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
|
|
||||||
beliefs[xi] /= sum;
|
|
||||||
}
|
|
||||||
return beliefs;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
BpNodeInfo::receivedBottomInfluence (void) const
|
|
||||||
{
|
|
||||||
// if all lambda values are equal, then neither
|
|
||||||
// this node neither its descendents have evidence,
|
|
||||||
// we can use this to don't send lambda messages his parents
|
|
||||||
bool childInfluenced = false;
|
|
||||||
for (unsigned xi = 1; xi < node_->nrStates(); xi++) {
|
|
||||||
if (ldVals_[xi] != ldVals_[0]) {
|
|
||||||
childInfluenced = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return childInfluenced;
|
|
||||||
}
|
|
||||||
|
|
@ -1,245 +0,0 @@
|
|||||||
#ifndef HORUS_BNBPSOLVER_H
|
|
||||||
#define HORUS_BNBPSOLVER_H
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <set>
|
|
||||||
|
|
||||||
#include "Solver.h"
|
|
||||||
#include "BayesNet.h"
|
|
||||||
#include "Horus.h"
|
|
||||||
#include "Util.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
class BpNodeInfo;
|
|
||||||
|
|
||||||
static const string PI_SYMBOL = "pi" ;
|
|
||||||
static const string LD_SYMBOL = "ld" ;
|
|
||||||
|
|
||||||
enum LinkOrientation {UP, DOWN};
|
|
||||||
|
|
||||||
class BpLink
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
BpLink (BayesNode* s, BayesNode* d, LinkOrientation o)
|
|
||||||
{
|
|
||||||
source_ = s;
|
|
||||||
destin_ = d;
|
|
||||||
orientation_ = o;
|
|
||||||
if (orientation_ == LinkOrientation::DOWN) {
|
|
||||||
v1_.resize (s->nrStates(), Util::tl (1.0 / s->nrStates()));
|
|
||||||
v2_.resize (s->nrStates(), Util::tl (1.0 / s->nrStates()));
|
|
||||||
} else {
|
|
||||||
v1_.resize (d->nrStates(), Util::tl (1.0 / d->nrStates()));
|
|
||||||
v2_.resize (d->nrStates(), Util::tl (1.0 / d->nrStates()));
|
|
||||||
}
|
|
||||||
currMsg_ = &v1_;
|
|
||||||
nextMsg_ = &v2_;
|
|
||||||
residual_ = 0;
|
|
||||||
msgSended_ = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void updateMessage (void)
|
|
||||||
{
|
|
||||||
swap (currMsg_, nextMsg_);
|
|
||||||
msgSended_ = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void updateResidual (void)
|
|
||||||
{
|
|
||||||
residual_ = Util::getMaxNorm (v1_, v2_);
|
|
||||||
}
|
|
||||||
|
|
||||||
string toString (void) const
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
if (orientation_ == LinkOrientation::DOWN) {
|
|
||||||
ss << PI_SYMBOL;
|
|
||||||
} else {
|
|
||||||
ss << LD_SYMBOL;
|
|
||||||
}
|
|
||||||
ss << "(" << source_->label();
|
|
||||||
ss << " --> " << destin_->label() << ")" ;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
string toString (unsigned stateIndex) const
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
ss << toString() << "[" ;
|
|
||||||
if (orientation_ == LinkOrientation::DOWN) {
|
|
||||||
ss << source_->states()[stateIndex] << "]" ;
|
|
||||||
} else {
|
|
||||||
ss << destin_->states()[stateIndex] << "]" ;
|
|
||||||
}
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
BayesNode* getSource (void) const { return source_; }
|
|
||||||
BayesNode* getDestination (void) const { return destin_; }
|
|
||||||
LinkOrientation getOrientation (void) const { return orientation_; }
|
|
||||||
const Params& getMessage (void) const { return *currMsg_; }
|
|
||||||
Params& getNextMessage (void) { return *nextMsg_; }
|
|
||||||
bool messageWasSended (void) const { return msgSended_; }
|
|
||||||
double getResidual (void) const { return residual_; }
|
|
||||||
void clearResidual (void) { residual_ = 0;}
|
|
||||||
|
|
||||||
private:
|
|
||||||
BayesNode* source_;
|
|
||||||
BayesNode* destin_;
|
|
||||||
LinkOrientation orientation_;
|
|
||||||
Params v1_;
|
|
||||||
Params v2_;
|
|
||||||
Params* currMsg_;
|
|
||||||
Params* nextMsg_;
|
|
||||||
bool msgSended_;
|
|
||||||
double residual_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
typedef vector<BpLink*> BpLinkSet;
|
|
||||||
|
|
||||||
|
|
||||||
class BpNodeInfo
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
BpNodeInfo (BayesNode*);
|
|
||||||
|
|
||||||
Params getBeliefs (void) const;
|
|
||||||
bool receivedBottomInfluence (void) const;
|
|
||||||
|
|
||||||
Params& getPiValues (void) { return piVals_; }
|
|
||||||
Params& getLambdaValues (void) { return ldVals_; }
|
|
||||||
|
|
||||||
const BpLinkSet& getIncomingParentLinks (void) { return inParentLinks_; }
|
|
||||||
const BpLinkSet& getIncomingChildLinks (void) { return inChildLinks_; }
|
|
||||||
const BpLinkSet& getOutcomingParentLinks (void) { return outParentLinks_; }
|
|
||||||
const BpLinkSet& getOutcomingChildLinks (void) { return outChildLinks_; }
|
|
||||||
|
|
||||||
void addIncomingParentLink (BpLink* l) { inParentLinks_.push_back (l); }
|
|
||||||
void addIncomingChildLink (BpLink* l) { inChildLinks_.push_back (l); }
|
|
||||||
void addOutcomingParentLink (BpLink* l) { outParentLinks_.push_back (l); }
|
|
||||||
void addOutcomingChildLink (BpLink* l) { outChildLinks_.push_back (l); }
|
|
||||||
|
|
||||||
private:
|
|
||||||
DISALLOW_COPY_AND_ASSIGN (BpNodeInfo);
|
|
||||||
|
|
||||||
const BayesNode* node_;
|
|
||||||
Params piVals_; // pi values
|
|
||||||
Params ldVals_; // lambda values
|
|
||||||
BpLinkSet inParentLinks_;
|
|
||||||
BpLinkSet inChildLinks_;
|
|
||||||
BpLinkSet outParentLinks_;
|
|
||||||
BpLinkSet outChildLinks_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class BnBpSolver : public Solver
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
BnBpSolver (const BayesNet&);
|
|
||||||
~BnBpSolver (void);
|
|
||||||
|
|
||||||
void runSolver (void);
|
|
||||||
Params getPosterioriOf (VarId);
|
|
||||||
Params getJointDistributionOf (const VarIds&);
|
|
||||||
|
|
||||||
|
|
||||||
private:
|
|
||||||
DISALLOW_COPY_AND_ASSIGN (BnBpSolver);
|
|
||||||
|
|
||||||
void initializeSolver (void);
|
|
||||||
void runLoopySolver (void);
|
|
||||||
void maxResidualSchedule (void);
|
|
||||||
bool converged (void) const;
|
|
||||||
void updatePiValues (BayesNode*);
|
|
||||||
void updateLambdaValues (BayesNode*);
|
|
||||||
void calculateLambdaMessage (BpLink*);
|
|
||||||
void calculatePiMessage (BpLink*);
|
|
||||||
Params getJointByJunctionNode (const VarIds&);
|
|
||||||
Params getJointByConditioning (const VarIds&) const;
|
|
||||||
void printPiLambdaValues (const BayesNode*) const;
|
|
||||||
void printAllMessageStatus (void) const;
|
|
||||||
|
|
||||||
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
|
|
||||||
{
|
|
||||||
if (DL >= 3) {
|
|
||||||
cout << "calculating & updating " << link->toString() << endl;
|
|
||||||
}
|
|
||||||
if (link->getOrientation() == LinkOrientation::DOWN) {
|
|
||||||
calculatePiMessage (link);
|
|
||||||
} else if (link->getOrientation() == LinkOrientation::UP) {
|
|
||||||
calculateLambdaMessage (link);
|
|
||||||
}
|
|
||||||
if (calcResidual) {
|
|
||||||
link->updateResidual();
|
|
||||||
}
|
|
||||||
link->updateMessage();
|
|
||||||
}
|
|
||||||
|
|
||||||
void calculateMessage (BpLink* link, bool calcResidual = true)
|
|
||||||
{
|
|
||||||
if (DL >= 3) {
|
|
||||||
cout << "calculating " << link->toString() << endl;
|
|
||||||
}
|
|
||||||
if (link->getOrientation() == LinkOrientation::DOWN) {
|
|
||||||
calculatePiMessage (link);
|
|
||||||
} else if (link->getOrientation() == LinkOrientation::UP) {
|
|
||||||
calculateLambdaMessage (link);
|
|
||||||
}
|
|
||||||
if (calcResidual) {
|
|
||||||
link->updateResidual();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void updateMessage (BpLink* link)
|
|
||||||
{
|
|
||||||
if (DL >= 3) {
|
|
||||||
cout << "updating " << link->toString() << endl;
|
|
||||||
}
|
|
||||||
link->updateMessage();
|
|
||||||
}
|
|
||||||
|
|
||||||
void updateValues (BpLink* link)
|
|
||||||
{
|
|
||||||
if (!link->getDestination()->hasEvidence()) {
|
|
||||||
if (link->getOrientation() == LinkOrientation::DOWN) {
|
|
||||||
updatePiValues (link->getDestination());
|
|
||||||
} else if (link->getOrientation() == LinkOrientation::UP) {
|
|
||||||
updateLambdaValues (link->getDestination());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
BpNodeInfo* ninf (const BayesNode* node) const
|
|
||||||
{
|
|
||||||
assert (node);
|
|
||||||
assert (node == bayesNet_->getBayesNode (node->varId()));
|
|
||||||
assert (node->getIndex() < nodesI_.size());
|
|
||||||
return nodesI_[node->getIndex()];
|
|
||||||
}
|
|
||||||
|
|
||||||
const BayesNet* bayesNet_;
|
|
||||||
vector<BpLink*> links_;
|
|
||||||
vector<BpNodeInfo*> nodesI_;
|
|
||||||
unsigned nIters_;
|
|
||||||
|
|
||||||
struct compare
|
|
||||||
{
|
|
||||||
inline bool operator() (const BpLink* e1, const BpLink* e2)
|
|
||||||
{
|
|
||||||
return e1->getResidual() > e2->getResidual();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
typedef multiset<BpLink*, compare> SortedOrder;
|
|
||||||
SortedOrder sortedOrder_;
|
|
||||||
|
|
||||||
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
|
||||||
BpLinkMap linkMap_;
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // HORUS_BNBPSOLVER_H
|
|
||||||
|
|
@ -5,21 +5,22 @@
|
|||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
#include "FgBpSolver.h"
|
#include "BpSolver.h"
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
#include "Indexer.h"
|
#include "Indexer.h"
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
FgBpSolver::FgBpSolver (const FactorGraph& fg) : Solver (&fg)
|
BpSolver::BpSolver (const FactorGraph& fg) : Solver (fg)
|
||||||
{
|
{
|
||||||
factorGraph_ = &fg;
|
fg_ = &fg;
|
||||||
|
runned_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FgBpSolver::~FgBpSolver (void)
|
BpSolver::~BpSolver (void)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||||
delete varsI_[i];
|
delete varsI_[i];
|
||||||
@ -34,64 +35,45 @@ FgBpSolver::~FgBpSolver (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
Params
|
||||||
FgBpSolver::runSolver (void)
|
BpSolver::solveQuery (VarIds queryVids)
|
||||||
{
|
{
|
||||||
clock_t start;
|
assert (queryVids.empty() == false);
|
||||||
if (COLLECT_STATISTICS) {
|
if (queryVids.size() == 1) {
|
||||||
start = clock();
|
return getPosterioriOf (queryVids[0]);
|
||||||
}
|
} else {
|
||||||
runLoopySolver();
|
return getJointDistributionOf (queryVids);
|
||||||
if (DL >= 2) {
|
|
||||||
cout << endl;
|
|
||||||
if (nIters_ < BpOptions::maxIter) {
|
|
||||||
cout << "Sum-Product converged in " ;
|
|
||||||
cout << nIters_ << " iterations" << endl;
|
|
||||||
} else {
|
|
||||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
unsigned size = factorGraph_->getVarNodes().size();
|
|
||||||
if (COLLECT_STATISTICS) {
|
|
||||||
unsigned nIters = 0;
|
|
||||||
bool loopy = factorGraph_->isTree() == false;
|
|
||||||
if (loopy) nIters = nIters_;
|
|
||||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
|
||||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
|
||||||
}
|
|
||||||
if (EXPORT_TO_GRAPHVIZ && size > EXPORT_MINIMAL_SIZE) {
|
|
||||||
stringstream ss;
|
|
||||||
ss << Statistics::getSolvedNetworksCounting() << "." << size << ".dot" ;
|
|
||||||
factorGraph_->exportToGraphViz (ss.str().c_str());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
FgBpSolver::getPosterioriOf (VarId vid)
|
BpSolver::getPosterioriOf (VarId vid)
|
||||||
{
|
{
|
||||||
assert (factorGraph_->getFgVarNode (vid));
|
if (runned_ == false) {
|
||||||
FgVarNode* var = factorGraph_->getFgVarNode (vid);
|
runSolver();
|
||||||
|
}
|
||||||
|
assert (fg_->getVarNode (vid));
|
||||||
|
VarNode* var = fg_->getVarNode (vid);
|
||||||
Params probs;
|
Params probs;
|
||||||
if (var->hasEvidence()) {
|
if (var->hasEvidence()) {
|
||||||
probs.resize (var->nrStates(), Util::noEvidence());
|
probs.resize (var->range(), LogAware::noEvidence());
|
||||||
probs[var->getEvidence()] = Util::withEvidence();
|
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||||
} else {
|
} else {
|
||||||
probs.resize (var->nrStates(), Util::multIdenty());
|
probs.resize (var->range(), LogAware::multIdenty());
|
||||||
const SpLinkSet& links = ninf(var)->getLinks();
|
const SpLinkSet& links = ninf(var)->getLinks();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
Util::add (probs, links[i]->getMessage());
|
Util::add (probs, links[i]->getMessage());
|
||||||
}
|
}
|
||||||
Util::normalize (probs);
|
LogAware::normalize (probs);
|
||||||
Util::fromLog (probs);
|
Util::fromLog (probs);
|
||||||
} else {
|
} else {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
Util::multiply (probs, links[i]->getMessage());
|
Util::multiply (probs, links[i]->getMessage());
|
||||||
}
|
}
|
||||||
Util::normalize (probs);
|
LogAware::normalize (probs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return probs;
|
return probs;
|
||||||
@ -100,13 +82,16 @@ FgBpSolver::getPosterioriOf (VarId vid)
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
BpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
||||||
{
|
{
|
||||||
FgVarNode* vn = factorGraph_->getFgVarNode (jointVarIds[0]);
|
if (runned_ == false) {
|
||||||
const FgFacSet& factorNodes = vn->neighbors();
|
runSolver();
|
||||||
|
}
|
||||||
int idx = -1;
|
int idx = -1;
|
||||||
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
VarNode* vn = fg_->getVarNode (jointVarIds[0]);
|
||||||
if (factorNodes[i]->factor()->contains (jointVarIds)) {
|
const FacNodes& facNodes = vn->neighbors();
|
||||||
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
|
if (facNodes[i]->factor().contains (jointVarIds)) {
|
||||||
idx = i;
|
idx = i;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -114,18 +99,18 @@ FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
|||||||
if (idx == -1) {
|
if (idx == -1) {
|
||||||
return getJointByConditioning (jointVarIds);
|
return getJointByConditioning (jointVarIds);
|
||||||
} else {
|
} else {
|
||||||
Factor r (*factorNodes[idx]->factor());
|
Factor res (facNodes[idx]->factor());
|
||||||
const SpLinkSet& links = ninf(factorNodes[idx])->getLinks();
|
const SpLinkSet& links = ninf(facNodes[idx])->getLinks();
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
Factor msg (links[i]->getVariable()->varId(),
|
Factor msg ({links[i]->getVariable()->varId()},
|
||||||
links[i]->getVariable()->nrStates(),
|
{links[i]->getVariable()->range()},
|
||||||
getVar2FactorMsg (links[i]));
|
getVar2FactorMsg (links[i]));
|
||||||
r.multiply (msg);
|
res.multiply (msg);
|
||||||
}
|
}
|
||||||
r.sumOutAllExcept (jointVarIds);
|
res.sumOutAllExcept (jointVarIds);
|
||||||
r.reorderVariables (jointVarIds);
|
res.reorderArguments (jointVarIds);
|
||||||
r.normalize();
|
res.normalize();
|
||||||
Params jointDist = r.getParameters();
|
Params jointDist = res.params();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::fromLog (jointDist);
|
Util::fromLog (jointDist);
|
||||||
}
|
}
|
||||||
@ -136,35 +121,29 @@ FgBpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FgBpSolver::runLoopySolver (void)
|
BpSolver::runSolver (void)
|
||||||
{
|
{
|
||||||
|
clock_t start;
|
||||||
|
if (Constants::COLLECT_STATS) {
|
||||||
|
start = clock();
|
||||||
|
}
|
||||||
initializeSolver();
|
initializeSolver();
|
||||||
nIters_ = 0;
|
nIters_ = 0;
|
||||||
|
|
||||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||||
|
|
||||||
nIters_ ++;
|
nIters_ ++;
|
||||||
if (DL >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << "****************************************" ;
|
Util::printHeader (string ("Iteration ") + Util::toString (nIters_));
|
||||||
cout << "****************************************" ;
|
// cout << endl;
|
||||||
cout << endl;
|
|
||||||
cout << " Iteration " << nIters_ << endl;
|
|
||||||
cout << "****************************************" ;
|
|
||||||
cout << "****************************************" ;
|
|
||||||
cout << endl;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (BpOptions::schedule) {
|
switch (BpOptions::schedule) {
|
||||||
case BpOptions::Schedule::SEQ_RANDOM:
|
case BpOptions::Schedule::SEQ_RANDOM:
|
||||||
random_shuffle (links_.begin(), links_.end());
|
random_shuffle (links_.begin(), links_.end());
|
||||||
// no break
|
// no break
|
||||||
|
|
||||||
case BpOptions::Schedule::SEQ_FIXED:
|
case BpOptions::Schedule::SEQ_FIXED:
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
calculateAndUpdateMessage (links_[i]);
|
calculateAndUpdateMessage (links_[i]);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case BpOptions::Schedule::PARALLEL:
|
case BpOptions::Schedule::PARALLEL:
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
calculateMessage (links_[i]);
|
calculateMessage (links_[i]);
|
||||||
@ -173,61 +152,43 @@ FgBpSolver::runLoopySolver (void)
|
|||||||
updateMessage(links_[i]);
|
updateMessage(links_[i]);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||||
maxResidualSchedule();
|
maxResidualSchedule();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (DL >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (Constants::DEBUG >= 2) {
|
||||||
|
cout << endl;
|
||||||
|
if (nIters_ < BpOptions::maxIter) {
|
||||||
|
cout << "Sum-Product converged in " ;
|
||||||
|
cout << nIters_ << " iterations" << endl;
|
||||||
|
} else {
|
||||||
|
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
unsigned size = fg_->varNodes().size();
|
||||||
|
if (Constants::COLLECT_STATS) {
|
||||||
|
unsigned nIters = 0;
|
||||||
|
bool loopy = fg_->isTree() == false;
|
||||||
|
if (loopy) nIters = nIters_;
|
||||||
|
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
||||||
|
Statistics::updateStatistics (size, loopy, nIters, time);
|
||||||
|
}
|
||||||
|
runned_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FgBpSolver::initializeSolver (void)
|
BpSolver::createLinks (void)
|
||||||
{
|
{
|
||||||
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
const FacNodes& facNodes = fg_->facNodes();
|
||||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
|
||||||
delete varsI_[i];
|
|
||||||
}
|
|
||||||
varsI_.reserve (varNodes.size());
|
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
|
||||||
varsI_.push_back (new SPNodeInfo());
|
|
||||||
}
|
|
||||||
|
|
||||||
const FgFacSet& facNodes = factorGraph_->getFactorNodes();
|
|
||||||
for (unsigned i = 0; i < facsI_.size(); i++) {
|
|
||||||
delete facsI_[i];
|
|
||||||
}
|
|
||||||
facsI_.reserve (facNodes.size());
|
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
facsI_.push_back (new SPNodeInfo());
|
const VarNodes& neighbors = facNodes[i]->neighbors();
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
delete links_[i];
|
|
||||||
}
|
|
||||||
createLinks();
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
FgFacNode* src = links_[i]->getFactor();
|
|
||||||
FgVarNode* dst = links_[i]->getVariable();
|
|
||||||
ninf (dst)->addSpLink (links_[i]);
|
|
||||||
ninf (src)->addSpLink (links_[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
FgBpSolver::createLinks (void)
|
|
||||||
{
|
|
||||||
const FgFacSet& facNodes = factorGraph_->getFactorNodes();
|
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
|
||||||
const FgVarSet& neighbors = facNodes[i]->neighbors();
|
|
||||||
for (unsigned j = 0; j < neighbors.size(); j++) {
|
for (unsigned j = 0; j < neighbors.size(); j++) {
|
||||||
links_.push_back (new SpLink (facNodes[i], neighbors[j]));
|
links_.push_back (new SpLink (facNodes[i], neighbors[j]));
|
||||||
}
|
}
|
||||||
@ -236,42 +197,8 @@ FgBpSolver::createLinks (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
FgBpSolver::converged (void)
|
|
||||||
{
|
|
||||||
if (links_.size() == 0) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (nIters_ == 0 || nIters_ == 1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
bool converged = true;
|
|
||||||
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
|
||||||
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
|
||||||
if (maxResidual > BpOptions::accuracy) {
|
|
||||||
converged = false;
|
|
||||||
} else {
|
|
||||||
converged = true;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
|
||||||
double residual = links_[i]->getResidual();
|
|
||||||
if (DL >= 2) {
|
|
||||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
|
||||||
}
|
|
||||||
if (residual > BpOptions::accuracy) {
|
|
||||||
converged = false;
|
|
||||||
if (DL == 0) break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return converged;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FgBpSolver::maxResidualSchedule (void)
|
BpSolver::maxResidualSchedule (void)
|
||||||
{
|
{
|
||||||
if (nIters_ == 1) {
|
if (nIters_ == 1) {
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
@ -283,7 +210,7 @@ FgBpSolver::maxResidualSchedule (void)
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned c = 0; c < links_.size(); c++) {
|
for (unsigned c = 0; c < links_.size(); c++) {
|
||||||
if (DL >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << "current residuals:" << endl;
|
cout << "current residuals:" << endl;
|
||||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
it != sortedOrder_.end(); it ++) {
|
it != sortedOrder_.end(); it ++) {
|
||||||
@ -303,7 +230,7 @@ FgBpSolver::maxResidualSchedule (void)
|
|||||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||||
|
|
||||||
// update the messages that depend on message source --> destin
|
// update the messages that depend on message source --> destin
|
||||||
const FgFacSet& factorNeighbors = link->getVariable()->neighbors();
|
const FacNodes& factorNeighbors = link->getVariable()->neighbors();
|
||||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||||
if (factorNeighbors[i] != link->getFactor()) {
|
if (factorNeighbors[i] != link->getFactor()) {
|
||||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
||||||
@ -317,9 +244,8 @@ FgBpSolver::maxResidualSchedule (void)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (DL >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << "----------------------------------------" ;
|
Util::printDashedLine();
|
||||||
cout << "----------------------------------------" << endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -327,26 +253,26 @@ FgBpSolver::maxResidualSchedule (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
BpSolver::calculateFactor2VariableMsg (SpLink* link)
|
||||||
{
|
{
|
||||||
const FgFacNode* src = link->getFactor();
|
FacNode* src = link->getFactor();
|
||||||
const FgVarNode* dst = link->getVariable();
|
const VarNode* dst = link->getVariable();
|
||||||
const SpLinkSet& links = ninf(src)->getLinks();
|
const SpLinkSet& links = ninf(src)->getLinks();
|
||||||
// calculate the product of messages that were sent
|
// calculate the product of messages that were sent
|
||||||
// to factor `src', except from var `dst'
|
// to factor `src', except from var `dst'
|
||||||
unsigned msgSize = 1;
|
unsigned msgSize = 1;
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
msgSize *= links[i]->getVariable()->nrStates();
|
msgSize *= links[i]->getVariable()->range();
|
||||||
}
|
}
|
||||||
unsigned repetitions = 1;
|
unsigned repetitions = 1;
|
||||||
Params msgProduct (msgSize, Util::multIdenty());
|
Params msgProduct (msgSize, LogAware::multIdenty());
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (int i = links.size() - 1; i >= 0; i--) {
|
for (int i = links.size() - 1; i >= 0; i--) {
|
||||||
if (links[i]->getVariable() != dst) {
|
if (links[i]->getVariable() != dst) {
|
||||||
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||||
repetitions *= links[i]->getVariable()->nrStates();
|
repetitions *= links[i]->getVariable()->range();
|
||||||
} else {
|
} else {
|
||||||
unsigned ds = links[i]->getVariable()->nrStates();
|
unsigned ds = links[i]->getVariable()->range();
|
||||||
Util::add (msgProduct, Params (ds, 1.0), repetitions);
|
Util::add (msgProduct, Params (ds, 1.0), repetitions);
|
||||||
repetitions *= ds;
|
repetitions *= ds;
|
||||||
}
|
}
|
||||||
@ -354,70 +280,64 @@ FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
|||||||
} else {
|
} else {
|
||||||
for (int i = links.size() - 1; i >= 0; i--) {
|
for (int i = links.size() - 1; i >= 0; i--) {
|
||||||
if (links[i]->getVariable() != dst) {
|
if (links[i]->getVariable() != dst) {
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " message from " << links[i]->getVariable()->label();
|
cout << " message from " << links[i]->getVariable()->label();
|
||||||
cout << ": " << endl;
|
cout << ": " << endl;
|
||||||
}
|
}
|
||||||
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||||
repetitions *= links[i]->getVariable()->nrStates();
|
repetitions *= links[i]->getVariable()->range();
|
||||||
} else {
|
} else {
|
||||||
unsigned ds = links[i]->getVariable()->nrStates();
|
unsigned ds = links[i]->getVariable()->range();
|
||||||
Util::multiply (msgProduct, Params (ds, 1.0), repetitions);
|
Util::multiply (msgProduct, Params (ds, 1.0), repetitions);
|
||||||
repetitions *= ds;
|
repetitions *= ds;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Factor result (src->factor()->getVarIds(),
|
Factor result (src->factor().arguments(),
|
||||||
src->factor()->getRanges(),
|
src->factor().ranges(), msgProduct);
|
||||||
msgProduct);
|
result.multiply (src->factor());
|
||||||
result.multiply (*(src->factor()));
|
if (Constants::DEBUG >= 5) {
|
||||||
if (DL >= 5) {
|
cout << " message product: " << msgProduct << endl;
|
||||||
cout << " message product: " ;
|
cout << " original factor: " << src->factor().params() << endl;
|
||||||
cout << Util::parametersToString (msgProduct) << endl;
|
cout << " factor product: " << result.params() << endl;
|
||||||
cout << " original factor: " ;
|
|
||||||
cout << Util::parametersToString (src->getParameters()) << endl;
|
|
||||||
cout << " factor product: " ;
|
|
||||||
cout << Util::parametersToString (result.getParameters()) << endl;
|
|
||||||
}
|
}
|
||||||
result.sumOutAllExcept (dst->varId());
|
result.sumOutAllExcept (dst->varId());
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " marginalized: " ;
|
cout << " marginalized: " ;
|
||||||
cout << Util::parametersToString (result.getParameters()) << endl;
|
cout << result.params() << endl;
|
||||||
}
|
}
|
||||||
const Params& resultParams = result.getParameters();
|
const Params& resultParams = result.params();
|
||||||
Params& message = link->getNextMessage();
|
Params& message = link->getNextMessage();
|
||||||
for (unsigned i = 0; i < resultParams.size(); i++) {
|
for (unsigned i = 0; i < resultParams.size(); i++) {
|
||||||
message[i] = resultParams[i];
|
message[i] = resultParams[i];
|
||||||
}
|
}
|
||||||
Util::normalize (message);
|
LogAware::normalize (message);
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " curr msg: " ;
|
cout << " curr msg: " << link->getMessage() << endl;
|
||||||
cout << Util::parametersToString (link->getMessage()) << endl;
|
cout << " next msg: " << message << endl;
|
||||||
cout << " next msg: " ;
|
|
||||||
cout << Util::parametersToString (message) << endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
FgBpSolver::getVar2FactorMsg (const SpLink* link) const
|
BpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||||
{
|
{
|
||||||
const FgVarNode* src = link->getVariable();
|
const VarNode* src = link->getVariable();
|
||||||
const FgFacNode* dst = link->getFactor();
|
const FacNode* dst = link->getFactor();
|
||||||
Params msg;
|
Params msg;
|
||||||
if (src->hasEvidence()) {
|
if (src->hasEvidence()) {
|
||||||
msg.resize (src->nrStates(), Util::noEvidence());
|
msg.resize (src->range(), LogAware::noEvidence());
|
||||||
msg[src->getEvidence()] = Util::withEvidence();
|
msg[src->getEvidence()] = LogAware::withEvidence();
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << Util::parametersToString (msg);
|
cout << msg;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
msg.resize (src->nrStates(), Util::one());
|
msg.resize (src->range(), LogAware::one());
|
||||||
}
|
}
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << Util::parametersToString (msg);
|
cout << msg;
|
||||||
}
|
}
|
||||||
const SpLinkSet& links = ninf (src)->getLinks();
|
const SpLinkSet& links = ninf (src)->getLinks();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
@ -430,14 +350,14 @@ FgBpSolver::getVar2FactorMsg (const SpLink* link) const
|
|||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
if (links[i]->getFactor() != dst) {
|
if (links[i]->getFactor() != dst) {
|
||||||
Util::multiply (msg, links[i]->getMessage());
|
Util::multiply (msg, links[i]->getMessage());
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " x " << Util::parametersToString (links[i]->getMessage());
|
cout << " x " << links[i]->getMessage();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " = " << Util::parametersToString (msg);
|
cout << " = " << msg;
|
||||||
}
|
}
|
||||||
return msg;
|
return msg;
|
||||||
}
|
}
|
||||||
@ -445,16 +365,16 @@ FgBpSolver::getVar2FactorMsg (const SpLink* link) const
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
BpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
||||||
{
|
{
|
||||||
FgVarSet jointVars;
|
VarNodes jointVars;
|
||||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||||
assert (factorGraph_->getFgVarNode (jointVarIds[i]));
|
assert (fg_->getVarNode (jointVarIds[i]));
|
||||||
jointVars.push_back (factorGraph_->getFgVarNode (jointVarIds[i]));
|
jointVars.push_back (fg_->getVarNode (jointVarIds[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
FactorGraph* fg = new FactorGraph (*factorGraph_);
|
FactorGraph* fg = new FactorGraph (*fg_);
|
||||||
FgBpSolver solver (*fg);
|
BpSolver solver (*fg);
|
||||||
solver.runSolver();
|
solver.runSolver();
|
||||||
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
Params prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
||||||
|
|
||||||
@ -463,9 +383,9 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
|||||||
for (unsigned i = 1; i < jointVarIds.size(); i++) {
|
for (unsigned i = 1; i < jointVarIds.size(); i++) {
|
||||||
assert (jointVars[i]->hasEvidence() == false);
|
assert (jointVars[i]->hasEvidence() == false);
|
||||||
Params newBeliefs;
|
Params newBeliefs;
|
||||||
VarNodes observedVars;
|
Vars observedVars;
|
||||||
for (unsigned j = 0; j < observedVids.size(); j++) {
|
for (unsigned j = 0; j < observedVids.size(); j++) {
|
||||||
observedVars.push_back (fg->getFgVarNode (observedVids[j]));
|
observedVars.push_back (fg->getVarNode (observedVids[j]));
|
||||||
}
|
}
|
||||||
StatesIndexer idx (observedVars, false);
|
StatesIndexer idx (observedVars, false);
|
||||||
while (idx.valid()) {
|
while (idx.valid()) {
|
||||||
@ -473,7 +393,7 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
|||||||
observedVars[j]->setEvidence (idx[j]);
|
observedVars[j]->setEvidence (idx[j]);
|
||||||
}
|
}
|
||||||
++ idx;
|
++ idx;
|
||||||
FgBpSolver solver (*fg);
|
BpSolver solver (*fg);
|
||||||
solver.runSolver();
|
solver.runSolver();
|
||||||
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
Params beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
||||||
for (unsigned k = 0; k < beliefs.size(); k++) {
|
for (unsigned k = 0; k < beliefs.size(); k++) {
|
||||||
@ -483,7 +403,7 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
|||||||
|
|
||||||
int count = -1;
|
int count = -1;
|
||||||
for (unsigned j = 0; j < newBeliefs.size(); j++) {
|
for (unsigned j = 0; j < newBeliefs.size(); j++) {
|
||||||
if (j % jointVars[i]->nrStates() == 0) {
|
if (j % jointVars[i]->range() == 0) {
|
||||||
count ++;
|
count ++;
|
||||||
}
|
}
|
||||||
newBeliefs[j] *= prevBeliefs[count];
|
newBeliefs[j] *= prevBeliefs[count];
|
||||||
@ -497,15 +417,76 @@ FgBpSolver::getJointByConditioning (const VarIds& jointVarIds) const
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FgBpSolver::printLinkInformation (void) const
|
BpSolver::initializeSolver (void)
|
||||||
|
{
|
||||||
|
const VarNodes& varNodes = fg_->varNodes();
|
||||||
|
varsI_.reserve (varNodes.size());
|
||||||
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
|
varsI_.push_back (new SPNodeInfo());
|
||||||
|
}
|
||||||
|
const FacNodes& facNodes = fg_->facNodes();
|
||||||
|
facsI_.reserve (facNodes.size());
|
||||||
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
|
facsI_.push_back (new SPNodeInfo());
|
||||||
|
}
|
||||||
|
createLinks();
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
FacNode* src = links_[i]->getFactor();
|
||||||
|
VarNode* dst = links_[i]->getVariable();
|
||||||
|
ninf (dst)->addSpLink (links_[i]);
|
||||||
|
ninf (src)->addSpLink (links_[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
BpSolver::converged (void)
|
||||||
|
{
|
||||||
|
if (links_.size() == 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (nIters_ <= 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool converged = true;
|
||||||
|
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
||||||
|
double maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
||||||
|
if (maxResidual > BpOptions::accuracy) {
|
||||||
|
converged = false;
|
||||||
|
} else {
|
||||||
|
converged = true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
|
double residual = links_[i]->getResidual();
|
||||||
|
if (Constants::DEBUG >= 2) {
|
||||||
|
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||||
|
}
|
||||||
|
if (residual > BpOptions::accuracy) {
|
||||||
|
converged = false;
|
||||||
|
if (Constants::DEBUG == 0) break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (Constants::DEBUG >= 2) {
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return converged;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
BpSolver::printLinkInformation (void) const
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
SpLink* l = links_[i];
|
SpLink* l = links_[i];
|
||||||
cout << l->toString() << ":" << endl;
|
cout << l->toString() << ":" << endl;
|
||||||
cout << " curr msg = " ;
|
cout << " curr msg = " ;
|
||||||
cout << Util::parametersToString (l->getMessage()) << endl;
|
cout << l->getMessage() << endl;
|
||||||
cout << " next msg = " ;
|
cout << " next msg = " ;
|
||||||
cout << Util::parametersToString (l->getNextMessage()) << endl;
|
cout << l->getNextMessage() << endl;
|
||||||
cout << " residual = " << l->getResidual() << endl;
|
cout << " residual = " << l->getResidual() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
188
packages/CLPBN/clpbn/bp/BpSolver.h
Normal file
188
packages/CLPBN/clpbn/bp/BpSolver.h
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
#ifndef HORUS_BPSOLVER_H
|
||||||
|
#define HORUS_BPSOLVER_H
|
||||||
|
|
||||||
|
#include <set>
|
||||||
|
#include <vector>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "Solver.h"
|
||||||
|
#include "Factor.h"
|
||||||
|
#include "FactorGraph.h"
|
||||||
|
#include "Util.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
class SpLink
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
SpLink (FacNode* fn, VarNode* vn)
|
||||||
|
{
|
||||||
|
fac_ = fn;
|
||||||
|
var_ = vn;
|
||||||
|
v1_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
|
||||||
|
v2_.resize (vn->range(), LogAware::tl (1.0 / vn->range()));
|
||||||
|
currMsg_ = &v1_;
|
||||||
|
nextMsg_ = &v2_;
|
||||||
|
msgSended_ = false;
|
||||||
|
residual_ = 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual ~SpLink (void) { };
|
||||||
|
|
||||||
|
FacNode* getFactor (void) const { return fac_; }
|
||||||
|
|
||||||
|
VarNode* getVariable (void) const { return var_; }
|
||||||
|
|
||||||
|
const Params& getMessage (void) const { return *currMsg_; }
|
||||||
|
|
||||||
|
Params& getNextMessage (void) { return *nextMsg_; }
|
||||||
|
|
||||||
|
bool messageWasSended (void) const { return msgSended_; }
|
||||||
|
|
||||||
|
double getResidual (void) const { return residual_; }
|
||||||
|
|
||||||
|
void clearResidual (void) { residual_ = 0.0; }
|
||||||
|
|
||||||
|
void updateResidual (void)
|
||||||
|
{
|
||||||
|
residual_ = LogAware::getMaxNorm (v1_,v2_);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void updateMessage (void)
|
||||||
|
{
|
||||||
|
swap (currMsg_, nextMsg_);
|
||||||
|
msgSended_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
string toString (void) const
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
ss << fac_->getLabel();
|
||||||
|
ss << " -- " ;
|
||||||
|
ss << var_->label();
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
FacNode* fac_;
|
||||||
|
VarNode* var_;
|
||||||
|
Params v1_;
|
||||||
|
Params v2_;
|
||||||
|
Params* currMsg_;
|
||||||
|
Params* nextMsg_;
|
||||||
|
bool msgSended_;
|
||||||
|
double residual_;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef vector<SpLink*> SpLinkSet;
|
||||||
|
|
||||||
|
|
||||||
|
class SPNodeInfo
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
void addSpLink (SpLink* link) { links_.push_back (link); }
|
||||||
|
const SpLinkSet& getLinks (void) { return links_; }
|
||||||
|
private:
|
||||||
|
SpLinkSet links_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class BpSolver : public Solver
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
BpSolver (const FactorGraph&);
|
||||||
|
|
||||||
|
virtual ~BpSolver (void);
|
||||||
|
|
||||||
|
Params solveQuery (VarIds);
|
||||||
|
|
||||||
|
virtual Params getPosterioriOf (VarId);
|
||||||
|
|
||||||
|
virtual Params getJointDistributionOf (const VarIds&);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void runSolver (void);
|
||||||
|
|
||||||
|
virtual void createLinks (void);
|
||||||
|
|
||||||
|
virtual void maxResidualSchedule (void);
|
||||||
|
|
||||||
|
virtual void calculateFactor2VariableMsg (SpLink*);
|
||||||
|
|
||||||
|
virtual Params getVar2FactorMsg (const SpLink*) const;
|
||||||
|
|
||||||
|
virtual Params getJointByConditioning (const VarIds&) const;
|
||||||
|
|
||||||
|
SPNodeInfo* ninf (const VarNode* var) const
|
||||||
|
{
|
||||||
|
return varsI_[var->getIndex()];
|
||||||
|
}
|
||||||
|
|
||||||
|
SPNodeInfo* ninf (const FacNode* fac) const
|
||||||
|
{
|
||||||
|
return facsI_[fac->getIndex()];
|
||||||
|
}
|
||||||
|
|
||||||
|
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
|
||||||
|
{
|
||||||
|
if (Constants::DEBUG >= 3) {
|
||||||
|
cout << "calculating & updating " << link->toString() << endl;
|
||||||
|
}
|
||||||
|
calculateFactor2VariableMsg (link);
|
||||||
|
if (calcResidual) {
|
||||||
|
link->updateResidual();
|
||||||
|
}
|
||||||
|
link->updateMessage();
|
||||||
|
}
|
||||||
|
|
||||||
|
void calculateMessage (SpLink* link, bool calcResidual = true)
|
||||||
|
{
|
||||||
|
if (Constants::DEBUG >= 3) {
|
||||||
|
cout << "calculating " << link->toString() << endl;
|
||||||
|
}
|
||||||
|
calculateFactor2VariableMsg (link);
|
||||||
|
if (calcResidual) {
|
||||||
|
link->updateResidual();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void updateMessage (SpLink* link)
|
||||||
|
{
|
||||||
|
link->updateMessage();
|
||||||
|
if (Constants::DEBUG >= 3) {
|
||||||
|
cout << "updating " << link->toString() << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct CompareResidual
|
||||||
|
{
|
||||||
|
inline bool operator() (const SpLink* link1, const SpLink* link2)
|
||||||
|
{
|
||||||
|
return link1->getResidual() > link2->getResidual();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
SpLinkSet links_;
|
||||||
|
unsigned nIters_;
|
||||||
|
vector<SPNodeInfo*> varsI_;
|
||||||
|
vector<SPNodeInfo*> facsI_;
|
||||||
|
bool runned_;
|
||||||
|
const FactorGraph* fg_;
|
||||||
|
|
||||||
|
typedef multiset<SpLink*, CompareResidual> SortedOrder;
|
||||||
|
SortedOrder sortedOrder_;
|
||||||
|
|
||||||
|
typedef unordered_map<SpLink*, SortedOrder::iterator> SpLinkMap;
|
||||||
|
SpLinkMap linkMap_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void initializeSolver (void);
|
||||||
|
|
||||||
|
bool converged (void);
|
||||||
|
|
||||||
|
void printLinkInformation (void) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // HORUS_BPSOLVER_H
|
||||||
|
|
@ -1,7 +1,6 @@
|
|||||||
|
|
||||||
#include "CFactorGraph.h"
|
#include "CFactorGraph.h"
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
#include "Distribution.h"
|
|
||||||
|
|
||||||
|
|
||||||
bool CFactorGraph::checkForIdenticalFactors = true;
|
bool CFactorGraph::checkForIdenticalFactors = true;
|
||||||
@ -11,22 +10,22 @@ CFactorGraph::CFactorGraph (const FactorGraph& fg)
|
|||||||
groundFg_ = &fg;
|
groundFg_ = &fg;
|
||||||
freeColor_ = 0;
|
freeColor_ = 0;
|
||||||
|
|
||||||
const FgVarSet& varNodes = fg.getVarNodes();
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
varSignatures_.reserve (varNodes.size());
|
varSignatures_.reserve (varNodes.size());
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
unsigned c = (varNodes[i]->neighbors().size() * 2) + 1;
|
unsigned c = (varNodes[i]->neighbors().size() * 2) + 1;
|
||||||
varSignatures_.push_back (Signature (c));
|
varSignatures_.push_back (Signature (c));
|
||||||
}
|
}
|
||||||
|
|
||||||
const FgFacSet& facNodes = fg.getFactorNodes();
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
factorSignatures_.reserve (facNodes.size());
|
facSignatures_.reserve (facNodes.size());
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
unsigned c = facNodes[i]->neighbors().size() + 1;
|
unsigned c = facNodes[i]->neighbors().size() + 1;
|
||||||
factorSignatures_.push_back (Signature (c));
|
facSignatures_.push_back (Signature (c));
|
||||||
}
|
}
|
||||||
|
|
||||||
varColors_.resize (varNodes.size());
|
varColors_.resize (varNodes.size());
|
||||||
factorColors_.resize (facNodes.size());
|
facColors_.resize (facNodes.size());
|
||||||
setInitialColors();
|
setInitialColors();
|
||||||
createGroups();
|
createGroups();
|
||||||
}
|
}
|
||||||
@ -50,9 +49,9 @@ CFactorGraph::setInitialColors (void)
|
|||||||
{
|
{
|
||||||
// create the initial variable colors
|
// create the initial variable colors
|
||||||
VarColorMap colorMap;
|
VarColorMap colorMap;
|
||||||
const FgVarSet& varNodes = groundFg_->getVarNodes();
|
const VarNodes& varNodes = groundFg_->varNodes();
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
unsigned dsize = varNodes[i]->nrStates();
|
unsigned dsize = varNodes[i]->range();
|
||||||
VarColorMap::iterator it = colorMap.find (dsize);
|
VarColorMap::iterator it = colorMap.find (dsize);
|
||||||
if (it == colorMap.end()) {
|
if (it == colorMap.end()) {
|
||||||
it = colorMap.insert (make_pair (
|
it = colorMap.insert (make_pair (
|
||||||
@ -71,29 +70,40 @@ CFactorGraph::setInitialColors (void)
|
|||||||
setColor (varNodes[i], stateColors[idx]);
|
setColor (varNodes[i], stateColors[idx]);
|
||||||
}
|
}
|
||||||
|
|
||||||
const FgFacSet& facNodes = groundFg_->getFactorNodes();
|
const FacNodes& facNodes = groundFg_->facNodes();
|
||||||
if (checkForIdenticalFactors) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
for (unsigned i = 0, s = facNodes.size(); i < s; i++) {
|
facNodes[i]->factor().setDistId (Util::maxUnsigned());
|
||||||
Distribution* dist1 = facNodes[i]->getDistribution();
|
}
|
||||||
for (unsigned j = 0; j < i; j++) {
|
// FIXME FIXME FIXME : pfl should give correct dist ids.
|
||||||
Distribution* dist2 = facNodes[j]->getDistribution();
|
if (checkForIdenticalFactors || true) {
|
||||||
if (dist1 != dist2 && dist1->params == dist2->params) {
|
unsigned groupCount = 1;
|
||||||
if (facNodes[i]->factor()->getRanges() ==
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
facNodes[j]->factor()->getRanges()) {
|
Factor& f1 = facNodes[i]->factor();
|
||||||
facNodes[i]->factor()->setDistribution (dist2);
|
if (f1.distId() != Util::maxUnsigned()) {
|
||||||
}
|
continue;
|
||||||
|
}
|
||||||
|
f1.setDistId (groupCount);
|
||||||
|
for (unsigned j = i + 1; j < facNodes.size(); j++) {
|
||||||
|
Factor& f2 = facNodes[j]->factor();
|
||||||
|
if (f2.distId() != Util::maxUnsigned()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (f1.size() == f2.size() &&
|
||||||
|
f1.ranges() == f2.ranges() &&
|
||||||
|
f1.params() == f2.params()) {
|
||||||
|
f2.setDistId (groupCount);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
groupCount ++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// create the initial factor colors
|
// create the initial factor colors
|
||||||
DistColorMap distColors;
|
DistColorMap distColors;
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
const Distribution* dist = facNodes[i]->getDistribution();
|
unsigned distId = facNodes[i]->factor().distId();
|
||||||
DistColorMap::iterator it = distColors.find (dist);
|
DistColorMap::iterator it = distColors.find (distId);
|
||||||
if (it == distColors.end()) {
|
if (it == distColors.end()) {
|
||||||
it = distColors.insert (make_pair (dist, getFreeColor())).first;
|
it = distColors.insert (make_pair (distId, getFreeColor())).first;
|
||||||
}
|
}
|
||||||
setColor (facNodes[i], it->second);
|
setColor (facNodes[i], it->second);
|
||||||
}
|
}
|
||||||
@ -104,31 +114,31 @@ CFactorGraph::setInitialColors (void)
|
|||||||
void
|
void
|
||||||
CFactorGraph::createGroups (void)
|
CFactorGraph::createGroups (void)
|
||||||
{
|
{
|
||||||
VarSignMap varGroups;
|
VarSignMap varGroups;
|
||||||
FacSignMap factorGroups;
|
FacSignMap facGroups;
|
||||||
unsigned nIters = 0;
|
unsigned nIters = 0;
|
||||||
bool groupsHaveChanged = true;
|
bool groupsHaveChanged = true;
|
||||||
const FgVarSet& varNodes = groundFg_->getVarNodes();
|
const VarNodes& varNodes = groundFg_->varNodes();
|
||||||
const FgFacSet& facNodes = groundFg_->getFactorNodes();
|
const FacNodes& facNodes = groundFg_->facNodes();
|
||||||
|
|
||||||
while (groupsHaveChanged || nIters == 1) {
|
while (groupsHaveChanged || nIters == 1) {
|
||||||
nIters ++;
|
nIters ++;
|
||||||
|
|
||||||
unsigned prevFactorGroupsSize = factorGroups.size();
|
unsigned prevFactorGroupsSize = facGroups.size();
|
||||||
factorGroups.clear();
|
facGroups.clear();
|
||||||
// set a new color to the factors with the same signature
|
// set a new color to the factors with the same signature
|
||||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
const Signature& signature = getSignature (facNodes[i]);
|
const Signature& signature = getSignature (facNodes[i]);
|
||||||
FacSignMap::iterator it = factorGroups.find (signature);
|
FacSignMap::iterator it = facGroups.find (signature);
|
||||||
if (it == factorGroups.end()) {
|
if (it == facGroups.end()) {
|
||||||
it = factorGroups.insert (make_pair (signature, FgFacSet())).first;
|
it = facGroups.insert (make_pair (signature, FacNodes())).first;
|
||||||
}
|
}
|
||||||
it->second.push_back (facNodes[i]);
|
it->second.push_back (facNodes[i]);
|
||||||
}
|
}
|
||||||
for (FacSignMap::iterator it = factorGroups.begin();
|
for (FacSignMap::iterator it = facGroups.begin();
|
||||||
it != factorGroups.end(); it++) {
|
it != facGroups.end(); it++) {
|
||||||
Color newColor = getFreeColor();
|
Color newColor = getFreeColor();
|
||||||
FgFacSet& groupMembers = it->second;
|
FacNodes& groupMembers = it->second;
|
||||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
setColor (groupMembers[i], newColor);
|
setColor (groupMembers[i], newColor);
|
||||||
}
|
}
|
||||||
@ -141,36 +151,37 @@ CFactorGraph::createGroups (void)
|
|||||||
const Signature& signature = getSignature (varNodes[i]);
|
const Signature& signature = getSignature (varNodes[i]);
|
||||||
VarSignMap::iterator it = varGroups.find (signature);
|
VarSignMap::iterator it = varGroups.find (signature);
|
||||||
if (it == varGroups.end()) {
|
if (it == varGroups.end()) {
|
||||||
it = varGroups.insert (make_pair (signature, FgVarSet())).first;
|
it = varGroups.insert (make_pair (signature, VarNodes())).first;
|
||||||
}
|
}
|
||||||
it->second.push_back (varNodes[i]);
|
it->second.push_back (varNodes[i]);
|
||||||
}
|
}
|
||||||
for (VarSignMap::iterator it = varGroups.begin();
|
for (VarSignMap::iterator it = varGroups.begin();
|
||||||
it != varGroups.end(); it++) {
|
it != varGroups.end(); it++) {
|
||||||
Color newColor = getFreeColor();
|
Color newColor = getFreeColor();
|
||||||
FgVarSet& groupMembers = it->second;
|
VarNodes& groupMembers = it->second;
|
||||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
setColor (groupMembers[i], newColor);
|
setColor (groupMembers[i], newColor);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
||||||
|| prevFactorGroupsSize != factorGroups.size();
|
|| prevFactorGroupsSize != facGroups.size();
|
||||||
}
|
}
|
||||||
//printGroups (varGroups, factorGroups);
|
printGroups (varGroups, facGroups);
|
||||||
createClusters (varGroups, factorGroups);
|
createClusters (varGroups, facGroups);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
CFactorGraph::createClusters (const VarSignMap& varGroups,
|
CFactorGraph::createClusters (
|
||||||
const FacSignMap& factorGroups)
|
const VarSignMap& varGroups,
|
||||||
|
const FacSignMap& facGroups)
|
||||||
{
|
{
|
||||||
varClusters_.reserve (varGroups.size());
|
varClusters_.reserve (varGroups.size());
|
||||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||||
it != varGroups.end(); it++) {
|
it != varGroups.end(); it++) {
|
||||||
const FgVarSet& groupVars = it->second;
|
const VarNodes& groupVars = it->second;
|
||||||
VarCluster* vc = new VarCluster (groupVars);
|
VarCluster* vc = new VarCluster (groupVars);
|
||||||
for (unsigned i = 0; i < groupVars.size(); i++) {
|
for (unsigned i = 0; i < groupVars.size(); i++) {
|
||||||
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
|
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
|
||||||
@ -178,12 +189,12 @@ CFactorGraph::createClusters (const VarSignMap& varGroups,
|
|||||||
varClusters_.push_back (vc);
|
varClusters_.push_back (vc);
|
||||||
}
|
}
|
||||||
|
|
||||||
facClusters_.reserve (factorGroups.size());
|
facClusters_.reserve (facGroups.size());
|
||||||
for (FacSignMap::const_iterator it = factorGroups.begin();
|
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||||
it != factorGroups.end(); it++) {
|
it != facGroups.end(); it++) {
|
||||||
FgFacNode* groupFactor = it->second[0];
|
FacNode* groupFactor = it->second[0];
|
||||||
const FgVarSet& neighs = groupFactor->neighbors();
|
const VarNodes& neighs = groupFactor->neighbors();
|
||||||
VarClusterSet varClusters;
|
VarClusters varClusters;
|
||||||
varClusters.reserve (neighs.size());
|
varClusters.reserve (neighs.size());
|
||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||||
VarId vid = neighs[i]->varId();
|
VarId vid = neighs[i]->varId();
|
||||||
@ -196,15 +207,15 @@ CFactorGraph::createClusters (const VarSignMap& varGroups,
|
|||||||
|
|
||||||
|
|
||||||
const Signature&
|
const Signature&
|
||||||
CFactorGraph::getSignature (const FgVarNode* varNode)
|
CFactorGraph::getSignature (const VarNode* varNode)
|
||||||
{
|
{
|
||||||
Signature& sign = varSignatures_[varNode->getIndex()];
|
Signature& sign = varSignatures_[varNode->getIndex()];
|
||||||
vector<Color>::iterator it = sign.colors.begin();
|
vector<Color>::iterator it = sign.colors.begin();
|
||||||
const FgFacSet& neighs = varNode->neighbors();
|
const FacNodes& neighs = varNode->neighbors();
|
||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||||
*it = getColor (neighs[i]);
|
*it = getColor (neighs[i]);
|
||||||
it ++;
|
it ++;
|
||||||
*it = neighs[i]->factor()->indexOf (varNode->varId());
|
*it = neighs[i]->factor().indexOf (varNode->varId());
|
||||||
it ++;
|
it ++;
|
||||||
}
|
}
|
||||||
*it = getColor (varNode);
|
*it = getColor (varNode);
|
||||||
@ -214,11 +225,11 @@ CFactorGraph::getSignature (const FgVarNode* varNode)
|
|||||||
|
|
||||||
|
|
||||||
const Signature&
|
const Signature&
|
||||||
CFactorGraph::getSignature (const FgFacNode* facNode)
|
CFactorGraph::getSignature (const FacNode* facNode)
|
||||||
{
|
{
|
||||||
Signature& sign = factorSignatures_[facNode->getIndex()];
|
Signature& sign = facSignatures_[facNode->getIndex()];
|
||||||
vector<Color>::iterator it = sign.colors.begin();
|
vector<Color>::iterator it = sign.colors.begin();
|
||||||
const FgVarSet& neighs = facNode->neighbors();
|
const VarNodes& neighs = facNode->neighbors();
|
||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||||
*it = getColor (neighs[i]);
|
*it = getColor (neighs[i]);
|
||||||
it ++;
|
it ++;
|
||||||
@ -230,55 +241,53 @@ CFactorGraph::getSignature (const FgFacNode* facNode)
|
|||||||
|
|
||||||
|
|
||||||
FactorGraph*
|
FactorGraph*
|
||||||
CFactorGraph::getCompressedFactorGraph (void)
|
CFactorGraph::getGroundFactorGraph (void) const
|
||||||
{
|
{
|
||||||
FactorGraph* fg = new FactorGraph();
|
FactorGraph* fg = new FactorGraph();
|
||||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||||
FgVarNode* var = varClusters_[i]->getGroundFgVarNodes()[0];
|
VarNode* var = varClusters_[i]->getGroundVarNodes()[0];
|
||||||
FgVarNode* newVar = new FgVarNode (var);
|
VarNode* newVar = new VarNode (var);
|
||||||
varClusters_[i]->setRepresentativeVariable (newVar);
|
varClusters_[i]->setRepresentativeVariable (newVar);
|
||||||
fg->addVariable (newVar);
|
fg->addVarNode (newVar);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned i = 0; i < facClusters_.size(); i++) {
|
for (unsigned i = 0; i < facClusters_.size(); i++) {
|
||||||
const VarClusterSet& myVarClusters = facClusters_[i]->getVarClusters();
|
const VarClusters& myVarClusters = facClusters_[i]->getVarClusters();
|
||||||
VarNodes myGroundVars;
|
Vars myGroundVars;
|
||||||
myGroundVars.reserve (myVarClusters.size());
|
myGroundVars.reserve (myVarClusters.size());
|
||||||
for (unsigned j = 0; j < myVarClusters.size(); j++) {
|
for (unsigned j = 0; j < myVarClusters.size(); j++) {
|
||||||
FgVarNode* v = myVarClusters[j]->getRepresentativeVariable();
|
VarNode* v = myVarClusters[j]->getRepresentativeVariable();
|
||||||
myGroundVars.push_back (v);
|
myGroundVars.push_back (v);
|
||||||
}
|
}
|
||||||
Factor* newFactor = new Factor (myGroundVars,
|
FacNode* fn = new FacNode (Factor (myGroundVars,
|
||||||
facClusters_[i]->getGroundFactors()[0]->getDistribution());
|
facClusters_[i]->getGroundFactors()[0]->factor().params()));
|
||||||
FgFacNode* fn = new FgFacNode (newFactor);
|
|
||||||
facClusters_[i]->setRepresentativeFactor (fn);
|
facClusters_[i]->setRepresentativeFactor (fn);
|
||||||
fg->addFactor (fn);
|
fg->addFacNode (fn);
|
||||||
for (unsigned j = 0; j < myGroundVars.size(); j++) {
|
for (unsigned j = 0; j < myGroundVars.size(); j++) {
|
||||||
fg->addEdge (fn, static_cast<FgVarNode*> (myGroundVars[j]));
|
fg->addEdge (static_cast<VarNode*> (myGroundVars[j]), fn);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fg->setIndexes();
|
|
||||||
return fg;
|
return fg;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
unsigned
|
||||||
CFactorGraph::getGroundEdgeCount (
|
CFactorGraph::getEdgeCount (
|
||||||
const FacCluster* fc,
|
const FacCluster* fc,
|
||||||
const VarCluster* vc) const
|
const VarCluster* vc) const
|
||||||
{
|
{
|
||||||
const FgFacSet& clusterGroundFactors = fc->getGroundFactors();
|
|
||||||
FgVarNode* varNode = vc->getGroundFgVarNodes()[0];
|
|
||||||
unsigned count = 0;
|
unsigned count = 0;
|
||||||
|
VarId vid = vc->getGroundVarNodes().front()->varId();
|
||||||
|
const FacNodes& clusterGroundFactors = fc->getGroundFactors();
|
||||||
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||||
if (clusterGroundFactors[i]->factor()->indexOf (varNode->varId()) != -1) {
|
if (clusterGroundFactors[i]->factor().contains (vid)) {
|
||||||
count ++;
|
count ++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// CFgVarSet vars = vc->getGroundFgVarNodes();
|
// CVarNodes vars = vc->getGroundVarNodes();
|
||||||
// for (unsigned i = 1; i < vars.size(); i++) {
|
// for (unsigned i = 1; i < vars.size(); i++) {
|
||||||
// FgVarNode* var = vc->getGroundFgVarNodes()[i];
|
// VarNode* var = vc->getGroundVarNodes()[i];
|
||||||
// unsigned count2 = 0;
|
// unsigned count2 = 0;
|
||||||
// for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
// for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||||
// if (clusterGroundFactors[i]->getPosition (var) != -1) {
|
// if (clusterGroundFactors[i]->getPosition (var) != -1) {
|
||||||
@ -293,14 +302,15 @@ CFactorGraph::getGroundEdgeCount (
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
CFactorGraph::printGroups (const VarSignMap& varGroups,
|
CFactorGraph::printGroups (
|
||||||
const FacSignMap& factorGroups) const
|
const VarSignMap& varGroups,
|
||||||
|
const FacSignMap& facGroups) const
|
||||||
{
|
{
|
||||||
unsigned count = 1;
|
unsigned count = 1;
|
||||||
cout << "variable groups:" << endl;
|
cout << "variable groups:" << endl;
|
||||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||||
it != varGroups.end(); it++) {
|
it != varGroups.end(); it++) {
|
||||||
const FgVarSet& groupMembers = it->second;
|
const VarNodes& groupMembers = it->second;
|
||||||
if (groupMembers.size() > 0) {
|
if (groupMembers.size() > 0) {
|
||||||
cout << count << ": " ;
|
cout << count << ": " ;
|
||||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
@ -313,9 +323,9 @@ CFactorGraph::printGroups (const VarSignMap& varGroups,
|
|||||||
|
|
||||||
count = 1;
|
count = 1;
|
||||||
cout << endl << "factor groups:" << endl;
|
cout << endl << "factor groups:" << endl;
|
||||||
for (FacSignMap::const_iterator it = factorGroups.begin();
|
for (FacSignMap::const_iterator it = facGroups.begin();
|
||||||
it != factorGroups.end(); it++) {
|
it != facGroups.end(); it++) {
|
||||||
const FgFacSet& groupMembers = it->second;
|
const FacNodes& groupMembers = it->second;
|
||||||
if (groupMembers.size() > 0) {
|
if (groupMembers.size() > 0) {
|
||||||
cout << ++count << ": " ;
|
cout << ++count << ": " ;
|
||||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||||
|
@ -15,23 +15,25 @@ class Signature;
|
|||||||
class SignatureHash;
|
class SignatureHash;
|
||||||
|
|
||||||
|
|
||||||
typedef long Color;
|
typedef long Color;
|
||||||
typedef unordered_map<unsigned, vector<Color> > VarColorMap;
|
|
||||||
typedef unordered_map<const Distribution*, Color> DistColorMap;
|
typedef unordered_map<unsigned, vector<Color>> VarColorMap;
|
||||||
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
|
||||||
typedef vector<VarCluster*> VarClusterSet;
|
typedef unordered_map<unsigned, Color> DistColorMap;
|
||||||
typedef vector<FacCluster*> FacClusterSet;
|
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
||||||
typedef unordered_map<Signature, FgVarSet, SignatureHash> VarSignMap;
|
|
||||||
typedef unordered_map<Signature, FgFacSet, SignatureHash> FacSignMap;
|
typedef vector<VarCluster*> VarClusters;
|
||||||
|
typedef vector<FacCluster*> FacClusters;
|
||||||
|
|
||||||
|
typedef unordered_map<Signature, VarNodes, SignatureHash> VarSignMap;
|
||||||
|
typedef unordered_map<Signature, FacNodes, SignatureHash> FacSignMap;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
struct Signature
|
struct Signature
|
||||||
{
|
{
|
||||||
Signature (unsigned size)
|
Signature (unsigned size) : colors(size) { }
|
||||||
{
|
|
||||||
colors.resize (size);
|
|
||||||
}
|
|
||||||
bool operator< (const Signature& sig) const
|
bool operator< (const Signature& sig) const
|
||||||
{
|
{
|
||||||
if (colors.size() < sig.colors.size()) {
|
if (colors.size() < sig.colors.size()) {
|
||||||
@ -49,6 +51,7 @@ struct Signature
|
|||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool operator== (const Signature& sig) const
|
bool operator== (const Signature& sig) const
|
||||||
{
|
{
|
||||||
if (colors.size() != sig.colors.size()) {
|
if (colors.size() != sig.colors.size()) {
|
||||||
@ -61,12 +64,14 @@ struct Signature
|
|||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<Color> colors;
|
vector<Color> colors;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
struct SignatureHash {
|
struct SignatureHash
|
||||||
|
{
|
||||||
size_t operator() (const Signature &sig) const
|
size_t operator() (const Signature &sig) const
|
||||||
{
|
{
|
||||||
size_t val = hash<size_t>()(sig.colors.size());
|
size_t val = hash<size_t>()(sig.colors.size());
|
||||||
@ -82,7 +87,7 @@ struct SignatureHash {
|
|||||||
class VarCluster
|
class VarCluster
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
VarCluster (const FgVarSet& vs)
|
VarCluster (const VarNodes& vs)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < vs.size(); i++) {
|
for (unsigned i = 0; i < vs.size(); i++) {
|
||||||
groundVars_.push_back (vs[i]);
|
groundVars_.push_back (vs[i]);
|
||||||
@ -94,26 +99,28 @@ class VarCluster
|
|||||||
facClusters_.push_back (fc);
|
facClusters_.push_back (fc);
|
||||||
}
|
}
|
||||||
|
|
||||||
const FacClusterSet& getFacClusters (void) const
|
const FacClusters& getFacClusters (void) const
|
||||||
{
|
{
|
||||||
return facClusters_;
|
return facClusters_;
|
||||||
}
|
}
|
||||||
|
|
||||||
FgVarNode* getRepresentativeVariable (void) const { return representVar_; }
|
VarNode* getRepresentativeVariable (void) const { return representVar_; }
|
||||||
void setRepresentativeVariable (FgVarNode* v) { representVar_ = v; }
|
|
||||||
const FgVarSet& getGroundFgVarNodes (void) const { return groundVars_; }
|
void setRepresentativeVariable (VarNode* v) { representVar_ = v; }
|
||||||
|
|
||||||
|
const VarNodes& getGroundVarNodes (void) const { return groundVars_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FgVarSet groundVars_;
|
VarNodes groundVars_;
|
||||||
FacClusterSet facClusters_;
|
FacClusters facClusters_;
|
||||||
FgVarNode* representVar_;
|
VarNode* representVar_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class FacCluster
|
class FacCluster
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FacCluster (const FgFacSet& groundFactors, const VarClusterSet& vcs)
|
FacCluster (const FacNodes& groundFactors, const VarClusters& vcs)
|
||||||
{
|
{
|
||||||
groundFactors_ = groundFactors;
|
groundFactors_ = groundFactors;
|
||||||
varClusters_ = vcs;
|
varClusters_ = vcs;
|
||||||
@ -122,12 +129,12 @@ class FacCluster
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const VarClusterSet& getVarClusters (void) const
|
const VarClusters& getVarClusters (void) const
|
||||||
{
|
{
|
||||||
return varClusters_;
|
return varClusters_;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool containsGround (const FgFacNode* fn)
|
bool containsGround (const FacNode* fn)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < groundFactors_.size(); i++) {
|
for (unsigned i = 0; i < groundFactors_.size(); i++) {
|
||||||
if (groundFactors_[i] == fn) {
|
if (groundFactors_[i] == fn) {
|
||||||
@ -137,24 +144,26 @@ class FacCluster
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
FgFacNode* getRepresentativeFactor (void) const
|
FacNode* getRepresentativeFactor (void) const
|
||||||
{
|
{
|
||||||
return representFactor_;
|
return representFactor_;
|
||||||
}
|
}
|
||||||
void setRepresentativeFactor (FgFacNode* fn)
|
|
||||||
|
void setRepresentativeFactor (FacNode* fn)
|
||||||
{
|
{
|
||||||
representFactor_ = fn;
|
representFactor_ = fn;
|
||||||
}
|
}
|
||||||
const FgFacSet& getGroundFactors (void) const
|
|
||||||
|
const FacNodes& getGroundFactors (void) const
|
||||||
{
|
{
|
||||||
return groundFactors_;
|
return groundFactors_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FgFacSet groundFactors_;
|
FacNodes groundFactors_;
|
||||||
VarClusterSet varClusters_;
|
VarClusters varClusters_;
|
||||||
FgFacNode* representFactor_;
|
FacNode* representFactor_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -162,51 +171,48 @@ class CFactorGraph
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
CFactorGraph (const FactorGraph&);
|
CFactorGraph (const FactorGraph&);
|
||||||
|
|
||||||
~CFactorGraph (void);
|
~CFactorGraph (void);
|
||||||
|
|
||||||
FactorGraph* getCompressedFactorGraph (void);
|
const VarClusters& getVarClusters (void) { return varClusters_; }
|
||||||
unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const;
|
|
||||||
|
|
||||||
FgVarNode* getEquivalentVariable (VarId vid)
|
const FacClusters& getFacClusters (void) { return facClusters_; }
|
||||||
|
|
||||||
|
VarNode* getEquivalentVariable (VarId vid)
|
||||||
{
|
{
|
||||||
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||||
return vc->getRepresentativeVariable();
|
return vc->getRepresentativeVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
const VarClusterSet& getVarClusters (void) { return varClusters_; }
|
|
||||||
const FacClusterSet& getFacClusters (void) { return facClusters_; }
|
|
||||||
|
|
||||||
|
FactorGraph* getGroundFactorGraph (void) const;
|
||||||
|
|
||||||
|
unsigned getEdgeCount (const FacCluster*, const VarCluster*) const;
|
||||||
|
|
||||||
static bool checkForIdenticalFactors;
|
static bool checkForIdenticalFactors;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void setInitialColors (void);
|
Color getFreeColor (void)
|
||||||
void createGroups (void);
|
{
|
||||||
void createClusters (const VarSignMap&, const FacSignMap&);
|
|
||||||
const Signature& getSignature (const FgVarNode*);
|
|
||||||
const Signature& getSignature (const FgFacNode*);
|
|
||||||
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
|
||||||
|
|
||||||
Color getFreeColor (void) {
|
|
||||||
++ freeColor_;
|
++ freeColor_;
|
||||||
return freeColor_ - 1;
|
return freeColor_ - 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
Color getColor (const FgVarNode* vn) const
|
Color getColor (const VarNode* vn) const
|
||||||
{
|
{
|
||||||
return varColors_[vn->getIndex()];
|
return varColors_[vn->getIndex()];
|
||||||
}
|
}
|
||||||
Color getColor (const FgFacNode* fn) const {
|
Color getColor (const FacNode* fn) const {
|
||||||
return factorColors_[fn->getIndex()];
|
return facColors_[fn->getIndex()];
|
||||||
}
|
}
|
||||||
|
|
||||||
void setColor (const FgVarNode* vn, Color c)
|
void setColor (const VarNode* vn, Color c)
|
||||||
{
|
{
|
||||||
varColors_[vn->getIndex()] = c;
|
varColors_[vn->getIndex()] = c;
|
||||||
}
|
}
|
||||||
|
|
||||||
void setColor (const FgFacNode* fn, Color c)
|
void setColor (const FacNode* fn, Color c)
|
||||||
{
|
{
|
||||||
factorColors_[fn->getIndex()] = c;
|
facColors_[fn->getIndex()] = c;
|
||||||
}
|
}
|
||||||
|
|
||||||
VarCluster* getVariableCluster (VarId vid) const
|
VarCluster* getVariableCluster (VarId vid) const
|
||||||
@ -214,14 +220,26 @@ class CFactorGraph
|
|||||||
return vid2VarCluster_.find (vid)->second;
|
return vid2VarCluster_.find (vid)->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void setInitialColors (void);
|
||||||
|
|
||||||
|
void createGroups (void);
|
||||||
|
|
||||||
|
void createClusters (const VarSignMap&, const FacSignMap&);
|
||||||
|
|
||||||
|
const Signature& getSignature (const VarNode*);
|
||||||
|
|
||||||
|
const Signature& getSignature (const FacNode*);
|
||||||
|
|
||||||
|
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
||||||
|
|
||||||
Color freeColor_;
|
Color freeColor_;
|
||||||
vector<Color> varColors_;
|
vector<Color> varColors_;
|
||||||
vector<Color> factorColors_;
|
vector<Color> facColors_;
|
||||||
vector<Signature> varSignatures_;
|
vector<Signature> varSignatures_;
|
||||||
vector<Signature> factorSignatures_;
|
vector<Signature> facSignatures_;
|
||||||
VarClusterSet varClusters_;
|
VarClusters varClusters_;
|
||||||
FacClusterSet facClusters_;
|
FacClusters facClusters_;
|
||||||
VarId2VarCluster vid2VarCluster_;
|
VarId2VarCluster vid2VarCluster_;
|
||||||
const FactorGraph* groundFg_;
|
const FactorGraph* groundFg_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1,10 +1,41 @@
|
|||||||
#include "CbpSolver.h"
|
#include "CbpSolver.h"
|
||||||
|
|
||||||
|
|
||||||
|
CbpSolver::CbpSolver (const FactorGraph& fg) : BpSolver (fg)
|
||||||
|
{
|
||||||
|
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
|
||||||
|
if (Constants::COLLECT_STATS) {
|
||||||
|
nGroundVars = fg_->varNodes().size();
|
||||||
|
nGroundFacs = fg_->facNodes().size();
|
||||||
|
const VarNodes& vars = fg_->varNodes();
|
||||||
|
nWithoutNeighs = 0;
|
||||||
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
|
const FacNodes& factors = vars[i]->neighbors();
|
||||||
|
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
|
||||||
|
nWithoutNeighs ++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cfg_ = new CFactorGraph (fg);
|
||||||
|
fg_ = cfg_->getGroundFactorGraph();
|
||||||
|
if (Constants::COLLECT_STATS) {
|
||||||
|
unsigned nClusterVars = fg_->varNodes().size();
|
||||||
|
unsigned nClusterFacs = fg_->facNodes().size();
|
||||||
|
Statistics::updateCompressingStatistics (nGroundVars,
|
||||||
|
nGroundFacs, nClusterVars, nClusterFacs, nWithoutNeighs);
|
||||||
|
}
|
||||||
|
Util::printHeader ("Uncompressed Factor Graph");
|
||||||
|
fg.print();
|
||||||
|
Util::printHeader ("Compressed Factor Graph");
|
||||||
|
fg_->print();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
CbpSolver::~CbpSolver (void)
|
CbpSolver::~CbpSolver (void)
|
||||||
{
|
{
|
||||||
delete lfg_;
|
delete cfg_;
|
||||||
delete factorGraph_;
|
delete fg_;
|
||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
delete links_[i];
|
delete links_[i];
|
||||||
}
|
}
|
||||||
@ -16,28 +47,31 @@ CbpSolver::~CbpSolver (void)
|
|||||||
Params
|
Params
|
||||||
CbpSolver::getPosterioriOf (VarId vid)
|
CbpSolver::getPosterioriOf (VarId vid)
|
||||||
{
|
{
|
||||||
assert (lfg_->getEquivalentVariable (vid));
|
if (runned_ == false) {
|
||||||
FgVarNode* var = lfg_->getEquivalentVariable (vid);
|
runSolver();
|
||||||
|
}
|
||||||
|
assert (cfg_->getEquivalentVariable (vid));
|
||||||
|
VarNode* var = cfg_->getEquivalentVariable (vid);
|
||||||
Params probs;
|
Params probs;
|
||||||
if (var->hasEvidence()) {
|
if (var->hasEvidence()) {
|
||||||
probs.resize (var->nrStates(), Util::noEvidence());
|
probs.resize (var->range(), LogAware::noEvidence());
|
||||||
probs[var->getEvidence()] = Util::withEvidence();
|
probs[var->getEvidence()] = LogAware::withEvidence();
|
||||||
} else {
|
} else {
|
||||||
probs.resize (var->nrStates(), Util::multIdenty());
|
probs.resize (var->range(), LogAware::multIdenty());
|
||||||
const SpLinkSet& links = ninf(var)->getLinks();
|
const SpLinkSet& links = ninf(var)->getLinks();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||||
Util::add (probs, l->getPoweredMessage());
|
Util::add (probs, l->poweredMessage());
|
||||||
}
|
}
|
||||||
Util::normalize (probs);
|
LogAware::normalize (probs);
|
||||||
Util::fromLog (probs);
|
Util::fromLog (probs);
|
||||||
} else {
|
} else {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||||
Util::multiply (probs, l->getPoweredMessage());
|
Util::multiply (probs, l->poweredMessage());
|
||||||
}
|
}
|
||||||
Util::normalize (probs);
|
LogAware::normalize (probs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return probs;
|
return probs;
|
||||||
@ -46,55 +80,14 @@ CbpSolver::getPosterioriOf (VarId vid)
|
|||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
CbpSolver::getJointDistributionOf (const VarIds& jointVarIds)
|
CbpSolver::getJointDistributionOf (const VarIds& jointVids)
|
||||||
{
|
{
|
||||||
VarIds eqVarIds;
|
VarIds eqVarIds;
|
||||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
for (unsigned i = 0; i < jointVids.size(); i++) {
|
||||||
eqVarIds.push_back (lfg_->getEquivalentVariable (jointVarIds[i])->varId());
|
VarNode* vn = cfg_->getEquivalentVariable (jointVids[i]);
|
||||||
|
eqVarIds.push_back (vn->varId());
|
||||||
}
|
}
|
||||||
return FgBpSolver::getJointDistributionOf (eqVarIds);
|
return BpSolver::getJointDistributionOf (eqVarIds);
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
CbpSolver::initializeSolver (void)
|
|
||||||
{
|
|
||||||
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
|
|
||||||
if (COLLECT_STATISTICS) {
|
|
||||||
nGroundVars = factorGraph_->getVarNodes().size();
|
|
||||||
nGroundFacs = factorGraph_->getFactorNodes().size();
|
|
||||||
const FgVarSet& vars = factorGraph_->getVarNodes();
|
|
||||||
nWithoutNeighs = 0;
|
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
|
||||||
const FgFacSet& factors = vars[i]->neighbors();
|
|
||||||
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
|
|
||||||
nWithoutNeighs ++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
lfg_ = new CFactorGraph (*factorGraph_);
|
|
||||||
|
|
||||||
// cout << "Uncompressed Factor Graph" << endl;
|
|
||||||
// factorGraph_->printGraphicalModel();
|
|
||||||
// factorGraph_->exportToGraphViz ("uncompressed_fg.dot");
|
|
||||||
factorGraph_ = lfg_->getCompressedFactorGraph();
|
|
||||||
|
|
||||||
if (COLLECT_STATISTICS) {
|
|
||||||
unsigned nClusterVars = factorGraph_->getVarNodes().size();
|
|
||||||
unsigned nClusterFacs = factorGraph_->getFactorNodes().size();
|
|
||||||
Statistics::updateCompressingStatistics (nGroundVars, nGroundFacs,
|
|
||||||
nClusterVars, nClusterFacs,
|
|
||||||
nWithoutNeighs);
|
|
||||||
}
|
|
||||||
|
|
||||||
// cout << "Compressed Factor Graph" << endl;
|
|
||||||
// factorGraph_->printGraphicalModel();
|
|
||||||
// factorGraph_->exportToGraphViz ("compressed_fg.dot");
|
|
||||||
// abort();
|
|
||||||
FgBpSolver::initializeSolver();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -102,12 +95,13 @@ CbpSolver::initializeSolver (void)
|
|||||||
void
|
void
|
||||||
CbpSolver::createLinks (void)
|
CbpSolver::createLinks (void)
|
||||||
{
|
{
|
||||||
const FacClusterSet fcs = lfg_->getFacClusters();
|
const FacClusters& fcs = cfg_->getFacClusters();
|
||||||
for (unsigned i = 0; i < fcs.size(); i++) {
|
for (unsigned i = 0; i < fcs.size(); i++) {
|
||||||
const VarClusterSet vcs = fcs[i]->getVarClusters();
|
const VarClusters& vcs = fcs[i]->getVarClusters();
|
||||||
for (unsigned j = 0; j < vcs.size(); j++) {
|
for (unsigned j = 0; j < vcs.size(); j++) {
|
||||||
unsigned c = lfg_->getGroundEdgeCount (fcs[i], vcs[j]);
|
unsigned c = cfg_->getEdgeCount (fcs[i], vcs[j]);
|
||||||
links_.push_back (new CbpSolverLink (fcs[i]->getRepresentativeFactor(),
|
links_.push_back (new CbpSolverLink (
|
||||||
|
fcs[i]->getRepresentativeFactor(),
|
||||||
vcs[j]->getRepresentativeVariable(), c));
|
vcs[j]->getRepresentativeVariable(), c));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -123,7 +117,7 @@ CbpSolver::maxResidualSchedule (void)
|
|||||||
calculateMessage (links_[i]);
|
calculateMessage (links_[i]);
|
||||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||||
linkMap_.insert (make_pair (links_[i], it));
|
linkMap_.insert (make_pair (links_[i], it));
|
||||||
if (DL >= 2 && DL < 5) {
|
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
||||||
cout << "calculating " << links_[i]->toString() << endl;
|
cout << "calculating " << links_[i]->toString() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -131,7 +125,7 @@ CbpSolver::maxResidualSchedule (void)
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned c = 0; c < links_.size(); c++) {
|
for (unsigned c = 0; c < links_.size(); c++) {
|
||||||
if (DL >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << endl << "current residuals:" << endl;
|
cout << endl << "current residuals:" << endl;
|
||||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
it != sortedOrder_.end(); it ++) {
|
it != sortedOrder_.end(); it ++) {
|
||||||
@ -142,7 +136,7 @@ CbpSolver::maxResidualSchedule (void)
|
|||||||
|
|
||||||
SortedOrder::iterator it = sortedOrder_.begin();
|
SortedOrder::iterator it = sortedOrder_.begin();
|
||||||
SpLink* link = *it;
|
SpLink* link = *it;
|
||||||
if (DL >= 2) {
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||||
}
|
}
|
||||||
if (link->getResidual() < BpOptions::accuracy) {
|
if (link->getResidual() < BpOptions::accuracy) {
|
||||||
@ -154,12 +148,12 @@ CbpSolver::maxResidualSchedule (void)
|
|||||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||||
|
|
||||||
// update the messages that depend on message source --> destin
|
// update the messages that depend on message source --> destin
|
||||||
const FgFacSet& factorNeighbors = link->getVariable()->neighbors();
|
const FacNodes& factorNeighbors = link->getVariable()->neighbors();
|
||||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
||||||
for (unsigned j = 0; j < links.size(); j++) {
|
for (unsigned j = 0; j < links.size(); j++) {
|
||||||
if (links[j]->getVariable() != link->getVariable()) {
|
if (links[j]->getVariable() != link->getVariable()) {
|
||||||
if (DL >= 2 && DL < 5) {
|
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
||||||
cout << " calculating " << links[j]->toString() << endl;
|
cout << " calculating " << links[j]->toString() << endl;
|
||||||
}
|
}
|
||||||
calculateMessage (links[j]);
|
calculateMessage (links[j]);
|
||||||
@ -174,7 +168,7 @@ CbpSolver::maxResidualSchedule (void)
|
|||||||
const SpLinkSet& links = ninf(link->getFactor())->getLinks();
|
const SpLinkSet& links = ninf(link->getFactor())->getLinks();
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
if (links[i]->getVariable() != link->getVariable()) {
|
if (links[i]->getVariable() != link->getVariable()) {
|
||||||
if (DL >= 2 && DL < 5) {
|
if (Constants::DEBUG >= 2 && Constants::DEBUG < 5) {
|
||||||
cout << " calculating " << links[i]->toString() << endl;
|
cout << " calculating " << links[i]->toString() << endl;
|
||||||
}
|
}
|
||||||
calculateMessage (links[i]);
|
calculateMessage (links[i]);
|
||||||
@ -192,43 +186,43 @@ Params
|
|||||||
CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||||
{
|
{
|
||||||
Params msg;
|
Params msg;
|
||||||
const FgVarNode* src = link->getVariable();
|
const VarNode* src = link->getVariable();
|
||||||
const FgFacNode* dst = link->getFactor();
|
const FacNode* dst = link->getFactor();
|
||||||
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
|
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
|
||||||
if (src->hasEvidence()) {
|
if (src->hasEvidence()) {
|
||||||
msg.resize (src->nrStates(), Util::noEvidence());
|
msg.resize (src->range(), LogAware::noEvidence());
|
||||||
double value = link->getMessage()[src->getEvidence()];
|
double value = link->getMessage()[src->getEvidence()];
|
||||||
msg[src->getEvidence()] = Util::pow (value, l->getNumberOfEdges() - 1);
|
msg[src->getEvidence()] = LogAware::pow (value, l->nrEdges() - 1);
|
||||||
} else {
|
} else {
|
||||||
msg = link->getMessage();
|
msg = link->getMessage();
|
||||||
Util::pow (msg, l->getNumberOfEdges() - 1);
|
LogAware::pow (msg, l->nrEdges() - 1);
|
||||||
}
|
}
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " " << "init: " << Util::parametersToString (msg) << endl;
|
cout << " " << "init: " << msg << endl;
|
||||||
}
|
}
|
||||||
const SpLinkSet& links = ninf(src)->getLinks();
|
const SpLinkSet& links = ninf(src)->getLinks();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
if (links[i]->getFactor() != dst) {
|
if (links[i]->getFactor() != dst) {
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||||
Util::add (msg, l->getPoweredMessage());
|
Util::add (msg, l->poweredMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (unsigned i = 0; i < links.size(); i++) {
|
for (unsigned i = 0; i < links.size(); i++) {
|
||||||
if (links[i]->getFactor() != dst) {
|
if (links[i]->getFactor() != dst) {
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||||
Util::multiply (msg, l->getPoweredMessage());
|
Util::multiply (msg, l->poweredMessage());
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
|
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
|
||||||
cout << Util::parametersToString (l->getPoweredMessage()) << endl;
|
cout << l->poweredMessage() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (DL >= 5) {
|
if (Constants::DEBUG >= 5) {
|
||||||
cout << " result = " << Util::parametersToString (msg) << endl;
|
cout << " result = " << msg << endl;
|
||||||
}
|
}
|
||||||
return msg;
|
return msg;
|
||||||
}
|
}
|
||||||
@ -241,12 +235,9 @@ CbpSolver::printLinkInformation (void) const
|
|||||||
for (unsigned i = 0; i < links_.size(); i++) {
|
for (unsigned i = 0; i < links_.size(); i++) {
|
||||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links_[i]);
|
CbpSolverLink* l = static_cast<CbpSolverLink*> (links_[i]);
|
||||||
cout << l->toString() << ":" << endl;
|
cout << l->toString() << ":" << endl;
|
||||||
cout << " curr msg = " ;
|
cout << " curr msg = " << l->getMessage() << endl;
|
||||||
cout << Util::parametersToString (l->getMessage()) << endl;
|
cout << " next msg = " << l->getNextMessage() << endl;
|
||||||
cout << " next msg = " ;
|
cout << " powered = " << l->poweredMessage() << endl;
|
||||||
cout << Util::parametersToString (l->getNextMessage()) << endl;
|
|
||||||
cout << " powered = " ;
|
|
||||||
cout << Util::parametersToString (l->getPoweredMessage()) << endl;
|
|
||||||
cout << " residual = " << l->getResidual() << endl;
|
cout << " residual = " << l->getResidual() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
#ifndef HORUS_CBP_H
|
#ifndef HORUS_CBP_H
|
||||||
#define HORUS_CBP_H
|
#define HORUS_CBP_H
|
||||||
|
|
||||||
#include "FgBpSolver.h"
|
#include "BpSolver.h"
|
||||||
#include "CFactorGraph.h"
|
#include "CFactorGraph.h"
|
||||||
|
|
||||||
class Factor;
|
class Factor;
|
||||||
@ -9,49 +9,51 @@ class Factor;
|
|||||||
class CbpSolverLink : public SpLink
|
class CbpSolverLink : public SpLink
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
CbpSolverLink (FgFacNode* fn, FgVarNode* vn, unsigned c) : SpLink (fn, vn)
|
CbpSolverLink (FacNode* fn, VarNode* vn, unsigned c)
|
||||||
{
|
: SpLink (fn, vn), nrEdges_(c),
|
||||||
edgeCount_ = c;
|
pwdMsg_(vn->range(), LogAware::one()) { }
|
||||||
poweredMsg_.resize (vn->nrStates(), Util::one());
|
|
||||||
}
|
unsigned nrEdges (void) const { return nrEdges_; }
|
||||||
|
|
||||||
|
const Params& poweredMessage (void) const { return pwdMsg_; }
|
||||||
|
|
||||||
void updateMessage (void)
|
void updateMessage (void)
|
||||||
{
|
{
|
||||||
poweredMsg_ = *nextMsg_;
|
pwdMsg_ = *nextMsg_;
|
||||||
swap (currMsg_, nextMsg_);
|
swap (currMsg_, nextMsg_);
|
||||||
msgSended_ = true;
|
msgSended_ = true;
|
||||||
Util::pow (poweredMsg_, edgeCount_);
|
LogAware::pow (pwdMsg_, nrEdges_);
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned getNumberOfEdges (void) const { return edgeCount_; }
|
|
||||||
const Params& getPoweredMessage (void) const { return poweredMsg_; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Params poweredMsg_;
|
unsigned nrEdges_;
|
||||||
unsigned edgeCount_;
|
Params pwdMsg_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CbpSolver : public FgBpSolver
|
class CbpSolver : public BpSolver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
CbpSolver (FactorGraph& fg) : FgBpSolver (fg) { }
|
CbpSolver (const FactorGraph& fg);
|
||||||
~CbpSolver (void);
|
|
||||||
|
|
||||||
Params getPosterioriOf (VarId);
|
~CbpSolver (void);
|
||||||
Params getJointDistributionOf (const VarIds&);
|
|
||||||
|
Params getPosterioriOf (VarId);
|
||||||
|
|
||||||
|
Params getJointDistributionOf (const VarIds&);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void initializeSolver (void);
|
|
||||||
void createLinks (void);
|
|
||||||
|
|
||||||
void maxResidualSchedule (void);
|
void createLinks (void);
|
||||||
Params getVar2FactorMsg (const SpLink*) const;
|
|
||||||
void printLinkInformation (void) const;
|
|
||||||
|
|
||||||
|
void maxResidualSchedule (void);
|
||||||
|
|
||||||
CFactorGraph* lfg_;
|
Params getVar2FactorMsg (const SpLink*) const;
|
||||||
|
|
||||||
|
void printLinkInformation (void) const;
|
||||||
|
|
||||||
|
CFactorGraph* cfg_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_CBP_H
|
#endif // HORUS_CBP_H
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
#include <queue>
|
#include <queue>
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
#include "ConstraintTree.h"
|
#include "ConstraintTree.h"
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
CTNode::addChild (CTNode* child, bool updateLevels)
|
CTNode::addChild (CTNode* child, bool updateLevels)
|
||||||
{
|
{
|
||||||
@ -42,6 +43,34 @@ CTNode::removeChild (CTNode* child)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CTNode::removeChilds (void)
|
||||||
|
{
|
||||||
|
childs_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CTNode::removeAndDeleteChild (CTNode* child)
|
||||||
|
{
|
||||||
|
removeChild (child);
|
||||||
|
CTNode::deleteSubtree (child);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CTNode::removeAndDeleteAllChilds (void)
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < childs_.size(); i++) {
|
||||||
|
deleteSubtree (childs_[i]);
|
||||||
|
}
|
||||||
|
childs_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
SymbolSet
|
SymbolSet
|
||||||
CTNode::childSymbols (void) const
|
CTNode::childSymbols (void) const
|
||||||
{
|
{
|
||||||
@ -66,6 +95,32 @@ CTNode::updateChildLevels (CTNode* n, unsigned level)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
CTNode*
|
||||||
|
CTNode::copySubtree (const CTNode* n)
|
||||||
|
{
|
||||||
|
CTNode* newNode = new CTNode (*n);
|
||||||
|
const CTNodes& childs = n->childs();
|
||||||
|
for (unsigned i = 0; i < childs.size(); i++) {
|
||||||
|
newNode->addChild (copySubtree (childs[i]));
|
||||||
|
}
|
||||||
|
return newNode;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
CTNode::deleteSubtree (CTNode* n)
|
||||||
|
{
|
||||||
|
assert (n);
|
||||||
|
const CTNodes& childs = n->childs();
|
||||||
|
for (unsigned i = 0; i < childs.size(); i++) {
|
||||||
|
deleteSubtree (childs[i]);
|
||||||
|
}
|
||||||
|
delete n;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ostream& operator<< (ostream &out, const CTNode& n)
|
ostream& operator<< (ostream &out, const CTNode& n)
|
||||||
{
|
{
|
||||||
// out << "(" << n.level() << ") " ;
|
// out << "(" << n.level() << ") " ;
|
||||||
@ -75,6 +130,17 @@ ostream& operator<< (ostream &out, const CTNode& n)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ConstraintTree::ConstraintTree (unsigned nrLvs)
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < nrLvs; i++) {
|
||||||
|
logVars_.push_back (LogVar (i));
|
||||||
|
}
|
||||||
|
root_ = new CTNode (0, 0);
|
||||||
|
logVarSet_ = LogVarSet (logVars_);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ConstraintTree::ConstraintTree (const LogVars& logVars)
|
ConstraintTree::ConstraintTree (const LogVars& logVars)
|
||||||
{
|
{
|
||||||
root_ = new CTNode (0, 0);
|
root_ = new CTNode (0, 0);
|
||||||
@ -99,7 +165,7 @@ ConstraintTree::ConstraintTree (const LogVars& logVars,
|
|||||||
|
|
||||||
ConstraintTree::ConstraintTree (const ConstraintTree& ct)
|
ConstraintTree::ConstraintTree (const ConstraintTree& ct)
|
||||||
{
|
{
|
||||||
root_ = copySubtree (ct.root_);
|
root_ = CTNode::copySubtree (ct.root_);
|
||||||
logVars_ = ct.logVars_;
|
logVars_ = ct.logVars_;
|
||||||
logVarSet_ = ct.logVarSet_;
|
logVarSet_ = ct.logVarSet_;
|
||||||
}
|
}
|
||||||
@ -108,7 +174,7 @@ ConstraintTree::ConstraintTree (const ConstraintTree& ct)
|
|||||||
|
|
||||||
ConstraintTree::~ConstraintTree (void)
|
ConstraintTree::~ConstraintTree (void)
|
||||||
{
|
{
|
||||||
deleteSubtree (root_);
|
CTNode::deleteSubtree (root_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -200,21 +266,28 @@ ConstraintTree::moveToBottom (const LogVars& lvs)
|
|||||||
|
|
||||||
void
|
void
|
||||||
ConstraintTree::join (ConstraintTree* ct, bool assertWhenNotFound)
|
ConstraintTree::join (ConstraintTree* ct, bool assertWhenNotFound)
|
||||||
{
|
{
|
||||||
|
if (logVarSet_.empty()) {
|
||||||
|
delete root_;
|
||||||
|
root_ = CTNode::copySubtree (ct->root());
|
||||||
|
logVars_ = ct->logVars();
|
||||||
|
logVarSet_ = ct->logVarSet();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
LogVarSet intersect = logVarSet_ & ct->logVarSet_;
|
LogVarSet intersect = logVarSet_ & ct->logVarSet_;
|
||||||
if (intersect.empty()) {
|
if (intersect.empty()) {
|
||||||
const CTNodes& childs = ct->root()->childs();
|
const CTNodes& childs = ct->root()->childs();
|
||||||
CTNodes leafs = getNodesAtLevel (getLevel (logVars_.back()));
|
CTNodes leafs = getNodesAtLevel (getLevel (logVars_.back()));
|
||||||
for (unsigned i = 0; i < leafs.size(); i++) {
|
for (unsigned i = 0; i < leafs.size(); i++) {
|
||||||
for (unsigned j = 0; j < childs.size(); j++) {
|
for (unsigned j = 0; j < childs.size(); j++) {
|
||||||
leafs[i]->addChild (copySubtree (childs[j]));
|
leafs[i]->addChild (CTNode::copySubtree (childs[j]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
logVars_.insert (logVars_.end(), ct->logVars_.begin(), ct->logVars_.end());
|
Util::addToVector (logVars_, ct->logVars_);
|
||||||
logVarSet_ |= ct->logVarSet_;
|
logVarSet_ |= ct->logVarSet_;
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
moveToBottom (intersect.elements());
|
moveToBottom (intersect.elements());
|
||||||
ct->moveToTop (intersect.elements());
|
ct->moveToTop (intersect.elements());
|
||||||
|
|
||||||
@ -222,25 +295,27 @@ ConstraintTree::join (ConstraintTree* ct, bool assertWhenNotFound)
|
|||||||
CTNodes nodes = getNodesAtLevel (level);
|
CTNodes nodes = getNodesAtLevel (level);
|
||||||
|
|
||||||
Tuples tuples;
|
Tuples tuples;
|
||||||
CTNodes continuationNodes;
|
CTNodes continNodes;
|
||||||
getTuples (ct->root(),
|
getTuples (ct->root(),
|
||||||
Tuples(),
|
Tuples(),
|
||||||
intersect.size(),
|
intersect.size(),
|
||||||
tuples,
|
tuples,
|
||||||
continuationNodes);
|
continNodes);
|
||||||
|
|
||||||
for (unsigned i = 0; i < tuples.size(); i++) {
|
for (unsigned i = 0; i < tuples.size(); i++) {
|
||||||
bool tupleFounded = false;
|
bool tupleFounded = false;
|
||||||
for (unsigned j = 0; j < nodes.size(); j++) {
|
for (unsigned j = 0; j < nodes.size(); j++) {
|
||||||
tupleFounded |= join (nodes[j], tuples[i], 0, continuationNodes[i]);
|
tupleFounded |= join (nodes[j], tuples[i], 0, continNodes[i]);
|
||||||
}
|
}
|
||||||
if (assertWhenNotFound) {
|
if (assertWhenNotFound) {
|
||||||
assert (tupleFounded);
|
assert (tupleFounded);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
LogVarSet newLvs = ct->logVarSet_ - intersect;
|
|
||||||
logVars_.insert (logVars_.end(), newLvs.begin(), newLvs.end());
|
LogVars newLvs (ct->logVars().begin() + intersect.size(),
|
||||||
logVarSet_ |= newLvs;
|
ct->logVars().end());
|
||||||
|
Util::addToVector (logVars_, newLvs);
|
||||||
|
logVarSet_ |= LogVarSet (newLvs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -280,6 +355,10 @@ ConstraintTree::rename (LogVar X_old, LogVar X_new)
|
|||||||
void
|
void
|
||||||
ConstraintTree::applySubstitution (const Substitution& theta)
|
ConstraintTree::applySubstitution (const Substitution& theta)
|
||||||
{
|
{
|
||||||
|
LogVars discardedLvs = theta.getDiscardedLogVars();
|
||||||
|
for (unsigned i = 0; i < discardedLvs.size(); i++) {
|
||||||
|
remove(discardedLvs[i]);
|
||||||
|
}
|
||||||
for (unsigned i = 0; i < logVars_.size(); i++) {
|
for (unsigned i = 0; i < logVars_.size(); i++) {
|
||||||
logVars_[i] = theta.newNameFor (logVars_[i]);
|
logVars_[i] = theta.newNameFor (logVars_[i]);
|
||||||
}
|
}
|
||||||
@ -308,11 +387,7 @@ ConstraintTree::remove (const LogVarSet& X)
|
|||||||
unsigned level = getLevel (X.front()) - 1;
|
unsigned level = getLevel (X.front()) - 1;
|
||||||
CTNodes nodes = getNodesAtLevel (level);
|
CTNodes nodes = getNodesAtLevel (level);
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||||
CTNodes childs = nodes[i]->childs();
|
nodes[i]->removeAndDeleteAllChilds();
|
||||||
for (unsigned j = 0; j < childs.size(); j++) {
|
|
||||||
nodes[i]->removeChild (childs[j]);
|
|
||||||
deleteSubtree (childs[j]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
logVars_.resize (logVars_.size() - X.size());
|
logVars_.resize (logVars_.size() - X.size());
|
||||||
logVarSet_ -= X;
|
logVarSet_ -= X;
|
||||||
@ -545,16 +620,16 @@ ConstraintTree::split (
|
|||||||
for (unsigned i = 0; i < commNodes.size(); i++) {
|
for (unsigned i = 0; i < commNodes.size(); i++) {
|
||||||
commCt->root()->addChild (commNodes[i]);
|
commCt->root()->addChild (commNodes[i]);
|
||||||
}
|
}
|
||||||
//cout << commCt->tupleSet() << " + " ;
|
// cout << commCt->tupleSet() << " + " ;
|
||||||
//cout << exclCt->tupleSet() << " = " ;
|
// cout << exclCt->tupleSet() << " = " ;
|
||||||
//cout << tupleSet() << endl << endl;
|
// cout << tupleSet() << endl << endl;
|
||||||
// if (((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet()) == false) {
|
// if (((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet()) == false) {
|
||||||
// exportToGraphViz ("_fail.dot", true);
|
// exportToGraphViz ("_fail.dot", true);
|
||||||
// commCt->exportToGraphViz ("_fail_comm.dot", true);
|
// commCt->exportToGraphViz ("_fail_comm.dot", true);
|
||||||
// exclCt->exportToGraphViz ("_fail_excl.dot", true);
|
// exclCt->exportToGraphViz ("_fail_excl.dot", true);
|
||||||
// }
|
// }
|
||||||
assert ((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet());
|
// assert ((commCt->tupleSet() | exclCt->tupleSet()) == tupleSet());
|
||||||
assert ((exclCt->tupleSet (stopLevel) & ct->tupleSet (stopLevel)).empty());
|
// assert ((exclCt->tupleSet (stopLevel) & ct->tupleSet (stopLevel)).empty());
|
||||||
return {commCt, exclCt};
|
return {commCt, exclCt};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -601,36 +676,32 @@ ConstraintTree::jointCountNormalize (
|
|||||||
LogVar X_new1,
|
LogVar X_new1,
|
||||||
LogVar X_new2)
|
LogVar X_new2)
|
||||||
{
|
{
|
||||||
exportToGraphViz ("C.dot", true);
|
|
||||||
commCt->exportToGraphViz ("C_comm.dot", true);
|
|
||||||
exclCt->exportToGraphViz ("C_exlc.dot", true);
|
|
||||||
unsigned N = getConditionalCount (X);
|
unsigned N = getConditionalCount (X);
|
||||||
cout << "My tuples: " << tupleSet() << endl;
|
// cout << "My tuples: " << tupleSet() << endl;
|
||||||
cout << "CommCt tuples: " << commCt->tupleSet() << endl;
|
// cout << "CommCt tuples: " << commCt->tupleSet() << endl;
|
||||||
cout << "ExclCt tuples: " << exclCt->tupleSet() << endl;
|
// cout << "ExclCt tuples: " << exclCt->tupleSet() << endl;
|
||||||
cout << "Counted Lv: " << X << endl;
|
// cout << "Counted Lv: " << X << endl;
|
||||||
cout << "Original N: " << N << endl;
|
// cout << "X_new1: " << X_new1 << endl;
|
||||||
cout << endl;
|
// cout << "X_new2: " << X_new2 << endl;
|
||||||
|
// cout << "Original N: " << N << endl;
|
||||||
|
// cout << endl;
|
||||||
|
|
||||||
ConstraintTrees normCts1 = commCt->countNormalize (X);
|
ConstraintTrees normCts1 = commCt->countNormalize (X);
|
||||||
vector<unsigned> counts1 (normCts1.size());
|
vector<unsigned> counts1 (normCts1.size());
|
||||||
for (unsigned i = 0; i < normCts1.size(); i++) {
|
for (unsigned i = 0; i < normCts1.size(); i++) {
|
||||||
counts1[i] = normCts1[i]->getConditionalCount (X);
|
counts1[i] = normCts1[i]->getConditionalCount (X);
|
||||||
cout << "normCts1[" << i << "] #" << counts1[i] ;
|
// cout << "normCts1[" << i << "] #" << counts1[i] ;
|
||||||
cout << " " << normCts1[i]->tupleSet() << endl;
|
// cout << " " << normCts1[i]->tupleSet() << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
ConstraintTrees normCts2 = exclCt->countNormalize (X);
|
ConstraintTrees normCts2 = exclCt->countNormalize (X);
|
||||||
vector<unsigned> counts2 (normCts2.size());
|
vector<unsigned> counts2 (normCts2.size());
|
||||||
for (unsigned i = 0; i < normCts2.size(); i++) {
|
for (unsigned i = 0; i < normCts2.size(); i++) {
|
||||||
counts2[i] = normCts2[i]->getConditionalCount (X);
|
counts2[i] = normCts2[i]->getConditionalCount (X);
|
||||||
cout << "normCts2[" << i << "] #" << counts2[i] ;
|
// cout << "normCts2[" << i << "] #" << counts2[i] ;
|
||||||
cout << " " << normCts2[i]->tupleSet() << endl;
|
// cout << " " << normCts2[i]->tupleSet() << endl;
|
||||||
}
|
}
|
||||||
cout << endl;
|
// cout << endl;
|
||||||
|
|
||||||
cout << "1###### " << normCts1.size() << endl;
|
|
||||||
cout << "2###### " << normCts2.size() << endl;
|
|
||||||
|
|
||||||
ConstraintTree* excl1 = 0;
|
ConstraintTree* excl1 = 0;
|
||||||
for (unsigned i = 0; i < normCts1.size(); i++) {
|
for (unsigned i = 0; i < normCts1.size(); i++) {
|
||||||
@ -638,7 +709,7 @@ ConstraintTree::jointCountNormalize (
|
|||||||
excl1 = normCts1[i];
|
excl1 = normCts1[i];
|
||||||
normCts1.erase (normCts1.begin() + i);
|
normCts1.erase (normCts1.begin() + i);
|
||||||
counts1.erase (counts1.begin() + i);
|
counts1.erase (counts1.begin() + i);
|
||||||
cout << ">joint-count(" << N << ",0)" << endl;
|
// cout << "joint-count(" << N << ",0)" << endl;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -649,22 +720,21 @@ ConstraintTree::jointCountNormalize (
|
|||||||
excl2 = normCts2[i];
|
excl2 = normCts2[i];
|
||||||
normCts2.erase (normCts2.begin() + i);
|
normCts2.erase (normCts2.begin() + i);
|
||||||
counts2.erase (counts2.begin() + i);
|
counts2.erase (counts2.begin() + i);
|
||||||
cout << ">>joint-count(0," << N << ")" << endl;
|
// cout << "joint-count(0," << N << ")" << endl;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cout << "3###### " << normCts1.size() << endl;
|
|
||||||
cout << "4###### " << normCts2.size() << endl;
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < normCts1.size(); i++) {
|
for (unsigned i = 0; i < normCts1.size(); i++) {
|
||||||
unsigned j;
|
unsigned j;
|
||||||
for (j = 0; counts1[i] + counts2[j] != N; j++) ;
|
for (j = 0; counts1[i] + counts2[j] != N; j++) ;
|
||||||
cout << "joint-count(" << counts1[i] << "," << counts2[j] << ")" << endl;
|
// cout << "joint-count(" << counts1[i] ;
|
||||||
|
// cout << "," << counts2[j] << ")" << endl;
|
||||||
const CTNodes& childs = normCts2[j]->root_->childs();
|
const CTNodes& childs = normCts2[j]->root_->childs();
|
||||||
for (unsigned k = 0; k < childs.size(); k++) {
|
for (unsigned k = 0; k < childs.size(); k++) {
|
||||||
normCts1[i]->root_->addChild (childs[k]);
|
normCts1[i]->root_->addChild (CTNode::copySubtree (childs[k]));
|
||||||
}
|
}
|
||||||
|
delete normCts2[j];
|
||||||
}
|
}
|
||||||
|
|
||||||
ConstraintTrees cts = normCts1;
|
ConstraintTrees cts = normCts1;
|
||||||
@ -683,11 +753,6 @@ ConstraintTree::jointCountNormalize (
|
|||||||
cts.push_back (excl2);
|
cts.push_back (excl2);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned i = 0; i < cts.size(); i++) {
|
|
||||||
stringstream ss;
|
|
||||||
ss << "aaacts_" << i + 1 << ".dot" ;
|
|
||||||
cts[i]->exportToGraphViz (ss.str().c_str(), true);
|
|
||||||
}
|
|
||||||
return cts;
|
return cts;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -735,11 +800,11 @@ ConstraintTree::expand (LogVar X)
|
|||||||
unsigned nrSymbols = getConditionalCount (X);
|
unsigned nrSymbols = getConditionalCount (X);
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||||
Symbols symbols;
|
Symbols symbols;
|
||||||
CTNodes childs = nodes[i]->childs();
|
const CTNodes& childs = nodes[i]->childs();
|
||||||
for (unsigned j = 0; j < childs.size(); j++) {
|
for (unsigned j = 0; j < childs.size(); j++) {
|
||||||
symbols.push_back (childs[j]->symbol());
|
symbols.push_back (childs[j]->symbol());
|
||||||
nodes[i]->removeChild (childs[j]);
|
|
||||||
}
|
}
|
||||||
|
nodes[i]->removeAndDeleteAllChilds();
|
||||||
CTNode* prev = nodes[i];
|
CTNode* prev = nodes[i];
|
||||||
assert (symbols.size() == nrSymbols);
|
assert (symbols.size() == nrSymbols);
|
||||||
for (unsigned j = 0; j < nrSymbols; j++) {
|
for (unsigned j = 0; j < nrSymbols; j++) {
|
||||||
@ -768,7 +833,7 @@ ConstraintTree::ground (LogVar X)
|
|||||||
ConstraintTrees cts;
|
ConstraintTrees cts;
|
||||||
const CTNodes& nodes = root_->childs();
|
const CTNodes& nodes = root_->childs();
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||||
CTNode* copy = copySubtree (nodes[i]);
|
CTNode* copy = CTNode::copySubtree (nodes[i]);
|
||||||
copy->setSymbol (nodes[i]->symbol());
|
copy->setSymbol (nodes[i]->symbol());
|
||||||
ConstraintTree* newCt = new ConstraintTree (logVars_);
|
ConstraintTree* newCt = new ConstraintTree (logVars_);
|
||||||
newCt->root()->addChild (copy);
|
newCt->root()->addChild (copy);
|
||||||
@ -840,19 +905,19 @@ ConstraintTree::getNodesAtLevel (unsigned level) const
|
|||||||
void
|
void
|
||||||
ConstraintTree::swapLogVar (LogVar X)
|
ConstraintTree::swapLogVar (LogVar X)
|
||||||
{
|
{
|
||||||
|
TupleSet before = tupleSet();
|
||||||
LogVars::iterator it =
|
LogVars::iterator it =
|
||||||
std::find (logVars_.begin(),logVars_.end(), X);
|
std::find (logVars_.begin(),logVars_.end(), X);
|
||||||
assert (it != logVars_.end());
|
assert (it != logVars_.end());
|
||||||
unsigned pos = std::distance (logVars_.begin(), it);
|
unsigned pos = std::distance (logVars_.begin(), it);
|
||||||
|
|
||||||
const CTNodes& nodes = getNodesAtLevel (pos);
|
const CTNodes& nodes = getNodesAtLevel (pos);
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||||
const CTNodes childs = nodes[i]->childs();
|
CTNodes childsCopy = nodes[i]->childs();
|
||||||
for (unsigned j = 0; j < childs.size(); j++) {
|
nodes[i]->removeChilds();
|
||||||
nodes[i]->removeChild (childs[j]);
|
for (unsigned j = 0; j < childsCopy.size(); j++) {
|
||||||
const CTNodes grandsons = childs[j]->childs();
|
const CTNodes grandsons = childsCopy[j]->childs();
|
||||||
for (unsigned k = 0; k < grandsons.size(); k++) {
|
for (unsigned k = 0; k < grandsons.size(); k++) {
|
||||||
CTNode* childCopy = new CTNode (*childs[j]);
|
CTNode* childCopy = new CTNode (*childsCopy[j]);
|
||||||
const CTNodes greatGrandsons = grandsons[k]->childs();
|
const CTNodes greatGrandsons = grandsons[k]->childs();
|
||||||
for (unsigned t = 0; t < greatGrandsons.size(); t++) {
|
for (unsigned t = 0; t < greatGrandsons.size(); t++) {
|
||||||
grandsons[k]->removeChild (greatGrandsons[t]);
|
grandsons[k]->removeChild (greatGrandsons[t]);
|
||||||
@ -863,10 +928,9 @@ ConstraintTree::swapLogVar (LogVar X)
|
|||||||
grandsons[k]->setLevel (grandsons[k]->level() - 1);
|
grandsons[k]->setLevel (grandsons[k]->level() - 1);
|
||||||
nodes[i]->addChild (grandsons[k], false);
|
nodes[i]->addChild (grandsons[k], false);
|
||||||
}
|
}
|
||||||
delete childs[j];
|
delete childsCopy[j];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::swap (logVars_[pos], logVars_[pos + 1]);
|
std::swap (logVars_[pos], logVars_[pos + 1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -884,7 +948,7 @@ ConstraintTree::join (
|
|||||||
if (currIdx == tuple.size() - 1) {
|
if (currIdx == tuple.size() - 1) {
|
||||||
const CTNodes& childs = appendNode->childs();
|
const CTNodes& childs = appendNode->childs();
|
||||||
for (unsigned i = 0; i < childs.size(); i++) {
|
for (unsigned i = 0; i < childs.size(); i++) {
|
||||||
n->addChild (copySubtree (childs[i]));
|
n->addChild (CTNode::copySubtree (childs[i]));
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -985,7 +1049,7 @@ ConstraintTree::countNormalize (
|
|||||||
{
|
{
|
||||||
if (n->level() == stopLevel) {
|
if (n->level() == stopLevel) {
|
||||||
return vector<pair<CTNode*, unsigned>>() = {
|
return vector<pair<CTNode*, unsigned>>() = {
|
||||||
make_pair (copySubtree (n), countTuples (n))
|
make_pair (CTNode::copySubtree (n), countTuples (n))
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1004,65 +1068,6 @@ ConstraintTree::countNormalize (
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/*
|
|
||||||
void
|
|
||||||
ConstraintTree::split (
|
|
||||||
CTNode* n1,
|
|
||||||
CTNode* n2,
|
|
||||||
CTNodes& nodes,
|
|
||||||
unsigned stopLevel)
|
|
||||||
{
|
|
||||||
CTNodes& childs1 = n1->childs();
|
|
||||||
CTNodes& childs2 = n2->childs();
|
|
||||||
// cout << string (n1->level() * 8, '-') << "Level = " << n1->level() + 1;
|
|
||||||
// cout << ", #I = " << childs1.size();
|
|
||||||
// cout << ", #J = " << childs2.size() << endl;
|
|
||||||
for (unsigned i = 0; i < childs1.size(); i++) {
|
|
||||||
for (unsigned j = 0; j < childs2.size(); j++) {
|
|
||||||
if (childs1[i]->symbol() != childs2[j]->symbol()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (childs1[i]->level() == stopLevel) {
|
|
||||||
CTNode* newNode = copySubtree (childs1[i]);
|
|
||||||
newNode->setSymbol (childs1[i]->symbol());
|
|
||||||
nodes.push_back (newNode);
|
|
||||||
childs1[i]->setSymbol (Symbol::invalid());
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
CTNodes lowerNodes;
|
|
||||||
split (childs1[i], childs2[j], lowerNodes, stopLevel);
|
|
||||||
if (lowerNodes.empty() == false) {
|
|
||||||
CTNode* me = new CTNode (childs1[i]->symbol(), childs1[i]->level());
|
|
||||||
for (unsigned k = 0; k < lowerNodes.size(); k++) {
|
|
||||||
me->addChild (lowerNodes[k]);
|
|
||||||
}
|
|
||||||
nodes.push_back (me);
|
|
||||||
}
|
|
||||||
if (childs1[i]->isLeaf()) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < (int)childs1.size(); i++) {
|
|
||||||
// cout << string (n1->level() * 8, '-') << childs1[i];
|
|
||||||
if (childs1[i]->symbol() == Symbol::invalid()) {
|
|
||||||
// cout << " empty, removing..." ;
|
|
||||||
n1->removeChild (childs1[i]);
|
|
||||||
i --;
|
|
||||||
} else if (childs1[i]->isLeaf() &&
|
|
||||||
childs1[i]->level() != stopLevel) {
|
|
||||||
// cout << " leaf, removing..." ;
|
|
||||||
n1->removeChild (childs1[i]);
|
|
||||||
i --;
|
|
||||||
}
|
|
||||||
// cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
ConstraintTree::split (
|
ConstraintTree::split (
|
||||||
@ -1085,7 +1090,7 @@ ConstraintTree::split (
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (childs1[i]->level() == stopLevel) {
|
if (childs1[i]->level() == stopLevel) {
|
||||||
CTNode* newNode = copySubtree (childs1[i]);
|
CTNode* newNode = CTNode::copySubtree (childs1[i]);
|
||||||
nodes.push_back (newNode);
|
nodes.push_back (newNode);
|
||||||
childs1[i]->setSymbol (Symbol::invalid());
|
childs1[i]->setSymbol (Symbol::invalid());
|
||||||
} else {
|
} else {
|
||||||
@ -1103,11 +1108,11 @@ ConstraintTree::split (
|
|||||||
|
|
||||||
for (int i = 0; i < (int)childs1.size(); i++) {
|
for (int i = 0; i < (int)childs1.size(); i++) {
|
||||||
if (childs1[i]->symbol() == Symbol::invalid()) {
|
if (childs1[i]->symbol() == Symbol::invalid()) {
|
||||||
n1->removeChild (childs1[i]);
|
n1->removeAndDeleteChild (childs1[i]);
|
||||||
i --;
|
i --;
|
||||||
} else if (childs1[i]->isLeaf() &&
|
} else if (childs1[i]->isLeaf() &&
|
||||||
childs1[i]->level() != stopLevel) {
|
childs1[i]->level() != stopLevel) {
|
||||||
n1->removeChild (childs1[i]);
|
n1->removeAndDeleteChild (childs1[i]);
|
||||||
i --;
|
i --;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1141,29 +1146,3 @@ ConstraintTree::overlap (
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
CTNode*
|
|
||||||
ConstraintTree::copySubtree (const CTNode* n)
|
|
||||||
{
|
|
||||||
CTNode* newNode = new CTNode (*n);
|
|
||||||
const CTNodes& childs = n->childs();
|
|
||||||
for (unsigned i = 0; i < childs.size(); i++) {
|
|
||||||
newNode->addChild (copySubtree (childs[i]));
|
|
||||||
}
|
|
||||||
return newNode;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
ConstraintTree::deleteSubtree (CTNode* n)
|
|
||||||
{
|
|
||||||
assert (n);
|
|
||||||
const CTNodes& childs = n->childs();
|
|
||||||
for (unsigned i = 0; i < childs.size(); i++) {
|
|
||||||
deleteSubtree (childs[i]);
|
|
||||||
}
|
|
||||||
delete n;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
@ -21,7 +21,6 @@ typedef vector<ConstraintTree*> ConstraintTrees;
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CTNode
|
class CTNode
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@ -47,29 +46,44 @@ class CTNode
|
|||||||
|
|
||||||
bool isLeaf (void) const { return childs_.empty(); }
|
bool isLeaf (void) const { return childs_.empty(); }
|
||||||
|
|
||||||
void addChild (CTNode*, bool = true);
|
void addChild (CTNode*, bool = true);
|
||||||
void removeChild (CTNode*);
|
|
||||||
SymbolSet childSymbols (void) const;
|
void removeChild (CTNode*);
|
||||||
|
|
||||||
|
void removeChilds (void);
|
||||||
|
|
||||||
|
void removeAndDeleteChild (CTNode*);
|
||||||
|
|
||||||
|
void removeAndDeleteAllChilds (void);
|
||||||
|
|
||||||
|
SymbolSet childSymbols (void) const;
|
||||||
|
|
||||||
|
static CTNode* copySubtree (const CTNode*);
|
||||||
|
|
||||||
|
static void deleteSubtree (CTNode*);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void updateChildLevels (CTNode*, unsigned);
|
void updateChildLevels (CTNode*, unsigned);
|
||||||
|
|
||||||
Symbol symbol_;
|
Symbol symbol_;
|
||||||
CTNodes childs_;
|
CTNodes childs_;
|
||||||
unsigned level_;
|
unsigned level_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ostream& operator<< (ostream &out, const CTNode&);
|
ostream& operator<< (ostream &out, const CTNode&);
|
||||||
|
|
||||||
|
|
||||||
class ConstraintTree
|
class ConstraintTree
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
ConstraintTree (unsigned);
|
||||||
|
|
||||||
ConstraintTree (const LogVars&);
|
ConstraintTree (const LogVars&);
|
||||||
|
|
||||||
ConstraintTree (const LogVars&, const Tuples&);
|
ConstraintTree (const LogVars&, const Tuples&);
|
||||||
|
|
||||||
ConstraintTree (const ConstraintTree&);
|
ConstraintTree (const ConstraintTree&);
|
||||||
|
|
||||||
~ConstraintTree (void);
|
~ConstraintTree (void);
|
||||||
|
|
||||||
CTNode* root (void) const { return root_; }
|
CTNode* root (void) const { return root_; }
|
||||||
@ -94,94 +108,95 @@ class ConstraintTree
|
|||||||
assert (LogVarSet (logVars_) == logVarSet_);
|
assert (LogVarSet (logVars_) == logVarSet_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void addTuple (const Tuple&);
|
void addTuple (const Tuple&);
|
||||||
bool containsTuple (const Tuple&);
|
|
||||||
void moveToTop (const LogVars&);
|
bool containsTuple (const Tuple&);
|
||||||
void moveToBottom (const LogVars&);
|
|
||||||
void join (ConstraintTree*, bool = false);
|
void moveToTop (const LogVars&);
|
||||||
unsigned getLevel (LogVar) const;
|
|
||||||
void rename (LogVar, LogVar);
|
void moveToBottom (const LogVars&);
|
||||||
void applySubstitution (const Substitution&);
|
|
||||||
void project (const LogVarSet&);
|
void join (ConstraintTree*, bool = false);
|
||||||
void remove (const LogVarSet&);
|
|
||||||
bool isSingleton (LogVar);
|
unsigned getLevel (LogVar) const;
|
||||||
LogVarSet singletons (void);
|
|
||||||
TupleSet tupleSet (unsigned = 0) const;
|
void rename (LogVar, LogVar);
|
||||||
TupleSet tupleSet (const LogVars&);
|
|
||||||
unsigned size (void) const;
|
void applySubstitution (const Substitution&);
|
||||||
unsigned nrSymbols (LogVar);
|
|
||||||
void exportToGraphViz (const char*, bool = false) const;
|
void project (const LogVarSet&);
|
||||||
bool isCountNormalized (const LogVarSet&);
|
|
||||||
unsigned getConditionalCount (const LogVarSet&);
|
void remove (const LogVarSet&);
|
||||||
TinySet<unsigned> getConditionalCounts (const LogVarSet&);
|
|
||||||
bool isCarteesianProduct (const LogVarSet&) const;
|
bool isSingleton (LogVar);
|
||||||
|
|
||||||
|
LogVarSet singletons (void);
|
||||||
|
|
||||||
|
TupleSet tupleSet (unsigned = 0) const;
|
||||||
|
|
||||||
|
TupleSet tupleSet (const LogVars&);
|
||||||
|
|
||||||
|
unsigned size (void) const;
|
||||||
|
|
||||||
|
unsigned nrSymbols (LogVar);
|
||||||
|
|
||||||
|
void exportToGraphViz (const char*, bool = false) const;
|
||||||
|
|
||||||
|
bool isCountNormalized (const LogVarSet&);
|
||||||
|
|
||||||
|
unsigned getConditionalCount (const LogVarSet&);
|
||||||
|
|
||||||
|
TinySet<unsigned> getConditionalCounts (const LogVarSet&);
|
||||||
|
|
||||||
|
bool isCarteesianProduct (const LogVarSet&) const;
|
||||||
|
|
||||||
std::pair<ConstraintTree*, ConstraintTree*> split (
|
std::pair<ConstraintTree*, ConstraintTree*> split (
|
||||||
const Tuple&,
|
const Tuple&, unsigned);
|
||||||
unsigned);
|
|
||||||
|
|
||||||
std::pair<ConstraintTree*, ConstraintTree*> split (
|
std::pair<ConstraintTree*, ConstraintTree*> split (
|
||||||
const ConstraintTree*,
|
const ConstraintTree*, unsigned) const;
|
||||||
unsigned) const;
|
|
||||||
|
|
||||||
ConstraintTrees countNormalize (const LogVarSet&);
|
ConstraintTrees countNormalize (const LogVarSet&);
|
||||||
|
|
||||||
ConstraintTrees jointCountNormalize (
|
ConstraintTrees jointCountNormalize (
|
||||||
ConstraintTree*,
|
ConstraintTree*, ConstraintTree*, LogVar, LogVar, LogVar);
|
||||||
ConstraintTree*,
|
|
||||||
LogVar,
|
|
||||||
LogVar,
|
|
||||||
LogVar);
|
|
||||||
|
|
||||||
static bool identical (
|
static bool identical (
|
||||||
const ConstraintTree*,
|
const ConstraintTree*, const ConstraintTree*, unsigned);
|
||||||
const ConstraintTree*,
|
|
||||||
unsigned);
|
|
||||||
|
|
||||||
static bool overlap (
|
static bool overlap (
|
||||||
const ConstraintTree*,
|
const ConstraintTree*, const ConstraintTree*, unsigned);
|
||||||
const ConstraintTree*,
|
|
||||||
unsigned);
|
|
||||||
|
|
||||||
LogVars expand (LogVar);
|
LogVars expand (LogVar);
|
||||||
ConstraintTrees ground (LogVar);
|
ConstraintTrees ground (LogVar);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
unsigned countTuples (const CTNode*) const;
|
unsigned countTuples (const CTNode*) const;
|
||||||
CTNodes getNodesBelow (CTNode*) const;
|
|
||||||
CTNodes getNodesAtLevel (unsigned) const;
|
|
||||||
void swapLogVar (LogVar);
|
|
||||||
bool join (CTNode*, const Tuple&, unsigned, CTNode*);
|
|
||||||
|
|
||||||
bool indenticalSubtrees (
|
CTNodes getNodesBelow (CTNode*) const;
|
||||||
const CTNode*,
|
|
||||||
const CTNode*,
|
|
||||||
bool) const;
|
|
||||||
|
|
||||||
void getTuples (
|
CTNodes getNodesAtLevel (unsigned) const;
|
||||||
CTNode*,
|
|
||||||
Tuples,
|
void swapLogVar (LogVar);
|
||||||
unsigned,
|
|
||||||
Tuples&,
|
bool join (CTNode*, const Tuple&, unsigned, CTNode*);
|
||||||
CTNodes&) const;
|
|
||||||
|
bool indenticalSubtrees (
|
||||||
|
const CTNode*, const CTNode*, bool) const;
|
||||||
|
|
||||||
|
void getTuples (CTNode*, Tuples, unsigned, Tuples&, CTNodes&) const;
|
||||||
|
|
||||||
vector<std::pair<CTNode*, unsigned>> countNormalize (
|
vector<std::pair<CTNode*, unsigned>> countNormalize (
|
||||||
const CTNode*,
|
const CTNode*, unsigned);
|
||||||
unsigned);
|
|
||||||
|
|
||||||
static void split (
|
static void split (
|
||||||
CTNode*,
|
CTNode*, CTNode*, CTNodes&, unsigned);
|
||||||
CTNode*,
|
|
||||||
CTNodes&,
|
|
||||||
unsigned);
|
|
||||||
|
|
||||||
static bool overlap (const CTNode*, const CTNode*, unsigned);
|
static bool overlap (const CTNode*, const CTNode*, unsigned);
|
||||||
static CTNode* copySubtree (const CTNode*);
|
|
||||||
static void deleteSubtree (CTNode*);
|
|
||||||
|
|
||||||
CTNode* root_;
|
CTNode* root_;
|
||||||
LogVars logVars_;
|
LogVars logVars_;
|
||||||
LogVarSet logVarSet_;
|
LogVarSet logVarSet_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,45 +0,0 @@
|
|||||||
#ifndef HORUS_DISTRIBUTION_H
|
|
||||||
#define HORUS_DISTRIBUTION_H
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "Horus.h"
|
|
||||||
|
|
||||||
//TODO die die die die die
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
|
|
||||||
struct Distribution
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
Distribution (int id)
|
|
||||||
{
|
|
||||||
this->id = id;
|
|
||||||
}
|
|
||||||
|
|
||||||
Distribution (const Params& params, int id = -1)
|
|
||||||
{
|
|
||||||
this->id = id;
|
|
||||||
this->params = params;
|
|
||||||
}
|
|
||||||
|
|
||||||
void updateParameters (const Params& params)
|
|
||||||
{
|
|
||||||
this->params = params;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool shared (void)
|
|
||||||
{
|
|
||||||
return id != -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
int id;
|
|
||||||
Params params;
|
|
||||||
|
|
||||||
private:
|
|
||||||
DISALLOW_COPY_AND_ASSIGN (Distribution);
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // HORUS_DISTRIBUTION_H
|
|
||||||
|
|
@ -1,53 +1,39 @@
|
|||||||
#include <limits>
|
#include <limits>
|
||||||
|
|
||||||
#include "ElimGraph.h"
|
#include <fstream>
|
||||||
#include "BayesNet.h"
|
|
||||||
|
|
||||||
|
#include "ElimGraph.h"
|
||||||
|
|
||||||
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
|
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
|
||||||
|
|
||||||
|
|
||||||
ElimGraph::ElimGraph (const BayesNet& bayesNet)
|
ElimGraph::ElimGraph (const vector<Factor*>& factors)
|
||||||
{
|
{
|
||||||
const BnNodeSet& bnNodes = bayesNet.getBayesNodes();
|
for (unsigned i = 0; i < factors.size(); i++) {
|
||||||
for (unsigned i = 0; i < bnNodes.size(); i++) {
|
const VarIds& vids = factors[i]->arguments();
|
||||||
if (bnNodes[i]->hasEvidence() == false) {
|
for (unsigned j = 0; j < vids.size() - 1; j++) {
|
||||||
addNode (new EgNode (bnNodes[i]));
|
EgNode* n1 = getEgNode (vids[j]);
|
||||||
}
|
if (n1 == 0) {
|
||||||
}
|
n1 = new EgNode (vids[j], factors[i]->range (j));
|
||||||
|
addNode (n1);
|
||||||
for (unsigned i = 0; i < bnNodes.size(); i++) {
|
}
|
||||||
if (bnNodes[i]->hasEvidence() == false) {
|
for (unsigned k = j + 1; k < vids.size(); k++) {
|
||||||
EgNode* n = getEgNode (bnNodes[i]->varId());
|
EgNode* n2 = getEgNode (vids[k]);
|
||||||
const BnNodeSet& childs = bnNodes[i]->getChilds();
|
if (n2 == 0) {
|
||||||
for (unsigned j = 0; j < childs.size(); j++) {
|
n2 = new EgNode (vids[k], factors[i]->range (k));
|
||||||
if (childs[j]->hasEvidence() == false) {
|
addNode (n2);
|
||||||
addEdge (n, getEgNode (childs[j]->varId()));
|
|
||||||
}
|
}
|
||||||
|
if (neighbors (n1, n2) == false) {
|
||||||
|
addEdge (n1, n2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (vids.size() == 1) {
|
||||||
|
if (getEgNode (vids[0]) == 0) {
|
||||||
|
addNode (new EgNode (vids[0], factors[i]->range (0)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned i = 0; i < bnNodes.size(); i++) {
|
|
||||||
vector<EgNode*> neighs;
|
|
||||||
const vector<BayesNode*>& parents = bnNodes[i]->getParents();
|
|
||||||
for (unsigned i = 0; i < parents.size(); i++) {
|
|
||||||
if (parents[i]->hasEvidence() == false) {
|
|
||||||
neighs.push_back (getEgNode (parents[i]->varId()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (neighs.size() > 0) {
|
|
||||||
for (unsigned i = 0; i < neighs.size() - 1; i++) {
|
|
||||||
for (unsigned j = i+1; j < neighs.size(); j++) {
|
|
||||||
if (!neighbors (neighs[i], neighs[j])) {
|
|
||||||
addEdge (neighs[i], neighs[j]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
setIndexes();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -61,40 +47,16 @@ ElimGraph::~ElimGraph (void)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
ElimGraph::addNode (EgNode* n)
|
|
||||||
{
|
|
||||||
nodes_.push_back (n);
|
|
||||||
varMap_.insert (make_pair (n->varId(), n));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
EgNode*
|
|
||||||
ElimGraph::getEgNode (VarId vid) const
|
|
||||||
{
|
|
||||||
unordered_map<VarId,EgNode*>::const_iterator it =varMap_.find (vid);
|
|
||||||
if (it ==varMap_.end()) {
|
|
||||||
return 0;
|
|
||||||
} else {
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarIds
|
VarIds
|
||||||
ElimGraph::getEliminatingOrder (const VarIds& exclude)
|
ElimGraph::getEliminatingOrder (const VarIds& exclude)
|
||||||
{
|
{
|
||||||
VarIds elimOrder;
|
VarIds elimOrder;
|
||||||
marked_.resize (nodes_.size(), false);
|
marked_.resize (nodes_.size(), false);
|
||||||
|
|
||||||
for (unsigned i = 0; i < exclude.size(); i++) {
|
for (unsigned i = 0; i < exclude.size(); i++) {
|
||||||
|
assert (getEgNode (exclude[i]));
|
||||||
EgNode* node = getEgNode (exclude[i]);
|
EgNode* node = getEgNode (exclude[i]);
|
||||||
assert (node);
|
|
||||||
marked_[*node] = true;
|
marked_[*node] = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned nVarsToEliminate = nodes_.size() - exclude.size();
|
unsigned nVarsToEliminate = nodes_.size() - exclude.size();
|
||||||
for (unsigned i = 0; i < nVarsToEliminate; i++) {
|
for (unsigned i = 0; i < nVarsToEliminate; i++) {
|
||||||
EgNode* node = getLowestCostNode();
|
EgNode* node = getLowestCostNode();
|
||||||
@ -107,6 +69,100 @@ ElimGraph::getEliminatingOrder (const VarIds& exclude)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ElimGraph::print (void) const
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
|
cout << "node " << nodes_[i]->label() << " neighs:" ;
|
||||||
|
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
||||||
|
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||||
|
cout << " " << neighs[j]->label();
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ElimGraph::exportToGraphViz (
|
||||||
|
const char* fileName,
|
||||||
|
bool showNeighborless,
|
||||||
|
const VarIds& highlightVarIds) const
|
||||||
|
{
|
||||||
|
ofstream out (fileName);
|
||||||
|
if (!out.is_open()) {
|
||||||
|
cerr << "error: cannot open file to write at " ;
|
||||||
|
cerr << "Markov::exportToDotFile()" << endl;
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
|
||||||
|
out << "strict graph {" << endl;
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
|
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
|
||||||
|
out << '"' << nodes_[i]->label() << '"' << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
|
||||||
|
EgNode* node =getEgNode (highlightVarIds[i]);
|
||||||
|
if (node) {
|
||||||
|
out << '"' << node->label() << '"' ;
|
||||||
|
out << " [shape=box3d]" << endl;
|
||||||
|
} else {
|
||||||
|
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||||
|
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
||||||
|
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||||
|
out << '"' << nodes_[i]->label() << '"' << " -- " ;
|
||||||
|
out << '"' << neighs[j]->label() << '"' << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out << "}" << endl;
|
||||||
|
out.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
VarIds
|
||||||
|
ElimGraph::getEliminationOrder (
|
||||||
|
const vector<Factor*> factors,
|
||||||
|
VarIds excludedVids)
|
||||||
|
{
|
||||||
|
ElimGraph graph (factors);
|
||||||
|
// graph.print();
|
||||||
|
// graph.exportToGraphViz ("_egg.dot");
|
||||||
|
return graph.getEliminatingOrder (excludedVids);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ElimGraph::addNode (EgNode* n)
|
||||||
|
{
|
||||||
|
nodes_.push_back (n);
|
||||||
|
n->setIndex (nodes_.size() - 1);
|
||||||
|
varMap_.insert (make_pair (n->varId(), n));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
EgNode*
|
||||||
|
ElimGraph::getEgNode (VarId vid) const
|
||||||
|
{
|
||||||
|
unordered_map<VarId, EgNode*>::const_iterator it;
|
||||||
|
it = varMap_.find (vid);
|
||||||
|
return (it != varMap_.end()) ? it->second : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
EgNode*
|
EgNode*
|
||||||
ElimGraph::getLowestCostNode (void) const
|
ElimGraph::getLowestCostNode (void) const
|
||||||
{
|
{
|
||||||
@ -164,7 +220,7 @@ ElimGraph::getWeightCost (const EgNode* n) const
|
|||||||
const vector<EgNode*>& neighs = n->neighbors();
|
const vector<EgNode*>& neighs = n->neighbors();
|
||||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||||
if (marked_[*neighs[i]] == false) {
|
if (marked_[*neighs[i]] == false) {
|
||||||
cost *= neighs[i]->nrStates();
|
cost *= neighs[i]->range();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return cost;
|
return cost;
|
||||||
@ -204,7 +260,7 @@ ElimGraph::getWeightedFillCost (const EgNode* n) const
|
|||||||
for (unsigned j = i+1; j < neighs.size(); j++) {
|
for (unsigned j = i+1; j < neighs.size(); j++) {
|
||||||
if (marked_[*neighs[j]] == true) continue;
|
if (marked_[*neighs[j]] == true) continue;
|
||||||
if (!neighbors (neighs[i], neighs[j])) {
|
if (!neighbors (neighs[i], neighs[j])) {
|
||||||
cost += neighs[i]->nrStates() * neighs[j]->nrStates();
|
cost += neighs[i]->range() * neighs[j]->range();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -245,78 +301,3 @@ ElimGraph::neighbors (const EgNode* n1, const EgNode* n2) const
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
ElimGraph::setIndexes (void)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
nodes_[i]->setIndex (i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
ElimGraph::printGraphicalModel (void) const
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
cout << "node " << nodes_[i]->label() << " neighs:" ;
|
|
||||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
|
||||||
cout << " " << neighs[j]->label();
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
ElimGraph::exportToGraphViz (const char* fileName,
|
|
||||||
bool showNeighborless,
|
|
||||||
const VarIds& highlightVarIds) const
|
|
||||||
{
|
|
||||||
ofstream out (fileName);
|
|
||||||
if (!out.is_open()) {
|
|
||||||
cerr << "error: cannot open file to write at " ;
|
|
||||||
cerr << "Markov::exportToDotFile()" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
|
|
||||||
out << "strict graph {" << endl;
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
|
|
||||||
out << '"' << nodes_[i]->label() << '"' ;
|
|
||||||
if (nodes_[i]->hasEvidence()) {
|
|
||||||
out << " [style=filled, fillcolor=yellow]" << endl;
|
|
||||||
} else {
|
|
||||||
out << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
|
|
||||||
EgNode* node =getEgNode (highlightVarIds[i]);
|
|
||||||
if (node) {
|
|
||||||
out << '"' << node->label() << '"' ;
|
|
||||||
out << " [shape=box3d]" << endl;
|
|
||||||
} else {
|
|
||||||
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
|
||||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
|
||||||
out << '"' << nodes_[i]->label() << '"' << " -- " ;
|
|
||||||
out << '"' << neighs[j]->label() << '"' << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
out << "}" << endl;
|
|
||||||
out.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
@ -17,15 +17,15 @@ enum ElimHeuristic
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class EgNode : public VarNode {
|
class EgNode : public Var
|
||||||
|
{
|
||||||
public:
|
public:
|
||||||
EgNode (VarNode* var) : VarNode (var) { }
|
EgNode (VarId vid, unsigned range) : Var (vid, range) { }
|
||||||
void addNeighbor (EgNode* n)
|
|
||||||
{
|
void addNeighbor (EgNode* n) { neighs_.push_back (n); }
|
||||||
neighs_.push_back (n);
|
|
||||||
}
|
|
||||||
|
|
||||||
const vector<EgNode*>& neighbors (void) const { return neighs_; }
|
const vector<EgNode*>& neighbors (void) const { return neighs_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
vector<EgNode*> neighs_;
|
vector<EgNode*> neighs_;
|
||||||
};
|
};
|
||||||
@ -34,22 +34,18 @@ class EgNode : public VarNode {
|
|||||||
class ElimGraph
|
class ElimGraph
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
ElimGraph (const BayesNet&);
|
ElimGraph (const vector<Factor*>&); // TODO
|
||||||
|
|
||||||
~ElimGraph (void);
|
~ElimGraph (void);
|
||||||
|
|
||||||
void addEdge (EgNode* n1, EgNode* n2)
|
VarIds getEliminatingOrder (const VarIds&);
|
||||||
{
|
|
||||||
assert (n1 != n2);
|
void print (void) const;
|
||||||
n1->addNeighbor (n2);
|
|
||||||
n2->addNeighbor (n1);
|
void exportToGraphViz (const char*, bool = true,
|
||||||
}
|
const VarIds& = VarIds()) const;
|
||||||
void addNode (EgNode*);
|
|
||||||
EgNode* getEgNode (VarId) const;
|
static VarIds getEliminationOrder (const vector<Factor*>, VarIds);
|
||||||
VarIds getEliminatingOrder (const VarIds&);
|
|
||||||
void printGraphicalModel (void) const;
|
|
||||||
void exportToGraphViz (const char*, bool = true,
|
|
||||||
const VarIds& = VarIds()) const;
|
|
||||||
void setIndexes();
|
|
||||||
|
|
||||||
static void setEliminationHeuristic (ElimHeuristic h)
|
static void setEliminationHeuristic (ElimHeuristic h)
|
||||||
{
|
{
|
||||||
@ -57,18 +53,34 @@ class ElimGraph
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
EgNode* getLowestCostNode (void) const;
|
|
||||||
unsigned getNeighborsCost (const EgNode*) const;
|
|
||||||
unsigned getWeightCost (const EgNode*) const;
|
|
||||||
unsigned getFillCost (const EgNode*) const;
|
|
||||||
unsigned getWeightedFillCost (const EgNode*) const;
|
|
||||||
void connectAllNeighbors (const EgNode*);
|
|
||||||
bool neighbors (const EgNode*, const EgNode*) const;
|
|
||||||
|
|
||||||
|
void addEdge (EgNode* n1, EgNode* n2)
|
||||||
|
{
|
||||||
|
assert (n1 != n2);
|
||||||
|
n1->addNeighbor (n2);
|
||||||
|
n2->addNeighbor (n1);
|
||||||
|
}
|
||||||
|
|
||||||
|
void addNode (EgNode*);
|
||||||
|
|
||||||
|
EgNode* getEgNode (VarId) const;
|
||||||
|
EgNode* getLowestCostNode (void) const;
|
||||||
|
|
||||||
|
unsigned getNeighborsCost (const EgNode*) const;
|
||||||
|
|
||||||
|
unsigned getWeightCost (const EgNode*) const;
|
||||||
|
|
||||||
|
unsigned getFillCost (const EgNode*) const;
|
||||||
|
|
||||||
|
unsigned getWeightedFillCost (const EgNode*) const;
|
||||||
|
|
||||||
|
void connectAllNeighbors (const EgNode*);
|
||||||
|
|
||||||
|
bool neighbors (const EgNode*, const EgNode*) const;
|
||||||
|
|
||||||
vector<EgNode*> nodes_;
|
vector<EgNode*> nodes_;
|
||||||
vector<bool> marked_;
|
vector<bool> marked_;
|
||||||
unordered_map<VarId,EgNode*> varMap_;
|
unordered_map<VarId, EgNode*> varMap_;
|
||||||
static ElimHeuristic elimHeuristic_;
|
static ElimHeuristic elimHeuristic_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@
|
|||||||
|
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
#include "Indexer.h"
|
#include "Indexer.h"
|
||||||
#include "Util.h"
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (const Factor& g)
|
Factor::Factor (const Factor& g)
|
||||||
@ -18,206 +18,33 @@ Factor::Factor (const Factor& g)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (VarId vid, unsigned nStates)
|
Factor::Factor (
|
||||||
|
const VarIds& vids,
|
||||||
|
const Ranges& ranges,
|
||||||
|
const Params& params,
|
||||||
|
unsigned distId)
|
||||||
{
|
{
|
||||||
varids_.push_back (vid);
|
args_ = vids;
|
||||||
ranges_.push_back (nStates);
|
|
||||||
dist_ = new Distribution (Params (nStates, 1.0));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (const VarNodes& vars)
|
|
||||||
{
|
|
||||||
int nParams = 1;
|
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
|
||||||
varids_.push_back (vars[i]->varId());
|
|
||||||
ranges_.push_back (vars[i]->nrStates());
|
|
||||||
nParams *= vars[i]->nrStates();
|
|
||||||
}
|
|
||||||
// create a uniform distribution
|
|
||||||
double val = 1.0 / nParams;
|
|
||||||
dist_ = new Distribution (Params (nParams, val));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (VarId vid, unsigned nStates, const Params& params)
|
|
||||||
{
|
|
||||||
varids_.push_back (vid);
|
|
||||||
ranges_.push_back (nStates);
|
|
||||||
dist_ = new Distribution (params);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (VarNodes& vars, Distribution* dist)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
|
||||||
varids_.push_back (vars[i]->varId());
|
|
||||||
ranges_.push_back (vars[i]->nrStates());
|
|
||||||
}
|
|
||||||
dist_ = dist;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (const VarNodes& vars, const Params& params)
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
|
||||||
varids_.push_back (vars[i]->varId());
|
|
||||||
ranges_.push_back (vars[i]->nrStates());
|
|
||||||
}
|
|
||||||
dist_ = new Distribution (params);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::Factor (const VarIds& vids,
|
|
||||||
const Ranges& ranges,
|
|
||||||
const Params& params)
|
|
||||||
{
|
|
||||||
varids_ = vids;
|
|
||||||
ranges_ = ranges;
|
ranges_ = ranges;
|
||||||
dist_ = new Distribution (params);
|
params_ = params;
|
||||||
|
distId_ = distId;
|
||||||
|
assert (params_.size() == Util::expectedSize (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Factor::~Factor (void)
|
Factor::Factor (
|
||||||
|
const Vars& vars,
|
||||||
|
const Params& params,
|
||||||
|
unsigned distId)
|
||||||
{
|
{
|
||||||
if (dist_->shared() == false) {
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
delete dist_;
|
args_.push_back (vars[i]->varId());
|
||||||
}
|
ranges_.push_back (vars[i]->range());
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::setParameters (const Params& params)
|
|
||||||
{
|
|
||||||
assert (dist_->params.size() == params.size());
|
|
||||||
dist_->params = params;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::copyFromFactor (const Factor& g)
|
|
||||||
{
|
|
||||||
varids_ = g.getVarIds();
|
|
||||||
ranges_ = g.getRanges();
|
|
||||||
dist_ = new Distribution (g.getParameters());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::multiply (const Factor& g)
|
|
||||||
{
|
|
||||||
if (varids_.size() == 0) {
|
|
||||||
copyFromFactor (g);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const VarIds& g_varids = g.getVarIds();
|
|
||||||
const Ranges& g_ranges = g.getRanges();
|
|
||||||
const Params& g_params = g.getParameters();
|
|
||||||
|
|
||||||
if (varids_ == g_varids) {
|
|
||||||
// optimization: if the factors contain the same set of variables,
|
|
||||||
// we can do a 1 to 1 operation on the parameters
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
Util::add (dist_->params, g_params);
|
|
||||||
} else {
|
|
||||||
Util::multiply (dist_->params, g_params);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
bool sharedVars = false;
|
|
||||||
vector<unsigned> gvarpos;
|
|
||||||
for (unsigned i = 0; i < g_varids.size(); i++) {
|
|
||||||
int idx = indexOf (g_varids[i]);
|
|
||||||
if (idx == -1) {
|
|
||||||
insertVariable (g_varids[i], g_ranges[i]);
|
|
||||||
gvarpos.push_back (varids_.size() - 1);
|
|
||||||
} else {
|
|
||||||
sharedVars = true;
|
|
||||||
gvarpos.push_back (idx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (sharedVars == false) {
|
|
||||||
// optimization: if the original factors doesn't have common variables,
|
|
||||||
// we don't need to marry the states of the common variables
|
|
||||||
unsigned count = 0;
|
|
||||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
dist_->params[i] += g_params[count];
|
|
||||||
} else {
|
|
||||||
dist_->params[i] *= g_params[count];
|
|
||||||
}
|
|
||||||
count ++;
|
|
||||||
if (count >= g_params.size()) {
|
|
||||||
count = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
StatesIndexer indexer (ranges_, false);
|
|
||||||
while (indexer.valid()) {
|
|
||||||
unsigned g_li = 0;
|
|
||||||
unsigned prod = 1;
|
|
||||||
for (int j = gvarpos.size() - 1; j >= 0; j--) {
|
|
||||||
g_li += indexer[gvarpos[j]] * prod;
|
|
||||||
prod *= g_ranges[j];
|
|
||||||
}
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
dist_->params[indexer] += g_params[g_li];
|
|
||||||
} else {
|
|
||||||
dist_->params[indexer] *= g_params[g_li];
|
|
||||||
}
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::insertVariable (VarId varId, unsigned nrStates)
|
|
||||||
{
|
|
||||||
assert (indexOf (varId) == -1);
|
|
||||||
Params oldParams = dist_->params;
|
|
||||||
dist_->params.clear();
|
|
||||||
dist_->params.reserve (oldParams.size() * nrStates);
|
|
||||||
for (unsigned i = 0; i < oldParams.size(); i++) {
|
|
||||||
for (unsigned reps = 0; reps < nrStates; reps++) {
|
|
||||||
dist_->params.push_back (oldParams[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
varids_.push_back (varId);
|
|
||||||
ranges_.push_back (nrStates);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::insertVariables (const VarIds& varIds, const Ranges& ranges)
|
|
||||||
{
|
|
||||||
Params oldParams = dist_->params;
|
|
||||||
unsigned nrStates = 1;
|
|
||||||
for (unsigned i = 0; i < varIds.size(); i++) {
|
|
||||||
assert (indexOf (varIds[i]) == -1);
|
|
||||||
varids_.push_back (varIds[i]);
|
|
||||||
ranges_.push_back (ranges[i]);
|
|
||||||
nrStates *= ranges[i];
|
|
||||||
}
|
|
||||||
dist_->params.clear();
|
|
||||||
dist_->params.reserve (oldParams.size() * nrStates);
|
|
||||||
for (unsigned i = 0; i < oldParams.size(); i++) {
|
|
||||||
for (unsigned reps = 0; reps < nrStates; reps++) {
|
|
||||||
dist_->params.push_back (oldParams[i]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
params_ = params;
|
||||||
|
distId_ = distId;
|
||||||
|
assert (params_.size() == Util::expectedSize (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -226,10 +53,10 @@ void
|
|||||||
Factor::sumOutAllExcept (VarId vid)
|
Factor::sumOutAllExcept (VarId vid)
|
||||||
{
|
{
|
||||||
assert (indexOf (vid) != -1);
|
assert (indexOf (vid) != -1);
|
||||||
while (varids_.back() != vid) {
|
while (args_.back() != vid) {
|
||||||
sumOutLastVariable();
|
sumOutLastVariable();
|
||||||
}
|
}
|
||||||
while (varids_.front() != vid) {
|
while (args_.front() != vid) {
|
||||||
sumOutFirstVariable();
|
sumOutFirstVariable();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -239,9 +66,10 @@ Factor::sumOutAllExcept (VarId vid)
|
|||||||
void
|
void
|
||||||
Factor::sumOutAllExcept (const VarIds& vids)
|
Factor::sumOutAllExcept (const VarIds& vids)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
for (int i = 0; i < (int)args_.size(); i++) {
|
||||||
if (std::find (vids.begin(), vids.end(), varids_[i]) == vids.end()) {
|
if (Util::contains (vids, args_[i]) == false) {
|
||||||
sumOut (varids_[i]);
|
sumOut (args_[i]);
|
||||||
|
i --;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -254,11 +82,11 @@ Factor::sumOut (VarId vid)
|
|||||||
int idx = indexOf (vid);
|
int idx = indexOf (vid);
|
||||||
assert (idx != -1);
|
assert (idx != -1);
|
||||||
|
|
||||||
if (vid == varids_.back()) {
|
if (vid == args_.back()) {
|
||||||
sumOutLastVariable(); // optimization
|
sumOutLastVariable(); // optimization
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (vid == varids_.front()) {
|
if (vid == args_.front()) {
|
||||||
sumOutFirstVariable(); // optimization
|
sumOutFirstVariable(); // optimization
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -271,7 +99,7 @@ Factor::sumOut (VarId vid)
|
|||||||
// on the left of `var', with the states of the remaining vars fixed
|
// on the left of `var', with the states of the remaining vars fixed
|
||||||
unsigned leftVarOffset = 1;
|
unsigned leftVarOffset = 1;
|
||||||
|
|
||||||
for (int i = varids_.size() - 1; i > idx; i--) {
|
for (int i = args_.size() - 1; i > idx; i--) {
|
||||||
varOffset *= ranges_[i];
|
varOffset *= ranges_[i];
|
||||||
leftVarOffset *= ranges_[i];
|
leftVarOffset *= ranges_[i];
|
||||||
}
|
}
|
||||||
@ -280,25 +108,24 @@ Factor::sumOut (VarId vid)
|
|||||||
unsigned offset = 0;
|
unsigned offset = 0;
|
||||||
unsigned count1 = 0;
|
unsigned count1 = 0;
|
||||||
unsigned count2 = 0;
|
unsigned count2 = 0;
|
||||||
unsigned newpsSize = dist_->params.size() / ranges_[idx];
|
unsigned newpsSize = params_.size() / ranges_[idx];
|
||||||
|
|
||||||
Params newps;
|
Params newps;
|
||||||
newps.reserve (newpsSize);
|
newps.reserve (newpsSize);
|
||||||
Params& params = dist_->params;
|
|
||||||
|
|
||||||
while (newps.size() < newpsSize) {
|
while (newps.size() < newpsSize) {
|
||||||
double sum = Util::addIdenty();
|
double sum = LogAware::addIdenty();
|
||||||
for (unsigned i = 0; i < ranges_[idx]; i++) {
|
for (unsigned i = 0; i < ranges_[idx]; i++) {
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::logSum (sum, params[offset]);
|
sum = Util::logSum (sum, params_[offset]);
|
||||||
} else {
|
} else {
|
||||||
sum += params[offset];
|
sum += params_[offset];
|
||||||
}
|
}
|
||||||
offset += varOffset;
|
offset += varOffset;
|
||||||
}
|
}
|
||||||
newps.push_back (sum);
|
newps.push_back (sum);
|
||||||
count1 ++;
|
count1 ++;
|
||||||
if (idx == (int)varids_.size() - 1) {
|
if (idx == (int)args_.size() - 1) {
|
||||||
offset = count1 * ranges_[idx];
|
offset = count1 * ranges_[idx];
|
||||||
} else {
|
} else {
|
||||||
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
|
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
|
||||||
@ -308,9 +135,9 @@ Factor::sumOut (VarId vid)
|
|||||||
offset = (leftVarOffset * count2) + count1;
|
offset = (leftVarOffset * count2) + count1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
varids_.erase (varids_.begin() + idx);
|
args_.erase (args_.begin() + idx);
|
||||||
ranges_.erase (ranges_.begin() + idx);
|
ranges_.erase (ranges_.begin() + idx);
|
||||||
dist_->params = newps;
|
params_ = newps;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -318,20 +145,19 @@ Factor::sumOut (VarId vid)
|
|||||||
void
|
void
|
||||||
Factor::sumOutFirstVariable (void)
|
Factor::sumOutFirstVariable (void)
|
||||||
{
|
{
|
||||||
Params& params = dist_->params;
|
unsigned range = ranges_.front();
|
||||||
unsigned nStates = ranges_.front();
|
unsigned sep = params_.size() / range;
|
||||||
unsigned sep = params.size() / nStates;
|
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
for (unsigned i = sep; i < params.size(); i++) {
|
for (unsigned i = sep; i < params_.size(); i++) {
|
||||||
Util::logSum (params[i % sep], params[i]);
|
params_[i % sep] = Util::logSum (params_[i % sep], params_[i]);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (unsigned i = sep; i < params.size(); i++) {
|
for (unsigned i = sep; i < params_.size(); i++) {
|
||||||
params[i % sep] += params[i];
|
params_[i % sep] += params_[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
params.resize (sep);
|
params_.resize (sep);
|
||||||
varids_.erase (varids_.begin());
|
args_.erase (args_.begin());
|
||||||
ranges_.erase (ranges_.begin());
|
ranges_.erase (ranges_.begin());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -340,143 +166,55 @@ Factor::sumOutFirstVariable (void)
|
|||||||
void
|
void
|
||||||
Factor::sumOutLastVariable (void)
|
Factor::sumOutLastVariable (void)
|
||||||
{
|
{
|
||||||
Params& params = dist_->params;
|
unsigned range = ranges_.back();
|
||||||
unsigned nStates = ranges_.back();
|
|
||||||
unsigned idx1 = 0;
|
unsigned idx1 = 0;
|
||||||
unsigned idx2 = 0;
|
unsigned idx2 = 0;
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
while (idx1 < params.size()) {
|
while (idx1 < params_.size()) {
|
||||||
params[idx2] = params[idx1];
|
params_[idx2] = params_[idx1];
|
||||||
idx1 ++;
|
idx1 ++;
|
||||||
for (unsigned j = 1; j < nStates; j++) {
|
for (unsigned j = 1; j < range; j++) {
|
||||||
Util::logSum (params[idx2], params[idx1]);
|
params_[idx2] = Util::logSum (params_[idx2], params_[idx1]);
|
||||||
idx1 ++;
|
idx1 ++;
|
||||||
}
|
}
|
||||||
idx2 ++;
|
idx2 ++;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
while (idx1 < params.size()) {
|
while (idx1 < params_.size()) {
|
||||||
params[idx2] = params[idx1];
|
params_[idx2] = params_[idx1];
|
||||||
idx1 ++;
|
idx1 ++;
|
||||||
for (unsigned j = 1; j < nStates; j++) {
|
for (unsigned j = 1; j < range; j++) {
|
||||||
params[idx2] += params[idx1];
|
params_[idx2] += params_[idx1];
|
||||||
idx1 ++;
|
idx1 ++;
|
||||||
}
|
}
|
||||||
idx2 ++;
|
idx2 ++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
params.resize (idx2);
|
params_.resize (idx2);
|
||||||
varids_.pop_back();
|
args_.pop_back();
|
||||||
ranges_.pop_back();
|
ranges_.pop_back();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Factor::orderVariables (void)
|
Factor::multiply (Factor& g)
|
||||||
{
|
{
|
||||||
VarIds sortedVarIds = varids_;
|
if (args_.size() == 0) {
|
||||||
sort (sortedVarIds.begin(), sortedVarIds.end());
|
copyFromFactor (g);
|
||||||
reorderVariables (sortedVarIds);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::reorderVariables (const VarIds& newVarIds)
|
|
||||||
{
|
|
||||||
assert (newVarIds.size() == varids_.size());
|
|
||||||
if (newVarIds == varids_) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
TFactor<VarId>::multiply (g);
|
||||||
Ranges newRanges;
|
|
||||||
vector<unsigned> positions;
|
|
||||||
for (unsigned i = 0; i < newVarIds.size(); i++) {
|
|
||||||
unsigned idx = indexOf (newVarIds[i]);
|
|
||||||
newRanges.push_back (ranges_[idx]);
|
|
||||||
positions.push_back (idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned N = ranges_.size();
|
|
||||||
Params newParams (dist_->params.size());
|
|
||||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
|
||||||
unsigned li = i;
|
|
||||||
// calculate vector index corresponding to linear index
|
|
||||||
vector<unsigned> vi (N);
|
|
||||||
for (int k = N-1; k >= 0; k--) {
|
|
||||||
vi[k] = li % ranges_[k];
|
|
||||||
li /= ranges_[k];
|
|
||||||
}
|
|
||||||
// convert permuted vector index to corresponding linear index
|
|
||||||
unsigned prod = 1;
|
|
||||||
unsigned new_li = 0;
|
|
||||||
for (int k = N-1; k >= 0; k--) {
|
|
||||||
new_li += vi[positions[k]] * prod;
|
|
||||||
prod *= ranges_[positions[k]];
|
|
||||||
}
|
|
||||||
newParams[new_li] = dist_->params[i];
|
|
||||||
}
|
|
||||||
varids_ = newVarIds;
|
|
||||||
ranges_ = newRanges;
|
|
||||||
dist_->params = newParams;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Factor::absorveEvidence (VarId vid, unsigned evidence)
|
Factor::reorderAccordingVarIds (void)
|
||||||
{
|
{
|
||||||
int idx = indexOf (vid);
|
VarIds sortedVarIds = args_;
|
||||||
assert (idx != -1);
|
sort (sortedVarIds.begin(), sortedVarIds.end());
|
||||||
|
reorderArguments (sortedVarIds);
|
||||||
Params oldParams = dist_->params;
|
|
||||||
dist_->params.clear();
|
|
||||||
dist_->params.reserve (oldParams.size() / ranges_[idx]);
|
|
||||||
StatesIndexer indexer (ranges_);
|
|
||||||
for (unsigned i = 0; i < evidence; i++) {
|
|
||||||
indexer.increment (idx);
|
|
||||||
}
|
|
||||||
while (indexer.valid()) {
|
|
||||||
dist_->params.push_back (oldParams[indexer]);
|
|
||||||
indexer.incrementExcluding (idx);
|
|
||||||
}
|
|
||||||
varids_.erase (varids_.begin() + idx);
|
|
||||||
ranges_.erase (ranges_.begin() + idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Factor::normalize (void)
|
|
||||||
{
|
|
||||||
Util::normalize (dist_->params);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
Factor::contains (const VarIds& vars) const
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
|
||||||
if (indexOf (vars[i]) == -1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
|
||||||
Factor::indexOf (VarId vid) const
|
|
||||||
{
|
|
||||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
|
||||||
if (varids_[i] == vid) {
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -486,9 +224,9 @@ Factor::getLabel (void) const
|
|||||||
{
|
{
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "f(" ;
|
ss << "f(" ;
|
||||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (i != 0) ss << "," ;
|
if (i != 0) ss << "," ;
|
||||||
ss << VarNode (varids_[i], ranges_[i]).label();
|
ss << Var (args_[i], ranges_[i]).label();
|
||||||
}
|
}
|
||||||
ss << ")" ;
|
ss << ")" ;
|
||||||
return ss.str();
|
return ss.str();
|
||||||
@ -499,14 +237,14 @@ Factor::getLabel (void) const
|
|||||||
void
|
void
|
||||||
Factor::print (void) const
|
Factor::print (void) const
|
||||||
{
|
{
|
||||||
VarNodes vars;
|
Vars vars;
|
||||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
vars.push_back (new VarNode (varids_[i], ranges_[i]));
|
vars.push_back (new Var (args_[i], ranges_[i]));
|
||||||
}
|
}
|
||||||
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
vector<string> jointStrings = Util::getStateLines (vars);
|
||||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
for (unsigned i = 0; i < params_.size(); i++) {
|
||||||
cout << "f(" << jointStrings[i] << ")" ;
|
cout << "[" << distId_ << "] f(" << jointStrings[i] << ")" ;
|
||||||
cout << " = " << dist_->params[i] << endl;
|
cout << " = " << params_[i] << endl;
|
||||||
}
|
}
|
||||||
cout << endl;
|
cout << endl;
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
@ -515,3 +253,13 @@ Factor::print (void) const
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Factor::copyFromFactor (const Factor& g)
|
||||||
|
{
|
||||||
|
args_ = g.arguments();
|
||||||
|
ranges_ = g.ranges();
|
||||||
|
params_ = g.params();
|
||||||
|
distId_ = g.distId();
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -3,64 +3,285 @@
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "Distribution.h"
|
#include "Var.h"
|
||||||
#include "VarNode.h"
|
#include "Indexer.h"
|
||||||
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
class Distribution;
|
|
||||||
|
template <typename T>
|
||||||
|
class TFactor
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
const vector<T>& arguments (void) const { return args_; }
|
||||||
|
|
||||||
|
vector<T>& arguments (void) { return args_; }
|
||||||
|
|
||||||
|
const Ranges& ranges (void) const { return ranges_; }
|
||||||
|
|
||||||
|
const Params& params (void) const { return params_; }
|
||||||
|
|
||||||
|
Params& params (void) { return params_; }
|
||||||
|
|
||||||
|
unsigned nrArguments (void) const { return args_.size(); }
|
||||||
|
|
||||||
|
unsigned size (void) const { return params_.size(); }
|
||||||
|
|
||||||
|
unsigned distId (void) const { return distId_; }
|
||||||
|
|
||||||
|
void setDistId (unsigned id) { distId_ = id; }
|
||||||
|
|
||||||
|
void normalize (void) { LogAware::normalize (params_); }
|
||||||
|
|
||||||
|
void setParams (const Params& newParams)
|
||||||
|
{
|
||||||
|
params_ = newParams;
|
||||||
|
assert (params_.size() == Util::expectedSize (ranges_));
|
||||||
|
}
|
||||||
|
|
||||||
|
int indexOf (const T& t) const
|
||||||
|
{
|
||||||
|
int idx = -1;
|
||||||
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
|
if (args_[i] == t) {
|
||||||
|
idx = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
const T& argument (unsigned idx) const
|
||||||
|
{
|
||||||
|
assert (idx < args_.size());
|
||||||
|
return args_[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
T& argument (unsigned idx)
|
||||||
|
{
|
||||||
|
assert (idx < args_.size());
|
||||||
|
return args_[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned range (unsigned idx) const
|
||||||
|
{
|
||||||
|
assert (idx < ranges_.size());
|
||||||
|
return ranges_[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
void multiply (TFactor<T>& g)
|
||||||
|
{
|
||||||
|
const vector<T>& g_args = g.arguments();
|
||||||
|
const Ranges& g_ranges = g.ranges();
|
||||||
|
const Params& g_params = g.params();
|
||||||
|
if (args_ == g_args) {
|
||||||
|
// optimization: if the factors contain the same set of args,
|
||||||
|
// we can do a 1 to 1 operation on the parameters
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
Util::add (params_, g_params);
|
||||||
|
} else {
|
||||||
|
Util::multiply (params_, g_params);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
bool sharedArgs = false;
|
||||||
|
vector<unsigned> gvarpos;
|
||||||
|
for (unsigned i = 0; i < g_args.size(); i++) {
|
||||||
|
int idx = indexOf (g_args[i]);
|
||||||
|
if (idx == -1) {
|
||||||
|
insertArgument (g_args[i], g_ranges[i]);
|
||||||
|
gvarpos.push_back (args_.size() - 1);
|
||||||
|
} else {
|
||||||
|
sharedArgs = true;
|
||||||
|
gvarpos.push_back (idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (sharedArgs == false) {
|
||||||
|
// optimization: if the original factors doesn't have common args,
|
||||||
|
// we don't need to marry the states of the common args
|
||||||
|
unsigned count = 0;
|
||||||
|
for (unsigned i = 0; i < params_.size(); i++) {
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
params_[i] += g_params[count];
|
||||||
|
} else {
|
||||||
|
params_[i] *= g_params[count];
|
||||||
|
}
|
||||||
|
count ++;
|
||||||
|
if (count >= g_params.size()) {
|
||||||
|
count = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
StatesIndexer indexer (ranges_, false);
|
||||||
|
while (indexer.valid()) {
|
||||||
|
unsigned g_li = 0;
|
||||||
|
unsigned prod = 1;
|
||||||
|
for (int j = gvarpos.size() - 1; j >= 0; j--) {
|
||||||
|
g_li += indexer[gvarpos[j]] * prod;
|
||||||
|
prod *= g_ranges[j];
|
||||||
|
}
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
params_[indexer] += g_params[g_li];
|
||||||
|
} else {
|
||||||
|
params_[indexer] *= g_params[g_li];
|
||||||
|
}
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void absorveEvidence (const T& arg, unsigned evidence)
|
||||||
|
{
|
||||||
|
int idx = indexOf (arg);
|
||||||
|
assert (idx != -1);
|
||||||
|
assert (evidence < ranges_[idx]);
|
||||||
|
Params copy = params_;
|
||||||
|
params_.clear();
|
||||||
|
params_.reserve (copy.size() / ranges_[idx]);
|
||||||
|
StatesIndexer indexer (ranges_);
|
||||||
|
for (unsigned i = 0; i < evidence; i++) {
|
||||||
|
indexer.increment (idx);
|
||||||
|
}
|
||||||
|
while (indexer.valid()) {
|
||||||
|
params_.push_back (copy[indexer]);
|
||||||
|
indexer.incrementExcluding (idx);
|
||||||
|
}
|
||||||
|
args_.erase (args_.begin() + idx);
|
||||||
|
ranges_.erase (ranges_.begin() + idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void reorderArguments (const vector<T> newArgs)
|
||||||
|
{
|
||||||
|
assert (newArgs.size() == args_.size());
|
||||||
|
if (newArgs == args_) {
|
||||||
|
return; // already in the wanted order
|
||||||
|
}
|
||||||
|
Ranges newRanges;
|
||||||
|
vector<unsigned> positions;
|
||||||
|
for (unsigned i = 0; i < newArgs.size(); i++) {
|
||||||
|
unsigned idx = indexOf (newArgs[i]);
|
||||||
|
newRanges.push_back (ranges_[idx]);
|
||||||
|
positions.push_back (idx);
|
||||||
|
}
|
||||||
|
unsigned N = ranges_.size();
|
||||||
|
Params newParams (params_.size());
|
||||||
|
for (unsigned i = 0; i < params_.size(); i++) {
|
||||||
|
unsigned li = i;
|
||||||
|
// calculate vector index corresponding to linear index
|
||||||
|
vector<unsigned> vi (N);
|
||||||
|
for (int k = N-1; k >= 0; k--) {
|
||||||
|
vi[k] = li % ranges_[k];
|
||||||
|
li /= ranges_[k];
|
||||||
|
}
|
||||||
|
// convert permuted vector index to corresponding linear index
|
||||||
|
unsigned prod = 1;
|
||||||
|
unsigned new_li = 0;
|
||||||
|
for (int k = N - 1; k >= 0; k--) {
|
||||||
|
new_li += vi[positions[k]] * prod;
|
||||||
|
prod *= ranges_[positions[k]];
|
||||||
|
}
|
||||||
|
newParams[new_li] = params_[i];
|
||||||
|
}
|
||||||
|
args_ = newArgs;
|
||||||
|
ranges_ = newRanges;
|
||||||
|
params_ = newParams;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool contains (const T& arg) const
|
||||||
|
{
|
||||||
|
return Util::contains (args_, arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool contains (const vector<T>& args) const
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
|
if (contains (args[i]) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
vector<T> args_;
|
||||||
|
Ranges ranges_;
|
||||||
|
Params params_;
|
||||||
|
unsigned distId_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void insertArgument (const T& arg, unsigned range)
|
||||||
|
{
|
||||||
|
assert (indexOf (arg) == -1);
|
||||||
|
Params copy = params_;
|
||||||
|
params_.clear();
|
||||||
|
params_.reserve (copy.size() * range);
|
||||||
|
for (unsigned i = 0; i < copy.size(); i++) {
|
||||||
|
for (unsigned reps = 0; reps < range; reps++) {
|
||||||
|
params_.push_back (copy[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
args_.push_back (arg);
|
||||||
|
ranges_.push_back (range);
|
||||||
|
}
|
||||||
|
|
||||||
|
void insertArguments (const vector<T>& args, const Ranges& ranges)
|
||||||
|
{
|
||||||
|
Params copy = params_;
|
||||||
|
unsigned nrStates = 1;
|
||||||
|
for (unsigned i = 0; i < args.size(); i++) {
|
||||||
|
assert (indexOf (args[i]) == -1);
|
||||||
|
args_.push_back (args[i]);
|
||||||
|
ranges_.push_back (ranges[i]);
|
||||||
|
nrStates *= ranges[i];
|
||||||
|
}
|
||||||
|
params_.clear();
|
||||||
|
params_.reserve (copy.size() * nrStates);
|
||||||
|
for (unsigned i = 0; i < copy.size(); i++) {
|
||||||
|
for (unsigned reps = 0; reps < nrStates; reps++) {
|
||||||
|
params_.push_back (copy[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
class Factor
|
|
||||||
|
class Factor : public TFactor<VarId>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Factor (void) { }
|
Factor (void) { }
|
||||||
|
|
||||||
Factor (const Factor&);
|
Factor (const Factor&);
|
||||||
Factor (VarId, unsigned);
|
|
||||||
Factor (const VarNodes&);
|
|
||||||
Factor (VarId, unsigned, const Params&);
|
|
||||||
Factor (VarNodes&, Distribution*);
|
|
||||||
Factor (const VarNodes&, const Params&);
|
|
||||||
Factor (const VarIds&, const Ranges&, const Params&);
|
|
||||||
~Factor (void);
|
|
||||||
|
|
||||||
void setParameters (const Params&);
|
Factor (const VarIds&, const Ranges&, const Params&,
|
||||||
void copyFromFactor (const Factor& f);
|
unsigned = Util::maxUnsigned());
|
||||||
void multiply (const Factor&);
|
|
||||||
void insertVariable (VarId, unsigned);
|
|
||||||
void insertVariables (const VarIds&, const Ranges&);
|
|
||||||
void sumOutAllExcept (VarId);
|
|
||||||
void sumOutAllExcept (const VarIds&);
|
|
||||||
void sumOut (VarId);
|
|
||||||
void sumOutFirstVariable (void);
|
|
||||||
void sumOutLastVariable (void);
|
|
||||||
void orderVariables (void);
|
|
||||||
void reorderVariables (const VarIds&);
|
|
||||||
void absorveEvidence (VarId, unsigned);
|
|
||||||
void normalize (void);
|
|
||||||
bool contains (const VarIds&) const;
|
|
||||||
int indexOf (VarId) const;
|
|
||||||
string getLabel (void) const;
|
|
||||||
void print (void) const;
|
|
||||||
|
|
||||||
const VarIds& getVarIds (void) const { return varids_; }
|
Factor (const Vars&, const Params&,
|
||||||
const Ranges& getRanges (void) const { return ranges_; }
|
unsigned = Util::maxUnsigned());
|
||||||
const Params& getParameters (void) const { return dist_->params; }
|
|
||||||
Distribution* getDistribution (void) const { return dist_; }
|
|
||||||
unsigned nrVariables (void) const { return varids_.size(); }
|
|
||||||
unsigned nrParameters() const { return dist_->params.size(); }
|
|
||||||
|
|
||||||
void setDistribution (Distribution* dist)
|
void sumOutAllExcept (VarId);
|
||||||
{
|
|
||||||
dist_ = dist;
|
void sumOutAllExcept (const VarIds&);
|
||||||
}
|
|
||||||
|
void sumOut (VarId);
|
||||||
|
|
||||||
|
void sumOutFirstVariable (void);
|
||||||
|
|
||||||
|
void sumOutLastVariable (void);
|
||||||
|
|
||||||
|
void multiply (Factor&);
|
||||||
|
|
||||||
|
void reorderAccordingVarIds (void);
|
||||||
|
|
||||||
|
string getLabel (void) const;
|
||||||
|
|
||||||
|
void print (void) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void copyFromFactor (const Factor& f);
|
||||||
|
|
||||||
VarIds varids_;
|
|
||||||
Ranges ranges_;
|
|
||||||
Distribution* dist_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_FACTOR_H
|
#endif // HORUS_FACTOR_H
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
#include "BayesNet.h"
|
#include "BayesNet.h"
|
||||||
|
#include "BayesBall.h"
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
@ -17,140 +18,92 @@ bool FactorGraph::orderFactorVariables = false;
|
|||||||
|
|
||||||
FactorGraph::FactorGraph (const FactorGraph& fg)
|
FactorGraph::FactorGraph (const FactorGraph& fg)
|
||||||
{
|
{
|
||||||
const FgVarSet& vars = fg.getVarNodes();
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
FgVarNode* varNode = new FgVarNode (vars[i]);
|
addVarNode (new VarNode (varNodes[i]));
|
||||||
addVariable (varNode);
|
|
||||||
}
|
}
|
||||||
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
const FgFacSet& facs = fg.getFactorNodes();
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
for (unsigned i = 0; i < facs.size(); i++) {
|
FacNode* facNode = new FacNode (facNodes[i]->factor());
|
||||||
FgFacNode* facNode = new FgFacNode (facs[i]);
|
addFacNode (facNode);
|
||||||
addFactor (facNode);
|
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||||
const FgVarSet& neighs = facs[i]->neighbors();
|
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||||
addEdge (facNode, varNodes_[neighs[j]->getIndex()]);
|
addEdge (varNodes_[neighs[j]->getIndex()], facNode);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FactorGraph::FactorGraph (const BayesNet& bn)
|
|
||||||
{
|
|
||||||
const BnNodeSet& nodes = bn.getBayesNodes();
|
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
||||||
FgVarNode* varNode = new FgVarNode (nodes[i]);
|
|
||||||
addVariable (varNode);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
||||||
const BnNodeSet& parents = nodes[i]->getParents();
|
|
||||||
if (!(nodes[i]->hasEvidence() && parents.size() == 0)) {
|
|
||||||
VarNodes neighs;
|
|
||||||
neighs.push_back (varNodes_[nodes[i]->getIndex()]);
|
|
||||||
for (unsigned j = 0; j < parents.size(); j++) {
|
|
||||||
neighs.push_back (varNodes_[parents[j]->getIndex()]);
|
|
||||||
}
|
|
||||||
FgFacNode* fn = new FgFacNode (
|
|
||||||
new Factor (neighs, nodes[i]->getDistribution()));
|
|
||||||
if (orderFactorVariables) {
|
|
||||||
sort (neighs.begin(), neighs.end(), CompVarId());
|
|
||||||
fn->factor()->orderVariables();
|
|
||||||
}
|
|
||||||
addFactor (fn);
|
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
|
||||||
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
setIndexes();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::readFromUaiFormat (const char* fileName)
|
FactorGraph::readFromUaiFormat (const char* fileName)
|
||||||
{
|
{
|
||||||
ifstream is (fileName);
|
std::ifstream is (fileName);
|
||||||
if (!is.is_open()) {
|
if (!is.is_open()) {
|
||||||
cerr << "error: cannot read from file " + std::string (fileName) << endl;
|
cerr << "error: cannot read from file " << fileName << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
|
ignoreLines (is);
|
||||||
string line;
|
string line;
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
|
||||||
getline (is, line);
|
getline (is, line);
|
||||||
if (line != "MARKOV") {
|
if (line != "MARKOV") {
|
||||||
cerr << "error: the network must be a MARKOV network " << endl;
|
cerr << "error: the network must be a MARKOV network " << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
|
// read the number of vars
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
ignoreLines (is);
|
||||||
unsigned nVars;
|
unsigned nrVars;
|
||||||
is >> nVars;
|
is >> nrVars;
|
||||||
|
// read the range of each var
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
ignoreLines (is);
|
||||||
vector<int> domainSizes (nVars);
|
Ranges ranges (nrVars);
|
||||||
for (unsigned i = 0; i < nVars; i++) {
|
for (unsigned i = 0; i < nrVars; i++) {
|
||||||
unsigned ds;
|
is >> ranges[i];
|
||||||
is >> ds;
|
|
||||||
domainSizes[i] = ds;
|
|
||||||
}
|
}
|
||||||
|
unsigned nrFactors;
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
unsigned nrArgs;
|
||||||
for (unsigned i = 0; i < nVars; i++) {
|
unsigned vid;
|
||||||
addVariable (new FgVarNode (i, domainSizes[i]));
|
is >> nrFactors;
|
||||||
}
|
vector<VarIds> factorVarIds;
|
||||||
|
vector<Ranges> factorRanges;
|
||||||
unsigned nFactors;
|
for (unsigned i = 0; i < nrFactors; i++) {
|
||||||
is >> nFactors;
|
ignoreLines (is);
|
||||||
for (unsigned i = 0; i < nFactors; i++) {
|
is >> nrArgs;
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
factorVarIds.push_back ({ });
|
||||||
unsigned nFactorVars;
|
factorRanges.push_back ({ });
|
||||||
is >> nFactorVars;
|
for (unsigned j = 0; j < nrArgs; j++) {
|
||||||
VarNodes neighs;
|
|
||||||
for (unsigned j = 0; j < nFactorVars; j++) {
|
|
||||||
unsigned vid;
|
|
||||||
is >> vid;
|
is >> vid;
|
||||||
FgVarNode* neigh = getFgVarNode (vid);
|
if (vid >= ranges.size()) {
|
||||||
if (!neigh) {
|
cerr << "error: invalid variable identifier `" << vid << "'" << endl;
|
||||||
cerr << "error: invalid variable identifier (" << vid << ")" << endl;
|
cerr << "identifiers must be between 0 and " << ranges.size() - 1 ;
|
||||||
|
cerr << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
neighs.push_back (neigh);
|
factorVarIds.back().push_back (vid);
|
||||||
}
|
factorRanges.back().push_back (ranges[vid]);
|
||||||
FgFacNode* fn = new FgFacNode (new Factor (neighs));
|
|
||||||
addFactor (fn);
|
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
|
||||||
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// read the parameters
|
||||||
for (unsigned i = 0; i < nFactors; i++) {
|
unsigned nrParams;
|
||||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
for (unsigned i = 0; i < nrFactors; i++) {
|
||||||
unsigned nParams;
|
ignoreLines (is);
|
||||||
is >> nParams;
|
is >> nrParams;
|
||||||
if (facNodes_[i]->getParameters().size() != nParams) {
|
if (nrParams != Util::expectedSize (factorRanges[i])) {
|
||||||
cerr << "error: invalid number of parameters for factor " ;
|
cerr << "error: invalid number of parameters for factor nº " << i ;
|
||||||
cerr << facNodes_[i]->getLabel() ;
|
cerr << ", expected: " << Util::expectedSize (factorRanges[i]);
|
||||||
cerr << ", expected: " << facNodes_[i]->getParameters().size();
|
cerr << ", given: " << nrParams << endl;
|
||||||
cerr << ", given: " << nParams << endl;
|
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
Params params (nParams);
|
Params params (nrParams);
|
||||||
for (unsigned j = 0; j < nParams; j++) {
|
for (unsigned j = 0; j < nrParams; j++) {
|
||||||
double param;
|
is >> params[j];
|
||||||
is >> param;
|
|
||||||
params[j] = param;
|
|
||||||
}
|
}
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::toLog (params);
|
Util::toLog (params);
|
||||||
}
|
}
|
||||||
facNodes_[i]->factor()->setParameters (params);
|
addFactor (Factor (factorVarIds[i], factorRanges[i], params));
|
||||||
}
|
}
|
||||||
is.close();
|
is.close();
|
||||||
setIndexes();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -158,87 +111,58 @@ FactorGraph::readFromUaiFormat (const char* fileName)
|
|||||||
void
|
void
|
||||||
FactorGraph::readFromLibDaiFormat (const char* fileName)
|
FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||||
{
|
{
|
||||||
ifstream is (fileName);
|
std::ifstream is (fileName);
|
||||||
if (!is.is_open()) {
|
if (!is.is_open()) {
|
||||||
cerr << "error: cannot read from file " + std::string (fileName) << endl;
|
cerr << "error: cannot read from file " << fileName << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
|
ignoreLines (is);
|
||||||
string line;
|
unsigned nrFactors;
|
||||||
unsigned nFactors;
|
unsigned nrArgs;
|
||||||
|
VarId vid;
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
is >> nrFactors;
|
||||||
is >> nFactors;
|
for (unsigned i = 0; i < nrFactors; i++) {
|
||||||
|
ignoreLines (is);
|
||||||
if (is.fail()) {
|
// read the factor arguments
|
||||||
cerr << "error: cannot read the number of factors" << endl;
|
is >> nrArgs;
|
||||||
abort();
|
|
||||||
}
|
|
||||||
|
|
||||||
getline (is, line);
|
|
||||||
if (is.fail() || line.size() > 0) {
|
|
||||||
cerr << "error: cannot read the number of factors" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < nFactors; i++) {
|
|
||||||
unsigned nVars;
|
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
|
||||||
|
|
||||||
is >> nVars;
|
|
||||||
VarIds vids;
|
VarIds vids;
|
||||||
for (unsigned j = 0; j < nVars; j++) {
|
for (unsigned j = 0; j < nrArgs; j++) {
|
||||||
VarId vid;
|
ignoreLines (is);
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
|
||||||
is >> vid;
|
is >> vid;
|
||||||
vids.push_back (vid);
|
vids.push_back (vid);
|
||||||
}
|
}
|
||||||
|
// read ranges
|
||||||
VarNodes neighs;
|
Ranges ranges (nrArgs);
|
||||||
unsigned nParams = 1;
|
for (unsigned j = 0; j < nrArgs; j++) {
|
||||||
for (unsigned j = 0; j < nVars; j++) {
|
ignoreLines (is);
|
||||||
unsigned dsize;
|
is >> ranges[j];
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
VarNode* var = getVarNode (vids[j]);
|
||||||
is >> dsize;
|
if (var != 0 && ranges[j] != var->range()) {
|
||||||
FgVarNode* var = getFgVarNode (vids[j]);
|
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
|
||||||
if (var == 0) {
|
cerr << "more factors with a different range" << endl;
|
||||||
var = new FgVarNode (vids[j], dsize);
|
|
||||||
addVariable (var);
|
|
||||||
} else {
|
|
||||||
if (var->nrStates() != dsize) {
|
|
||||||
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
|
|
||||||
cerr << "more factors with different domain sizes" << endl;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
neighs.push_back (var);
|
|
||||||
nParams *= var->nrStates();
|
|
||||||
}
|
}
|
||||||
Params params (nParams, 0);
|
// read parameters
|
||||||
|
ignoreLines (is);
|
||||||
unsigned nNonzeros;
|
unsigned nNonzeros;
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
|
||||||
is >> nNonzeros;
|
is >> nNonzeros;
|
||||||
|
Params params (Util::expectedSize (ranges), 0);
|
||||||
for (unsigned j = 0; j < nNonzeros; j++) {
|
for (unsigned j = 0; j < nNonzeros; j++) {
|
||||||
|
ignoreLines (is);
|
||||||
unsigned index;
|
unsigned index;
|
||||||
double val;
|
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
|
||||||
is >> index;
|
is >> index;
|
||||||
while ((is.peek()) == '#') getline (is, line);
|
ignoreLines (is);
|
||||||
|
double val;
|
||||||
is >> val;
|
is >> val;
|
||||||
params[index] = val;
|
params[index] = val;
|
||||||
}
|
}
|
||||||
reverse (neighs.begin(), neighs.end());
|
reverse (vids.begin(), vids.end());
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::toLog (params);
|
Util::toLog (params);
|
||||||
}
|
}
|
||||||
FgFacNode* fn = new FgFacNode (new Factor (neighs, params));
|
addFactor (Factor (vids, ranges, params));
|
||||||
addFactor (fn);
|
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
|
||||||
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
is.close();
|
is.close();
|
||||||
setIndexes();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -256,17 +180,41 @@ FactorGraph::~FactorGraph (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::addVariable (FgVarNode* vn)
|
FactorGraph::addFactor (const Factor& factor)
|
||||||
{
|
{
|
||||||
varNodes_.push_back (vn);
|
FacNode* fn = new FacNode (factor);
|
||||||
vn->setIndex (varNodes_.size() - 1);
|
addFacNode (fn);
|
||||||
varMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1));
|
const VarIds& vids = factor.arguments();
|
||||||
|
for (unsigned i = 0; i < vids.size(); i++) {
|
||||||
|
bool found = false;
|
||||||
|
for (unsigned j = 0; j < varNodes_.size(); j++) {
|
||||||
|
if (varNodes_[j]->varId() == vids[i]) {
|
||||||
|
addEdge (varNodes_[j], fn);
|
||||||
|
found = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (found == false) {
|
||||||
|
VarNode* vn = new VarNode (vids[i], factor.range (i));
|
||||||
|
addVarNode (vn);
|
||||||
|
addEdge (vn, fn);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::addFactor (FgFacNode* fn)
|
FactorGraph::addVarNode (VarNode* vn)
|
||||||
|
{
|
||||||
|
varNodes_.push_back (vn);
|
||||||
|
vn->setIndex (varNodes_.size() - 1);
|
||||||
|
varMap_.insert (make_pair (vn->varId(), vn));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::addFacNode (FacNode* fn)
|
||||||
{
|
{
|
||||||
facNodes_.push_back (fn);
|
facNodes_.push_back (fn);
|
||||||
fn->setIndex (facNodes_.size() - 1);
|
fn->setIndex (facNodes_.size() - 1);
|
||||||
@ -275,7 +223,7 @@ FactorGraph::addFactor (FgFacNode* fn)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn)
|
FactorGraph::addEdge (VarNode* vn, FacNode* fn)
|
||||||
{
|
{
|
||||||
vn->addNeighbor (fn);
|
vn->addNeighbor (fn);
|
||||||
fn->addNeighbor (vn);
|
fn->addNeighbor (vn);
|
||||||
@ -283,37 +231,6 @@ FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
FactorGraph::addEdge (FgFacNode* fn, FgVarNode* vn)
|
|
||||||
{
|
|
||||||
fn->addNeighbor (vn);
|
|
||||||
vn->addNeighbor (fn);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarNode*
|
|
||||||
FactorGraph::getVariableNode (VarId vid) const
|
|
||||||
{
|
|
||||||
FgVarNode* vn = getFgVarNode (vid);
|
|
||||||
assert (vn);
|
|
||||||
return vn;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarNodes
|
|
||||||
FactorGraph::getVariableNodes (void) const
|
|
||||||
{
|
|
||||||
VarNodes vars;
|
|
||||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
|
||||||
vars.push_back (varNodes_[i]);
|
|
||||||
}
|
|
||||||
return vars;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
bool
|
||||||
FactorGraph::isTree (void) const
|
FactorGraph::isTree (void) const
|
||||||
{
|
{
|
||||||
@ -322,51 +239,42 @@ FactorGraph::isTree (void) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
DAGraph&
|
||||||
FactorGraph::setIndexes (void)
|
FactorGraph::getStructure (void)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
assert (fromBayesNet_);
|
||||||
varNodes_[i]->setIndex (i);
|
if (structure_.empty()) {
|
||||||
}
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
structure_.addNode (new DAGraphNode (varNodes_[i]));
|
||||||
facNodes_[i]->setIndex (i);
|
}
|
||||||
|
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||||
|
const VarIds& vids = facNodes_[i]->factor().arguments();
|
||||||
|
for (unsigned j = 1; j < vids.size(); j++) {
|
||||||
|
structure_.addEdge (vids[j], vids[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
return structure_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FactorGraph::freeDistributions (void)
|
FactorGraph::print (void) const
|
||||||
{
|
|
||||||
set<Distribution*> dists;
|
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
|
||||||
dists.insert (facNodes_[i]->factor()->getDistribution());
|
|
||||||
}
|
|
||||||
for (set<Distribution*>::iterator it = dists.begin();
|
|
||||||
it != dists.end(); it++) {
|
|
||||||
delete *it;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
FactorGraph::printGraphicalModel (void) const
|
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
cout << "VarId = " << varNodes_[i]->varId() << endl;
|
cout << "var id = " << varNodes_[i]->varId() << endl;
|
||||||
cout << "Label = " << varNodes_[i]->label() << endl;
|
cout << "label = " << varNodes_[i]->label() << endl;
|
||||||
cout << "Nr States = " << varNodes_[i]->nrStates() << endl;
|
cout << "range = " << varNodes_[i]->range() << endl;
|
||||||
cout << "Evidence = " << varNodes_[i]->getEvidence() << endl;
|
cout << "evidence = " << varNodes_[i]->getEvidence() << endl;
|
||||||
cout << "Factors = " ;
|
cout << "factors = " ;
|
||||||
for (unsigned j = 0; j < varNodes_[i]->neighbors().size(); j++) {
|
for (unsigned j = 0; j < varNodes_[i]->neighbors().size(); j++) {
|
||||||
cout << varNodes_[i]->neighbors()[j]->getLabel() << " " ;
|
cout << varNodes_[i]->neighbors()[j]->getLabel() << " " ;
|
||||||
}
|
}
|
||||||
cout << endl << endl;
|
cout << endl << endl;
|
||||||
}
|
}
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||||
facNodes_[i]->factor()->print();
|
facNodes_[i]->factor().print();
|
||||||
cout << endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -381,31 +289,26 @@ FactorGraph::exportToGraphViz (const char* fileName) const
|
|||||||
cerr << "FactorGraph::exportToDotFile()" << endl;
|
cerr << "FactorGraph::exportToDotFile()" << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
|
|
||||||
out << "graph \"" << fileName << "\" {" << endl;
|
out << "graph \"" << fileName << "\" {" << endl;
|
||||||
|
|
||||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
if (varNodes_[i]->hasEvidence()) {
|
if (varNodes_[i]->hasEvidence()) {
|
||||||
out << '"' << varNodes_[i]->label() << '"' ;
|
out << '"' << varNodes_[i]->label() << '"' ;
|
||||||
out << " [style=filled, fillcolor=yellow]" << endl;
|
out << " [style=filled, fillcolor=yellow]" << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||||
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
||||||
out << " [label=\"" << facNodes_[i]->getLabel();
|
out << " [label=\"" << facNodes_[i]->getLabel();
|
||||||
out << "\"" << ", shape=box]" << endl;
|
out << "\"" << ", shape=box]" << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||||
const FgVarSet& myVars = facNodes_[i]->neighbors();
|
const VarNodes& myVars = facNodes_[i]->neighbors();
|
||||||
for (unsigned j = 0; j < myVars.size(); j++) {
|
for (unsigned j = 0; j < myVars.size(); j++) {
|
||||||
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
||||||
out << " -- " ;
|
out << " -- " ;
|
||||||
out << '"' << myVars[j]->label() << '"' << endl;
|
out << '"' << myVars[j]->label() << '"' << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
out << "}" << endl;
|
out << "}" << endl;
|
||||||
out.close();
|
out.close();
|
||||||
}
|
}
|
||||||
@ -417,30 +320,26 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
|
|||||||
{
|
{
|
||||||
ofstream out (fileName);
|
ofstream out (fileName);
|
||||||
if (!out.is_open()) {
|
if (!out.is_open()) {
|
||||||
cerr << "error: cannot open file to write at " ;
|
cerr << "error: cannot open file " << fileName << endl;
|
||||||
cerr << "FactorGraph::exportToUaiFormat()" << endl;
|
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
|
|
||||||
out << "MARKOV" << endl;
|
out << "MARKOV" << endl;
|
||||||
out << varNodes_.size() << endl;
|
out << varNodes_.size() << endl;
|
||||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||||
out << varNodes_[i]->nrStates() << " " ;
|
out << varNodes_[i]->range() << " " ;
|
||||||
}
|
}
|
||||||
out << endl;
|
out << endl;
|
||||||
|
|
||||||
out << facNodes_.size() << endl;
|
out << facNodes_.size() << endl;
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||||
const FgVarSet& factorVars = facNodes_[i]->neighbors();
|
const VarNodes& factorVars = facNodes_[i]->neighbors();
|
||||||
out << factorVars.size();
|
out << factorVars.size();
|
||||||
for (unsigned j = 0; j < factorVars.size(); j++) {
|
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||||
out << " " << factorVars[j]->getIndex();
|
out << " " << factorVars[j]->getIndex();
|
||||||
}
|
}
|
||||||
out << endl;
|
out << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||||
Params params = facNodes_[i]->getParameters();
|
Params params = facNodes_[i]->factor().params();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::fromLog (params);
|
Util::fromLog (params);
|
||||||
}
|
}
|
||||||
@ -450,7 +349,6 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
|
|||||||
}
|
}
|
||||||
out << endl;
|
out << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
out.close();
|
out.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -461,23 +359,22 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
|||||||
{
|
{
|
||||||
ofstream out (fileName);
|
ofstream out (fileName);
|
||||||
if (!out.is_open()) {
|
if (!out.is_open()) {
|
||||||
cerr << "error: cannot open file to write at " ;
|
cerr << "error: cannot open file " << fileName << endl;
|
||||||
cerr << "FactorGraph::exportToLibDaiFormat()" << endl;
|
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
out << facNodes_.size() << endl << endl;
|
out << facNodes_.size() << endl << endl;
|
||||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||||
const FgVarSet& factorVars = facNodes_[i]->neighbors();
|
const VarNodes& factorVars = facNodes_[i]->neighbors();
|
||||||
out << factorVars.size() << endl;
|
out << factorVars.size() << endl;
|
||||||
for (int j = factorVars.size() - 1; j >= 0; j--) {
|
for (int j = factorVars.size() - 1; j >= 0; j--) {
|
||||||
out << factorVars[j]->varId() << " " ;
|
out << factorVars[j]->varId() << " " ;
|
||||||
}
|
}
|
||||||
out << endl;
|
out << endl;
|
||||||
for (unsigned j = 0; j < factorVars.size(); j++) {
|
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||||
out << factorVars[j]->nrStates() << " " ;
|
out << factorVars[j]->range() << " " ;
|
||||||
}
|
}
|
||||||
out << endl;
|
out << endl;
|
||||||
Params params = facNodes_[i]->factor()->getParameters();
|
Params params = facNodes_[i]->factor().params();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::fromLog (params);
|
Util::fromLog (params);
|
||||||
}
|
}
|
||||||
@ -492,6 +389,17 @@ FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FactorGraph::ignoreLines (std::ifstream& is) const
|
||||||
|
{
|
||||||
|
string ignoreStr;
|
||||||
|
while (is.peek() == '#' || is.peek() == '\n') {
|
||||||
|
getline (is, ignoreStr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
bool
|
||||||
FactorGraph::containsCycle (void) const
|
FactorGraph::containsCycle (void) const
|
||||||
{
|
{
|
||||||
@ -511,13 +419,14 @@ FactorGraph::containsCycle (void) const
|
|||||||
|
|
||||||
|
|
||||||
bool
|
bool
|
||||||
FactorGraph::containsCycle (const FgVarNode* v,
|
FactorGraph::containsCycle (
|
||||||
const FgFacNode* p,
|
const VarNode* v,
|
||||||
vector<bool>& visitedVars,
|
const FacNode* p,
|
||||||
vector<bool>& visitedFactors) const
|
vector<bool>& visitedVars,
|
||||||
|
vector<bool>& visitedFactors) const
|
||||||
{
|
{
|
||||||
visitedVars[v->getIndex()] = true;
|
visitedVars[v->getIndex()] = true;
|
||||||
const FgFacSet& adjacencies = v->neighbors();
|
const FacNodes& adjacencies = v->neighbors();
|
||||||
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
||||||
int w = adjacencies[i]->getIndex();
|
int w = adjacencies[i]->getIndex();
|
||||||
if (!visitedFactors[w]) {
|
if (!visitedFactors[w]) {
|
||||||
@ -535,13 +444,14 @@ FactorGraph::containsCycle (const FgVarNode* v,
|
|||||||
|
|
||||||
|
|
||||||
bool
|
bool
|
||||||
FactorGraph::containsCycle (const FgFacNode* v,
|
FactorGraph::containsCycle (
|
||||||
const FgVarNode* p,
|
const FacNode* v,
|
||||||
vector<bool>& visitedVars,
|
const VarNode* p,
|
||||||
vector<bool>& visitedFactors) const
|
vector<bool>& visitedVars,
|
||||||
|
vector<bool>& visitedFactors) const
|
||||||
{
|
{
|
||||||
visitedFactors[v->getIndex()] = true;
|
visitedFactors[v->getIndex()] = true;
|
||||||
const FgVarSet& adjacencies = v->neighbors();
|
const VarNodes& adjacencies = v->neighbors();
|
||||||
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
||||||
int w = adjacencies[i]->getIndex();
|
int w = adjacencies[i]->getIndex();
|
||||||
if (!visitedVars[w]) {
|
if (!visitedVars[w]) {
|
||||||
|
@ -3,135 +3,139 @@
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "GraphicalModel.h"
|
|
||||||
#include "Distribution.h"
|
|
||||||
#include "Factor.h"
|
#include "Factor.h"
|
||||||
|
#include "BayesNet.h"
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
class BayesNet;
|
|
||||||
class FgFacNode;
|
|
||||||
|
|
||||||
class FgVarNode : public VarNode
|
class FacNode;
|
||||||
|
|
||||||
|
class VarNode : public Var
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FgVarNode (VarId varId, unsigned nrStates) : VarNode (varId, nrStates) { }
|
VarNode (VarId varId, unsigned nrStates)
|
||||||
FgVarNode (const VarNode* v) : VarNode (v) { }
|
: Var (varId, nrStates) { }
|
||||||
|
|
||||||
void addNeighbor (FgFacNode* fn) { neighs_.push_back (fn); }
|
VarNode (const Var* v) : Var (v) { }
|
||||||
const FgFacSet& neighbors (void) const { return neighs_; }
|
|
||||||
|
void addNeighbor (FacNode* fn) { neighs_.push_back (fn); }
|
||||||
|
|
||||||
|
const FacNodes& neighbors (void) const { return neighs_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (FgVarNode);
|
DISALLOW_COPY_AND_ASSIGN (VarNode);
|
||||||
// members
|
|
||||||
FgFacSet neighs_;
|
FacNodes neighs_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class FgFacNode
|
class FacNode
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FgFacNode (const FgFacNode* fn) {
|
FacNode (const Factor& f) : factor_(f), index_(-1) { }
|
||||||
factor_ = new Factor (*fn->factor());
|
|
||||||
index_ = -1;
|
const Factor& factor (void) const { return factor_; }
|
||||||
}
|
|
||||||
FgFacNode (Factor* f) : factor_(new Factor(*f)), index_(-1) { }
|
Factor& factor (void) { return factor_; }
|
||||||
Factor* factor() const { return factor_; }
|
|
||||||
void addNeighbor (FgVarNode* vn) { neighs_.push_back (vn); }
|
void addNeighbor (VarNode* vn) { neighs_.push_back (vn); }
|
||||||
const FgVarSet& neighbors (void) const { return neighs_; }
|
|
||||||
|
const VarNodes& neighbors (void) const { return neighs_; }
|
||||||
|
|
||||||
|
int getIndex (void) const { return index_; }
|
||||||
|
|
||||||
|
void setIndex (int index) { index_ = index; }
|
||||||
|
|
||||||
|
string getLabel (void) { return factor_.getLabel(); }
|
||||||
|
|
||||||
int getIndex (void) const
|
|
||||||
{
|
|
||||||
assert (index_ != -1);
|
|
||||||
return index_;
|
|
||||||
}
|
|
||||||
void setIndex (int index)
|
|
||||||
{
|
|
||||||
index_ = index;
|
|
||||||
}
|
|
||||||
Distribution* getDistribution (void)
|
|
||||||
{
|
|
||||||
return factor_->getDistribution();
|
|
||||||
}
|
|
||||||
const Params& getParameters (void) const
|
|
||||||
{
|
|
||||||
return factor_->getParameters();
|
|
||||||
}
|
|
||||||
string getLabel (void)
|
|
||||||
{
|
|
||||||
return factor_->getLabel();
|
|
||||||
}
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN (FgFacNode);
|
DISALLOW_COPY_AND_ASSIGN (FacNode);
|
||||||
|
|
||||||
Factor* factor_;
|
VarNodes neighs_;
|
||||||
int index_;
|
Factor factor_;
|
||||||
FgVarSet neighs_;
|
int index_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
struct CompVarId
|
struct CompVarId
|
||||||
{
|
{
|
||||||
bool operator() (const VarNode* vn1, const VarNode* vn2) const
|
bool operator() (const Var* v1, const Var* v2) const
|
||||||
{
|
{
|
||||||
return vn1->varId() < vn2->varId();
|
return v1->varId() < v2->varId();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class FactorGraph : public GraphicalModel
|
class FactorGraph
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FactorGraph (void) {};
|
FactorGraph (bool fbn = false) : fromBayesNet_(fbn) { }
|
||||||
|
|
||||||
FactorGraph (const FactorGraph&);
|
FactorGraph (const FactorGraph&);
|
||||||
FactorGraph (const BayesNet&);
|
|
||||||
~FactorGraph (void);
|
~FactorGraph (void);
|
||||||
|
|
||||||
void readFromUaiFormat (const char*);
|
const VarNodes& varNodes (void) const { return varNodes_; }
|
||||||
void readFromLibDaiFormat (const char*);
|
|
||||||
void addVariable (FgVarNode*);
|
|
||||||
void addFactor (FgFacNode*);
|
|
||||||
void addEdge (FgVarNode*, FgFacNode*);
|
|
||||||
void addEdge (FgFacNode*, FgVarNode*);
|
|
||||||
VarNode* getVariableNode (unsigned) const;
|
|
||||||
VarNodes getVariableNodes (void) const;
|
|
||||||
bool isTree (void) const;
|
|
||||||
void setIndexes (void);
|
|
||||||
void freeDistributions (void);
|
|
||||||
void printGraphicalModel (void) const;
|
|
||||||
void exportToGraphViz (const char*) const;
|
|
||||||
void exportToUaiFormat (const char*) const;
|
|
||||||
void exportToLibDaiFormat (const char*) const;
|
|
||||||
|
|
||||||
const FgVarSet& getVarNodes (void) const { return varNodes_; }
|
|
||||||
const FgFacSet& getFactorNodes (void) const { return facNodes_; }
|
|
||||||
|
|
||||||
FgVarNode* getFgVarNode (VarId vid) const
|
const FacNodes& facNodes (void) const { return facNodes_; }
|
||||||
|
|
||||||
|
bool isFromBayesNetwork (void) const { return fromBayesNet_ ; }
|
||||||
|
|
||||||
|
VarNode* getVarNode (VarId vid) const
|
||||||
{
|
{
|
||||||
IndexMap::const_iterator it = varMap_.find (vid);
|
VarMap::const_iterator it = varMap_.find (vid);
|
||||||
if (it == varMap_.end()) {
|
return it != varMap_.end() ? it->second : 0;
|
||||||
return 0;
|
|
||||||
} else {
|
|
||||||
return varNodes_[it->second];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void readFromUaiFormat (const char*);
|
||||||
|
|
||||||
|
void readFromLibDaiFormat (const char*);
|
||||||
|
|
||||||
|
void addFactor (const Factor& factor);
|
||||||
|
|
||||||
|
void addVarNode (VarNode*);
|
||||||
|
|
||||||
|
void addFacNode (FacNode*);
|
||||||
|
|
||||||
|
void addEdge (VarNode*, FacNode*);
|
||||||
|
|
||||||
|
bool isTree (void) const;
|
||||||
|
|
||||||
|
DAGraph& getStructure (void);
|
||||||
|
|
||||||
|
void print (void) const;
|
||||||
|
|
||||||
|
void exportToGraphViz (const char*) const;
|
||||||
|
|
||||||
|
void exportToUaiFormat (const char*) const;
|
||||||
|
|
||||||
|
void exportToLibDaiFormat (const char*) const;
|
||||||
|
|
||||||
static bool orderFactorVariables;
|
static bool orderFactorVariables;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
//DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
// DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
||||||
bool containsCycle (void) const;
|
|
||||||
bool containsCycle (const FgVarNode*, const FgFacNode*,
|
|
||||||
vector<bool>&, vector<bool>&) const;
|
|
||||||
bool containsCycle (const FgFacNode*, const FgVarNode*,
|
|
||||||
vector<bool>&, vector<bool>&) const;
|
|
||||||
|
|
||||||
FgVarSet varNodes_;
|
void ignoreLines (std::ifstream&) const;
|
||||||
FgFacSet facNodes_;
|
|
||||||
|
|
||||||
typedef unordered_map<unsigned, unsigned> IndexMap;
|
bool containsCycle (void) const;
|
||||||
IndexMap varMap_;
|
|
||||||
|
bool containsCycle (const VarNode*, const FacNode*,
|
||||||
|
vector<bool>&, vector<bool>&) const;
|
||||||
|
|
||||||
|
bool containsCycle (const FacNode*, const VarNode*,
|
||||||
|
vector<bool>&, vector<bool>&) const;
|
||||||
|
|
||||||
|
VarNodes varNodes_;
|
||||||
|
FacNodes facNodes_;
|
||||||
|
|
||||||
|
DAGraph structure_;
|
||||||
|
bool fromBayesNet_;
|
||||||
|
|
||||||
|
typedef unordered_map<unsigned, VarNode*> VarMap;
|
||||||
|
VarMap varMap_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_FACTORGRAPH_H
|
#endif // HORUS_FACTORGRAPH_H
|
||||||
|
@ -1,175 +0,0 @@
|
|||||||
#ifndef HORUS_FGBPSOLVER_H
|
|
||||||
#define HORUS_FGBPSOLVER_H
|
|
||||||
|
|
||||||
#include <set>
|
|
||||||
#include <vector>
|
|
||||||
#include <sstream>
|
|
||||||
|
|
||||||
#include "Solver.h"
|
|
||||||
#include "Factor.h"
|
|
||||||
#include "FactorGraph.h"
|
|
||||||
#include "Util.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SpLink
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
SpLink (FgFacNode* fn, FgVarNode* vn)
|
|
||||||
{
|
|
||||||
fac_ = fn;
|
|
||||||
var_ = vn;
|
|
||||||
v1_.resize (vn->nrStates(), Util::tl (1.0 / vn->nrStates()));
|
|
||||||
v2_.resize (vn->nrStates(), Util::tl (1.0 / vn->nrStates()));
|
|
||||||
currMsg_ = &v1_;
|
|
||||||
nextMsg_ = &v2_;
|
|
||||||
msgSended_ = false;
|
|
||||||
residual_ = 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual ~SpLink (void) {};
|
|
||||||
|
|
||||||
virtual void updateMessage (void)
|
|
||||||
{
|
|
||||||
swap (currMsg_, nextMsg_);
|
|
||||||
msgSended_ = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void updateResidual (void)
|
|
||||||
{
|
|
||||||
residual_ = Util::getMaxNorm (v1_, v2_);
|
|
||||||
}
|
|
||||||
|
|
||||||
string toString (void) const
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
ss << fac_->getLabel();
|
|
||||||
ss << " -- " ;
|
|
||||||
ss << var_->label();
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
FgFacNode* getFactor (void) const { return fac_; }
|
|
||||||
FgVarNode* getVariable (void) const { return var_; }
|
|
||||||
const Params& getMessage (void) const { return *currMsg_; }
|
|
||||||
Params& getNextMessage (void) { return *nextMsg_; }
|
|
||||||
bool messageWasSended (void) const { return msgSended_; }
|
|
||||||
double getResidual (void) const { return residual_; }
|
|
||||||
void clearResidual (void) { residual_ = 0.0; }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
FgFacNode* fac_;
|
|
||||||
FgVarNode* var_;
|
|
||||||
Params v1_;
|
|
||||||
Params v2_;
|
|
||||||
Params* currMsg_;
|
|
||||||
Params* nextMsg_;
|
|
||||||
bool msgSended_;
|
|
||||||
double residual_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
typedef vector<SpLink*> SpLinkSet;
|
|
||||||
|
|
||||||
|
|
||||||
class SPNodeInfo
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
void addSpLink (SpLink* link) { links_.push_back (link); }
|
|
||||||
const SpLinkSet& getLinks (void) { return links_; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
SpLinkSet links_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class FgBpSolver : public Solver
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
FgBpSolver (const FactorGraph&);
|
|
||||||
virtual ~FgBpSolver (void);
|
|
||||||
|
|
||||||
void runSolver (void);
|
|
||||||
virtual Params getPosterioriOf (VarId);
|
|
||||||
virtual Params getJointDistributionOf (const VarIds&);
|
|
||||||
|
|
||||||
protected:
|
|
||||||
virtual void initializeSolver (void);
|
|
||||||
virtual void createLinks (void);
|
|
||||||
virtual void maxResidualSchedule (void);
|
|
||||||
virtual void calculateFactor2VariableMsg (SpLink*) const;
|
|
||||||
virtual Params getVar2FactorMsg (const SpLink*) const;
|
|
||||||
virtual Params getJointByConditioning (const VarIds&) const;
|
|
||||||
virtual void printLinkInformation (void) const;
|
|
||||||
|
|
||||||
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
|
|
||||||
{
|
|
||||||
if (DL >= 3) {
|
|
||||||
cout << "calculating & updating " << link->toString() << endl;
|
|
||||||
}
|
|
||||||
calculateFactor2VariableMsg (link);
|
|
||||||
if (calcResidual) {
|
|
||||||
link->updateResidual();
|
|
||||||
}
|
|
||||||
link->updateMessage();
|
|
||||||
}
|
|
||||||
|
|
||||||
void calculateMessage (SpLink* link, bool calcResidual = true)
|
|
||||||
{
|
|
||||||
if (DL >= 3) {
|
|
||||||
cout << "calculating " << link->toString() << endl;
|
|
||||||
}
|
|
||||||
calculateFactor2VariableMsg (link);
|
|
||||||
if (calcResidual) {
|
|
||||||
link->updateResidual();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void updateMessage (SpLink* link)
|
|
||||||
{
|
|
||||||
link->updateMessage();
|
|
||||||
if (DL >= 3) {
|
|
||||||
cout << "updating " << link->toString() << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SPNodeInfo* ninf (const FgVarNode* var) const
|
|
||||||
{
|
|
||||||
return varsI_[var->getIndex()];
|
|
||||||
}
|
|
||||||
|
|
||||||
SPNodeInfo* ninf (const FgFacNode* fac) const
|
|
||||||
{
|
|
||||||
return facsI_[fac->getIndex()];
|
|
||||||
}
|
|
||||||
|
|
||||||
struct CompareResidual {
|
|
||||||
inline bool operator() (const SpLink* link1, const SpLink* link2)
|
|
||||||
{
|
|
||||||
return link1->getResidual() > link2->getResidual();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
SpLinkSet links_;
|
|
||||||
unsigned nIters_;
|
|
||||||
vector<SPNodeInfo*> varsI_;
|
|
||||||
vector<SPNodeInfo*> facsI_;
|
|
||||||
const FactorGraph* factorGraph_;
|
|
||||||
|
|
||||||
typedef multiset<SpLink*, CompareResidual> SortedOrder;
|
|
||||||
SortedOrder sortedOrder_;
|
|
||||||
|
|
||||||
typedef unordered_map<SpLink*, SortedOrder::iterator> SpLinkMap;
|
|
||||||
SpLinkMap linkMap_;
|
|
||||||
|
|
||||||
private:
|
|
||||||
void runLoopySolver (void);
|
|
||||||
bool converged (void);
|
|
||||||
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // HORUS_FGBPSOLVER_H
|
|
||||||
|
|
@ -8,7 +8,9 @@
|
|||||||
|
|
||||||
|
|
||||||
vector<LiftedOperator*>
|
vector<LiftedOperator*>
|
||||||
LiftedOperator::getValidOps (ParfactorList& pfList, const Grounds& query)
|
LiftedOperator::getValidOps (
|
||||||
|
ParfactorList& pfList,
|
||||||
|
const Grounds& query)
|
||||||
{
|
{
|
||||||
vector<LiftedOperator*> validOps;
|
vector<LiftedOperator*> validOps;
|
||||||
vector<SumOutOperator*> sumOutOps;
|
vector<SumOutOperator*> sumOutOps;
|
||||||
@ -28,12 +30,15 @@ LiftedOperator::getValidOps (ParfactorList& pfList, const Grounds& query)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
LiftedOperator::printValidOps (ParfactorList& pfList, const Grounds& query)
|
LiftedOperator::printValidOps (
|
||||||
|
ParfactorList& pfList,
|
||||||
|
const Grounds& query)
|
||||||
{
|
{
|
||||||
vector<LiftedOperator*> validOps;
|
vector<LiftedOperator*> validOps;
|
||||||
validOps = LiftedOperator::getValidOps (pfList, query);
|
validOps = LiftedOperator::getValidOps (pfList, query);
|
||||||
for (unsigned i = 0; i < validOps.size(); i++) {
|
for (unsigned i = 0; i < validOps.size(); i++) {
|
||||||
cout << "-> " << validOps[i]->toString() << endl;
|
cout << "-> " << validOps[i]->toString() << endl;
|
||||||
|
delete validOps[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -56,14 +61,14 @@ SumOutOperator::getCost (void)
|
|||||||
pfIter = pfList_.begin();
|
pfIter = pfList_.begin();
|
||||||
while (pfIter != pfList_.end()) {
|
while (pfIter != pfList_.end()) {
|
||||||
if ((*pfIter)->containsGroup (groupSet[i])) {
|
if ((*pfIter)->containsGroup (groupSet[i])) {
|
||||||
int idx = (*pfIter)->indexOfFormulaWithGroup (groupSet[i]);
|
int idx = (*pfIter)->indexOfGroup (groupSet[i]);
|
||||||
cost *= (*pfIter)->range (idx);
|
cost *= (*pfIter)->range (idx);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
++ pfIter;
|
++ pfIter;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return cost;
|
return cost;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -77,14 +82,13 @@ SumOutOperator::apply (void)
|
|||||||
pfList_.remove (iters[0]);
|
pfList_.remove (iters[0]);
|
||||||
for (unsigned i = 1; i < iters.size(); i++) {
|
for (unsigned i = 1; i < iters.size(); i++) {
|
||||||
product->multiply (**(iters[i]));
|
product->multiply (**(iters[i]));
|
||||||
delete *(iters[i]);
|
pfList_.removeAndDelete (iters[i]);
|
||||||
pfList_.remove (iters[i]);
|
|
||||||
}
|
}
|
||||||
if (product->nrFormulas() == 1) {
|
if (product->nrArguments() == 1) {
|
||||||
delete product;
|
delete product;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
int fIdx = product->indexOfFormulaWithGroup (group_);
|
int fIdx = product->indexOfGroup (group_);
|
||||||
LogVarSet excl = product->exclusiveLogVars (fIdx);
|
LogVarSet excl = product->exclusiveLogVars (fIdx);
|
||||||
if (product->constr()->isCountNormalized (excl)) {
|
if (product->constr()->isCountNormalized (excl)) {
|
||||||
product->sumOut (fIdx);
|
product->sumOut (fIdx);
|
||||||
@ -96,21 +100,21 @@ SumOutOperator::apply (void)
|
|||||||
pfList_.add (pfs[i]);
|
pfList_.add (pfs[i]);
|
||||||
}
|
}
|
||||||
delete product;
|
delete product;
|
||||||
pfList_.shatter();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<SumOutOperator*>
|
vector<SumOutOperator*>
|
||||||
SumOutOperator::getValidOps (ParfactorList& pfList, const Grounds& query)
|
SumOutOperator::getValidOps (
|
||||||
|
ParfactorList& pfList,
|
||||||
|
const Grounds& query)
|
||||||
{
|
{
|
||||||
vector<SumOutOperator*> validOps;
|
vector<SumOutOperator*> validOps;
|
||||||
set<unsigned> allGroups;
|
set<unsigned> allGroups;
|
||||||
ParfactorList::const_iterator it = pfList.begin();
|
ParfactorList::const_iterator it = pfList.begin();
|
||||||
while (it != pfList.end()) {
|
while (it != pfList.end()) {
|
||||||
assert (*it);
|
const ProbFormulas& formulas = (*it)->arguments();
|
||||||
const ProbFormulas& formulas = (*it)->formulas();
|
|
||||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
for (unsigned i = 0; i < formulas.size(); i++) {
|
||||||
allGroups.insert (formulas[i].group());
|
allGroups.insert (formulas[i].group());
|
||||||
}
|
}
|
||||||
@ -134,8 +138,8 @@ SumOutOperator::toString (void)
|
|||||||
stringstream ss;
|
stringstream ss;
|
||||||
vector<ParfactorList::iterator> pfIters;
|
vector<ParfactorList::iterator> pfIters;
|
||||||
pfIters = parfactorsWithGroup (pfList_, group_);
|
pfIters = parfactorsWithGroup (pfList_, group_);
|
||||||
int idx = (*pfIters[0])->indexOfFormulaWithGroup (group_);
|
int idx = (*pfIters[0])->indexOfGroup (group_);
|
||||||
ProbFormula f = (*pfIters[0])->formula (idx);
|
ProbFormula f = (*pfIters[0])->argument (idx);
|
||||||
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
|
TupleSet tupleSet = (*pfIters[0])->constr()->tupleSet (f.logVars());
|
||||||
ss << "sum out " << f.functor() << "/" << f.arity();
|
ss << "sum out " << f.functor() << "/" << f.arity();
|
||||||
ss << "|" << tupleSet << " (group " << group_ << ")";
|
ss << "|" << tupleSet << " (group " << group_ << ")";
|
||||||
@ -158,9 +162,9 @@ SumOutOperator::validOp (
|
|||||||
}
|
}
|
||||||
unordered_map<unsigned, unsigned> groupToRange;
|
unordered_map<unsigned, unsigned> groupToRange;
|
||||||
for (unsigned i = 0; i < pfIters.size(); i++) {
|
for (unsigned i = 0; i < pfIters.size(); i++) {
|
||||||
int fIdx = (*pfIters[i])->indexOfFormulaWithGroup (group);
|
int fIdx = (*pfIters[i])->indexOfGroup (group);
|
||||||
if ((*pfIters[i])->formulas()[fIdx].contains (
|
if ((*pfIters[i])->argument (fIdx).contains (
|
||||||
(*pfIters[i])->elimLogVars()) == false) {
|
(*pfIters[i])->elimLogVars()) == false) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
vector<unsigned> ranges = (*pfIters[i])->ranges();
|
vector<unsigned> ranges = (*pfIters[i])->ranges();
|
||||||
@ -206,8 +210,8 @@ SumOutOperator::isToEliminate (
|
|||||||
unsigned group,
|
unsigned group,
|
||||||
const Grounds& query)
|
const Grounds& query)
|
||||||
{
|
{
|
||||||
int fIdx = g->indexOfFormulaWithGroup (group);
|
int fIdx = g->indexOfGroup (group);
|
||||||
const ProbFormula& formula = g->formula (fIdx);
|
const ProbFormula& formula = g->argument (fIdx);
|
||||||
bool toElim = true;
|
bool toElim = true;
|
||||||
for (unsigned i = 0; i < query.size(); i++) {
|
for (unsigned i = 0; i < query.size(); i++) {
|
||||||
if (formula.functor() == query[i].functor() &&
|
if (formula.functor() == query[i].functor() &&
|
||||||
@ -228,7 +232,7 @@ unsigned
|
|||||||
CountingOperator::getCost (void)
|
CountingOperator::getCost (void)
|
||||||
{
|
{
|
||||||
unsigned cost = 0;
|
unsigned cost = 0;
|
||||||
int fIdx = (*pfIter_)->indexOfFormulaWithLogVar (X_);
|
int fIdx = (*pfIter_)->indexOfLogVar (X_);
|
||||||
unsigned range = (*pfIter_)->range (fIdx);
|
unsigned range = (*pfIter_)->range (fIdx);
|
||||||
unsigned size = (*pfIter_)->size() / range;
|
unsigned size = (*pfIter_)->size() / range;
|
||||||
TinySet<unsigned> counts;
|
TinySet<unsigned> counts;
|
||||||
@ -247,18 +251,19 @@ CountingOperator::apply (void)
|
|||||||
if ((*pfIter_)->constr()->isCountNormalized (X_)) {
|
if ((*pfIter_)->constr()->isCountNormalized (X_)) {
|
||||||
(*pfIter_)->countConvert (X_);
|
(*pfIter_)->countConvert (X_);
|
||||||
} else {
|
} else {
|
||||||
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
|
Parfactor* pf = *pfIter_;
|
||||||
|
pfList_.remove (pfIter_);
|
||||||
|
Parfactors pfs = FoveSolver::countNormalize (pf, X_);
|
||||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||||
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
|
unsigned condCount = pfs[i]->constr()->getConditionalCount (X_);
|
||||||
bool cartProduct = pfs[i]->constr()->isCarteesianProduct (
|
bool cartProduct = pfs[i]->constr()->isCarteesianProduct (
|
||||||
(*pfIter_)->countedLogVars() | X_);
|
pfs[i]->countedLogVars() | X_);
|
||||||
if (condCount > 1 && cartProduct) {
|
if (condCount > 1 && cartProduct) {
|
||||||
pfs[i]->countConvert (X_);
|
pfs[i]->countConvert (X_);
|
||||||
}
|
}
|
||||||
pfList_.add (pfs[i]);
|
pfList_.add (pfs[i]);
|
||||||
}
|
}
|
||||||
pfList_.deleteAndRemove (pfIter_);
|
delete pf;
|
||||||
pfList_.shatter();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -289,14 +294,17 @@ CountingOperator::toString (void)
|
|||||||
{
|
{
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "count convert " << X_ << " in " ;
|
ss << "count convert " << X_ << " in " ;
|
||||||
ss << (*pfIter_)->getHeaderString();
|
ss << (*pfIter_)->getLabel();
|
||||||
ss << " [cost=" << getCost() << "]" << endl;
|
ss << " [cost=" << getCost() << "]" << endl;
|
||||||
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
|
Parfactors pfs = FoveSolver::countNormalize (*pfIter_, X_);
|
||||||
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
if ((*pfIter_)->constr()->isCountNormalized (X_) == false) {
|
||||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||||
ss << " º " << pfs[i]->getHeaderString() << endl;
|
ss << " º " << pfs[i]->getLabel() << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||||
|
delete pfs[i];
|
||||||
|
}
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -308,8 +316,8 @@ CountingOperator::validOp (Parfactor* g, LogVar X)
|
|||||||
if (g->nrFormulas (X) != 1) {
|
if (g->nrFormulas (X) != 1) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
int fIdx = g->indexOfFormulaWithLogVar (X);
|
int fIdx = g->indexOfLogVar (X);
|
||||||
if (g->formulas()[fIdx].isCounting()) {
|
if (g->argument (fIdx).isCounting()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
bool countNormalized = g->constr()->isCountNormalized (X);
|
bool countNormalized = g->constr()->isCountNormalized (X);
|
||||||
@ -332,10 +340,10 @@ GroundOperator::getCost (void)
|
|||||||
unsigned cost = 0;
|
unsigned cost = 0;
|
||||||
bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_);
|
bool isCountingLv = (*pfIter_)->countedLogVars().contains (X_);
|
||||||
if (isCountingLv) {
|
if (isCountingLv) {
|
||||||
int fIdx = (*pfIter_)->indexOfFormulaWithLogVar (X_);
|
int fIdx = (*pfIter_)->indexOfLogVar (X_);
|
||||||
unsigned currSize = (*pfIter_)->size();
|
unsigned currSize = (*pfIter_)->size();
|
||||||
unsigned nrHists = (*pfIter_)->range (fIdx);
|
unsigned nrHists = (*pfIter_)->range (fIdx);
|
||||||
unsigned range = (*pfIter_)->formula(fIdx).range();
|
unsigned range = (*pfIter_)->argument (fIdx).range();
|
||||||
unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_);
|
unsigned nrSymbols = (*pfIter_)->constr()->getConditionalCount (X_);
|
||||||
cost = (currSize / nrHists) * (std::pow (range, nrSymbols));
|
cost = (currSize / nrHists) * (std::pow (range, nrSymbols));
|
||||||
} else {
|
} else {
|
||||||
@ -350,18 +358,17 @@ void
|
|||||||
GroundOperator::apply (void)
|
GroundOperator::apply (void)
|
||||||
{
|
{
|
||||||
bool countedLv = (*pfIter_)->countedLogVars().contains (X_);
|
bool countedLv = (*pfIter_)->countedLogVars().contains (X_);
|
||||||
|
Parfactor* pf = *pfIter_;
|
||||||
|
pfList_.remove (pfIter_);
|
||||||
if (countedLv) {
|
if (countedLv) {
|
||||||
(*pfIter_)->fullExpand (X_);
|
pf->fullExpand (X_);
|
||||||
(*pfIter_)->setNewGroups();
|
pfList_.add (pf);
|
||||||
pfList_.shatter();
|
|
||||||
} else {
|
} else {
|
||||||
ConstraintTrees cts = (*pfIter_)->constr()->ground (X_);
|
ConstraintTrees cts = pf->constr()->ground (X_);
|
||||||
for (unsigned i = 0; i < cts.size(); i++) {
|
for (unsigned i = 0; i < cts.size(); i++) {
|
||||||
Parfactor* newPf = new Parfactor (*pfIter_, cts[i]);
|
pfList_.add (new Parfactor (pf, cts[i]));
|
||||||
pfList_.add (newPf);
|
|
||||||
}
|
}
|
||||||
pfList_.deleteAndRemove (pfIter_);
|
delete pf;
|
||||||
pfList_.shatter();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -393,24 +400,13 @@ GroundOperator::toString (void)
|
|||||||
((*pfIter_)->countedLogVars().contains (X_))
|
((*pfIter_)->countedLogVars().contains (X_))
|
||||||
? ss << "full expanding "
|
? ss << "full expanding "
|
||||||
: ss << "grounding " ;
|
: ss << "grounding " ;
|
||||||
ss << X_ << " in " << (*pfIter_)->getHeaderString();
|
ss << X_ << " in " << (*pfIter_)->getLabel();
|
||||||
ss << " [cost=" << getCost() << "]" << endl;
|
ss << " [cost=" << getCost() << "]" << endl;
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FoveSolver::FoveSolver (const ParfactorList* pfList)
|
|
||||||
{
|
|
||||||
for (ParfactorList::const_iterator it = pfList->begin();
|
|
||||||
it != pfList->end();
|
|
||||||
it ++) {
|
|
||||||
pfList_.addShattered (new Parfactor (**it));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
FoveSolver::getPosterioriOf (const Ground& query)
|
FoveSolver::getPosterioriOf (const Ground& query)
|
||||||
{
|
{
|
||||||
@ -422,14 +418,12 @@ FoveSolver::getPosterioriOf (const Ground& query)
|
|||||||
Params
|
Params
|
||||||
FoveSolver::getJointDistributionOf (const Grounds& query)
|
FoveSolver::getJointDistributionOf (const Grounds& query)
|
||||||
{
|
{
|
||||||
shatterAgainstQuery (query);
|
|
||||||
runSolver (query);
|
runSolver (query);
|
||||||
(*pfList_.begin())->normalize();
|
(*pfList_.begin())->normalize();
|
||||||
Params params = (*pfList_.begin())->params();
|
Params params = (*pfList_.begin())->params();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::fromLog (params);
|
Util::fromLog (params);
|
||||||
}
|
}
|
||||||
delete *pfList_.begin();
|
|
||||||
return params;
|
return params;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -438,32 +432,38 @@ FoveSolver::getJointDistributionOf (const Grounds& query)
|
|||||||
void
|
void
|
||||||
FoveSolver::absorveEvidence (
|
FoveSolver::absorveEvidence (
|
||||||
ParfactorList& pfList,
|
ParfactorList& pfList,
|
||||||
const ObservedFormulas& obsFormulas)
|
ObservedFormulas& obsFormulas)
|
||||||
{
|
{
|
||||||
ParfactorList::iterator it = pfList.begin();
|
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
||||||
while (it != pfList.end()) {
|
Parfactors newPfs;
|
||||||
bool increment = true;
|
ParfactorList::iterator it = pfList.begin();
|
||||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
while (it != pfList.end()) {
|
||||||
if (absorved (pfList, it, obsFormulas[i])) {
|
Parfactor* pf = *it;
|
||||||
it = pfList.deleteAndRemove (it);
|
it = pfList.remove (it);
|
||||||
increment = false;
|
Parfactors absorvedPfs = absorve (obsFormulas[i], pf);
|
||||||
break;
|
if (absorvedPfs.empty() == false) {
|
||||||
}
|
if (absorvedPfs.size() == 1 && absorvedPfs[0] == 0) {
|
||||||
}
|
// just remove pf;
|
||||||
if (increment) {
|
} else {
|
||||||
++ it;
|
Util::addToVector (newPfs, absorvedPfs);
|
||||||
|
}
|
||||||
|
delete pf;
|
||||||
|
} else {
|
||||||
|
it = pfList.insertShattered (it, pf);
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
pfList.add (newPfs);
|
||||||
}
|
}
|
||||||
pfList.shatter();
|
if (Constants::DEBUG >= 2 && obsFormulas.empty() == false) {
|
||||||
if (obsFormulas.empty() == false) {
|
Util::printAsteriskLine();
|
||||||
cout << "*******************************************************" << endl;
|
|
||||||
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
cout << "AFTER EVIDENCE ABSORVED" << endl;
|
||||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
||||||
cout << " -> " << *obsFormulas[i] << endl;
|
cout << " -> " << obsFormulas[i] << endl;
|
||||||
}
|
}
|
||||||
cout << "*******************************************************" << endl;
|
Util::printAsteriskLine();
|
||||||
|
pfList.print();
|
||||||
}
|
}
|
||||||
pfList.print();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -473,14 +473,14 @@ FoveSolver::countNormalize (
|
|||||||
Parfactor* g,
|
Parfactor* g,
|
||||||
const LogVarSet& set)
|
const LogVarSet& set)
|
||||||
{
|
{
|
||||||
if (set.empty()) {
|
|
||||||
assert (false); // TODO
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
Parfactors normPfs;
|
Parfactors normPfs;
|
||||||
ConstraintTrees normCts = g->constr()->countNormalize (set);
|
if (set.empty()) {
|
||||||
for (unsigned i = 0; i < normCts.size(); i++) {
|
normPfs.push_back (new Parfactor (*g));
|
||||||
normPfs.push_back (new Parfactor (g, normCts[i]));
|
} else {
|
||||||
|
ConstraintTrees normCts = g->constr()->countNormalize (set);
|
||||||
|
for (unsigned i = 0; i < normCts.size(); i++) {
|
||||||
|
normPfs.push_back (new Parfactor (g, normCts[i]));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return normPfs;
|
return normPfs;
|
||||||
}
|
}
|
||||||
@ -490,17 +490,25 @@ FoveSolver::countNormalize (
|
|||||||
void
|
void
|
||||||
FoveSolver::runSolver (const Grounds& query)
|
FoveSolver::runSolver (const Grounds& query)
|
||||||
{
|
{
|
||||||
|
shatterAgainstQuery (query);
|
||||||
|
runWeakBayesBall (query);
|
||||||
while (true) {
|
while (true) {
|
||||||
cout << "---------------------------------------------------" << endl;
|
if (Constants::DEBUG >= 2) {
|
||||||
pfList_.print();
|
Util::printDashedLine();
|
||||||
LiftedOperator::printValidOps (pfList_, query);
|
pfList_.print();
|
||||||
|
LiftedOperator::printValidOps (pfList_, query);
|
||||||
|
}
|
||||||
LiftedOperator* op = getBestOperation (query);
|
LiftedOperator* op = getBestOperation (query);
|
||||||
if (op == 0) {
|
if (op == 0) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
cout << "best operation: " << op->toString() << endl;
|
if (Constants::DEBUG >= 2) {
|
||||||
|
cout << "best operation: " << op->toString() << endl;
|
||||||
|
}
|
||||||
op->apply();
|
op->apply();
|
||||||
|
delete op;
|
||||||
}
|
}
|
||||||
|
assert (pfList_.size() > 0);
|
||||||
if (pfList_.size() > 1) {
|
if (pfList_.size() > 1) {
|
||||||
ParfactorList::iterator pfIter = pfList_.begin();
|
ParfactorList::iterator pfIter = pfList_.begin();
|
||||||
pfIter ++;
|
pfIter ++;
|
||||||
@ -514,26 +522,6 @@ FoveSolver::runSolver (const Grounds& query)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
FoveSolver::allEliminated (const Grounds&)
|
|
||||||
{
|
|
||||||
ParfactorList::iterator pfIter = pfList_.begin();
|
|
||||||
while (pfIter != pfList_.end()) {
|
|
||||||
const ProbFormulas formulas = (*pfIter)->formulas();
|
|
||||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
|
||||||
//bool toElim = false;
|
|
||||||
//for (unsigned j = 0; j < queries.size(); j++) {
|
|
||||||
// if ((*pfIter)->containsGround (queries[i]) == false) {
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
}
|
|
||||||
++ pfIter;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
LiftedOperator*
|
LiftedOperator*
|
||||||
FoveSolver::getBestOperation (const Grounds& query)
|
FoveSolver::getBestOperation (const Grounds& query)
|
||||||
{
|
{
|
||||||
@ -548,156 +536,176 @@ FoveSolver::getBestOperation (const Grounds& query)
|
|||||||
bestCost = cost;
|
bestCost = cost;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for (unsigned i = 0; i < validOps.size(); i++) {
|
||||||
|
if (validOps[i] != bestOp) {
|
||||||
|
delete validOps[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
return bestOp;
|
return bestOp;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
FoveSolver::runWeakBayesBall (const Grounds& query)
|
||||||
|
{
|
||||||
|
queue<unsigned> todo; // groups to process
|
||||||
|
set<unsigned> done; // processed or in queue
|
||||||
|
for (unsigned i = 0; i < query.size(); i++) {
|
||||||
|
ParfactorList::iterator it = pfList_.begin();
|
||||||
|
while (it != pfList_.end()) {
|
||||||
|
int group = (*it)->findGroup (query[i]);
|
||||||
|
if (group != -1) {
|
||||||
|
todo.push (group);
|
||||||
|
done.insert (group);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
set<Parfactor*> requiredPfs;
|
||||||
|
while (todo.empty() == false) {
|
||||||
|
unsigned group = todo.front();
|
||||||
|
ParfactorList::iterator it = pfList_.begin();
|
||||||
|
while (it != pfList_.end()) {
|
||||||
|
if (Util::contains (requiredPfs, *it) == false &&
|
||||||
|
(*it)->containsGroup (group)) {
|
||||||
|
vector<unsigned> groups = (*it)->getAllGroups();
|
||||||
|
for (unsigned i = 0; i < groups.size(); i++) {
|
||||||
|
if (Util::contains (done, groups[i]) == false) {
|
||||||
|
todo.push (groups[i]);
|
||||||
|
done.insert (groups[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
requiredPfs.insert (*it);
|
||||||
|
}
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
todo.pop();
|
||||||
|
}
|
||||||
|
|
||||||
|
ParfactorList::iterator it = pfList_.begin();
|
||||||
|
while (it != pfList_.end()) {
|
||||||
|
if (Util::contains (requiredPfs, *it) == false) {
|
||||||
|
it = pfList_.removeAndDelete (it);
|
||||||
|
} else {
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Constants::DEBUG >= 2) {
|
||||||
|
Util::printHeader ("REQUIRED PARFACTORS");
|
||||||
|
pfList_.print();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
FoveSolver::shatterAgainstQuery (const Grounds& query)
|
FoveSolver::shatterAgainstQuery (const Grounds& query)
|
||||||
{
|
{
|
||||||
// return;
|
|
||||||
for (unsigned i = 0; i < query.size(); i++) {
|
for (unsigned i = 0; i < query.size(); i++) {
|
||||||
if (query[i].isAtom()) {
|
if (query[i].isAtom()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
ParfactorList pfListCopy = pfList_;
|
bool found = false;
|
||||||
pfList_.clear();
|
Parfactors newPfs;
|
||||||
for (ParfactorList::iterator it = pfListCopy.begin();
|
ParfactorList::iterator it = pfList_.begin();
|
||||||
it != pfListCopy.end(); ++ it) {
|
while (it != pfList_.end()) {
|
||||||
Parfactor* pf = *it;
|
if ((*it)->containsGround (query[i])) {
|
||||||
if (pf->containsGround (query[i])) {
|
found = true;
|
||||||
std::pair<ConstraintTree*, ConstraintTree*> split =
|
std::pair<ConstraintTree*, ConstraintTree*> split =
|
||||||
pf->constr()->split (query[i].args(), query[i].arity());
|
(*it)->constr()->split (query[i].args(), query[i].arity());
|
||||||
ConstraintTree* commCt = split.first;
|
ConstraintTree* commCt = split.first;
|
||||||
ConstraintTree* exclCt = split.second;
|
ConstraintTree* exclCt = split.second;
|
||||||
pfList_.add (new Parfactor (pf, commCt));
|
newPfs.push_back (new Parfactor (*it, commCt));
|
||||||
if (exclCt->empty() == false) {
|
if (exclCt->empty() == false) {
|
||||||
pfList_.add (new Parfactor (pf, exclCt));
|
newPfs.push_back (new Parfactor (*it, exclCt));
|
||||||
} else {
|
} else {
|
||||||
delete exclCt;
|
delete exclCt;
|
||||||
}
|
}
|
||||||
delete pf;
|
it = pfList_.removeAndDelete (it);
|
||||||
} else {
|
} else {
|
||||||
pfList_.add (pf);
|
++ it;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pfList_.shatter();
|
if (found == false) {
|
||||||
|
cerr << "error: could not find a parfactor with ground " ;
|
||||||
|
cerr << "`" << query[i] << "'" << endl;
|
||||||
|
exit (0);
|
||||||
|
}
|
||||||
|
pfList_.add (newPfs);
|
||||||
}
|
}
|
||||||
cout << endl;
|
if (Constants::DEBUG >= 2) {
|
||||||
cout << "*******************************************************" << endl;
|
cout << endl;
|
||||||
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
Util::printAsteriskLine();
|
||||||
for (unsigned i = 0; i < query.size(); i++) {
|
cout << "SHATTERED AGAINST THE QUERY" << endl;
|
||||||
cout << " -> " << query[i] << endl;
|
for (unsigned i = 0; i < query.size(); i++) {
|
||||||
|
cout << " -> " << query[i] << endl;
|
||||||
|
}
|
||||||
|
Util::printAsteriskLine();
|
||||||
|
pfList_.print();
|
||||||
}
|
}
|
||||||
cout << "*******************************************************" << endl;
|
|
||||||
pfList_.print();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
Parfactors
|
||||||
FoveSolver::absorved (
|
FoveSolver::absorve (
|
||||||
ParfactorList& pfList,
|
ObservedFormula& obsFormula,
|
||||||
ParfactorList::iterator pfIter,
|
Parfactor* g)
|
||||||
const ObservedFormula* obsFormula)
|
|
||||||
{
|
{
|
||||||
Parfactors absorvedPfs;
|
Parfactors absorvedPfs;
|
||||||
Parfactor* g = *pfIter;
|
const ProbFormulas& formulas = g->arguments();
|
||||||
const ProbFormulas& formulas = g->formulas();
|
|
||||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
for (unsigned i = 0; i < formulas.size(); i++) {
|
||||||
if (obsFormula->functor() == formulas[i].functor() &&
|
if (obsFormula.functor() == formulas[i].functor() &&
|
||||||
obsFormula->arity() == formulas[i].arity()) {
|
obsFormula.arity() == formulas[i].arity()) {
|
||||||
|
|
||||||
if (obsFormula->isAtom()) {
|
if (obsFormula.isAtom()) {
|
||||||
if (formulas.size() > 1) {
|
if (formulas.size() > 1) {
|
||||||
g->absorveEvidence (i, obsFormula->evidence());
|
g->absorveEvidence (formulas[i], obsFormula.evidence());
|
||||||
} else {
|
} else {
|
||||||
return true;
|
// hack to erase parfactor g
|
||||||
|
absorvedPfs.push_back (0);
|
||||||
}
|
}
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
g->constr()->moveToTop (formulas[i].logVars());
|
g->constr()->moveToTop (formulas[i].logVars());
|
||||||
std::pair<ConstraintTree*, ConstraintTree*> res
|
std::pair<ConstraintTree*, ConstraintTree*> res
|
||||||
= g->constr()->split (obsFormula->constr(), formulas[i].arity());
|
= g->constr()->split (&(obsFormula.constr()), formulas[i].arity());
|
||||||
ConstraintTree* commCt = res.first;
|
ConstraintTree* commCt = res.first;
|
||||||
ConstraintTree* exclCt = res.second;
|
ConstraintTree* exclCt = res.second;
|
||||||
|
|
||||||
if (commCt->empty()) {
|
if (commCt->empty() == false) {
|
||||||
delete commCt;
|
if (formulas.size() > 1) {
|
||||||
delete exclCt;
|
LogVarSet excl = g->exclusiveLogVars (i);
|
||||||
continue;
|
Parfactors countNormPfs = countNormalize (g, excl);
|
||||||
}
|
for (unsigned j = 0; j < countNormPfs.size(); j++) {
|
||||||
|
countNormPfs[j]->absorveEvidence (
|
||||||
if (exclCt->empty() == false) {
|
formulas[i], obsFormula.evidence());
|
||||||
pfList.add (new Parfactor (g, exclCt));
|
absorvedPfs.push_back (countNormPfs[j]);
|
||||||
} else {
|
}
|
||||||
delete exclCt;
|
} else {
|
||||||
}
|
delete commCt;
|
||||||
|
|
||||||
if (formulas.size() > 1) {
|
|
||||||
LogVarSet excl = g->exclusiveLogVars (i);
|
|
||||||
Parfactors countNormPfs = countNormalize (g, excl);
|
|
||||||
for (unsigned j = 0; j < countNormPfs.size(); j++) {
|
|
||||||
countNormPfs[j]->absorveEvidence (i, obsFormula->evidence());
|
|
||||||
absorvedPfs.push_back (countNormPfs[j]);
|
|
||||||
}
|
}
|
||||||
|
if (exclCt->empty() == false) {
|
||||||
|
absorvedPfs.push_back (new Parfactor (g, exclCt));
|
||||||
|
} else {
|
||||||
|
delete exclCt;
|
||||||
|
}
|
||||||
|
if (absorvedPfs.empty()) {
|
||||||
|
// hack to erase parfactor g
|
||||||
|
absorvedPfs.push_back (0);
|
||||||
|
}
|
||||||
|
break;
|
||||||
} else {
|
} else {
|
||||||
delete commCt;
|
delete commCt;
|
||||||
|
delete exclCt;
|
||||||
}
|
}
|
||||||
return true;
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false;
|
return absorvedPfs;
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
FoveSolver::proper (
|
|
||||||
const ProbFormula& f1,
|
|
||||||
ConstraintTree* c1,
|
|
||||||
const ProbFormula& f2,
|
|
||||||
ConstraintTree* c2)
|
|
||||||
{
|
|
||||||
return disjoint (f1, c1, f2, c2)
|
|
||||||
|| identical (f1, c1, f2, c2);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
FoveSolver::identical (
|
|
||||||
const ProbFormula& f1,
|
|
||||||
ConstraintTree* c1,
|
|
||||||
const ProbFormula& f2,
|
|
||||||
ConstraintTree* c2)
|
|
||||||
{
|
|
||||||
if (f1.sameSkeletonAs (f2) == false) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
c1->moveToTop (f1.logVars());
|
|
||||||
c2->moveToTop (f2.logVars());
|
|
||||||
return ConstraintTree::identical (
|
|
||||||
c1, c2, f1.logVars().size());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
FoveSolver::disjoint (
|
|
||||||
const ProbFormula& f1,
|
|
||||||
ConstraintTree* c1,
|
|
||||||
const ProbFormula& f2,
|
|
||||||
ConstraintTree* c2)
|
|
||||||
{
|
|
||||||
if (f1.sameSkeletonAs (f2) == false) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
c1->moveToTop (f1.logVars());
|
|
||||||
c2->moveToTop (f2.logVars());
|
|
||||||
return ConstraintTree::overlap (
|
|
||||||
c1, c2, f1.arity()) == false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9,10 +9,14 @@ class LiftedOperator
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
virtual unsigned getCost (void) = 0;
|
virtual unsigned getCost (void) = 0;
|
||||||
|
|
||||||
virtual void apply (void) = 0;
|
virtual void apply (void) = 0;
|
||||||
|
|
||||||
virtual string toString (void) = 0;
|
virtual string toString (void) = 0;
|
||||||
|
|
||||||
static vector<LiftedOperator*> getValidOps (
|
static vector<LiftedOperator*> getValidOps (
|
||||||
ParfactorList&, const Grounds&);
|
ParfactorList&, const Grounds&);
|
||||||
|
|
||||||
static void printValidOps (ParfactorList&, const Grounds&);
|
static void printValidOps (ParfactorList&, const Grounds&);
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -23,18 +27,26 @@ class SumOutOperator : public LiftedOperator
|
|||||||
public:
|
public:
|
||||||
SumOutOperator (unsigned group, ParfactorList& pfList)
|
SumOutOperator (unsigned group, ParfactorList& pfList)
|
||||||
: group_(group), pfList_(pfList) { }
|
: group_(group), pfList_(pfList) { }
|
||||||
|
|
||||||
unsigned getCost (void);
|
unsigned getCost (void);
|
||||||
|
|
||||||
void apply (void);
|
void apply (void);
|
||||||
|
|
||||||
static vector<SumOutOperator*> getValidOps (
|
static vector<SumOutOperator*> getValidOps (
|
||||||
ParfactorList&, const Grounds&);
|
ParfactorList&, const Grounds&);
|
||||||
|
|
||||||
string toString (void);
|
string toString (void);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static bool validOp (unsigned, ParfactorList&, const Grounds&);
|
static bool validOp (unsigned, ParfactorList&, const Grounds&);
|
||||||
|
|
||||||
static vector<ParfactorList::iterator> parfactorsWithGroup (
|
static vector<ParfactorList::iterator> parfactorsWithGroup (
|
||||||
ParfactorList& pfList, unsigned group);
|
ParfactorList& pfList, unsigned group);
|
||||||
|
|
||||||
static bool isToEliminate (Parfactor*, unsigned, const Grounds&);
|
static bool isToEliminate (Parfactor*, unsigned, const Grounds&);
|
||||||
unsigned group_;
|
|
||||||
ParfactorList& pfList_;
|
unsigned group_;
|
||||||
|
ParfactorList& pfList_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -47,15 +59,21 @@ class CountingOperator : public LiftedOperator
|
|||||||
LogVar X,
|
LogVar X,
|
||||||
ParfactorList& pfList)
|
ParfactorList& pfList)
|
||||||
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
||||||
|
|
||||||
unsigned getCost (void);
|
unsigned getCost (void);
|
||||||
|
|
||||||
void apply (void);
|
void apply (void);
|
||||||
|
|
||||||
static vector<CountingOperator*> getValidOps (ParfactorList&);
|
static vector<CountingOperator*> getValidOps (ParfactorList&);
|
||||||
|
|
||||||
string toString (void);
|
string toString (void);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static bool validOp (Parfactor*, LogVar);
|
static bool validOp (Parfactor*, LogVar);
|
||||||
ParfactorList::iterator pfIter_;
|
|
||||||
LogVar X_;
|
ParfactorList::iterator pfIter_;
|
||||||
ParfactorList& pfList_;
|
LogVar X_;
|
||||||
|
ParfactorList& pfList_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -68,14 +86,19 @@ class GroundOperator : public LiftedOperator
|
|||||||
LogVar X,
|
LogVar X,
|
||||||
ParfactorList& pfList)
|
ParfactorList& pfList)
|
||||||
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
: pfIter_(pfIter), X_(X), pfList_(pfList) { }
|
||||||
|
|
||||||
unsigned getCost (void);
|
unsigned getCost (void);
|
||||||
|
|
||||||
void apply (void);
|
void apply (void);
|
||||||
|
|
||||||
static vector<GroundOperator*> getValidOps (ParfactorList&);
|
static vector<GroundOperator*> getValidOps (ParfactorList&);
|
||||||
|
|
||||||
string toString (void);
|
string toString (void);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ParfactorList::iterator pfIter_;
|
ParfactorList::iterator pfIter_;
|
||||||
LogVar X_;
|
LogVar X_;
|
||||||
ParfactorList& pfList_;
|
ParfactorList& pfList_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -83,49 +106,29 @@ class GroundOperator : public LiftedOperator
|
|||||||
class FoveSolver
|
class FoveSolver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
FoveSolver (const ParfactorList*);
|
FoveSolver (const ParfactorList& pfList) : pfList_(pfList) { }
|
||||||
|
|
||||||
Params getPosterioriOf (const Ground&);
|
Params getPosterioriOf (const Ground&);
|
||||||
Params getJointDistributionOf (const Grounds&);
|
|
||||||
|
|
||||||
static void absorveEvidence (
|
Params getJointDistributionOf (const Grounds&);
|
||||||
ParfactorList& pfList,
|
|
||||||
const ObservedFormulas& obsFormulas);
|
|
||||||
|
|
||||||
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
|
static void absorveEvidence (
|
||||||
|
ParfactorList& pfList, ObservedFormulas& obsFormulas);
|
||||||
|
|
||||||
|
static Parfactors countNormalize (Parfactor*, const LogVarSet&);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void runSolver (const Grounds&);
|
void runSolver (const Grounds&);
|
||||||
bool allEliminated (const Grounds&);
|
|
||||||
LiftedOperator* getBestOperation (const Grounds&);
|
|
||||||
void shatterAgainstQuery (const Grounds&);
|
|
||||||
|
|
||||||
static bool absorved (
|
LiftedOperator* getBestOperation (const Grounds&);
|
||||||
ParfactorList& pfList,
|
|
||||||
ParfactorList::iterator pfIter,
|
|
||||||
const ObservedFormula*);
|
|
||||||
|
|
||||||
public:
|
void runWeakBayesBall (const Grounds&);
|
||||||
|
|
||||||
static bool proper (
|
void shatterAgainstQuery (const Grounds&);
|
||||||
const ProbFormula&,
|
|
||||||
ConstraintTree*,
|
|
||||||
const ProbFormula&,
|
|
||||||
ConstraintTree*);
|
|
||||||
|
|
||||||
static bool identical (
|
static Parfactors absorve (ObservedFormula&, Parfactor*);
|
||||||
const ProbFormula&,
|
|
||||||
ConstraintTree*,
|
|
||||||
const ProbFormula&,
|
|
||||||
ConstraintTree*);
|
|
||||||
|
|
||||||
static bool disjoint (
|
ParfactorList pfList_;
|
||||||
const ProbFormula&,
|
|
||||||
ConstraintTree*,
|
|
||||||
const ProbFormula&,
|
|
||||||
ConstraintTree*);
|
|
||||||
|
|
||||||
ParfactorList pfList_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_FOVESOLVER_H
|
#endif // HORUS_FOVESOLVER_H
|
||||||
|
@ -1,67 +0,0 @@
|
|||||||
#ifndef HORUS_GRAPHICALMODEL_H
|
|
||||||
#define HORUS_GRAPHICALMODEL_H
|
|
||||||
|
|
||||||
#include <sstream>
|
|
||||||
|
|
||||||
#include "VarNode.h"
|
|
||||||
#include "Distribution.h"
|
|
||||||
#include "Horus.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
|
|
||||||
struct VariableInfo
|
|
||||||
{
|
|
||||||
VariableInfo (string l, const States& sts)
|
|
||||||
{
|
|
||||||
label = l;
|
|
||||||
states = sts;
|
|
||||||
}
|
|
||||||
string label;
|
|
||||||
States states;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class GraphicalModel
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
virtual ~GraphicalModel (void) {};
|
|
||||||
virtual VarNode* getVariableNode (VarId) const = 0;
|
|
||||||
virtual VarNodes getVariableNodes (void) const = 0;
|
|
||||||
virtual void printGraphicalModel (void) const = 0;
|
|
||||||
|
|
||||||
static void addVariableInformation (VarId vid, string label,
|
|
||||||
const States& states)
|
|
||||||
{
|
|
||||||
assert (varsInfo_.find (vid) == varsInfo_.end());
|
|
||||||
varsInfo_.insert (make_pair (vid, VariableInfo (label, states)));
|
|
||||||
}
|
|
||||||
static VariableInfo getVariableInformation (VarId vid)
|
|
||||||
{
|
|
||||||
assert (varsInfo_.find (vid) != varsInfo_.end());
|
|
||||||
return varsInfo_.find (vid)->second;
|
|
||||||
}
|
|
||||||
static bool variablesHaveInformation (void)
|
|
||||||
{
|
|
||||||
return varsInfo_.size() != 0;
|
|
||||||
}
|
|
||||||
static void clearVariablesInformation (void)
|
|
||||||
{
|
|
||||||
varsInfo_.clear();
|
|
||||||
}
|
|
||||||
static void addDistribution (unsigned id, Distribution* dist)
|
|
||||||
{
|
|
||||||
distsInfo_[id] = dist;
|
|
||||||
}
|
|
||||||
static void updateDistribution (unsigned id, const Params& params)
|
|
||||||
{
|
|
||||||
distsInfo_[id]->updateParameters (params);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
static unordered_map<VarId,VariableInfo> varsInfo_;
|
|
||||||
static unordered_map<unsigned,Distribution*> distsInfo_;
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // HORUS_GRAPHICALMODEL_H
|
|
||||||
|
|
@ -84,16 +84,34 @@ HistogramSet::nrHistograms (unsigned N, unsigned R)
|
|||||||
|
|
||||||
unsigned
|
unsigned
|
||||||
HistogramSet::findIndex (
|
HistogramSet::findIndex (
|
||||||
const Histogram& hist,
|
const Histogram& h,
|
||||||
const vector<Histogram>& histograms)
|
const vector<Histogram>& hists)
|
||||||
{
|
{
|
||||||
vector<Histogram>::const_iterator it = std::lower_bound (
|
vector<Histogram>::const_iterator it = std::lower_bound (
|
||||||
histograms.begin(),
|
hists.begin(), hists.end(), h, std::greater<Histogram>());
|
||||||
histograms.end(),
|
assert (it != hists.end() && *it == h);
|
||||||
hist,
|
return std::distance (hists.begin(), it);
|
||||||
std::greater<Histogram>());
|
}
|
||||||
assert (it != histograms.end() && *it == hist);
|
|
||||||
return std::distance (histograms.begin(), it);
|
|
||||||
|
|
||||||
|
vector<double>
|
||||||
|
HistogramSet::getNumAssigns (unsigned N, unsigned R)
|
||||||
|
{
|
||||||
|
HistogramSet hs (N, R);
|
||||||
|
unsigned N_factorial = Util::factorial (N);
|
||||||
|
unsigned H = hs.nrHistograms();
|
||||||
|
vector<double> numAssigns;
|
||||||
|
numAssigns.reserve (H);
|
||||||
|
for (unsigned h = 0; h < H; h++) {
|
||||||
|
unsigned prod = 1;
|
||||||
|
for (unsigned r = 0; r < R; r++) {
|
||||||
|
prod *= Util::factorial (hs[r]);
|
||||||
|
}
|
||||||
|
numAssigns.push_back (LogAware::tl (N_factorial / prod));
|
||||||
|
hs.nextHistogram();
|
||||||
|
}
|
||||||
|
return numAssigns;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -26,8 +26,9 @@ class HistogramSet
|
|||||||
static unsigned nrHistograms (unsigned, unsigned);
|
static unsigned nrHistograms (unsigned, unsigned);
|
||||||
|
|
||||||
static unsigned findIndex (
|
static unsigned findIndex (
|
||||||
const Histogram&,
|
const Histogram&, const vector<Histogram>&);
|
||||||
const vector<Histogram>&);
|
|
||||||
|
static vector<double> getNumAssigns (unsigned, unsigned);
|
||||||
|
|
||||||
friend std::ostream& operator<< (ostream &os, const HistogramSet& hs);
|
friend std::ostream& operator<< (ostream &os, const HistogramSet& hs);
|
||||||
|
|
||||||
|
@ -1,17 +1,9 @@
|
|||||||
#ifndef HORUS_HORUS_H
|
#ifndef HORUS_HORUS_H
|
||||||
#define HORUS_HORUS_H
|
#define HORUS_HORUS_H
|
||||||
|
|
||||||
#include <cmath>
|
|
||||||
#include <cassert>
|
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
#include <fstream>
|
|
||||||
#include <sstream>
|
|
||||||
|
|
||||||
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
|
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
|
||||||
TypeName(const TypeName&); \
|
TypeName(const TypeName&); \
|
||||||
@ -19,55 +11,51 @@
|
|||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
class VarNode;
|
class Var;
|
||||||
class BayesNode;
|
|
||||||
class FgVarNode;
|
|
||||||
class FgFacNode;
|
|
||||||
class Factor;
|
class Factor;
|
||||||
|
class VarNode;
|
||||||
|
class FacNode;
|
||||||
|
|
||||||
typedef vector<double> Params;
|
typedef vector<double> Params;
|
||||||
typedef unsigned VarId;
|
typedef unsigned VarId;
|
||||||
typedef vector<VarId> VarIds;
|
typedef vector<VarId> VarIds;
|
||||||
typedef vector<VarNode*> VarNodes;
|
typedef vector<Var*> Vars;
|
||||||
typedef vector<BayesNode*> BnNodeSet;
|
typedef vector<VarNode*> VarNodes;
|
||||||
typedef vector<FgVarNode*> FgVarSet;
|
typedef vector<FacNode*> FacNodes;
|
||||||
typedef vector<FgFacNode*> FgFacSet;
|
typedef vector<Factor*> Factors;
|
||||||
typedef vector<Factor*> FactorSet;
|
typedef vector<string> States;
|
||||||
typedef vector<string> States;
|
typedef vector<unsigned> Ranges;
|
||||||
typedef vector<unsigned> Ranges;
|
|
||||||
|
|
||||||
|
|
||||||
namespace Globals {
|
enum InfAlgorithms
|
||||||
extern bool logDomain;
|
{
|
||||||
|
VE, // variable elimination
|
||||||
|
BP, // belief propagation
|
||||||
|
CBP // counting belief propagation
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// level of debug information
|
namespace Globals {
|
||||||
static const unsigned DL = 1;
|
|
||||||
|
|
||||||
static const int NO_EVIDENCE = -1;
|
extern bool logDomain;
|
||||||
|
|
||||||
|
extern InfAlgorithms infAlgorithm;
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
namespace Constants {
|
||||||
|
|
||||||
|
// level of debug information
|
||||||
|
const unsigned DEBUG = 0;
|
||||||
|
|
||||||
|
const int NO_EVIDENCE = -1;
|
||||||
|
|
||||||
// number of digits to show when printing a parameter
|
// number of digits to show when printing a parameter
|
||||||
static const unsigned PRECISION = 5;
|
const unsigned PRECISION = 5;
|
||||||
|
|
||||||
static const bool COLLECT_STATISTICS = false;
|
const bool COLLECT_STATS = false;
|
||||||
|
|
||||||
static const bool EXPORT_TO_GRAPHVIZ = false;
|
|
||||||
static const unsigned EXPORT_MINIMAL_SIZE = 100;
|
|
||||||
|
|
||||||
static const double INF = -numeric_limits<double>::infinity();
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
namespace InfAlgorithms {
|
|
||||||
enum InfAlgs
|
|
||||||
{
|
|
||||||
VE, // variable elimination
|
|
||||||
BN_BP, // bayesian network belief propagation
|
|
||||||
FG_BP, // factor graph belief propagation
|
|
||||||
CBP // counting bp solver
|
|
||||||
};
|
|
||||||
extern InfAlgs infAlgorithm;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,197 +3,89 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "BayesNet.h"
|
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "VarElimSolver.h"
|
#include "VarElimSolver.h"
|
||||||
#include "BnBpSolver.h"
|
#include "BpSolver.h"
|
||||||
#include "FgBpSolver.h"
|
|
||||||
#include "CbpSolver.h"
|
#include "CbpSolver.h"
|
||||||
|
|
||||||
//#include "TinySet.h"
|
|
||||||
#include "LiftedUtils.h"
|
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
void processArguments (BayesNet&, int, const char* []);
|
|
||||||
void processArguments (FactorGraph&, int, const char* []);
|
void processArguments (FactorGraph&, int, const char* []);
|
||||||
void runSolver (Solver*, const VarNodes&);
|
void runSolver (const FactorGraph&, const VarIds&);
|
||||||
|
|
||||||
const string USAGE = "usage: \
|
const string USAGE = "usage: \
|
||||||
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
./hcli ve|bp|cbp NETWORK_FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
||||||
|
|
||||||
|
|
||||||
class Cenas
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
Cenas (int cc)
|
|
||||||
{
|
|
||||||
c = cc;
|
|
||||||
}
|
|
||||||
//operator int (void) const
|
|
||||||
//{
|
|
||||||
// cout << "return int" << endl;
|
|
||||||
// return c;
|
|
||||||
//}
|
|
||||||
operator double (void) const
|
|
||||||
{
|
|
||||||
cout << "return double" << endl;
|
|
||||||
return 0.0;
|
|
||||||
}
|
|
||||||
private:
|
|
||||||
int c;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
main (int argc, const char* argv[])
|
main (int argc, const char* argv[])
|
||||||
{
|
{
|
||||||
LogVar X = 3;
|
if (argc <= 1) {
|
||||||
LogVarSet Xs = X;
|
cerr << "error: no solver specified" << endl;
|
||||||
cout << "set: " << X << endl;
|
|
||||||
Cenas c1 (1);
|
|
||||||
Cenas c2 (3);
|
|
||||||
cout << (c1 < c2) << endl;
|
|
||||||
return 0;
|
|
||||||
if (!argv[1]) {
|
|
||||||
cerr << "error: no graphical model specified" << endl;
|
cerr << "error: no graphical model specified" << endl;
|
||||||
cerr << USAGE << endl;
|
cerr << USAGE << endl;
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
const string& fileName = argv[1];
|
if (argc <= 2) {
|
||||||
const string& extension = fileName.substr (fileName.find_last_of ('.') + 1);
|
cerr << "error: no graphical model specified" << endl;
|
||||||
if (extension == "xml") {
|
cerr << USAGE << endl;
|
||||||
BayesNet bn;
|
|
||||||
bn.readFromBifFormat (argv[1]);
|
|
||||||
processArguments (bn, argc, argv);
|
|
||||||
} else if (extension == "uai") {
|
|
||||||
FactorGraph fg;
|
|
||||||
fg.readFromUaiFormat (argv[1]);
|
|
||||||
processArguments (fg, argc, argv);
|
|
||||||
} else if (extension == "fg") {
|
|
||||||
FactorGraph fg;
|
|
||||||
fg.readFromLibDaiFormat (argv[1]);
|
|
||||||
processArguments (fg, argc, argv);
|
|
||||||
} else {
|
|
||||||
cerr << "error: the graphical model must be defined either " ;
|
|
||||||
cerr << "in a xml, uai or libDAI file" << endl;
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
|
string solver (argv[1]);
|
||||||
|
if (solver == "ve") {
|
||||||
|
Globals::infAlgorithm = InfAlgorithms::VE;
|
||||||
|
} else if (solver == "bp") {
|
||||||
|
Globals::infAlgorithm = InfAlgorithms::BP;
|
||||||
|
} else if (solver == "cbp") {
|
||||||
|
Globals::infAlgorithm = InfAlgorithms::CBP;
|
||||||
|
} else {
|
||||||
|
cerr << "error: unknow solver `" << solver << "'" << endl ;
|
||||||
|
cerr << USAGE << endl;
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
string fileName (argv[2]);
|
||||||
|
string extension = fileName.substr (
|
||||||
|
fileName.find_last_of ('.') + 1);
|
||||||
|
FactorGraph fg;
|
||||||
|
if (extension == "uai") {
|
||||||
|
fg.readFromUaiFormat (fileName.c_str());
|
||||||
|
} else if (extension == "fg") {
|
||||||
|
fg.readFromLibDaiFormat (fileName.c_str());
|
||||||
|
} else {
|
||||||
|
cerr << "error: the graphical model must be defined either " ;
|
||||||
|
cerr << "in a UAI or libDAI file" << endl;
|
||||||
|
exit (0);
|
||||||
|
}
|
||||||
|
processArguments (fg, argc, argv);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
processArguments (BayesNet& bn, int argc, const char* argv[])
|
|
||||||
{
|
|
||||||
VarNodes queryVars;
|
|
||||||
for (int i = 2; i < argc; i++) {
|
|
||||||
const string& arg = argv[i];
|
|
||||||
if (arg.find ('=') == std::string::npos) {
|
|
||||||
BayesNode* queryVar = bn.getBayesNode (arg);
|
|
||||||
if (queryVar) {
|
|
||||||
queryVars.push_back (queryVar);
|
|
||||||
} else {
|
|
||||||
cerr << "error: there isn't a variable labeled of " ;
|
|
||||||
cerr << "`" << arg << "'" ;
|
|
||||||
cerr << endl;
|
|
||||||
bn.freeDistributions();
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
size_t pos = arg.find ('=');
|
|
||||||
const string& label = arg.substr (0, pos);
|
|
||||||
const string& state = arg.substr (pos + 1);
|
|
||||||
if (label.empty()) {
|
|
||||||
cerr << "error: missing left argument" << endl;
|
|
||||||
cerr << USAGE << endl;
|
|
||||||
bn.freeDistributions();
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
if (state.empty()) {
|
|
||||||
cerr << "error: missing right argument" << endl;
|
|
||||||
cerr << USAGE << endl;
|
|
||||||
bn.freeDistributions();
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
BayesNode* node = bn.getBayesNode (label);
|
|
||||||
if (node) {
|
|
||||||
if (node->isValidState (state)) {
|
|
||||||
node->setEvidence (state);
|
|
||||||
} else {
|
|
||||||
cerr << "error: `" << state << "' " ;
|
|
||||||
cerr << "is not a valid state for " ;
|
|
||||||
cerr << "`" << node->label() << "'" ;
|
|
||||||
cerr << endl;
|
|
||||||
bn.freeDistributions();
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
cerr << "error: there isn't a variable labeled of " ;
|
|
||||||
cerr << "`" << label << "'" ;
|
|
||||||
cerr << endl;
|
|
||||||
bn.freeDistributions();
|
|
||||||
exit (0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Solver* solver = 0;
|
|
||||||
FactorGraph* fg = 0;
|
|
||||||
switch (InfAlgorithms::infAlgorithm) {
|
|
||||||
case InfAlgorithms::VE:
|
|
||||||
fg = new FactorGraph (bn);
|
|
||||||
solver = new VarElimSolver (*fg);
|
|
||||||
break;
|
|
||||||
case InfAlgorithms::BN_BP:
|
|
||||||
solver = new BnBpSolver (bn);
|
|
||||||
break;
|
|
||||||
case InfAlgorithms::FG_BP:
|
|
||||||
fg = new FactorGraph (bn);
|
|
||||||
solver = new FgBpSolver (*fg);
|
|
||||||
break;
|
|
||||||
case InfAlgorithms::CBP:
|
|
||||||
fg = new FactorGraph (bn);
|
|
||||||
solver = new CbpSolver (*fg);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
assert (false);
|
|
||||||
}
|
|
||||||
runSolver (solver, queryVars);
|
|
||||||
delete fg;
|
|
||||||
bn.freeDistributions();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
processArguments (FactorGraph& fg, int argc, const char* argv[])
|
processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||||
{
|
{
|
||||||
VarNodes queryVars;
|
VarIds queryIds;
|
||||||
for (int i = 2; i < argc; i++) {
|
for (int i = 3; i < argc; i++) {
|
||||||
const string& arg = argv[i];
|
const string& arg = argv[i];
|
||||||
if (arg.find ('=') == std::string::npos) {
|
if (arg.find ('=') == std::string::npos) {
|
||||||
if (!Util::isInteger (arg)) {
|
if (!Util::isInteger (arg)) {
|
||||||
cerr << "error: `" << arg << "' " ;
|
cerr << "error: `" << arg << "' " ;
|
||||||
cerr << "is not a valid variable id" ;
|
cerr << "is not a valid variable id" ;
|
||||||
cerr << endl;
|
cerr << endl;
|
||||||
fg.freeDistributions();
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
VarId vid;
|
VarId vid;
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << arg;
|
ss << arg;
|
||||||
ss >> vid;
|
ss >> vid;
|
||||||
VarNode* queryVar = fg.getFgVarNode (vid);
|
VarNode* queryVar = fg.getVarNode (vid);
|
||||||
if (queryVar) {
|
if (queryVar) {
|
||||||
queryVars.push_back (queryVar);
|
queryIds.push_back (vid);
|
||||||
} else {
|
} else {
|
||||||
cerr << "error: there isn't a variable with " ;
|
cerr << "error: there isn't a variable with " ;
|
||||||
cerr << "`" << vid << "' as id" ;
|
cerr << "`" << vid << "' as id" ;
|
||||||
cerr << endl;
|
cerr << endl;
|
||||||
fg.freeDistributions();
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -201,33 +93,29 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
|||||||
if (arg.substr (0, pos).empty()) {
|
if (arg.substr (0, pos).empty()) {
|
||||||
cerr << "error: missing left argument" << endl;
|
cerr << "error: missing left argument" << endl;
|
||||||
cerr << USAGE << endl;
|
cerr << USAGE << endl;
|
||||||
fg.freeDistributions();
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
if (arg.substr (pos + 1).empty()) {
|
if (arg.substr (pos + 1).empty()) {
|
||||||
cerr << "error: missing right argument" << endl;
|
cerr << "error: missing right argument" << endl;
|
||||||
cerr << USAGE << endl;
|
cerr << USAGE << endl;
|
||||||
fg.freeDistributions();
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
if (!Util::isInteger (arg.substr (0, pos))) {
|
if (!Util::isInteger (arg.substr (0, pos))) {
|
||||||
cerr << "error: `" << arg.substr (0, pos) << "' " ;
|
cerr << "error: `" << arg.substr (0, pos) << "' " ;
|
||||||
cerr << "is not a variable id" ;
|
cerr << "is not a variable id" ;
|
||||||
cerr << endl;
|
cerr << endl;
|
||||||
fg.freeDistributions();
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
VarId vid;
|
VarId vid;
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << arg.substr (0, pos);
|
ss << arg.substr (0, pos);
|
||||||
ss >> vid;
|
ss >> vid;
|
||||||
VarNode* var = fg.getFgVarNode (vid);
|
VarNode* var = fg.getVarNode (vid);
|
||||||
if (var) {
|
if (var) {
|
||||||
if (!Util::isInteger (arg.substr (pos + 1))) {
|
if (!Util::isInteger (arg.substr (pos + 1))) {
|
||||||
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
|
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
|
||||||
cerr << "is not a state index" ;
|
cerr << "is not a state index" ;
|
||||||
cerr << endl;
|
cerr << endl;
|
||||||
fg.freeDistributions();
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
int stateIndex;
|
int stateIndex;
|
||||||
@ -241,29 +129,31 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
|||||||
cerr << "is not a valid state index for variable " ;
|
cerr << "is not a valid state index for variable " ;
|
||||||
cerr << "`" << var->varId() << "'" ;
|
cerr << "`" << var->varId() << "'" ;
|
||||||
cerr << endl;
|
cerr << endl;
|
||||||
fg.freeDistributions();
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
cerr << "error: there isn't a variable with " ;
|
cerr << "error: there isn't a variable with " ;
|
||||||
cerr << "`" << vid << "' as id" ;
|
cerr << "`" << vid << "' as id" ;
|
||||||
cerr << endl;
|
cerr << endl;
|
||||||
fg.freeDistributions();
|
|
||||||
exit (0);
|
exit (0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
runSolver (fg, queryIds);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
runSolver (const FactorGraph& fg, const VarIds& queryIds)
|
||||||
|
{
|
||||||
Solver* solver = 0;
|
Solver* solver = 0;
|
||||||
switch (InfAlgorithms::infAlgorithm) {
|
switch (Globals::infAlgorithm) {
|
||||||
case InfAlgorithms::VE:
|
case InfAlgorithms::VE:
|
||||||
solver = new VarElimSolver (fg);
|
solver = new VarElimSolver (fg);
|
||||||
break;
|
break;
|
||||||
case InfAlgorithms::BN_BP:
|
case InfAlgorithms::BP:
|
||||||
case InfAlgorithms::FG_BP:
|
solver = new BpSolver (fg);
|
||||||
//cout << "here!" << endl;
|
|
||||||
//fg.printGraphicalModel();
|
|
||||||
//fg.exportToLibDaiFormat ("net.fg");
|
|
||||||
solver = new FgBpSolver (fg);
|
|
||||||
break;
|
break;
|
||||||
case InfAlgorithms::CBP:
|
case InfAlgorithms::CBP:
|
||||||
solver = new CbpSolver (fg);
|
solver = new CbpSolver (fg);
|
||||||
@ -271,28 +161,10 @@ processArguments (FactorGraph& fg, int argc, const char* argv[])
|
|||||||
default:
|
default:
|
||||||
assert (false);
|
assert (false);
|
||||||
}
|
}
|
||||||
runSolver (solver, queryVars);
|
if (queryIds.size() == 0) {
|
||||||
fg.freeDistributions();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
runSolver (Solver* solver, const VarNodes& queryVars)
|
|
||||||
{
|
|
||||||
VarIds vids;
|
|
||||||
for (unsigned i = 0; i < queryVars.size(); i++) {
|
|
||||||
vids.push_back (queryVars[i]->varId());
|
|
||||||
}
|
|
||||||
if (queryVars.size() == 0) {
|
|
||||||
solver->runSolver();
|
|
||||||
solver->printAllPosterioris();
|
solver->printAllPosterioris();
|
||||||
} else if (queryVars.size() == 1) {
|
|
||||||
solver->runSolver();
|
|
||||||
solver->printPosterioriOf (vids[0]);
|
|
||||||
} else {
|
} else {
|
||||||
solver->runSolver();
|
solver->printAnswer (queryIds);
|
||||||
solver->printJointDistributionOf (vids);
|
|
||||||
}
|
}
|
||||||
delete solver;
|
delete solver;
|
||||||
}
|
}
|
||||||
|
@ -7,22 +7,50 @@
|
|||||||
|
|
||||||
#include <YapInterface.h>
|
#include <YapInterface.h>
|
||||||
|
|
||||||
#include "BayesNet.h"
|
#include "ParfactorList.h"
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
|
#include "FoveSolver.h"
|
||||||
#include "VarElimSolver.h"
|
#include "VarElimSolver.h"
|
||||||
#include "BnBpSolver.h"
|
#include "BpSolver.h"
|
||||||
#include "FgBpSolver.h"
|
|
||||||
#include "CbpSolver.h"
|
#include "CbpSolver.h"
|
||||||
#include "ElimGraph.h"
|
#include "ElimGraph.h"
|
||||||
#include "FoveSolver.h"
|
#include "BayesBall.h"
|
||||||
#include "ParfactorList.h"
|
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
typedef std::pair<ParfactorList*, ObservedFormulas*> LiftedNetwork;
|
||||||
|
|
||||||
|
|
||||||
|
Params readParameters (YAP_Term);
|
||||||
|
|
||||||
|
vector<unsigned> readUnsignedList (YAP_Term);
|
||||||
|
|
||||||
|
void readLiftedEvidence (YAP_Term, ObservedFormulas&);
|
||||||
|
|
||||||
|
Parfactor* readParfactor (YAP_Term);
|
||||||
|
|
||||||
|
void runVeSolver (FactorGraph* fg, const vector<VarIds>& tasks,
|
||||||
|
vector<Params>& results);
|
||||||
|
|
||||||
|
void runBpSolver (FactorGraph* fg, const vector<VarIds>& tasks,
|
||||||
|
vector<Params>& results);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<unsigned>
|
||||||
|
readUnsignedList (YAP_Term list)
|
||||||
|
{
|
||||||
|
vector<unsigned> vec;
|
||||||
|
while (list != YAP_TermNil()) {
|
||||||
|
vec.push_back ((unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (list)));
|
||||||
|
list = YAP_TailOfTerm (list);
|
||||||
|
}
|
||||||
|
return vec;
|
||||||
|
}
|
||||||
|
|
||||||
Params readParams (YAP_Term);
|
|
||||||
|
|
||||||
|
|
||||||
int createLiftedNetwork (void)
|
int createLiftedNetwork (void)
|
||||||
@ -30,107 +58,121 @@ int createLiftedNetwork (void)
|
|||||||
Parfactors parfactors;
|
Parfactors parfactors;
|
||||||
YAP_Term parfactorList = YAP_ARG1;
|
YAP_Term parfactorList = YAP_ARG1;
|
||||||
while (parfactorList != YAP_TermNil()) {
|
while (parfactorList != YAP_TermNil()) {
|
||||||
YAP_Term parfactor = YAP_HeadOfTerm (parfactorList);
|
YAP_Term pfTerm = YAP_HeadOfTerm (parfactorList);
|
||||||
|
parfactors.push_back (readParfactor (pfTerm));
|
||||||
// read dist id
|
|
||||||
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, parfactor));
|
|
||||||
|
|
||||||
// read the ranges
|
|
||||||
Ranges ranges;
|
|
||||||
YAP_Term rangeList = YAP_ArgOfTerm (3, parfactor);
|
|
||||||
while (rangeList != YAP_TermNil()) {
|
|
||||||
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
|
|
||||||
ranges.push_back (range);
|
|
||||||
rangeList = YAP_TailOfTerm (rangeList);
|
|
||||||
}
|
|
||||||
|
|
||||||
// read parametric random vars
|
|
||||||
ProbFormulas formulas;
|
|
||||||
unsigned count = 0;
|
|
||||||
unordered_map<YAP_Term, LogVar> lvMap;
|
|
||||||
YAP_Term pvList = YAP_ArgOfTerm (2, parfactor);
|
|
||||||
while (pvList != YAP_TermNil()) {
|
|
||||||
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
|
|
||||||
if (YAP_IsAtomTerm (formulaTerm)) {
|
|
||||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
|
|
||||||
Symbol functor = LiftedUtils::getSymbol (name);
|
|
||||||
formulas.push_back (ProbFormula (functor, ranges[count]));
|
|
||||||
} else {
|
|
||||||
LogVars logVars;
|
|
||||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
|
|
||||||
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
|
|
||||||
Symbol functor = LiftedUtils::getSymbol (name);
|
|
||||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
|
||||||
for (unsigned i = 1; i <= arity; i++) {
|
|
||||||
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
|
|
||||||
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
|
|
||||||
if (it != lvMap.end()) {
|
|
||||||
logVars.push_back (it->second);
|
|
||||||
} else {
|
|
||||||
unsigned newLv = lvMap.size();
|
|
||||||
lvMap[ti] = newLv;
|
|
||||||
logVars.push_back (newLv);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
formulas.push_back (ProbFormula (functor, logVars, ranges[count]));
|
|
||||||
}
|
|
||||||
count ++;
|
|
||||||
pvList = YAP_TailOfTerm (pvList);
|
|
||||||
}
|
|
||||||
|
|
||||||
// read the parameters
|
|
||||||
const Params& params = readParams (YAP_ArgOfTerm (4, parfactor));
|
|
||||||
|
|
||||||
// read the constraint
|
|
||||||
Tuples tuples;
|
|
||||||
if (lvMap.size() >= 1) {
|
|
||||||
YAP_Term tupleList = YAP_ArgOfTerm (5, parfactor);
|
|
||||||
while (tupleList != YAP_TermNil()) {
|
|
||||||
YAP_Term term = YAP_HeadOfTerm (tupleList);
|
|
||||||
assert (YAP_IsApplTerm (term));
|
|
||||||
YAP_Functor yapFunctor = YAP_FunctorOfTerm (term);
|
|
||||||
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
|
||||||
assert (lvMap.size() == arity);
|
|
||||||
Tuple tuple (arity);
|
|
||||||
for (unsigned i = 1; i <= arity; i++) {
|
|
||||||
YAP_Term ti = YAP_ArgOfTerm (i, term);
|
|
||||||
if (YAP_IsAtomTerm (ti) == false) {
|
|
||||||
cerr << "error: bad formed constraint" << endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
|
||||||
tuple[i - 1] = LiftedUtils::getSymbol (name);
|
|
||||||
}
|
|
||||||
tuples.push_back (tuple);
|
|
||||||
tupleList = YAP_TailOfTerm (tupleList);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
parfactors.push_back (new Parfactor (formulas, params, tuples, distId));
|
|
||||||
parfactorList = YAP_TailOfTerm (parfactorList);
|
parfactorList = YAP_TailOfTerm (parfactorList);
|
||||||
}
|
}
|
||||||
|
|
||||||
// LiftedUtils::printSymbolDictionary();
|
// LiftedUtils::printSymbolDictionary();
|
||||||
cout << "*******************************************************" << endl;
|
if (Constants::DEBUG > 2) {
|
||||||
cout << "INITIAL PARFACTORS" << endl;
|
// Util::printHeader ("INITIAL PARFACTORS");
|
||||||
cout << "*******************************************************" << endl;
|
// for (unsigned i = 0; i < parfactors.size(); i++) {
|
||||||
for (unsigned i = 0; i < parfactors.size(); i++) {
|
// parfactors[i]->print();
|
||||||
parfactors[i]->print();
|
// }
|
||||||
cout << endl;
|
|
||||||
}
|
}
|
||||||
ParfactorList* pfList = new ParfactorList();
|
|
||||||
for (unsigned i = 0; i < parfactors.size(); i++) {
|
|
||||||
pfList->add (parfactors[i]);
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
cout << "*******************************************************" << endl;
|
|
||||||
cout << "SHATTERED PARFACTORS" << endl;
|
|
||||||
cout << "*******************************************************" << endl;
|
|
||||||
pfList->shatter();
|
|
||||||
pfList->print();
|
|
||||||
|
|
||||||
// insert the evidence
|
ParfactorList* pfList = new ParfactorList (parfactors);
|
||||||
ObservedFormulas obsFormulas;
|
|
||||||
YAP_Term observedList = YAP_ARG2;
|
if (Constants::DEBUG >= 2) {
|
||||||
|
Util::printHeader ("SHATTERED PARFACTORS");
|
||||||
|
pfList->print();
|
||||||
|
}
|
||||||
|
|
||||||
|
// read evidence
|
||||||
|
ObservedFormulas* obsFormulas = new ObservedFormulas();
|
||||||
|
readLiftedEvidence (YAP_ARG2, *(obsFormulas));
|
||||||
|
|
||||||
|
LiftedNetwork* net = new LiftedNetwork (pfList, obsFormulas);
|
||||||
|
YAP_Int p = (YAP_Int) (net);
|
||||||
|
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Parfactor* readParfactor (YAP_Term pfTerm)
|
||||||
|
{
|
||||||
|
// read dist id
|
||||||
|
unsigned distId = YAP_IntOfTerm (YAP_ArgOfTerm (1, pfTerm));
|
||||||
|
|
||||||
|
// read the ranges
|
||||||
|
Ranges ranges;
|
||||||
|
YAP_Term rangeList = YAP_ArgOfTerm (3, pfTerm);
|
||||||
|
while (rangeList != YAP_TermNil()) {
|
||||||
|
unsigned range = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (rangeList));
|
||||||
|
ranges.push_back (range);
|
||||||
|
rangeList = YAP_TailOfTerm (rangeList);
|
||||||
|
}
|
||||||
|
|
||||||
|
// read parametric random vars
|
||||||
|
ProbFormulas formulas;
|
||||||
|
unsigned count = 0;
|
||||||
|
unordered_map<YAP_Term, LogVar> lvMap;
|
||||||
|
YAP_Term pvList = YAP_ArgOfTerm (2, pfTerm);
|
||||||
|
while (pvList != YAP_TermNil()) {
|
||||||
|
YAP_Term formulaTerm = YAP_HeadOfTerm (pvList);
|
||||||
|
if (YAP_IsAtomTerm (formulaTerm)) {
|
||||||
|
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (formulaTerm)));
|
||||||
|
Symbol functor = LiftedUtils::getSymbol (name);
|
||||||
|
formulas.push_back (ProbFormula (functor, ranges[count]));
|
||||||
|
} else {
|
||||||
|
LogVars logVars;
|
||||||
|
YAP_Functor yapFunctor = YAP_FunctorOfTerm (formulaTerm);
|
||||||
|
string name ((char*) YAP_AtomName (YAP_NameOfFunctor (yapFunctor)));
|
||||||
|
Symbol functor = LiftedUtils::getSymbol (name);
|
||||||
|
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||||
|
for (unsigned i = 1; i <= arity; i++) {
|
||||||
|
YAP_Term ti = YAP_ArgOfTerm (i, formulaTerm);
|
||||||
|
unordered_map<YAP_Term, LogVar>::iterator it = lvMap.find (ti);
|
||||||
|
if (it != lvMap.end()) {
|
||||||
|
logVars.push_back (it->second);
|
||||||
|
} else {
|
||||||
|
unsigned newLv = lvMap.size();
|
||||||
|
lvMap[ti] = newLv;
|
||||||
|
logVars.push_back (newLv);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
formulas.push_back (ProbFormula (functor, logVars, ranges[count]));
|
||||||
|
}
|
||||||
|
count ++;
|
||||||
|
pvList = YAP_TailOfTerm (pvList);
|
||||||
|
}
|
||||||
|
|
||||||
|
// read the parameters
|
||||||
|
const Params& params = readParameters (YAP_ArgOfTerm (4, pfTerm));
|
||||||
|
|
||||||
|
// read the constraint
|
||||||
|
Tuples tuples;
|
||||||
|
if (lvMap.size() >= 1) {
|
||||||
|
YAP_Term tupleList = YAP_ArgOfTerm (5, pfTerm);
|
||||||
|
while (tupleList != YAP_TermNil()) {
|
||||||
|
YAP_Term term = YAP_HeadOfTerm (tupleList);
|
||||||
|
assert (YAP_IsApplTerm (term));
|
||||||
|
YAP_Functor yapFunctor = YAP_FunctorOfTerm (term);
|
||||||
|
unsigned arity = (unsigned) YAP_ArityOfFunctor (yapFunctor);
|
||||||
|
assert (lvMap.size() == arity);
|
||||||
|
Tuple tuple (arity);
|
||||||
|
for (unsigned i = 1; i <= arity; i++) {
|
||||||
|
YAP_Term ti = YAP_ArgOfTerm (i, term);
|
||||||
|
if (YAP_IsAtomTerm (ti) == false) {
|
||||||
|
cerr << "error: constraint has free variables" << endl;
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
string name ((char*) YAP_AtomName (YAP_AtomOfTerm (ti)));
|
||||||
|
tuple[i - 1] = LiftedUtils::getSymbol (name);
|
||||||
|
}
|
||||||
|
tuples.push_back (tuple);
|
||||||
|
tupleList = YAP_TailOfTerm (tupleList);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return new Parfactor (formulas, params, tuples, distId);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void readLiftedEvidence (
|
||||||
|
YAP_Term observedList,
|
||||||
|
ObservedFormulas& obsFormulas)
|
||||||
|
{
|
||||||
while (observedList != YAP_TermNil()) {
|
while (observedList != YAP_TermNil()) {
|
||||||
YAP_Term pair = YAP_HeadOfTerm (observedList);
|
YAP_Term pair = YAP_HeadOfTerm (observedList);
|
||||||
YAP_Term ground = YAP_ArgOfTerm (1, pair);
|
YAP_Term ground = YAP_ArgOfTerm (1, pair);
|
||||||
@ -155,22 +197,18 @@ int createLiftedNetwork (void)
|
|||||||
unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair));
|
unsigned evidence = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, pair));
|
||||||
bool found = false;
|
bool found = false;
|
||||||
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
for (unsigned i = 0; i < obsFormulas.size(); i++) {
|
||||||
if (obsFormulas[i]->functor() == functor &&
|
if (obsFormulas[i].functor() == functor &&
|
||||||
obsFormulas[i]->arity() == args.size() &&
|
obsFormulas[i].arity() == args.size() &&
|
||||||
obsFormulas[i]->evidence() == evidence) {
|
obsFormulas[i].evidence() == evidence) {
|
||||||
obsFormulas[i]->addTuple (args);
|
obsFormulas[i].addTuple (args);
|
||||||
found = true;
|
found = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (found == false) {
|
if (found == false) {
|
||||||
obsFormulas.push_back (new ObservedFormula (functor, evidence, args));
|
obsFormulas.push_back (ObservedFormula (functor, evidence, args));
|
||||||
}
|
}
|
||||||
observedList = YAP_TailOfTerm (observedList);
|
observedList = YAP_TailOfTerm (observedList);
|
||||||
}
|
}
|
||||||
FoveSolver::absorveEvidence (*pfList, obsFormulas);
|
|
||||||
|
|
||||||
YAP_Int p = (YAP_Int) (pfList);
|
|
||||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG3);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -178,93 +216,46 @@ int createLiftedNetwork (void)
|
|||||||
int
|
int
|
||||||
createGroundNetwork (void)
|
createGroundNetwork (void)
|
||||||
{
|
{
|
||||||
Statistics::incrementPrimaryNetworksCounting();
|
string factorsType ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||||
// cout << "creating network number " ;
|
bool fromBayesNet = factorsType == "bayes";
|
||||||
// cout << Statistics::getPrimaryNetworksCounting() << endl;
|
FactorGraph* fg = new FactorGraph (fromBayesNet);
|
||||||
// if (Statistics::getPrimaryNetworksCounting() > 98) {
|
YAP_Term factorList = YAP_ARG2;
|
||||||
// Statistics::writeStatisticsToFile ("../../compressing.stats");
|
while (factorList != YAP_TermNil()) {
|
||||||
// }
|
YAP_Term factor = YAP_HeadOfTerm (factorList);
|
||||||
BayesNet* bn = new BayesNet();
|
// read the var ids
|
||||||
YAP_Term varList = YAP_ARG1;
|
VarIds varIds = readUnsignedList (YAP_ArgOfTerm (1, factor));
|
||||||
BnNodeSet nodes;
|
// read the ranges
|
||||||
vector<VarIds> parents;
|
Ranges ranges = readUnsignedList (YAP_ArgOfTerm (2, factor));
|
||||||
while (varList != YAP_TermNil()) {
|
// read the parameters
|
||||||
YAP_Term var = YAP_HeadOfTerm (varList);
|
Params params = readParameters (YAP_ArgOfTerm (3, factor));
|
||||||
VarId vid = (VarId) YAP_IntOfTerm (YAP_ArgOfTerm (1, var));
|
// read dist id
|
||||||
unsigned dsize = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, var));
|
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (4, factor));
|
||||||
int evidence = (int) YAP_IntOfTerm (YAP_ArgOfTerm (3, var));
|
fg->addFactor (Factor (varIds, ranges, params, distId));
|
||||||
YAP_Term parentL = YAP_ArgOfTerm (4, var);
|
factorList = YAP_TailOfTerm (factorList);
|
||||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (5, var));
|
|
||||||
parents.push_back (VarIds());
|
|
||||||
while (parentL != YAP_TermNil()) {
|
|
||||||
unsigned parentId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (parentL));
|
|
||||||
parents.back().push_back (parentId);
|
|
||||||
parentL = YAP_TailOfTerm (parentL);
|
|
||||||
}
|
|
||||||
Distribution* dist = bn->getDistribution (distId);
|
|
||||||
if (!dist) {
|
|
||||||
dist = new Distribution (distId);
|
|
||||||
bn->addDistribution (dist);
|
|
||||||
}
|
|
||||||
assert (bn->getBayesNode (vid) == 0);
|
|
||||||
nodes.push_back (bn->addNode (vid, dsize, evidence, dist));
|
|
||||||
varList = YAP_TailOfTerm (varList);
|
|
||||||
}
|
}
|
||||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
|
||||||
BnNodeSet ps;
|
YAP_Term evidenceList = YAP_ARG3;
|
||||||
for (unsigned j = 0; j < parents[i].size(); j++) {
|
while (evidenceList != YAP_TermNil()) {
|
||||||
assert (bn->getBayesNode (parents[i][j]) != 0);
|
YAP_Term evTerm = YAP_HeadOfTerm (evidenceList);
|
||||||
ps.push_back (bn->getBayesNode (parents[i][j]));
|
unsigned vid = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (1, evTerm)));
|
||||||
}
|
unsigned ev = (unsigned) YAP_IntOfTerm ((YAP_ArgOfTerm (2, evTerm)));
|
||||||
nodes[i]->setParents (ps);
|
assert (fg->getVarNode (vid));
|
||||||
|
fg->getVarNode (vid)->setEvidence (ev);
|
||||||
|
evidenceList = YAP_TailOfTerm (evidenceList);
|
||||||
}
|
}
|
||||||
bn->setIndexes();
|
|
||||||
YAP_Int p = (YAP_Int) (bn);
|
|
||||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG2);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
YAP_Int p = (YAP_Int) (fg);
|
||||||
|
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG4);
|
||||||
int
|
|
||||||
setBayesNetParams (void)
|
|
||||||
{
|
|
||||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
|
||||||
YAP_Term distList = YAP_ARG2;
|
|
||||||
while (distList != YAP_TermNil()) {
|
|
||||||
YAP_Term dist = YAP_HeadOfTerm (distList);
|
|
||||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
|
||||||
const Params params = readParams (YAP_ArgOfTerm (2, dist));
|
|
||||||
bn->getDistribution(distId)->updateParameters (params);
|
|
||||||
distList = YAP_TailOfTerm (distList);
|
|
||||||
}
|
|
||||||
return TRUE;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
|
||||||
setParfactorGraphParams (void)
|
|
||||||
{
|
|
||||||
// FIXME
|
|
||||||
// ParfactorGraph* pfg = (ParfactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
|
||||||
YAP_Term distList = YAP_ARG2;
|
|
||||||
while (distList != YAP_TermNil()) {
|
|
||||||
// YAP_Term dist = YAP_HeadOfTerm (distList);
|
|
||||||
// unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
|
||||||
// const Params params = readParams (YAP_ArgOfTerm (2, dist));
|
|
||||||
// pfg->getDistribution(distId)->setData (params);
|
|
||||||
distList = YAP_TailOfTerm (distList);
|
|
||||||
}
|
|
||||||
return TRUE;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
readParams (YAP_Term paramL)
|
readParameters (YAP_Term paramL)
|
||||||
{
|
{
|
||||||
Params params;
|
Params params;
|
||||||
while (paramL!= YAP_TermNil()) {
|
assert (YAP_IsPairTerm (paramL));
|
||||||
|
while (paramL != YAP_TermNil()) {
|
||||||
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
|
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
|
||||||
paramL = YAP_TailOfTerm (paramL);
|
paramL = YAP_TailOfTerm (paramL);
|
||||||
}
|
}
|
||||||
@ -279,15 +270,14 @@ readParams (YAP_Term paramL)
|
|||||||
int
|
int
|
||||||
runLiftedSolver (void)
|
runLiftedSolver (void)
|
||||||
{
|
{
|
||||||
ParfactorList* pfList = (ParfactorList*) YAP_IntOfTerm (YAP_ARG1);
|
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
YAP_Term taskList = YAP_ARG2;
|
YAP_Term taskList = YAP_ARG2;
|
||||||
vector<Params> results;
|
vector<Params> results;
|
||||||
|
ParfactorList pfListCopy (*network->first);
|
||||||
|
FoveSolver::absorveEvidence (pfListCopy, *network->second);
|
||||||
while (taskList != YAP_TermNil()) {
|
while (taskList != YAP_TermNil()) {
|
||||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
|
||||||
Grounds queryVars;
|
Grounds queryVars;
|
||||||
assert (YAP_IsPairTerm (taskList));
|
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||||
assert (YAP_IsPairTerm (jointList));
|
|
||||||
while (jointList != YAP_TermNil()) {
|
while (jointList != YAP_TermNil()) {
|
||||||
YAP_Term ground = YAP_HeadOfTerm (jointList);
|
YAP_Term ground = YAP_HeadOfTerm (jointList);
|
||||||
if (YAP_IsAtomTerm (ground)) {
|
if (YAP_IsAtomTerm (ground)) {
|
||||||
@ -310,11 +300,11 @@ runLiftedSolver (void)
|
|||||||
}
|
}
|
||||||
jointList = YAP_TailOfTerm (jointList);
|
jointList = YAP_TailOfTerm (jointList);
|
||||||
}
|
}
|
||||||
FoveSolver solver (pfList);
|
FoveSolver solver (pfListCopy);
|
||||||
if (queryVars.size() == 1) {
|
if (queryVars.size() == 1) {
|
||||||
results.push_back (solver.getPosterioriOf (queryVars[0]));
|
results.push_back (solver.getPosterioriOf (queryVars[0]));
|
||||||
} else {
|
} else {
|
||||||
assert (false); // TODO joint dist
|
results.push_back (solver.getJointDistributionOf (queryVars));
|
||||||
}
|
}
|
||||||
taskList = YAP_TailOfTerm (taskList);
|
taskList = YAP_TailOfTerm (taskList);
|
||||||
}
|
}
|
||||||
@ -339,77 +329,23 @@ runLiftedSolver (void)
|
|||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
runOtherSolvers (void)
|
runGroundSolver (void)
|
||||||
{
|
{
|
||||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
YAP_Term taskList = YAP_ARG2;
|
|
||||||
vector<VarIds> tasks;
|
vector<VarIds> tasks;
|
||||||
std::set<VarId> vids;
|
YAP_Term taskList = YAP_ARG2;
|
||||||
while (taskList != YAP_TermNil()) {
|
while (taskList != YAP_TermNil()) {
|
||||||
if (YAP_IsPairTerm (YAP_HeadOfTerm (taskList))) {
|
tasks.push_back (readUnsignedList (YAP_HeadOfTerm (taskList)));
|
||||||
tasks.push_back (VarIds());
|
|
||||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
|
||||||
while (jointList != YAP_TermNil()) {
|
|
||||||
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
|
|
||||||
assert (bn->getBayesNode (vid));
|
|
||||||
tasks.back().push_back (vid);
|
|
||||||
vids.insert (vid);
|
|
||||||
jointList = YAP_TailOfTerm (jointList);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (taskList));
|
|
||||||
assert (bn->getBayesNode (vid));
|
|
||||||
tasks.push_back (VarIds() = {vid});
|
|
||||||
vids.insert (vid);
|
|
||||||
}
|
|
||||||
taskList = YAP_TailOfTerm (taskList);
|
taskList = YAP_TailOfTerm (taskList);
|
||||||
}
|
}
|
||||||
|
|
||||||
Solver* bpSolver = 0;
|
|
||||||
GraphicalModel* graphicalModel = 0;
|
|
||||||
CFactorGraph::checkForIdenticalFactors = false;
|
|
||||||
if (InfAlgorithms::infAlgorithm != InfAlgorithms::VE) {
|
|
||||||
BayesNet* mrn = bn->getMinimalRequesiteNetwork (
|
|
||||||
VarIds (vids.begin(), vids.end()));
|
|
||||||
if (InfAlgorithms::infAlgorithm == InfAlgorithms::BN_BP) {
|
|
||||||
graphicalModel = mrn;
|
|
||||||
bpSolver = new BnBpSolver (*static_cast<BayesNet*> (graphicalModel));
|
|
||||||
} else if (InfAlgorithms::infAlgorithm == InfAlgorithms::FG_BP) {
|
|
||||||
graphicalModel = new FactorGraph (*mrn);
|
|
||||||
bpSolver = new FgBpSolver (*static_cast<FactorGraph*> (graphicalModel));
|
|
||||||
delete mrn;
|
|
||||||
} else if (InfAlgorithms::infAlgorithm == InfAlgorithms::CBP) {
|
|
||||||
graphicalModel = new FactorGraph (*mrn);
|
|
||||||
bpSolver = new CbpSolver (*static_cast<FactorGraph*> (graphicalModel));
|
|
||||||
delete mrn;
|
|
||||||
}
|
|
||||||
bpSolver->runSolver();
|
|
||||||
}
|
|
||||||
|
|
||||||
vector<Params> results;
|
vector<Params> results;
|
||||||
results.reserve (tasks.size());
|
if (Globals::infAlgorithm == InfAlgorithms::VE) {
|
||||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
runVeSolver (fg, tasks, results);
|
||||||
//if (i == 1) exit (0);
|
} else {
|
||||||
if (InfAlgorithms::infAlgorithm == InfAlgorithms::VE) {
|
runBpSolver (fg, tasks, results);
|
||||||
BayesNet* mrn = bn->getMinimalRequesiteNetwork (tasks[i]);
|
|
||||||
VarElimSolver* veSolver = new VarElimSolver (*mrn);
|
|
||||||
if (tasks[i].size() == 1) {
|
|
||||||
results.push_back (veSolver->getPosterioriOf (tasks[i][0]));
|
|
||||||
} else {
|
|
||||||
results.push_back (veSolver->getJointDistributionOf (tasks[i]));
|
|
||||||
}
|
|
||||||
delete mrn;
|
|
||||||
delete veSolver;
|
|
||||||
} else {
|
|
||||||
if (tasks[i].size() == 1) {
|
|
||||||
results.push_back (bpSolver->getPosterioriOf (tasks[i][0]));
|
|
||||||
} else {
|
|
||||||
results.push_back (bpSolver->getJointDistributionOf (tasks[i]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
delete bpSolver;
|
|
||||||
delete graphicalModel;
|
|
||||||
|
|
||||||
YAP_Term list = YAP_TermNil();
|
YAP_Term list = YAP_TermNil();
|
||||||
for (int i = results.size() - 1; i >= 0; i--) {
|
for (int i = results.size() - 1; i >= 0; i--) {
|
||||||
@ -424,32 +360,142 @@ runOtherSolvers (void)
|
|||||||
}
|
}
|
||||||
list = YAP_MkPairTerm (queryBeliefsL, list);
|
list = YAP_MkPairTerm (queryBeliefsL, list);
|
||||||
}
|
}
|
||||||
|
|
||||||
return YAP_Unify (list, YAP_ARG3);
|
return YAP_Unify (list, YAP_ARG3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
void runVeSolver (
|
||||||
setExtraVarsInfo (void)
|
FactorGraph* fg,
|
||||||
|
const vector<VarIds>& tasks,
|
||||||
|
vector<Params>& results)
|
||||||
{
|
{
|
||||||
// BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
results.reserve (tasks.size());
|
||||||
GraphicalModel::clearVariablesInformation();
|
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||||
YAP_Term varsInfoL = YAP_ARG2;
|
FactorGraph* mfg = fg;
|
||||||
while (varsInfoL != YAP_TermNil()) {
|
if (fg->isFromBayesNetwork()) {
|
||||||
YAP_Term head = YAP_HeadOfTerm (varsInfoL);
|
mfg = BayesBall::getMinimalFactorGraph (*fg, tasks[i]);
|
||||||
VarId vid = YAP_IntOfTerm (YAP_ArgOfTerm (1, head));
|
|
||||||
YAP_Atom label = YAP_AtomOfTerm (YAP_ArgOfTerm (2, head));
|
|
||||||
YAP_Term statesL = YAP_ArgOfTerm (3, head);
|
|
||||||
States states;
|
|
||||||
while (statesL != YAP_TermNil()) {
|
|
||||||
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (statesL));
|
|
||||||
states.push_back ((char*) YAP_AtomName (atom));
|
|
||||||
statesL = YAP_TailOfTerm (statesL);
|
|
||||||
}
|
}
|
||||||
GraphicalModel::addVariableInformation (vid,
|
VarElimSolver solver (*mfg);
|
||||||
(char*) YAP_AtomName (label), states);
|
results.push_back (solver.solveQuery (tasks[i]));
|
||||||
varsInfoL = YAP_TailOfTerm (varsInfoL);
|
if (fg->isFromBayesNetwork()) {
|
||||||
|
delete mfg;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void runBpSolver (
|
||||||
|
FactorGraph* fg,
|
||||||
|
const vector<VarIds>& tasks,
|
||||||
|
vector<Params>& results)
|
||||||
|
{
|
||||||
|
std::set<VarId> vids;
|
||||||
|
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||||
|
Util::addToSet (vids, tasks[i]);
|
||||||
|
}
|
||||||
|
Solver* solver = 0;
|
||||||
|
FactorGraph* mfg = fg;
|
||||||
|
if (fg->isFromBayesNetwork()) {
|
||||||
|
mfg = BayesBall::getMinimalFactorGraph (
|
||||||
|
*fg, VarIds (vids.begin(),vids.end()));
|
||||||
|
}
|
||||||
|
if (Globals::infAlgorithm == InfAlgorithms::BP) {
|
||||||
|
solver = new BpSolver (*mfg);
|
||||||
|
} else if (Globals::infAlgorithm == InfAlgorithms::CBP) {
|
||||||
|
CFactorGraph::checkForIdenticalFactors = false;
|
||||||
|
solver = new CbpSolver (*mfg);
|
||||||
|
} else {
|
||||||
|
cerr << "error: unknow solver" << endl;
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
results.reserve (tasks.size());
|
||||||
|
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||||
|
results.push_back (solver->solveQuery (tasks[i]));
|
||||||
|
}
|
||||||
|
if (fg->isFromBayesNetwork()) {
|
||||||
|
delete mfg;
|
||||||
|
}
|
||||||
|
delete solver;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
setParfactorsParams (void)
|
||||||
|
{
|
||||||
|
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
|
ParfactorList* pfList = network->first;
|
||||||
|
YAP_Term distList = YAP_ARG2;
|
||||||
|
unordered_map<unsigned, Params> paramsMap;
|
||||||
|
while (distList != YAP_TermNil()) {
|
||||||
|
YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||||
|
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||||
|
assert (Util::contains (paramsMap, distId) == false);
|
||||||
|
paramsMap[distId] = readParameters (YAP_ArgOfTerm (2, dist));
|
||||||
|
distList = YAP_TailOfTerm (distList);
|
||||||
|
}
|
||||||
|
ParfactorList::iterator it = pfList->begin();
|
||||||
|
while (it != pfList->end()) {
|
||||||
|
assert (Util::contains (paramsMap, (*it)->distId()));
|
||||||
|
// (*it)->setParams (paramsMap[(*it)->distId()]);
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
return TRUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
setFactorsParams (void)
|
||||||
|
{
|
||||||
|
return TRUE; // TODO
|
||||||
|
FactorGraph* fg = (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
|
YAP_Term distList = YAP_ARG2;
|
||||||
|
unordered_map<unsigned, Params> paramsMap;
|
||||||
|
while (distList != YAP_TermNil()) {
|
||||||
|
YAP_Term dist = YAP_HeadOfTerm (distList);
|
||||||
|
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (1, dist));
|
||||||
|
assert (Util::contains (paramsMap, distId) == false);
|
||||||
|
paramsMap[distId] = readParameters (YAP_ArgOfTerm (2, dist));
|
||||||
|
distList = YAP_TailOfTerm (distList);
|
||||||
|
}
|
||||||
|
const FacNodes& facNodes = fg->facNodes();
|
||||||
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
|
unsigned distId = facNodes[i]->factor().distId();
|
||||||
|
assert (Util::contains (paramsMap, distId));
|
||||||
|
facNodes[i]->factor().setParams (paramsMap[distId]);
|
||||||
|
}
|
||||||
|
return TRUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int
|
||||||
|
setVarsInformation (void)
|
||||||
|
{
|
||||||
|
Var::clearVarsInfo();
|
||||||
|
YAP_Term labelsL = YAP_ARG1;
|
||||||
|
vector<string> labels;
|
||||||
|
while (labelsL != YAP_TermNil()) {
|
||||||
|
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (labelsL));
|
||||||
|
labels.push_back ((char*) YAP_AtomName (atom));
|
||||||
|
labelsL = YAP_TailOfTerm (labelsL);
|
||||||
|
}
|
||||||
|
unsigned count = 0;
|
||||||
|
YAP_Term stateNamesL = YAP_ARG2;
|
||||||
|
while (stateNamesL != YAP_TermNil()) {
|
||||||
|
States states;
|
||||||
|
YAP_Term namesL = YAP_HeadOfTerm (stateNamesL);
|
||||||
|
while (namesL != YAP_TermNil()) {
|
||||||
|
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (namesL));
|
||||||
|
states.push_back ((char*) YAP_AtomName (atom));
|
||||||
|
namesL = YAP_TailOfTerm (namesL);
|
||||||
|
}
|
||||||
|
Var::addVarInfo (count, labels[count], states);
|
||||||
|
count ++;
|
||||||
|
stateNamesL = YAP_TailOfTerm (stateNamesL);
|
||||||
}
|
}
|
||||||
return TRUE;
|
return TRUE;
|
||||||
}
|
}
|
||||||
@ -463,13 +509,11 @@ setHorusFlag (void)
|
|||||||
if (key == "inf_alg") {
|
if (key == "inf_alg") {
|
||||||
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
||||||
if ( value == "ve") {
|
if ( value == "ve") {
|
||||||
InfAlgorithms::infAlgorithm = InfAlgorithms::VE;
|
Globals::infAlgorithm = InfAlgorithms::VE;
|
||||||
} else if (value == "bn_bp") {
|
} else if (value == "bp") {
|
||||||
InfAlgorithms::infAlgorithm = InfAlgorithms::BN_BP;
|
Globals::infAlgorithm = InfAlgorithms::BP;
|
||||||
} else if (value == "fg_bp") {
|
|
||||||
InfAlgorithms::infAlgorithm = InfAlgorithms::FG_BP;
|
|
||||||
} else if (value == "cbp") {
|
} else if (value == "cbp") {
|
||||||
InfAlgorithms::infAlgorithm = InfAlgorithms::CBP;
|
Globals::infAlgorithm = InfAlgorithms::CBP;
|
||||||
} else {
|
} else {
|
||||||
cerr << "warning: invalid value `" << value << "' " ;
|
cerr << "warning: invalid value `" << value << "' " ;
|
||||||
cerr << "for `" << key << "'" << endl;
|
cerr << "for `" << key << "'" << endl;
|
||||||
@ -541,21 +585,21 @@ setHorusFlag (void)
|
|||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
freeBayesNetwork (void)
|
freeGroundNetwork (void)
|
||||||
{
|
{
|
||||||
//Statistics::writeStatisticsToFile ("stats.txt");
|
delete (FactorGraph*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
|
||||||
bn->freeDistributions();
|
|
||||||
delete bn;
|
|
||||||
return TRUE;
|
return TRUE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
freeParfactorGraph (void)
|
freeParfactors (void)
|
||||||
{
|
{
|
||||||
delete (ParfactorList*) YAP_IntOfTerm (YAP_ARG1);
|
LiftedNetwork* network = (LiftedNetwork*) YAP_IntOfTerm (YAP_ARG1);
|
||||||
|
delete network->first;
|
||||||
|
delete network->second;
|
||||||
|
delete network;
|
||||||
return TRUE;
|
return TRUE;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -564,15 +608,15 @@ freeParfactorGraph (void)
|
|||||||
extern "C" void
|
extern "C" void
|
||||||
init_predicates (void)
|
init_predicates (void)
|
||||||
{
|
{
|
||||||
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
|
YAP_UserCPredicate ("create_lifted_network", createLiftedNetwork, 3);
|
||||||
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 2);
|
YAP_UserCPredicate ("create_ground_network", createGroundNetwork, 4);
|
||||||
YAP_UserCPredicate ("set_parfactor_graph_params", setParfactorGraphParams, 2);
|
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
|
||||||
YAP_UserCPredicate ("set_bayes_net_params", setBayesNetParams, 2);
|
YAP_UserCPredicate ("run_ground_solver", runGroundSolver, 3);
|
||||||
YAP_UserCPredicate ("run_lifted_solver", runLiftedSolver, 3);
|
YAP_UserCPredicate ("set_parfactors_params", setParfactorsParams, 2);
|
||||||
YAP_UserCPredicate ("run_other_solvers", runOtherSolvers, 3);
|
YAP_UserCPredicate ("set_factors_params", setFactorsParams, 2);
|
||||||
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
|
YAP_UserCPredicate ("set_vars_information", setVarsInformation, 2);
|
||||||
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
|
YAP_UserCPredicate ("set_horus_flag", setHorusFlag, 2);
|
||||||
YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1);
|
YAP_UserCPredicate ("free_parfactors", freeParfactors, 1);
|
||||||
YAP_UserCPredicate ("free_parfactor_graph", freeParfactorGraph, 1);
|
YAP_UserCPredicate ("free_ground_network", freeGroundNetwork, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,11 +8,13 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
|
|
||||||
#include "VarNode.h"
|
#include "Var.h"
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
class StatesIndexer {
|
|
||||||
|
class StatesIndexer
|
||||||
|
{
|
||||||
public:
|
public:
|
||||||
|
|
||||||
StatesIndexer (const Ranges& ranges, bool calcOffsets = true)
|
StatesIndexer (const Ranges& ranges, bool calcOffsets = true)
|
||||||
@ -29,14 +31,14 @@ class StatesIndexer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
StatesIndexer (const VarNodes& vars, bool calcOffsets = true)
|
StatesIndexer (const Vars& vars, bool calcOffsets = true)
|
||||||
{
|
{
|
||||||
size_ = 1;
|
size_ = 1;
|
||||||
indices_.resize (vars.size(), 0);
|
indices_.resize (vars.size(), 0);
|
||||||
ranges_.reserve (vars.size());
|
ranges_.reserve (vars.size());
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
ranges_.push_back (vars[i]->nrStates());
|
ranges_.push_back (vars[i]->range());
|
||||||
size_ *= vars[i]->nrStates();
|
size_ *= vars[i]->range();
|
||||||
}
|
}
|
||||||
li_ = 0;
|
li_ = 0;
|
||||||
if (calcOffsets) {
|
if (calcOffsets) {
|
||||||
@ -134,11 +136,11 @@ class StatesIndexer {
|
|||||||
return size_ ;
|
return size_ ;
|
||||||
}
|
}
|
||||||
|
|
||||||
friend ostream& operator<< (ostream &out, const StatesIndexer& idx)
|
friend ostream& operator<< (ostream &os, const StatesIndexer& idx)
|
||||||
{
|
{
|
||||||
out << "(" << std::setw (2) << std::setfill('0') << idx.li_ << ") " ;
|
os << "(" << std::setw (2) << std::setfill('0') << idx.li_ << ") " ;
|
||||||
out << idx.indices_;
|
os << idx.indices_;
|
||||||
return out;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -274,21 +276,14 @@ class MapIndexer
|
|||||||
index_ = 0;
|
index_ = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
friend ostream& operator<< (ostream &out, const MapIndexer& idx)
|
friend ostream& operator<< (ostream &os, const MapIndexer& idx)
|
||||||
{
|
{
|
||||||
out << "(" << std::setw (2) << std::setfill('0') << idx.index_ << ") " ;
|
os << "(" << std::setw (2) << std::setfill('0') << idx.index_ << ") " ;
|
||||||
out << idx.indices_;
|
os << idx.indices_;
|
||||||
return out;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
MapIndexer (const Ranges& ranges) :
|
|
||||||
ranges_(ranges),
|
|
||||||
indices_(ranges.size(), 0),
|
|
||||||
offsets_(ranges.size())
|
|
||||||
{
|
|
||||||
index_ = 0;
|
|
||||||
}
|
|
||||||
unsigned index_;
|
unsigned index_;
|
||||||
bool valid_;
|
bool valid_;
|
||||||
vector<unsigned> ranges_;
|
vector<unsigned> ranges_;
|
||||||
|
@ -95,26 +95,37 @@ ostream& operator<< (ostream &os, const Ground& gr)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
LogVars
|
||||||
ObservedFormula::addTuple (const Tuple& t)
|
Substitution::getDiscardedLogVars (void) const
|
||||||
{
|
{
|
||||||
if (constr_ == 0) {
|
LogVars discardedLvs;
|
||||||
LogVars lvs (arity_);
|
set<LogVar> doneLvs;
|
||||||
for (unsigned i = 0; i < arity_; i++) {
|
unordered_map<LogVar, LogVar>::const_iterator it;
|
||||||
lvs[i] = i;
|
it = subs_.begin();
|
||||||
|
while (it != subs_.end()) {
|
||||||
|
if (Util::contains (doneLvs, it->second)) {
|
||||||
|
discardedLvs.push_back (it->first);
|
||||||
|
} else {
|
||||||
|
doneLvs.insert (it->second);
|
||||||
}
|
}
|
||||||
constr_ = new ConstraintTree (lvs);
|
it ++;
|
||||||
}
|
}
|
||||||
constr_->addTuple (t);
|
return discardedLvs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ostream& operator<< (ostream &os, const ObservedFormula of)
|
ostream& operator<< (ostream &os, const Substitution& theta)
|
||||||
{
|
{
|
||||||
os << of.functor_ << "/" << of.arity_;
|
unordered_map<LogVar, LogVar>::const_iterator it;
|
||||||
os << "|" << of.constr_->tupleSet();
|
os << "[" ;
|
||||||
os << " [evidence=" << of.evidence_ << "]";
|
it = theta.subs_.begin();
|
||||||
|
while (it != theta.subs_.end()) {
|
||||||
|
if (it != theta.subs_.begin()) os << ", " ;
|
||||||
|
os << it->first << "->" << it->second ;
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
os << "]" ;
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -18,11 +18,17 @@ class Symbol
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Symbol (void) : id_(numeric_limits<unsigned>::max()) { }
|
Symbol (void) : id_(numeric_limits<unsigned>::max()) { }
|
||||||
|
|
||||||
Symbol (unsigned id) : id_(id) { }
|
Symbol (unsigned id) : id_(id) { }
|
||||||
|
|
||||||
operator unsigned (void) const { return id_; }
|
operator unsigned (void) const { return id_; }
|
||||||
|
|
||||||
bool valid (void) const { return id_ != numeric_limits<unsigned>::max(); }
|
bool valid (void) const { return id_ != numeric_limits<unsigned>::max(); }
|
||||||
|
|
||||||
static Symbol invalid (void) { return Symbol(); }
|
static Symbol invalid (void) { return Symbol(); }
|
||||||
|
|
||||||
friend ostream& operator<< (ostream &os, const Symbol& s);
|
friend ostream& operator<< (ostream &os, const Symbol& s);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
unsigned id_;
|
unsigned id_;
|
||||||
};
|
};
|
||||||
@ -32,7 +38,9 @@ class LogVar
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
LogVar (void) : id_(numeric_limits<unsigned>::max()) { }
|
LogVar (void) : id_(numeric_limits<unsigned>::max()) { }
|
||||||
|
|
||||||
LogVar (unsigned id) : id_(id) { }
|
LogVar (unsigned id) : id_(id) { }
|
||||||
|
|
||||||
operator unsigned (void) const { return id_; }
|
operator unsigned (void) const { return id_; }
|
||||||
|
|
||||||
LogVar& operator++ (void)
|
LogVar& operator++ (void)
|
||||||
@ -48,6 +56,7 @@ class LogVar
|
|||||||
}
|
}
|
||||||
|
|
||||||
friend ostream& operator<< (ostream &os, const LogVar& X);
|
friend ostream& operator<< (ostream &os, const LogVar& X);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
unsigned id_;
|
unsigned id_;
|
||||||
};
|
};
|
||||||
@ -79,8 +88,8 @@ ostream& operator<< (ostream &os, const Tuple& t);
|
|||||||
|
|
||||||
|
|
||||||
namespace LiftedUtils {
|
namespace LiftedUtils {
|
||||||
Symbol getSymbol (const string&);
|
Symbol getSymbol (const string&);
|
||||||
void printSymbolDictionary (void);
|
void printSymbolDictionary (void);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -89,71 +98,56 @@ class Ground
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Ground (Symbol f) : functor_(f) { }
|
Ground (Symbol f) : functor_(f) { }
|
||||||
|
|
||||||
Ground (Symbol f, const Symbols& args) : functor_(f), args_(args) { }
|
Ground (Symbol f, const Symbols& args) : functor_(f), args_(args) { }
|
||||||
|
|
||||||
Symbol functor (void) const { return functor_; }
|
Symbol functor (void) const { return functor_; }
|
||||||
Symbols args (void) const { return args_; }
|
|
||||||
unsigned arity (void) const { return args_.size(); }
|
Symbols args (void) const { return args_; }
|
||||||
bool isAtom (void) const { return args_.size() == 0; }
|
|
||||||
|
unsigned arity (void) const { return args_.size(); }
|
||||||
|
|
||||||
|
bool isAtom (void) const { return args_.size() == 0; }
|
||||||
|
|
||||||
friend ostream& operator<< (ostream &os, const Ground& gr);
|
friend ostream& operator<< (ostream &os, const Ground& gr);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Symbol functor_;
|
Symbol functor_;
|
||||||
Symbols args_;
|
Symbols args_;
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef vector<Ground> Grounds;
|
typedef vector<Ground> Grounds;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ConstraintTree;
|
|
||||||
class ObservedFormula
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
ObservedFormula (Symbol f, unsigned a, unsigned ev)
|
|
||||||
: functor_(f), arity_(a), evidence_(ev), constr_(0) { }
|
|
||||||
|
|
||||||
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
|
|
||||||
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(0)
|
|
||||||
{
|
|
||||||
addTuple (tuple);
|
|
||||||
}
|
|
||||||
|
|
||||||
Symbol functor (void) const { return functor_; }
|
|
||||||
unsigned arity (void) const { return arity_; }
|
|
||||||
unsigned evidence (void) const { return evidence_; }
|
|
||||||
ConstraintTree* constr (void) const { return constr_; }
|
|
||||||
bool isAtom (void) const { return arity_ == 0; }
|
|
||||||
|
|
||||||
void addTuple (const Tuple& t);
|
|
||||||
friend ostream& operator<< (ostream &os, const ObservedFormula opv);
|
|
||||||
private:
|
|
||||||
Symbol functor_;
|
|
||||||
unsigned arity_;
|
|
||||||
unsigned evidence_;
|
|
||||||
ConstraintTree* constr_;
|
|
||||||
};
|
|
||||||
typedef vector<ObservedFormula*> ObservedFormulas;
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Substitution
|
class Substitution
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
void add (LogVar X_old, LogVar X_new)
|
void add (LogVar X_old, LogVar X_new)
|
||||||
{
|
{
|
||||||
|
assert (Util::contains (subs_, X_old) == false);
|
||||||
subs_.insert (make_pair (X_old, X_new));
|
subs_.insert (make_pair (X_old, X_new));
|
||||||
}
|
}
|
||||||
|
|
||||||
void rename (LogVar X_old, LogVar X_new)
|
void rename (LogVar X_old, LogVar X_new)
|
||||||
{
|
{
|
||||||
assert (subs_.find (X_old) != subs_.end());
|
assert (Util::contains (subs_, X_old));
|
||||||
subs_.find (X_old)->second = X_new;
|
subs_.find (X_old)->second = X_new;
|
||||||
}
|
}
|
||||||
|
|
||||||
LogVar newNameFor (LogVar X) const
|
LogVar newNameFor (LogVar X) const
|
||||||
{
|
{
|
||||||
assert (subs_.find (X) != subs_.end());
|
assert (Util::contains (subs_, X));
|
||||||
return subs_.find (X)->second;
|
return subs_.find (X)->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogVars getDiscardedLogVars (void) const;
|
||||||
|
|
||||||
|
friend ostream& operator<< (ostream &os, const Substitution& theta);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
unordered_map<LogVar, LogVar> subs_;
|
unordered_map<LogVar, LogVar> subs_;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,9 +45,8 @@ CWD=$(PWD)
|
|||||||
|
|
||||||
|
|
||||||
HEADERS = \
|
HEADERS = \
|
||||||
$(srcdir)/GraphicalModel.h \
|
|
||||||
$(srcdir)/BayesNet.h \
|
$(srcdir)/BayesNet.h \
|
||||||
$(srcdir)/BayesNode.h \
|
$(srcdir)/BayesBall.h \
|
||||||
$(srcdir)/ElimGraph.h \
|
$(srcdir)/ElimGraph.h \
|
||||||
$(srcdir)/FactorGraph.h \
|
$(srcdir)/FactorGraph.h \
|
||||||
$(srcdir)/Factor.h \
|
$(srcdir)/Factor.h \
|
||||||
@ -55,12 +54,10 @@ HEADERS = \
|
|||||||
$(srcdir)/ConstraintTree.h \
|
$(srcdir)/ConstraintTree.h \
|
||||||
$(srcdir)/Solver.h \
|
$(srcdir)/Solver.h \
|
||||||
$(srcdir)/VarElimSolver.h \
|
$(srcdir)/VarElimSolver.h \
|
||||||
$(srcdir)/BnBpSolver.h \
|
$(srcdir)/BpSolver.h \
|
||||||
$(srcdir)/FgBpSolver.h \
|
|
||||||
$(srcdir)/CbpSolver.h \
|
$(srcdir)/CbpSolver.h \
|
||||||
$(srcdir)/FoveSolver.h \
|
$(srcdir)/FoveSolver.h \
|
||||||
$(srcdir)/VarNode.h \
|
$(srcdir)/Var.h \
|
||||||
$(srcdir)/Distribution.h \
|
|
||||||
$(srcdir)/Indexer.h \
|
$(srcdir)/Indexer.h \
|
||||||
$(srcdir)/Parfactor.h \
|
$(srcdir)/Parfactor.h \
|
||||||
$(srcdir)/ProbFormula.h \
|
$(srcdir)/ProbFormula.h \
|
||||||
@ -69,22 +66,20 @@ HEADERS = \
|
|||||||
$(srcdir)/LiftedUtils.h \
|
$(srcdir)/LiftedUtils.h \
|
||||||
$(srcdir)/TinySet.h \
|
$(srcdir)/TinySet.h \
|
||||||
$(srcdir)/Util.h \
|
$(srcdir)/Util.h \
|
||||||
$(srcdir)/Horus.h \
|
$(srcdir)/Horus.h
|
||||||
$(srcdir)/xmlParser/xmlParser.h
|
|
||||||
|
|
||||||
CPP_SOURCES = \
|
CPP_SOURCES = \
|
||||||
$(srcdir)/BayesNet.cpp \
|
$(srcdir)/BayesNet.cpp \
|
||||||
$(srcdir)/BayesNode.cpp \
|
$(srcdir)/BayesBall.cpp \
|
||||||
$(srcdir)/ElimGraph.cpp \
|
$(srcdir)/ElimGraph.cpp \
|
||||||
$(srcdir)/FactorGraph.cpp \
|
$(srcdir)/FactorGraph.cpp \
|
||||||
$(srcdir)/Factor.cpp \
|
$(srcdir)/Factor.cpp \
|
||||||
$(srcdir)/CFactorGraph.cpp \
|
$(srcdir)/CFactorGraph.cpp \
|
||||||
$(srcdir)/ConstraintTree.cpp \
|
$(srcdir)/ConstraintTree.cpp \
|
||||||
$(srcdir)/VarNode.cpp \
|
$(srcdir)/Var.cpp \
|
||||||
$(srcdir)/Solver.cpp \
|
$(srcdir)/Solver.cpp \
|
||||||
$(srcdir)/VarElimSolver.cpp \
|
$(srcdir)/VarElimSolver.cpp \
|
||||||
$(srcdir)/BnBpSolver.cpp \
|
$(srcdir)/BpSolver.cpp \
|
||||||
$(srcdir)/FgBpSolver.cpp \
|
|
||||||
$(srcdir)/CbpSolver.cpp \
|
$(srcdir)/CbpSolver.cpp \
|
||||||
$(srcdir)/FoveSolver.cpp \
|
$(srcdir)/FoveSolver.cpp \
|
||||||
$(srcdir)/Parfactor.cpp \
|
$(srcdir)/Parfactor.cpp \
|
||||||
@ -94,22 +89,20 @@ CPP_SOURCES = \
|
|||||||
$(srcdir)/LiftedUtils.cpp \
|
$(srcdir)/LiftedUtils.cpp \
|
||||||
$(srcdir)/Util.cpp \
|
$(srcdir)/Util.cpp \
|
||||||
$(srcdir)/HorusYap.cpp \
|
$(srcdir)/HorusYap.cpp \
|
||||||
$(srcdir)/HorusCli.cpp \
|
$(srcdir)/HorusCli.cpp
|
||||||
$(srcdir)/xmlParser/xmlParser.cpp
|
|
||||||
|
|
||||||
OBJS = \
|
OBJS = \
|
||||||
BayesNet.o \
|
BayesNet.o \
|
||||||
BayesNode.o \
|
BayesBall.o \
|
||||||
ElimGraph.o \
|
ElimGraph.o \
|
||||||
FactorGraph.o \
|
FactorGraph.o \
|
||||||
Factor.o \
|
Factor.o \
|
||||||
CFactorGraph.o \
|
CFactorGraph.o \
|
||||||
ConstraintTree.o \
|
ConstraintTree.o \
|
||||||
VarNode.o \
|
Var.o \
|
||||||
Solver.o \
|
Solver.o \
|
||||||
VarElimSolver.o \
|
VarElimSolver.o \
|
||||||
BnBpSolver.o \
|
BpSolver.o \
|
||||||
FgBpSolver.o \
|
|
||||||
CbpSolver.o \
|
CbpSolver.o \
|
||||||
FoveSolver.o \
|
FoveSolver.o \
|
||||||
Parfactor.o \
|
Parfactor.o \
|
||||||
@ -122,17 +115,16 @@ OBJS = \
|
|||||||
|
|
||||||
HCLI_OBJS = \
|
HCLI_OBJS = \
|
||||||
BayesNet.o \
|
BayesNet.o \
|
||||||
BayesNode.o \
|
BayesBall.o \
|
||||||
ElimGraph.o \
|
ElimGraph.o \
|
||||||
FactorGraph.o \
|
FactorGraph.o \
|
||||||
Factor.o \
|
Factor.o \
|
||||||
CFactorGraph.o \
|
CFactorGraph.o \
|
||||||
ConstraintTree.o \
|
ConstraintTree.o \
|
||||||
VarNode.o \
|
Var.o \
|
||||||
Solver.o \
|
Solver.o \
|
||||||
VarElimSolver.o \
|
VarElimSolver.o \
|
||||||
BnBpSolver.o \
|
BpSolver.o \
|
||||||
FgBpSolver.o \
|
|
||||||
CbpSolver.o \
|
CbpSolver.o \
|
||||||
FoveSolver.o \
|
FoveSolver.o \
|
||||||
Parfactor.o \
|
Parfactor.o \
|
||||||
@ -141,7 +133,6 @@ HCLI_OBJS = \
|
|||||||
ParfactorList.o \
|
ParfactorList.o \
|
||||||
LiftedUtils.o \
|
LiftedUtils.o \
|
||||||
Util.o \
|
Util.o \
|
||||||
xmlParser/xmlParser.o \
|
|
||||||
HorusCli.o
|
HorusCli.o
|
||||||
|
|
||||||
SOBJS=horus.@SO@
|
SOBJS=horus.@SO@
|
||||||
@ -154,10 +145,6 @@ all: $(SOBJS) hcli
|
|||||||
$(CXX) -c $(CXXFLAGS) $< -o $@
|
$(CXX) -c $(CXXFLAGS) $< -o $@
|
||||||
|
|
||||||
|
|
||||||
xmlParser/xmlParser.o : $(srcdir)/xmlParser/xmlParser.cpp
|
|
||||||
$(CXX) -c $(CXXFLAGS) $< -o $@
|
|
||||||
|
|
||||||
|
|
||||||
@DO_SECOND_LD@horus.@SO@: $(OBJS)
|
@DO_SECOND_LD@horus.@SO@: $(OBJS)
|
||||||
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o horus.@SO@ $(OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
|
@DO_SECOND_LD@ @SHLIB_CXX_LD@ -o horus.@SO@ $(OBJS) @EXTRA_LIBS_FOR_SWIDLLS@
|
||||||
|
|
||||||
@ -171,7 +158,7 @@ install: all
|
|||||||
|
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
rm -f *.o *~ $(OBJS) $(SOBJS) *.BAK hcli xmlParser/*.o
|
rm -f *.o *~ $(OBJS) $(SOBJS) *.BAK hcli
|
||||||
|
|
||||||
|
|
||||||
erase_dots:
|
erase_dots:
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
#include "Parfactor.h"
|
#include "Parfactor.h"
|
||||||
#include "Histogram.h"
|
#include "Histogram.h"
|
||||||
#include "Indexer.h"
|
#include "Indexer.h"
|
||||||
|
#include "Util.h"
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
@ -11,55 +12,58 @@ Parfactor::Parfactor (
|
|||||||
const Tuples& tuples,
|
const Tuples& tuples,
|
||||||
unsigned distId)
|
unsigned distId)
|
||||||
{
|
{
|
||||||
formulas_ = formulas;
|
args_ = formulas;
|
||||||
params_ = params;
|
params_ = params;
|
||||||
distId_ = distId;
|
distId_ = distId;
|
||||||
|
|
||||||
LogVars logVars;
|
LogVars logVars;
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
ranges_.push_back (formulas_[i].range());
|
ranges_.push_back (args_[i].range());
|
||||||
const LogVars& lvs = formulas_[i].logVars();
|
const LogVars& lvs = args_[i].logVars();
|
||||||
for (unsigned j = 0; j < lvs.size(); j++) {
|
for (unsigned j = 0; j < lvs.size(); j++) {
|
||||||
if (std::find (logVars.begin(), logVars.end(), lvs[j]) ==
|
if (Util::contains (logVars, lvs[j]) == false) {
|
||||||
logVars.end()) {
|
|
||||||
logVars.push_back (lvs[j]);
|
logVars.push_back (lvs[j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
constr_ = new ConstraintTree (logVars, tuples);
|
constr_ = new ConstraintTree (logVars, tuples);
|
||||||
|
assert (params_.size() == Util::expectedSize (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Parfactor::Parfactor (const Parfactor* g, const Tuple& tuple)
|
Parfactor::Parfactor (const Parfactor* g, const Tuple& tuple)
|
||||||
{
|
{
|
||||||
formulas_ = g->formulas();
|
args_ = g->arguments();
|
||||||
params_ = g->params();
|
params_ = g->params();
|
||||||
ranges_ = g->ranges();
|
ranges_ = g->ranges();
|
||||||
distId_ = g->distId();
|
distId_ = g->distId();
|
||||||
constr_ = new ConstraintTree (g->logVars(), {tuple});
|
constr_ = new ConstraintTree (g->logVars(), {tuple});
|
||||||
|
assert (params_.size() == Util::expectedSize (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Parfactor::Parfactor (const Parfactor* g, ConstraintTree* constr)
|
Parfactor::Parfactor (const Parfactor* g, ConstraintTree* constr)
|
||||||
{
|
{
|
||||||
formulas_ = g->formulas();
|
args_ = g->arguments();
|
||||||
params_ = g->params();
|
params_ = g->params();
|
||||||
ranges_ = g->ranges();
|
ranges_ = g->ranges();
|
||||||
distId_ = g->distId();
|
distId_ = g->distId();
|
||||||
constr_ = constr;
|
constr_ = constr;
|
||||||
|
assert (params_.size() == Util::expectedSize (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Parfactor::Parfactor (const Parfactor& g)
|
Parfactor::Parfactor (const Parfactor& g)
|
||||||
{
|
{
|
||||||
formulas_ = g.formulas();
|
args_ = g.arguments();
|
||||||
params_ = g.params();
|
params_ = g.params();
|
||||||
ranges_ = g.ranges();
|
ranges_ = g.ranges();
|
||||||
distId_ = g.distId();
|
distId_ = g.distId();
|
||||||
constr_ = new ConstraintTree (*g.constr());
|
constr_ = new ConstraintTree (*g.constr());
|
||||||
|
assert (params_.size() == Util::expectedSize (ranges_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -75,9 +79,9 @@ LogVarSet
|
|||||||
Parfactor::countedLogVars (void) const
|
Parfactor::countedLogVars (void) const
|
||||||
{
|
{
|
||||||
LogVarSet set;
|
LogVarSet set;
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (formulas_[i].isCounting()) {
|
if (args_[i].isCounting()) {
|
||||||
set.insert (formulas_[i].countedLogVar());
|
set.insert (args_[i].countedLogVar());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return set;
|
return set;
|
||||||
@ -107,14 +111,14 @@ Parfactor::elimLogVars (void) const
|
|||||||
LogVarSet
|
LogVarSet
|
||||||
Parfactor::exclusiveLogVars (unsigned fIdx) const
|
Parfactor::exclusiveLogVars (unsigned fIdx) const
|
||||||
{
|
{
|
||||||
assert (fIdx < formulas_.size());
|
assert (fIdx < args_.size());
|
||||||
LogVarSet remaining;
|
LogVarSet remaining;
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (i != fIdx) {
|
if (i != fIdx) {
|
||||||
remaining |= formulas_[i].logVarSet();
|
remaining |= args_[i].logVarSet();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return formulas_[fIdx].logVarSet() - remaining;
|
return args_[fIdx].logVarSet() - remaining;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -131,44 +135,51 @@ Parfactor::setConstraintTree (ConstraintTree* newTree)
|
|||||||
void
|
void
|
||||||
Parfactor::sumOut (unsigned fIdx)
|
Parfactor::sumOut (unsigned fIdx)
|
||||||
{
|
{
|
||||||
assert (fIdx < formulas_.size());
|
assert (fIdx < args_.size());
|
||||||
assert (formulas_[fIdx].contains (elimLogVars()));
|
assert (args_[fIdx].contains (elimLogVars()));
|
||||||
|
|
||||||
LogVarSet excl = exclusiveLogVars (fIdx);
|
LogVarSet excl = exclusiveLogVars (fIdx);
|
||||||
unsigned condCount = constr_->getConditionalCount (excl);
|
if (args_[fIdx].isCounting()) {
|
||||||
Util::pow (params_, condCount);
|
LogAware::pow (params_, constr_->getConditionalCount (
|
||||||
|
excl - args_[fIdx].countedLogVar()));
|
||||||
|
} else {
|
||||||
|
LogAware::pow (params_, constr_->getConditionalCount (excl));
|
||||||
|
}
|
||||||
|
|
||||||
vector<unsigned> numAssigns (ranges_[fIdx], 1);
|
if (args_[fIdx].isCounting()) {
|
||||||
if (formulas_[fIdx].isCounting()) {
|
|
||||||
unsigned N = constr_->getConditionalCount (
|
unsigned N = constr_->getConditionalCount (
|
||||||
formulas_[fIdx].countedLogVar());
|
args_[fIdx].countedLogVar());
|
||||||
unsigned R = formulas_[fIdx].range();
|
unsigned R = args_[fIdx].range();
|
||||||
unsigned H = ranges_[fIdx];
|
vector<double> numAssigns = HistogramSet::getNumAssigns (N, R);
|
||||||
HistogramSet hs (N, R);
|
StatesIndexer sindexer (ranges_, fIdx);
|
||||||
unsigned N_factorial = Util::factorial (N);
|
while (sindexer.valid()) {
|
||||||
for (unsigned h = 0; h < H; h++) {
|
unsigned h = sindexer[fIdx];
|
||||||
unsigned prod = 1;
|
if (Globals::logDomain) {
|
||||||
for (unsigned r = 0; r < R; r++) {
|
params_[sindexer] += numAssigns[h];
|
||||||
prod *= Util::factorial (hs[r]);
|
} else {
|
||||||
|
params_[sindexer] *= numAssigns[h];
|
||||||
}
|
}
|
||||||
numAssigns[h] = N_factorial / prod;
|
++ sindexer;
|
||||||
hs.nextHistogram();
|
|
||||||
}
|
}
|
||||||
cout << endl;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Params copy = params_;
|
Params copy = params_;
|
||||||
params_.clear();
|
params_.clear();
|
||||||
params_.resize (copy.size() / ranges_[fIdx], 0.0);
|
params_.resize (copy.size() / ranges_[fIdx], LogAware::addIdenty());
|
||||||
|
|
||||||
MapIndexer indexer (ranges_, fIdx);
|
MapIndexer indexer (ranges_, fIdx);
|
||||||
for (unsigned i = 0; i < copy.size(); i++) {
|
if (Globals::logDomain) {
|
||||||
unsigned h = indexer[fIdx];
|
for (unsigned i = 0; i < copy.size(); i++) {
|
||||||
// TODO NOT LOG DOMAIN AWARE :(
|
params_[indexer] = Util::logSum (params_[indexer], copy[i]);
|
||||||
params_[indexer] += numAssigns[h] * copy[i];
|
++ indexer;
|
||||||
++ indexer;
|
}
|
||||||
|
} else {
|
||||||
|
for (unsigned i = 0; i < copy.size(); i++) {
|
||||||
|
params_[indexer] += copy[i];
|
||||||
|
++ indexer;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
formulas_.erase (formulas_.begin() + fIdx);
|
|
||||||
|
args_.erase (args_.begin() + fIdx);
|
||||||
ranges_.erase (ranges_.begin() + fIdx);
|
ranges_.erase (ranges_.begin() + fIdx);
|
||||||
constr_->remove (excl);
|
constr_->remove (excl);
|
||||||
}
|
}
|
||||||
@ -179,55 +190,7 @@ void
|
|||||||
Parfactor::multiply (Parfactor& g)
|
Parfactor::multiply (Parfactor& g)
|
||||||
{
|
{
|
||||||
alignAndExponentiate (this, &g);
|
alignAndExponentiate (this, &g);
|
||||||
bool sharedVars = false;
|
TFactor<ProbFormula>::multiply (g);
|
||||||
vector<unsigned> g_varpos;
|
|
||||||
const ProbFormulas& g_formulas = g.formulas();
|
|
||||||
const Params& g_params = g.params();
|
|
||||||
const Ranges& g_ranges = g.ranges();
|
|
||||||
|
|
||||||
for (unsigned i = 0; i < g_formulas.size(); i++) {
|
|
||||||
int group = g_formulas[i].group();
|
|
||||||
if (indexOfFormulaWithGroup (group) == -1) {
|
|
||||||
insertDimension (g.ranges()[i]);
|
|
||||||
formulas_.push_back (g_formulas[i]);
|
|
||||||
g_varpos.push_back (formulas_.size() - 1);
|
|
||||||
} else {
|
|
||||||
sharedVars = true;
|
|
||||||
g_varpos.push_back (indexOfFormulaWithGroup (group));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (sharedVars == false) {
|
|
||||||
unsigned count = 0;
|
|
||||||
for (unsigned i = 0; i < params_.size(); i++) {
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
params_[i] += g_params[count];
|
|
||||||
} else {
|
|
||||||
params_[i] *= g_params[count];
|
|
||||||
}
|
|
||||||
count ++;
|
|
||||||
if (count >= g_params.size()) {
|
|
||||||
count = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
StatesIndexer indexer (ranges_, false);
|
|
||||||
while (indexer.valid()) {
|
|
||||||
unsigned g_li = 0;
|
|
||||||
unsigned prod = 1;
|
|
||||||
for (int j = g_varpos.size() - 1; j >= 0; j--) {
|
|
||||||
g_li += indexer[g_varpos[j]] * prod;
|
|
||||||
prod *= g_ranges[j];
|
|
||||||
}
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
params_[indexer] += g_params[g_li];
|
|
||||||
} else {
|
|
||||||
params_[indexer] *= g_params[g_li];
|
|
||||||
}
|
|
||||||
++ indexer;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
constr_->join (g.constr(), true);
|
constr_->join (g.constr(), true);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -236,7 +199,7 @@ Parfactor::multiply (Parfactor& g)
|
|||||||
void
|
void
|
||||||
Parfactor::countConvert (LogVar X)
|
Parfactor::countConvert (LogVar X)
|
||||||
{
|
{
|
||||||
int fIdx = indexOfFormulaWithLogVar (X);
|
int fIdx = indexOfLogVar (X);
|
||||||
assert (fIdx != -1);
|
assert (fIdx != -1);
|
||||||
assert (constr_->isCountNormalized (X));
|
assert (constr_->isCountNormalized (X));
|
||||||
assert (constr_->getConditionalCount (X) > 1);
|
assert (constr_->getConditionalCount (X) > 1);
|
||||||
@ -248,12 +211,12 @@ Parfactor::countConvert (LogVar X)
|
|||||||
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
||||||
|
|
||||||
StatesIndexer indexer (ranges_);
|
StatesIndexer indexer (ranges_);
|
||||||
vector<Params> summout (params_.size() / R);
|
vector<Params> sumout (params_.size() / R);
|
||||||
unsigned count = 0;
|
unsigned count = 0;
|
||||||
while (indexer.valid()) {
|
while (indexer.valid()) {
|
||||||
summout[count].reserve (R);
|
sumout[count].reserve (R);
|
||||||
for (unsigned r = 0; r < R; r++) {
|
for (unsigned r = 0; r < R; r++) {
|
||||||
summout[count].push_back (params_[indexer]);
|
sumout[count].push_back (params_[indexer]);
|
||||||
indexer.increment (fIdx);
|
indexer.increment (fIdx);
|
||||||
}
|
}
|
||||||
count ++;
|
count ++;
|
||||||
@ -262,45 +225,42 @@ Parfactor::countConvert (LogVar X)
|
|||||||
}
|
}
|
||||||
|
|
||||||
params_.clear();
|
params_.clear();
|
||||||
params_.reserve (summout.size() * H);
|
params_.reserve (sumout.size() * H);
|
||||||
|
|
||||||
vector<bool> mapDims (ranges_.size(), true);
|
|
||||||
ranges_[fIdx] = H;
|
ranges_[fIdx] = H;
|
||||||
mapDims[fIdx] = false;
|
MapIndexer mapIndexer (ranges_, fIdx);
|
||||||
MapIndexer mapIndexer (ranges_, mapDims);
|
|
||||||
while (mapIndexer.valid()) {
|
while (mapIndexer.valid()) {
|
||||||
double prod = 1.0;
|
double prod = LogAware::multIdenty();
|
||||||
unsigned i = mapIndexer.mappedIndex();
|
unsigned i = mapIndexer.mappedIndex();
|
||||||
unsigned h = mapIndexer[fIdx];
|
unsigned h = mapIndexer[fIdx];
|
||||||
for (unsigned r = 0; r < R; r++) {
|
for (unsigned r = 0; r < R; r++) {
|
||||||
// TODO not log domain aware
|
if (Globals::logDomain) {
|
||||||
prod *= Util::pow (summout[i][r], histograms[h][r]);
|
prod += LogAware::pow (sumout[i][r], histograms[h][r]);
|
||||||
|
} else {
|
||||||
|
prod *= LogAware::pow (sumout[i][r], histograms[h][r]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
params_.push_back (prod);
|
params_.push_back (prod);
|
||||||
++ mapIndexer;
|
++ mapIndexer;
|
||||||
}
|
}
|
||||||
formulas_[fIdx].setCountedLogVar (X);
|
args_[fIdx].setCountedLogVar (X);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Parfactor::expandPotential (
|
Parfactor::expand (LogVar X, LogVar X_new1, LogVar X_new2)
|
||||||
LogVar X,
|
|
||||||
LogVar X_new1,
|
|
||||||
LogVar X_new2)
|
|
||||||
{
|
{
|
||||||
int fIdx = indexOfFormulaWithLogVar (X);
|
int fIdx = indexOfLogVar (X);
|
||||||
assert (fIdx != -1);
|
assert (fIdx != -1);
|
||||||
assert (formulas_[fIdx].isCounting());
|
assert (args_[fIdx].isCounting());
|
||||||
|
|
||||||
unsigned N1 = constr_->getConditionalCount (X_new1);
|
unsigned N1 = constr_->getConditionalCount (X_new1);
|
||||||
unsigned N2 = constr_->getConditionalCount (X_new2);
|
unsigned N2 = constr_->getConditionalCount (X_new2);
|
||||||
unsigned N = N1 + N2;
|
unsigned N = N1 + N2;
|
||||||
unsigned R = formulas_[fIdx].range();
|
unsigned R = args_[fIdx].range();
|
||||||
unsigned H1 = HistogramSet::nrHistograms (N1, R);
|
unsigned H1 = HistogramSet::nrHistograms (N1, R);
|
||||||
unsigned H2 = HistogramSet::nrHistograms (N2, R);
|
unsigned H2 = HistogramSet::nrHistograms (N2, R);
|
||||||
unsigned H = ranges_[fIdx];
|
|
||||||
|
|
||||||
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
vector<Histogram> histograms = HistogramSet::getHistograms (N, R);
|
||||||
vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R);
|
vector<Histogram> histograms1 = HistogramSet::getHistograms (N1, R);
|
||||||
@ -320,48 +280,11 @@ Parfactor::expandPotential (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned size = (params_.size() / H) * H1 * H2;
|
expandPotential (fIdx, H1 * H2, sumIndexes);
|
||||||
Params copy = params_;
|
|
||||||
params_.clear();
|
|
||||||
params_.reserve (size);
|
|
||||||
|
|
||||||
unsigned prod = 1;
|
args_.insert (args_.begin() + fIdx + 1, args_[fIdx]);
|
||||||
vector<unsigned> offsets_ (ranges_.size());
|
args_[fIdx].rename (X, X_new1);
|
||||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
args_[fIdx + 1].rename (X, X_new2);
|
||||||
offsets_[i] = prod;
|
|
||||||
prod *= ranges_[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned index = 0;
|
|
||||||
ranges_[fIdx] = H1 * H2;
|
|
||||||
vector<unsigned> indices (ranges_.size(), 0);
|
|
||||||
for (unsigned k = 0; k < size; k++) {
|
|
||||||
params_.push_back (copy[index]);
|
|
||||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
|
||||||
indices[i] ++;
|
|
||||||
if (i == fIdx) {
|
|
||||||
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
|
|
||||||
index += diff * offsets_[i];
|
|
||||||
} else {
|
|
||||||
index += offsets_[i];
|
|
||||||
}
|
|
||||||
if (indices[i] != ranges_[i]) {
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
if (i == fIdx) {
|
|
||||||
int diff = sumIndexes[0] - sumIndexes[indices[i]];
|
|
||||||
index += diff * offsets_[i];
|
|
||||||
} else {
|
|
||||||
index -= offsets_[i] * ranges_[i];
|
|
||||||
}
|
|
||||||
indices[i] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
formulas_.insert (formulas_.begin() + fIdx + 1, formulas_[fIdx]);
|
|
||||||
formulas_[fIdx].rename (X, X_new1);
|
|
||||||
formulas_[fIdx + 1].rename (X, X_new2);
|
|
||||||
ranges_.insert (ranges_.begin() + fIdx + 1, H2);
|
ranges_.insert (ranges_.begin() + fIdx + 1, H2);
|
||||||
ranges_[fIdx] = H1;
|
ranges_[fIdx] = H1;
|
||||||
}
|
}
|
||||||
@ -371,13 +294,12 @@ Parfactor::expandPotential (
|
|||||||
void
|
void
|
||||||
Parfactor::fullExpand (LogVar X)
|
Parfactor::fullExpand (LogVar X)
|
||||||
{
|
{
|
||||||
int fIdx = indexOfFormulaWithLogVar (X);
|
int fIdx = indexOfLogVar (X);
|
||||||
assert (fIdx != -1);
|
assert (fIdx != -1);
|
||||||
assert (formulas_[fIdx].isCounting());
|
assert (args_[fIdx].isCounting());
|
||||||
|
|
||||||
unsigned N = constr_->getConditionalCount (X);
|
unsigned N = constr_->getConditionalCount (X);
|
||||||
unsigned R = formulas_[fIdx].range();
|
unsigned R = args_[fIdx].range();
|
||||||
unsigned H = ranges_[fIdx];
|
|
||||||
|
|
||||||
vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
|
vector<Histogram> originHists = HistogramSet::getHistograms (N, R);
|
||||||
vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
|
vector<Histogram> expandHists = HistogramSet::getHistograms (1, R);
|
||||||
@ -400,54 +322,17 @@ Parfactor::fullExpand (LogVar X)
|
|||||||
++ indexer;
|
++ indexer;
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned size = (params_.size() / H) * std::pow (R, N);
|
expandPotential (fIdx, std::pow (R, N), sumIndexes);
|
||||||
Params copy = params_;
|
|
||||||
params_.clear();
|
|
||||||
params_.reserve (size);
|
|
||||||
|
|
||||||
unsigned prod = 1;
|
ProbFormula f = args_[fIdx];
|
||||||
vector<unsigned> offsets_ (ranges_.size());
|
args_.erase (args_.begin() + fIdx);
|
||||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
|
||||||
offsets_[i] = prod;
|
|
||||||
prod *= ranges_[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned index = 0;
|
|
||||||
ranges_[fIdx] = std::pow (R, N);
|
|
||||||
vector<unsigned> indices (ranges_.size(), 0);
|
|
||||||
for (unsigned k = 0; k < size; k++) {
|
|
||||||
params_.push_back (copy[index]);
|
|
||||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
|
||||||
indices[i] ++;
|
|
||||||
if (i == fIdx) {
|
|
||||||
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
|
|
||||||
index += diff * offsets_[i];
|
|
||||||
} else {
|
|
||||||
index += offsets_[i];
|
|
||||||
}
|
|
||||||
if (indices[i] != ranges_[i]) {
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
if (i == fIdx) {
|
|
||||||
int diff = sumIndexes[0] - sumIndexes[indices[i]];
|
|
||||||
index += diff * offsets_[i];
|
|
||||||
} else {
|
|
||||||
index -= offsets_[i] * ranges_[i];
|
|
||||||
}
|
|
||||||
indices[i] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ProbFormula f = formulas_[fIdx];
|
|
||||||
formulas_.erase (formulas_.begin() + fIdx);
|
|
||||||
ranges_.erase (ranges_.begin() + fIdx);
|
ranges_.erase (ranges_.begin() + fIdx);
|
||||||
LogVars newLvs = constr_->expand (X);
|
LogVars newLvs = constr_->expand (X);
|
||||||
assert (newLvs.size() == N);
|
assert (newLvs.size() == N);
|
||||||
for (unsigned i = 0 ; i < N; i++) {
|
for (unsigned i = 0 ; i < N; i++) {
|
||||||
ProbFormula newFormula (f.functor(), f.logVars(), f.range());
|
ProbFormula newFormula (f.functor(), f.logVars(), f.range());
|
||||||
newFormula.rename (X, newLvs[i]);
|
newFormula.rename (X, newLvs[i]);
|
||||||
formulas_.insert (formulas_.begin() + fIdx + i, newFormula);
|
args_.insert (args_.begin() + fIdx + i, newFormula);
|
||||||
ranges_.insert (ranges_.begin() + fIdx + i, R);
|
ranges_.insert (ranges_.begin() + fIdx + i, R);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -459,117 +344,43 @@ Parfactor::reorderAccordingGrounds (const Grounds& grounds)
|
|||||||
{
|
{
|
||||||
ProbFormulas newFormulas;
|
ProbFormulas newFormulas;
|
||||||
for (unsigned i = 0; i < grounds.size(); i++) {
|
for (unsigned i = 0; i < grounds.size(); i++) {
|
||||||
for (unsigned j = 0; j < formulas_.size(); j++) {
|
for (unsigned j = 0; j < args_.size(); j++) {
|
||||||
if (grounds[i].functor() == formulas_[j].functor() &&
|
if (grounds[i].functor() == args_[j].functor() &&
|
||||||
grounds[i].arity() == formulas_[j].arity()) {
|
grounds[i].arity() == args_[j].arity()) {
|
||||||
constr_->moveToTop (formulas_[j].logVars());
|
constr_->moveToTop (args_[j].logVars());
|
||||||
if (constr_->containsTuple (grounds[i].args())) {
|
if (constr_->containsTuple (grounds[i].args())) {
|
||||||
newFormulas.push_back (formulas_[j]);
|
newFormulas.push_back (args_[j]);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert (newFormulas.size() == i + 1);
|
assert (newFormulas.size() == i + 1);
|
||||||
}
|
}
|
||||||
reorderFormulas (newFormulas);
|
reorderArguments (newFormulas);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Parfactor::reorderFormulas (const ProbFormulas& newFormulas)
|
Parfactor::absorveEvidence (const ProbFormula& formula, unsigned evidence)
|
||||||
{
|
|
||||||
assert (newFormulas.size() == formulas_.size());
|
|
||||||
if (newFormulas == formulas_) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ranges newRanges;
|
|
||||||
vector<unsigned> positions;
|
|
||||||
for (unsigned i = 0; i < newFormulas.size(); i++) {
|
|
||||||
unsigned idx = indexOf (newFormulas[i]);
|
|
||||||
newRanges.push_back (ranges_[idx]);
|
|
||||||
positions.push_back (idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned N = ranges_.size();
|
|
||||||
Params newParams (params_.size());
|
|
||||||
for (unsigned i = 0; i < params_.size(); i++) {
|
|
||||||
unsigned li = i;
|
|
||||||
// calculate vector index corresponding to linear index
|
|
||||||
vector<unsigned> vi (N);
|
|
||||||
for (int k = N-1; k >= 0; k--) {
|
|
||||||
vi[k] = li % ranges_[k];
|
|
||||||
li /= ranges_[k];
|
|
||||||
}
|
|
||||||
// convert permuted vector index to corresponding linear index
|
|
||||||
unsigned prod = 1;
|
|
||||||
unsigned new_li = 0;
|
|
||||||
for (int k = N - 1; k >= 0; k--) {
|
|
||||||
new_li += vi[positions[k]] * prod;
|
|
||||||
prod *= ranges_[positions[k]];
|
|
||||||
}
|
|
||||||
newParams[new_li] = params_[i];
|
|
||||||
}
|
|
||||||
formulas_ = newFormulas;
|
|
||||||
ranges_ = newRanges;
|
|
||||||
params_ = newParams;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Parfactor::absorveEvidence (unsigned fIdx, unsigned evidence)
|
|
||||||
{
|
{
|
||||||
|
int fIdx = indexOf (formula);
|
||||||
|
assert (fIdx != -1);
|
||||||
LogVarSet excl = exclusiveLogVars (fIdx);
|
LogVarSet excl = exclusiveLogVars (fIdx);
|
||||||
assert (fIdx < formulas_.size());
|
assert (args_[fIdx].isCounting() == false);
|
||||||
assert (evidence < formulas_[fIdx].range());
|
|
||||||
assert (formulas_[fIdx].isCounting() == false);
|
|
||||||
assert (constr_->isCountNormalized (excl));
|
assert (constr_->isCountNormalized (excl));
|
||||||
|
LogAware::pow (params_, constr_->getConditionalCount (excl));
|
||||||
Util::pow (params_, constr_->getConditionalCount (excl));
|
TFactor<ProbFormula>::absorveEvidence (formula, evidence);
|
||||||
|
|
||||||
Params copy = params_;
|
|
||||||
params_.clear();
|
|
||||||
params_.reserve (copy.size() / formulas_[fIdx].range());
|
|
||||||
|
|
||||||
StatesIndexer indexer (ranges_);
|
|
||||||
for (unsigned i = 0; i < evidence; i++) {
|
|
||||||
indexer.increment (fIdx);
|
|
||||||
}
|
|
||||||
while (indexer.valid()) {
|
|
||||||
params_.push_back (copy[indexer]);
|
|
||||||
indexer.incrementExcluding (fIdx);
|
|
||||||
}
|
|
||||||
formulas_.erase (formulas_.begin() + fIdx);
|
|
||||||
ranges_.erase (ranges_.begin() + fIdx);
|
|
||||||
constr_->remove (excl);
|
constr_->remove (excl);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Parfactor::normalize (void)
|
|
||||||
{
|
|
||||||
Util::normalize (params_);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Parfactor::setFormulaGroup (const ProbFormula& f, int group)
|
|
||||||
{
|
|
||||||
assert (indexOf (f) != -1);
|
|
||||||
formulas_[indexOf (f)].setGroup (group);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Parfactor::setNewGroups (void)
|
Parfactor::setNewGroups (void)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
formulas_[i].setGroup (ProbFormula::getNewGroup());
|
args_[i].setGroup (ProbFormula::getNewGroup());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -578,14 +389,14 @@ Parfactor::setNewGroups (void)
|
|||||||
void
|
void
|
||||||
Parfactor::applySubstitution (const Substitution& theta)
|
Parfactor::applySubstitution (const Substitution& theta)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
LogVars& lvs = formulas_[i].logVars();
|
LogVars& lvs = args_[i].logVars();
|
||||||
for (unsigned j = 0; j < lvs.size(); j++) {
|
for (unsigned j = 0; j < lvs.size(); j++) {
|
||||||
lvs[j] = theta.newNameFor (lvs[j]);
|
lvs[j] = theta.newNameFor (lvs[j]);
|
||||||
}
|
}
|
||||||
if (formulas_[i].isCounting()) {
|
if (args_[i].isCounting()) {
|
||||||
LogVar clv = formulas_[i].countedLogVar();
|
LogVar clv = args_[i].countedLogVar();
|
||||||
formulas_[i].setCountedLogVar (theta.newNameFor (clv));
|
args_[i].setCountedLogVar (theta.newNameFor (clv));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
constr_->applySubstitution (theta);
|
constr_->applySubstitution (theta);
|
||||||
@ -593,19 +404,29 @@ Parfactor::applySubstitution (const Substitution& theta)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
int
|
||||||
Parfactor::containsGround (const Ground& ground) const
|
Parfactor::findGroup (const Ground& ground) const
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
int group = -1;
|
||||||
if (formulas_[i].functor() == ground.functor() &&
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
formulas_[i].arity() == ground.arity()) {
|
if (args_[i].functor() == ground.functor() &&
|
||||||
constr_->moveToTop (formulas_[i].logVars());
|
args_[i].arity() == ground.arity()) {
|
||||||
|
constr_->moveToTop (args_[i].logVars());
|
||||||
if (constr_->containsTuple (ground.args())) {
|
if (constr_->containsTuple (ground.args())) {
|
||||||
return true;
|
group = args_[i].group();
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false;
|
return group;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Parfactor::containsGround (const Ground& ground) const
|
||||||
|
{
|
||||||
|
return findGroup (ground) != -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -613,8 +434,8 @@ Parfactor::containsGround (const Ground& ground) const
|
|||||||
bool
|
bool
|
||||||
Parfactor::containsGroup (unsigned group) const
|
Parfactor::containsGroup (unsigned group) const
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (formulas_[i].group() == group) {
|
if (args_[i].group() == group) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -623,30 +444,12 @@ Parfactor::containsGroup (unsigned group) const
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
const ProbFormula&
|
|
||||||
Parfactor::formula (unsigned fIdx) const
|
|
||||||
{
|
|
||||||
assert (fIdx < formulas_.size());
|
|
||||||
return formulas_[fIdx];
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
|
||||||
Parfactor::range (unsigned fIdx) const
|
|
||||||
{
|
|
||||||
assert (fIdx < ranges_.size());
|
|
||||||
return ranges_[fIdx];
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
unsigned
|
unsigned
|
||||||
Parfactor::nrFormulas (LogVar X) const
|
Parfactor::nrFormulas (LogVar X) const
|
||||||
{
|
{
|
||||||
unsigned count = 0;
|
unsigned count = 0;
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (formulas_[i].contains (X)) {
|
if (args_[i].contains (X)) {
|
||||||
count ++;
|
count ++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -656,27 +459,12 @@ Parfactor::nrFormulas (LogVar X) const
|
|||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
Parfactor::indexOf (const ProbFormula& f) const
|
Parfactor::indexOfLogVar (LogVar X) const
|
||||||
{
|
|
||||||
int idx = -1;
|
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
|
||||||
if (f == formulas_[i]) {
|
|
||||||
idx = i;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return idx;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int
|
|
||||||
Parfactor::indexOfFormulaWithLogVar (LogVar X) const
|
|
||||||
{
|
{
|
||||||
int idx = -1;
|
int idx = -1;
|
||||||
assert (nrFormulas (X) == 1);
|
assert (nrFormulas (X) == 1);
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (formulas_[i].contains (X)) {
|
if (args_[i].contains (X)) {
|
||||||
idx = i;
|
idx = i;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -687,11 +475,11 @@ Parfactor::indexOfFormulaWithLogVar (LogVar X) const
|
|||||||
|
|
||||||
|
|
||||||
int
|
int
|
||||||
Parfactor::indexOfFormulaWithGroup (unsigned group) const
|
Parfactor::indexOfGroup (unsigned group) const
|
||||||
{
|
{
|
||||||
int pos = -1;
|
int pos = -1;
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (formulas_[i].group() == group) {
|
if (args_[i].group() == group) {
|
||||||
pos = i;
|
pos = i;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -704,9 +492,9 @@ Parfactor::indexOfFormulaWithGroup (unsigned group) const
|
|||||||
vector<unsigned>
|
vector<unsigned>
|
||||||
Parfactor::getAllGroups (void) const
|
Parfactor::getAllGroups (void) const
|
||||||
{
|
{
|
||||||
vector<unsigned> groups (formulas_.size());
|
vector<unsigned> groups (args_.size());
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
groups[i] = formulas_[i].group();
|
groups[i] = args_[i].group();
|
||||||
}
|
}
|
||||||
return groups;
|
return groups;
|
||||||
}
|
}
|
||||||
@ -714,13 +502,13 @@ Parfactor::getAllGroups (void) const
|
|||||||
|
|
||||||
|
|
||||||
string
|
string
|
||||||
Parfactor::getHeaderString (void) const
|
Parfactor::getLabel (void) const
|
||||||
{
|
{
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
ss << "phi(" ;
|
ss << "phi(" ;
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (i != 0) ss << "," ;
|
if (i != 0) ss << "," ;
|
||||||
ss << formulas_[i];
|
ss << args_[i];
|
||||||
}
|
}
|
||||||
ss << ")" ;
|
ss << ")" ;
|
||||||
ConstraintTree copy (*constr_);
|
ConstraintTree copy (*constr_);
|
||||||
@ -735,32 +523,37 @@ void
|
|||||||
Parfactor::print (bool printParams) const
|
Parfactor::print (bool printParams) const
|
||||||
{
|
{
|
||||||
cout << "Formulas: " ;
|
cout << "Formulas: " ;
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (i != 0) cout << ", " ;
|
if (i != 0) cout << ", " ;
|
||||||
cout << formulas_[i];
|
cout << args_[i];
|
||||||
}
|
}
|
||||||
cout << endl;
|
cout << endl;
|
||||||
vector<string> groups;
|
if (args_[0].group() != Util::maxUnsigned()) {
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
vector<string> groups;
|
||||||
groups.push_back (string ("g") + Util::toString (formulas_[i].group()));
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
|
groups.push_back (string ("g") + Util::toString (args_[i].group()));
|
||||||
|
}
|
||||||
|
cout << "Groups: " << groups << endl;
|
||||||
}
|
}
|
||||||
cout << "Groups: " << groups << endl;
|
cout << "LogVars: " << constr_->logVarSet() << endl;
|
||||||
cout << "LogVars: " << constr_->logVars() << endl;
|
|
||||||
cout << "Ranges: " << ranges_ << endl;
|
cout << "Ranges: " << ranges_ << endl;
|
||||||
if (printParams == false) {
|
if (printParams == false) {
|
||||||
cout << "Params: " << params_ << endl;
|
cout << "Params: " << params_ << endl;
|
||||||
}
|
}
|
||||||
cout << "Tuples: " << constr_->tupleSet() << endl;
|
ConstraintTree copy (*constr_);
|
||||||
|
copy.moveToTop (copy.logVarSet().elements());
|
||||||
|
cout << "Tuples: " << copy.tupleSet() << endl;
|
||||||
if (printParams) {
|
if (printParams) {
|
||||||
vector<string> jointStrings;
|
vector<string> jointStrings;
|
||||||
StatesIndexer indexer (ranges_);
|
StatesIndexer indexer (ranges_);
|
||||||
while (indexer.valid()) {
|
while (indexer.valid()) {
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
for (unsigned i = 0; i < formulas_.size(); i++) {
|
for (unsigned i = 0; i < args_.size(); i++) {
|
||||||
if (i != 0) ss << ", " ;
|
if (i != 0) ss << ", " ;
|
||||||
if (formulas_[i].isCounting()) {
|
if (args_[i].isCounting()) {
|
||||||
unsigned N = constr_->getConditionalCount (formulas_[i].countedLogVar());
|
unsigned N = constr_->getConditionalCount (
|
||||||
HistogramSet hs (N, formulas_[i].range());
|
args_[i].countedLogVar());
|
||||||
|
HistogramSet hs (N, args_[i].range());
|
||||||
unsigned c = 0;
|
unsigned c = 0;
|
||||||
while (c < indexer[i]) {
|
while (c < indexer[i]) {
|
||||||
hs.nextHistogram();
|
hs.nextHistogram();
|
||||||
@ -779,22 +572,56 @@ Parfactor::print (bool printParams) const
|
|||||||
cout << " = " << params_[i] << endl;
|
cout << " = " << params_[i] << endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Parfactor::insertDimension (unsigned range)
|
Parfactor::expandPotential (
|
||||||
|
int fIdx,
|
||||||
|
unsigned newRange,
|
||||||
|
const vector<unsigned>& sumIndexes)
|
||||||
{
|
{
|
||||||
|
unsigned size = (params_.size() / ranges_[fIdx]) * newRange;
|
||||||
Params copy = params_;
|
Params copy = params_;
|
||||||
params_.clear();
|
params_.clear();
|
||||||
params_.reserve (copy.size() * range);
|
params_.reserve (size);
|
||||||
for (unsigned i = 0; i < copy.size(); i++) {
|
|
||||||
for (unsigned reps = 0; reps < range; reps++) {
|
unsigned prod = 1;
|
||||||
params_.push_back (copy[i]);
|
vector<unsigned> offsets_ (ranges_.size());
|
||||||
|
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||||
|
offsets_[i] = prod;
|
||||||
|
prod *= ranges_[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned index = 0;
|
||||||
|
ranges_[fIdx] = newRange;
|
||||||
|
vector<unsigned> indices (ranges_.size(), 0);
|
||||||
|
for (unsigned k = 0; k < size; k++) {
|
||||||
|
params_.push_back (copy[index]);
|
||||||
|
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||||
|
indices[i] ++;
|
||||||
|
if (i == fIdx) {
|
||||||
|
assert (indices[i] - 1 < sumIndexes.size());
|
||||||
|
int diff = sumIndexes[indices[i]] - sumIndexes[indices[i] - 1];
|
||||||
|
index += diff * offsets_[i];
|
||||||
|
} else {
|
||||||
|
index += offsets_[i];
|
||||||
|
}
|
||||||
|
if (indices[i] != ranges_[i]) {
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
if (i == fIdx) {
|
||||||
|
int diff = sumIndexes[0] - sumIndexes[indices[i]];
|
||||||
|
index += diff * offsets_[i];
|
||||||
|
} else {
|
||||||
|
index -= offsets_[i] * ranges_[i];
|
||||||
|
}
|
||||||
|
indices[i] = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ranges_.push_back (range);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -803,29 +630,27 @@ void
|
|||||||
Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
|
Parfactor::alignAndExponentiate (Parfactor* g1, Parfactor* g2)
|
||||||
{
|
{
|
||||||
LogVars X_1, X_2;
|
LogVars X_1, X_2;
|
||||||
const ProbFormulas& formulas1 = g1->formulas();
|
const ProbFormulas& formulas1 = g1->arguments();
|
||||||
const ProbFormulas& formulas2 = g2->formulas();
|
const ProbFormulas& formulas2 = g2->arguments();
|
||||||
for (unsigned i = 0; i < formulas1.size(); i++) {
|
for (unsigned i = 0; i < formulas1.size(); i++) {
|
||||||
for (unsigned j = 0; j < formulas2.size(); j++) {
|
for (unsigned j = 0; j < formulas2.size(); j++) {
|
||||||
if (formulas1[i].group() == formulas2[j].group()) {
|
if (formulas1[i].group() == formulas2[j].group()) {
|
||||||
X_1.insert (X_1.end(),
|
Util::addToVector (X_1, formulas1[i].logVars());
|
||||||
formulas1[i].logVars().begin(),
|
Util::addToVector (X_2, formulas2[j].logVars());
|
||||||
formulas1[i].logVars().end());
|
|
||||||
X_2.insert (X_2.end(),
|
|
||||||
formulas2[j].logVars().begin(),
|
|
||||||
formulas2[j].logVars().end());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
align (g1, X_1, g2, X_2);
|
|
||||||
LogVarSet Y_1 = g1->logVarSet() - LogVarSet (X_1);
|
LogVarSet Y_1 = g1->logVarSet() - LogVarSet (X_1);
|
||||||
LogVarSet Y_2 = g2->logVarSet() - LogVarSet (X_2);
|
LogVarSet Y_2 = g2->logVarSet() - LogVarSet (X_2);
|
||||||
assert (g1->constr()->isCountNormalized (Y_1));
|
assert (g1->constr()->isCountNormalized (Y_1));
|
||||||
assert (g2->constr()->isCountNormalized (Y_2));
|
assert (g2->constr()->isCountNormalized (Y_2));
|
||||||
unsigned condCount1 = g1->constr()->getConditionalCount (Y_1);
|
unsigned condCount1 = g1->constr()->getConditionalCount (Y_1);
|
||||||
unsigned condCount2 = g2->constr()->getConditionalCount (Y_2);
|
unsigned condCount2 = g2->constr()->getConditionalCount (Y_2);
|
||||||
Util::pow (g1->params(), 1.0 / condCount2);
|
LogAware::pow (g1->params(), 1.0 / condCount2);
|
||||||
Util::pow (g2->params(), 1.0 / condCount1);
|
LogAware::pow (g2->params(), 1.0 / condCount1);
|
||||||
|
// this must be done in the end or else X_1 and X_2
|
||||||
|
// will refer the old log var names in the code above
|
||||||
|
align (g1, X_1, g2, X_2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -838,7 +663,6 @@ Parfactor::align (
|
|||||||
LogVar freeLogVar = 0;
|
LogVar freeLogVar = 0;
|
||||||
Substitution theta1;
|
Substitution theta1;
|
||||||
Substitution theta2;
|
Substitution theta2;
|
||||||
|
|
||||||
const LogVarSet& allLvs1 = g1->logVarSet();
|
const LogVarSet& allLvs1 = g1->logVarSet();
|
||||||
for (unsigned i = 0; i < allLvs1.size(); i++) {
|
for (unsigned i = 0; i < allLvs1.size(); i++) {
|
||||||
theta1.add (allLvs1[i], freeLogVar);
|
theta1.add (allLvs1[i], freeLogVar);
|
||||||
@ -850,7 +674,7 @@ Parfactor::align (
|
|||||||
theta2.add (allLvs2[i], freeLogVar);
|
theta2.add (allLvs2[i], freeLogVar);
|
||||||
++ freeLogVar;
|
++ freeLogVar;
|
||||||
}
|
}
|
||||||
|
|
||||||
assert (alignLvs1.size() == alignLvs2.size());
|
assert (alignLvs1.size() == alignLvs2.size());
|
||||||
for (unsigned i = 0; i < alignLvs1.size(); i++) {
|
for (unsigned i = 0; i < alignLvs1.size(); i++) {
|
||||||
theta1.rename (alignLvs1[i], theta2.newNameFor (alignLvs2[i]));
|
theta1.rename (alignLvs1[i], theta2.newNameFor (alignLvs2[i]));
|
||||||
|
@ -9,8 +9,9 @@
|
|||||||
#include "LiftedUtils.h"
|
#include "LiftedUtils.h"
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
|
#include "Factor.h"
|
||||||
|
|
||||||
class Parfactor
|
class Parfactor : public TFactor<ProbFormula>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Parfactor (
|
Parfactor (
|
||||||
@ -18,27 +19,15 @@ class Parfactor
|
|||||||
const Params&,
|
const Params&,
|
||||||
const Tuples&,
|
const Tuples&,
|
||||||
unsigned);
|
unsigned);
|
||||||
|
|
||||||
Parfactor (const Parfactor*, const Tuple&);
|
Parfactor (const Parfactor*, const Tuple&);
|
||||||
|
|
||||||
Parfactor (const Parfactor*, ConstraintTree*);
|
Parfactor (const Parfactor*, ConstraintTree*);
|
||||||
|
|
||||||
Parfactor (const Parfactor&);
|
Parfactor (const Parfactor&);
|
||||||
|
|
||||||
~Parfactor (void);
|
~Parfactor (void);
|
||||||
|
|
||||||
ProbFormulas& formulas (void) { return formulas_; }
|
|
||||||
|
|
||||||
const ProbFormulas& formulas (void) const { return formulas_; }
|
|
||||||
|
|
||||||
unsigned nrFormulas (void) const { return formulas_.size(); }
|
|
||||||
|
|
||||||
Params& params (void) { return params_; }
|
|
||||||
|
|
||||||
const Params& params (void) const { return params_; }
|
|
||||||
|
|
||||||
unsigned size (void) const { return params_.size(); }
|
|
||||||
|
|
||||||
const Ranges& ranges (void) const { return ranges_; }
|
|
||||||
|
|
||||||
unsigned distId (void) const { return distId_; }
|
|
||||||
|
|
||||||
ConstraintTree* constr (void) { return constr_; }
|
ConstraintTree* constr (void) { return constr_; }
|
||||||
|
|
||||||
const ConstraintTree* constr (void) const { return constr_; }
|
const ConstraintTree* constr (void) const { return constr_; }
|
||||||
@ -57,64 +46,52 @@ class Parfactor
|
|||||||
|
|
||||||
void setConstraintTree (ConstraintTree*);
|
void setConstraintTree (ConstraintTree*);
|
||||||
|
|
||||||
void sumOut (unsigned);
|
void sumOut (unsigned fIdx);
|
||||||
|
|
||||||
void multiply (Parfactor&);
|
void multiply (Parfactor&);
|
||||||
|
|
||||||
void countConvert (LogVar);
|
void countConvert (LogVar);
|
||||||
|
|
||||||
void expandPotential (LogVar, LogVar, LogVar);
|
void expand (LogVar, LogVar, LogVar);
|
||||||
|
|
||||||
void fullExpand (LogVar);
|
void fullExpand (LogVar);
|
||||||
|
|
||||||
void reorderAccordingGrounds (const Grounds&);
|
void reorderAccordingGrounds (const Grounds&);
|
||||||
|
|
||||||
void reorderFormulas (const ProbFormulas&);
|
void absorveEvidence (const ProbFormula&, unsigned);
|
||||||
|
|
||||||
void absorveEvidence (unsigned, unsigned);
|
|
||||||
|
|
||||||
void normalize (void);
|
|
||||||
|
|
||||||
void setFormulaGroup (const ProbFormula&, int);
|
|
||||||
|
|
||||||
void setNewGroups (void);
|
void setNewGroups (void);
|
||||||
|
|
||||||
void applySubstitution (const Substitution&);
|
void applySubstitution (const Substitution&);
|
||||||
|
|
||||||
|
int findGroup (const Ground&) const;
|
||||||
|
|
||||||
bool containsGround (const Ground&) const;
|
bool containsGround (const Ground&) const;
|
||||||
|
|
||||||
bool containsGroup (unsigned) const;
|
bool containsGroup (unsigned) const;
|
||||||
|
|
||||||
const ProbFormula& formula (unsigned) const;
|
|
||||||
|
|
||||||
unsigned range (unsigned) const;
|
|
||||||
|
|
||||||
unsigned nrFormulas (LogVar) const;
|
unsigned nrFormulas (LogVar) const;
|
||||||
|
|
||||||
int indexOf (const ProbFormula&) const;
|
int indexOfLogVar (LogVar) const;
|
||||||
|
|
||||||
int indexOfFormulaWithLogVar (LogVar) const;
|
int indexOfGroup (unsigned) const;
|
||||||
|
|
||||||
int indexOfFormulaWithGroup (unsigned) const;
|
|
||||||
|
|
||||||
vector<unsigned> getAllGroups (void) const;
|
vector<unsigned> getAllGroups (void) const;
|
||||||
|
|
||||||
void print (bool = false) const;
|
void print (bool = false) const;
|
||||||
|
|
||||||
string getHeaderString (void) const;
|
string getLabel (void) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void expandPotential (int fIdx, unsigned newRange,
|
||||||
|
const vector<unsigned>& sumIndexes);
|
||||||
|
|
||||||
static void alignAndExponentiate (Parfactor*, Parfactor*);
|
static void alignAndExponentiate (Parfactor*, Parfactor*);
|
||||||
|
|
||||||
static void align (
|
static void align (
|
||||||
Parfactor*, const LogVars&, Parfactor*, const LogVars&);
|
Parfactor*, const LogVars&, Parfactor*, const LogVars&);
|
||||||
|
|
||||||
void insertDimension (unsigned);
|
|
||||||
|
|
||||||
ProbFormulas formulas_;
|
ConstraintTree* constr_;
|
||||||
Ranges ranges_;
|
|
||||||
Params params_;
|
|
||||||
unsigned distId_;
|
|
||||||
ConstraintTree* constr_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,9 +3,32 @@
|
|||||||
#include "ParfactorList.h"
|
#include "ParfactorList.h"
|
||||||
|
|
||||||
|
|
||||||
ParfactorList::ParfactorList (Parfactors& pfs)
|
ParfactorList::ParfactorList (const ParfactorList& pfList)
|
||||||
{
|
{
|
||||||
pfList_.insert (pfList_.end(), pfs.begin(), pfs.end());
|
ParfactorList::const_iterator it = pfList.begin();
|
||||||
|
while (it != pfList.end()) {
|
||||||
|
addShattered (new Parfactor (**it));
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ParfactorList::ParfactorList (const Parfactors& pfs)
|
||||||
|
{
|
||||||
|
add (pfs);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ParfactorList::~ParfactorList (void)
|
||||||
|
{
|
||||||
|
ParfactorList::const_iterator it = pfList_.begin();
|
||||||
|
while (it != pfList_.end()) {
|
||||||
|
delete *it;
|
||||||
|
++ it;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -14,17 +37,17 @@ void
|
|||||||
ParfactorList::add (Parfactor* pf)
|
ParfactorList::add (Parfactor* pf)
|
||||||
{
|
{
|
||||||
pf->setNewGroups();
|
pf->setNewGroups();
|
||||||
pfList_.push_back (pf);
|
addToShatteredList (pf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
ParfactorList::add (Parfactors& pfs)
|
ParfactorList::add (const Parfactors& pfs)
|
||||||
{
|
{
|
||||||
for (unsigned i = 0; i < pfs.size(); i++) {
|
for (unsigned i = 0; i < pfs.size(); i++) {
|
||||||
pfs[i]->setNewGroups();
|
pfs[i]->setNewGroups();
|
||||||
pfList_.push_back (pfs[i]);
|
addToShatteredList (pfs[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -33,7 +56,20 @@ ParfactorList::add (Parfactors& pfs)
|
|||||||
void
|
void
|
||||||
ParfactorList::addShattered (Parfactor* pf)
|
ParfactorList::addShattered (Parfactor* pf)
|
||||||
{
|
{
|
||||||
|
assert (isAllShattered());
|
||||||
pfList_.push_back (pf);
|
pfList_.push_back (pf);
|
||||||
|
assert (isAllShattered());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
list<Parfactor*>::iterator
|
||||||
|
ParfactorList::insertShattered (
|
||||||
|
list<Parfactor*>::iterator it,
|
||||||
|
Parfactor* pf)
|
||||||
|
{
|
||||||
|
return pfList_.insert (it, pf);
|
||||||
|
assert (isAllShattered());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -47,7 +83,7 @@ ParfactorList::remove (list<Parfactor*>::iterator it)
|
|||||||
|
|
||||||
|
|
||||||
list<Parfactor*>::iterator
|
list<Parfactor*>::iterator
|
||||||
ParfactorList::deleteAndRemove (list<Parfactor*>::iterator it)
|
ParfactorList::removeAndDelete (list<Parfactor*>::iterator it)
|
||||||
{
|
{
|
||||||
delete *it;
|
delete *it;
|
||||||
return pfList_.erase (it);
|
return pfList_.erase (it);
|
||||||
@ -55,58 +91,21 @@ ParfactorList::deleteAndRemove (list<Parfactor*>::iterator it)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
bool
|
||||||
ParfactorList::shatter (void)
|
ParfactorList::isAllShattered (void) const
|
||||||
{
|
{
|
||||||
list<Parfactor*> tempList;
|
if (pfList_.size() <= 1) {
|
||||||
Parfactors newPfs;
|
return true;
|
||||||
newPfs.insert (newPfs.end(), pfList_.begin(), pfList_.end());
|
}
|
||||||
while (newPfs.empty() == false) {
|
vector<Parfactor*> pfs (pfList_.begin(), pfList_.end());
|
||||||
tempList.insert (tempList.end(), newPfs.begin(), newPfs.end());
|
for (unsigned i = 0; i < pfs.size() - 1; i++) {
|
||||||
newPfs.clear();
|
for (unsigned j = i + 1; j < pfs.size(); j++) {
|
||||||
list<Parfactor*>::iterator iter1 = tempList.begin();
|
if (isShattered (pfs[i], pfs[j]) == false) {
|
||||||
while (tempList.size() > 1 && iter1 != -- tempList.end()) {
|
return false;
|
||||||
list<Parfactor*>::iterator iter2 = iter1;
|
|
||||||
++ iter2;
|
|
||||||
bool incIter1 = true;
|
|
||||||
while (iter2 != tempList.end()) {
|
|
||||||
assert (iter1 != iter2);
|
|
||||||
std::pair<Parfactors, Parfactors> res = shatter (
|
|
||||||
(*iter1)->formulas(), *iter1, (*iter2)->formulas(), *iter2);
|
|
||||||
bool incIter2 = true;
|
|
||||||
if (res.second.empty() == false) {
|
|
||||||
// cout << "second unshattered" << endl;
|
|
||||||
delete *iter2;
|
|
||||||
iter2 = tempList.erase (iter2);
|
|
||||||
incIter2 = false;
|
|
||||||
newPfs.insert (
|
|
||||||
newPfs.begin(), res.second.begin(), res.second.end());
|
|
||||||
}
|
|
||||||
if (res.first.empty() == false) {
|
|
||||||
// cout << "first unshattered" << endl;
|
|
||||||
delete *iter1;
|
|
||||||
iter1 = tempList.erase (iter1);
|
|
||||||
newPfs.insert (
|
|
||||||
newPfs.begin(), res.first.begin(), res.first.end());
|
|
||||||
incIter1 = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (incIter2) {
|
|
||||||
++ iter2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (incIter1) {
|
|
||||||
++ iter1;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// cout << "|||||||||||||||||||||||||||||||||||||||||||||||||" << endl;
|
|
||||||
// cout << "||||||||||||| SHATTERING ITERATION ||||||||||||||" << endl;
|
|
||||||
// cout << "|||||||||||||||||||||||||||||||||||||||||||||||||" << endl;
|
|
||||||
// printParfactors (newPfs);
|
|
||||||
// cout << "|||||||||||||||||||||||||||||||||||||||||||||||||" << endl;
|
|
||||||
}
|
}
|
||||||
pfList_.clear();
|
return true;
|
||||||
pfList_.insert (pfList_.end(), tempList.begin(), tempList.end());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -117,25 +116,88 @@ ParfactorList::print (void) const
|
|||||||
list<Parfactor*>::const_iterator it;
|
list<Parfactor*>::const_iterator it;
|
||||||
for (it = pfList_.begin(); it != pfList_.end(); ++it) {
|
for (it = pfList_.begin(); it != pfList_.end(); ++it) {
|
||||||
(*it)->print();
|
(*it)->print();
|
||||||
cout << endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
std::pair<Parfactors, Parfactors>
|
bool
|
||||||
ParfactorList::shatter (
|
ParfactorList::isShattered (
|
||||||
ProbFormulas& formulas1,
|
const Parfactor* g1,
|
||||||
Parfactor* g1,
|
const Parfactor* g2) const
|
||||||
ProbFormulas& formulas2,
|
|
||||||
Parfactor* g2)
|
|
||||||
{
|
{
|
||||||
|
assert (g1 != g2);
|
||||||
|
const ProbFormulas& fms1 = g1->arguments();
|
||||||
|
const ProbFormulas& fms2 = g2->arguments();
|
||||||
|
for (unsigned i = 0; i < fms1.size(); i++) {
|
||||||
|
for (unsigned j = 0; j < fms2.size(); j++) {
|
||||||
|
if (fms1[i].group() == fms2[j].group()) {
|
||||||
|
if (identical (
|
||||||
|
fms1[i], *(g1->constr()),
|
||||||
|
fms2[j], *(g2->constr())) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (disjoint (
|
||||||
|
fms1[i], *(g1->constr()),
|
||||||
|
fms2[j], *(g2->constr())) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
ParfactorList::addToShatteredList (Parfactor* g)
|
||||||
|
{
|
||||||
|
queue<Parfactor*> residuals;
|
||||||
|
residuals.push (g);
|
||||||
|
while (residuals.empty() == false) {
|
||||||
|
Parfactor* pf = residuals.front();
|
||||||
|
bool pfSplitted = false;
|
||||||
|
list<Parfactor*>::iterator pfIter;
|
||||||
|
pfIter = pfList_.begin();
|
||||||
|
while (pfIter != pfList_.end()) {
|
||||||
|
std::pair<Parfactors, Parfactors> shattRes;
|
||||||
|
shattRes = shatter (*pfIter, pf);
|
||||||
|
if (shattRes.first.empty() == false) {
|
||||||
|
pfIter = removeAndDelete (pfIter);
|
||||||
|
Util::addToQueue (residuals, shattRes.first);
|
||||||
|
} else {
|
||||||
|
++ pfIter;
|
||||||
|
}
|
||||||
|
if (shattRes.second.empty() == false) {
|
||||||
|
delete pf;
|
||||||
|
Util::addToQueue (residuals, shattRes.second);
|
||||||
|
pfSplitted = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
residuals.pop();
|
||||||
|
if (pfSplitted == false) {
|
||||||
|
addShattered (pf);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert (isAllShattered());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
std::pair<Parfactors, Parfactors>
|
||||||
|
ParfactorList::shatter (Parfactor* g1, Parfactor* g2)
|
||||||
|
{
|
||||||
|
ProbFormulas& formulas1 = g1->arguments();
|
||||||
|
ProbFormulas& formulas2 = g2->arguments();
|
||||||
assert (g1 != 0 && g2 != 0 && g1 != g2);
|
assert (g1 != 0 && g2 != 0 && g1 != g2);
|
||||||
for (unsigned i = 0; i < formulas1.size(); i++) {
|
for (unsigned i = 0; i < formulas1.size(); i++) {
|
||||||
for (unsigned j = 0; j < formulas2.size(); j++) {
|
for (unsigned j = 0; j < formulas2.size(); j++) {
|
||||||
if (formulas1[i].sameSkeletonAs (formulas2[j])) {
|
if (formulas1[i].sameSkeletonAs (formulas2[j])) {
|
||||||
std::pair<Parfactors, Parfactors> res
|
std::pair<Parfactors, Parfactors> res;
|
||||||
= shatter (formulas1[i], g1, formulas2[j], g2);
|
res = shatter (i, g1, j, g2);
|
||||||
if (res.first.empty() == false ||
|
if (res.first.empty() == false ||
|
||||||
res.second.empty() == false) {
|
res.second.empty() == false) {
|
||||||
return res;
|
return res;
|
||||||
@ -150,21 +212,22 @@ ParfactorList::shatter (
|
|||||||
|
|
||||||
std::pair<Parfactors, Parfactors>
|
std::pair<Parfactors, Parfactors>
|
||||||
ParfactorList::shatter (
|
ParfactorList::shatter (
|
||||||
ProbFormula& f1,
|
unsigned fIdx1, Parfactor* g1,
|
||||||
Parfactor* g1,
|
unsigned fIdx2, Parfactor* g2)
|
||||||
ProbFormula& f2,
|
|
||||||
Parfactor* g2)
|
|
||||||
{
|
{
|
||||||
|
ProbFormula& f1 = g1->argument (fIdx1);
|
||||||
|
ProbFormula& f2 = g2->argument (fIdx2);
|
||||||
// cout << endl;
|
// cout << endl;
|
||||||
// cout << "-------------------------------------------------" << endl;
|
// Util::printDashedLine();
|
||||||
// cout << "-> SHATTERING (#" << g1 << ", #" << g2 << ")" << endl;
|
// cout << "-> SHATTERING (#" << g1 << ", #" << g2 << ")" << endl;
|
||||||
// g1->print();
|
// g1->print();
|
||||||
// cout << "-> WITH" << endl;
|
// cout << "-> WITH" << endl;
|
||||||
// g2->print();
|
// g2->print();
|
||||||
// cout << "-> ON: " << f1.toString (g1->constr()) << endl;
|
// cout << "-> ON: " << f1 << "|" ;
|
||||||
// cout << "-> ON: " << f2.toString (g2->constr()) << endl;
|
// cout << g1->constr()->tupleSet (f1.logVars()) << endl;
|
||||||
// cout << "-------------------------------------------------" << endl;
|
// cout << "-> ON: " << f2 << "|" ;
|
||||||
|
// cout << g2->constr()->tupleSet (f2.logVars()) << endl;
|
||||||
|
// Util::printDashedLine();
|
||||||
if (f1.isAtom()) {
|
if (f1.isAtom()) {
|
||||||
unsigned group = (f1.group() < f2.group()) ? f1.group() : f2.group();
|
unsigned group = (f1.group() < f2.group()) ? f1.group() : f2.group();
|
||||||
f1.setGroup (group);
|
f1.setGroup (group);
|
||||||
@ -174,7 +237,7 @@ ParfactorList::shatter (
|
|||||||
assert (g1->constr()->empty() == false);
|
assert (g1->constr()->empty() == false);
|
||||||
assert (g2->constr()->empty() == false);
|
assert (g2->constr()->empty() == false);
|
||||||
if (f1.group() == f2.group()) {
|
if (f1.group() == f2.group()) {
|
||||||
// assert (identical (f1, g1->constr(), f2, g2->constr()));
|
assert (identical (f1, *(g1->constr()), f2, *(g2->constr())));
|
||||||
return { };
|
return { };
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -201,21 +264,24 @@ ParfactorList::shatter (
|
|||||||
assert (commCt1->tupleSet (f1.arity()) ==
|
assert (commCt1->tupleSet (f1.arity()) ==
|
||||||
commCt2->tupleSet (f2.arity()));
|
commCt2->tupleSet (f2.arity()));
|
||||||
|
|
||||||
// stringstream ss1; ss1 << "" << count << "_A.dot" ;
|
// unsigned static count = 0; count ++;
|
||||||
// stringstream ss2; ss2 << "" << count << "_B.dot" ;
|
// stringstream ss1; ss1 << "" << count << "_A.dot" ;
|
||||||
// stringstream ss3; ss3 << "" << count << "_A_comm.dot" ;
|
// stringstream ss2; ss2 << "" << count << "_B.dot" ;
|
||||||
// stringstream ss4; ss4 << "" << count << "_A_excl.dot" ;
|
// stringstream ss3; ss3 << "" << count << "_A_comm.dot" ;
|
||||||
// stringstream ss5; ss5 << "" << count << "_B_comm.dot" ;
|
// stringstream ss4; ss4 << "" << count << "_A_excl.dot" ;
|
||||||
// stringstream ss6; ss6 << "" << count << "_B_excl.dot" ;
|
// stringstream ss5; ss5 << "" << count << "_B_comm.dot" ;
|
||||||
// ct1->exportToGraphViz (ss1.str().c_str(), true);
|
// stringstream ss6; ss6 << "" << count << "_B_excl.dot" ;
|
||||||
// ct2->exportToGraphViz (ss2.str().c_str(), true);
|
// g1->constr()->exportToGraphViz (ss1.str().c_str(), true);
|
||||||
// commCt1->exportToGraphViz (ss3.str().c_str(), true);
|
// g2->constr()->exportToGraphViz (ss2.str().c_str(), true);
|
||||||
// exclCt1->exportToGraphViz (ss4.str().c_str(), true);
|
// commCt1->exportToGraphViz (ss3.str().c_str(), true);
|
||||||
// commCt2->exportToGraphViz (ss5.str().c_str(), true);
|
// exclCt1->exportToGraphViz (ss4.str().c_str(), true);
|
||||||
// exclCt2->exportToGraphViz (ss6.str().c_str(), true);
|
// commCt2->exportToGraphViz (ss5.str().c_str(), true);
|
||||||
|
// exclCt2->exportToGraphViz (ss6.str().c_str(), true);
|
||||||
|
|
||||||
if (exclCt1->empty() && exclCt2->empty()) {
|
if (exclCt1->empty() && exclCt2->empty()) {
|
||||||
unsigned group = (f1.group() < f2.group()) ? f1.group() : f2.group();
|
unsigned group = (f1.group() < f2.group())
|
||||||
|
? f1.group()
|
||||||
|
: f2.group();
|
||||||
// identical
|
// identical
|
||||||
f1.setGroup (group);
|
f1.setGroup (group);
|
||||||
f2.setGroup (group);
|
f2.setGroup (group);
|
||||||
@ -235,8 +301,8 @@ ParfactorList::shatter (
|
|||||||
} else {
|
} else {
|
||||||
group = ProbFormula::getNewGroup();
|
group = ProbFormula::getNewGroup();
|
||||||
}
|
}
|
||||||
Parfactors res1 = shatter (g1, f1, commCt1, exclCt1, group);
|
Parfactors res1 = shatter (g1, fIdx1, commCt1, exclCt1, group);
|
||||||
Parfactors res2 = shatter (g2, f2, commCt2, exclCt2, group);
|
Parfactors res2 = shatter (g2, fIdx2, commCt2, exclCt2, group);
|
||||||
return make_pair (res1, res2);
|
return make_pair (res1, res2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -245,11 +311,19 @@ ParfactorList::shatter (
|
|||||||
Parfactors
|
Parfactors
|
||||||
ParfactorList::shatter (
|
ParfactorList::shatter (
|
||||||
Parfactor* g,
|
Parfactor* g,
|
||||||
const ProbFormula& f,
|
unsigned fIdx,
|
||||||
ConstraintTree* commCt,
|
ConstraintTree* commCt,
|
||||||
ConstraintTree* exclCt,
|
ConstraintTree* exclCt,
|
||||||
unsigned commGroup)
|
unsigned commGroup)
|
||||||
{
|
{
|
||||||
|
ProbFormula& f = g->argument (fIdx);
|
||||||
|
if (exclCt->empty()) {
|
||||||
|
delete commCt;
|
||||||
|
delete exclCt;
|
||||||
|
f.setGroup (commGroup);
|
||||||
|
return { };
|
||||||
|
}
|
||||||
|
|
||||||
Parfactors result;
|
Parfactors result;
|
||||||
if (f.isCounting()) {
|
if (f.isCounting()) {
|
||||||
LogVar X_new1 = g->constr()->logVarSet().back() + 1;
|
LogVar X_new1 = g->constr()->logVarSet().back() + 1;
|
||||||
@ -259,7 +333,7 @@ ParfactorList::shatter (
|
|||||||
for (unsigned i = 0; i < cts.size(); i++) {
|
for (unsigned i = 0; i < cts.size(); i++) {
|
||||||
Parfactor* newPf = new Parfactor (g, cts[i]);
|
Parfactor* newPf = new Parfactor (g, cts[i]);
|
||||||
if (cts[i]->nrLogVars() == g->constr()->nrLogVars() + 1) {
|
if (cts[i]->nrLogVars() == g->constr()->nrLogVars() + 1) {
|
||||||
newPf->expandPotential (f.countedLogVar(), X_new1, X_new2);
|
newPf->expand (f.countedLogVar(), X_new1, X_new2);
|
||||||
assert (g->constr()->getConditionalCount (f.countedLogVar()) ==
|
assert (g->constr()->getConditionalCount (f.countedLogVar()) ==
|
||||||
cts[i]->getConditionalCount (X_new1) +
|
cts[i]->getConditionalCount (X_new1) +
|
||||||
cts[i]->getConditionalCount (X_new2));
|
cts[i]->getConditionalCount (X_new2));
|
||||||
@ -270,20 +344,16 @@ ParfactorList::shatter (
|
|||||||
newPf->setNewGroups();
|
newPf->setNewGroups();
|
||||||
result.push_back (newPf);
|
result.push_back (newPf);
|
||||||
}
|
}
|
||||||
|
delete commCt;
|
||||||
|
delete exclCt;
|
||||||
} else {
|
} else {
|
||||||
if (exclCt->empty()) {
|
Parfactor* newPf = new Parfactor (g, commCt);
|
||||||
delete commCt;
|
newPf->setNewGroups();
|
||||||
delete exclCt;
|
newPf->argument (fIdx).setGroup (commGroup);
|
||||||
g->setFormulaGroup (f, commGroup);
|
result.push_back (newPf);
|
||||||
} else {
|
newPf = new Parfactor (g, exclCt);
|
||||||
Parfactor* newPf = new Parfactor (g, commCt);
|
newPf->setNewGroups();
|
||||||
newPf->setNewGroups();
|
result.push_back (newPf);
|
||||||
newPf->setFormulaGroup (f, commGroup);
|
|
||||||
result.push_back (newPf);
|
|
||||||
newPf = new Parfactor (g, exclCt);
|
|
||||||
newPf->setNewGroups();
|
|
||||||
result.push_back (newPf);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
@ -296,7 +366,7 @@ ParfactorList::unifyGroups (unsigned group1, unsigned group2)
|
|||||||
unsigned newGroup = ProbFormula::getNewGroup();
|
unsigned newGroup = ProbFormula::getNewGroup();
|
||||||
for (ParfactorList::iterator it = pfList_.begin();
|
for (ParfactorList::iterator it = pfList_.begin();
|
||||||
it != pfList_.end(); it++) {
|
it != pfList_.end(); it++) {
|
||||||
ProbFormulas& formulas = (*it)->formulas();
|
ProbFormulas& formulas = (*it)->arguments();
|
||||||
for (unsigned i = 0; i < formulas.size(); i++) {
|
for (unsigned i = 0; i < formulas.size(); i++) {
|
||||||
if (formulas[i].group() == group1 ||
|
if (formulas[i].group() == group1 ||
|
||||||
formulas[i].group() == group2) {
|
formulas[i].group() == group2) {
|
||||||
@ -306,3 +376,52 @@ ParfactorList::unifyGroups (unsigned group1, unsigned group2)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
ParfactorList::proper (
|
||||||
|
const ProbFormula& f1, ConstraintTree c1,
|
||||||
|
const ProbFormula& f2, ConstraintTree c2) const
|
||||||
|
{
|
||||||
|
return disjoint (f1, c1, f2, c2)
|
||||||
|
|| identical (f1, c1, f2, c2);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
ParfactorList::identical (
|
||||||
|
const ProbFormula& f1, ConstraintTree c1,
|
||||||
|
const ProbFormula& f2, ConstraintTree c2) const
|
||||||
|
{
|
||||||
|
if (f1.sameSkeletonAs (f2) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (f1.isAtom()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
c1.moveToTop (f1.logVars());
|
||||||
|
c2.moveToTop (f2.logVars());
|
||||||
|
return ConstraintTree::identical (
|
||||||
|
&c1, &c2, f1.logVars().size());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
ParfactorList::disjoint (
|
||||||
|
const ProbFormula& f1, ConstraintTree c1,
|
||||||
|
const ProbFormula& f2, ConstraintTree c2) const
|
||||||
|
{
|
||||||
|
if (f1.sameSkeletonAs (f2) == false) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (f1.isAtom()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
c1.moveToTop (f1.logVars());
|
||||||
|
c2.moveToTop (f2.logVars());
|
||||||
|
return ConstraintTree::overlap (
|
||||||
|
&c1, &c2, f1.arity()) == false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
#define HORUS_PARFACTORLIST_H
|
#define HORUS_PARFACTORLIST_H
|
||||||
|
|
||||||
#include <list>
|
#include <list>
|
||||||
|
#include <queue>
|
||||||
|
|
||||||
#include "Parfactor.h"
|
#include "Parfactor.h"
|
||||||
#include "ProbFormula.h"
|
#include "ProbFormula.h"
|
||||||
@ -14,56 +15,82 @@ class ParfactorList
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
ParfactorList (void) { }
|
ParfactorList (void) { }
|
||||||
ParfactorList (Parfactors&);
|
|
||||||
list<Parfactor*>& getParfactors (void) { return pfList_; }
|
|
||||||
const list<Parfactor*>& getParfactors (void) const { return pfList_; }
|
|
||||||
|
|
||||||
void add (Parfactor* pf);
|
ParfactorList (const ParfactorList&);
|
||||||
void add (Parfactors& pfs);
|
|
||||||
void addShattered (Parfactor* pf);
|
|
||||||
list<Parfactor*>::iterator remove (list<Parfactor*>::iterator);
|
|
||||||
list<Parfactor*>::iterator deleteAndRemove (list<Parfactor*>::iterator);
|
|
||||||
|
|
||||||
void clear (void) { pfList_.clear(); }
|
ParfactorList (const Parfactors&);
|
||||||
unsigned size (void) const { return pfList_.size(); }
|
|
||||||
|
|
||||||
|
|
||||||
void shatter (void);
|
~ParfactorList (void);
|
||||||
|
|
||||||
|
const list<Parfactor*>& parfactors (void) const { return pfList_; }
|
||||||
|
|
||||||
|
void clear (void) { pfList_.clear(); }
|
||||||
|
|
||||||
|
unsigned size (void) const { return pfList_.size(); }
|
||||||
|
|
||||||
typedef std::list<Parfactor*>::iterator iterator;
|
typedef std::list<Parfactor*>::iterator iterator;
|
||||||
|
|
||||||
iterator begin (void) { return pfList_.begin(); }
|
iterator begin (void) { return pfList_.begin(); }
|
||||||
iterator end (void) { return pfList_.end(); }
|
|
||||||
|
iterator end (void) { return pfList_.end(); }
|
||||||
|
|
||||||
typedef std::list<Parfactor*>::const_iterator const_iterator;
|
typedef std::list<Parfactor*>::const_iterator const_iterator;
|
||||||
|
|
||||||
const_iterator begin (void) const { return pfList_.begin(); }
|
const_iterator begin (void) const { return pfList_.begin(); }
|
||||||
const_iterator end (void) const { return pfList_.end(); }
|
|
||||||
|
const_iterator end (void) const { return pfList_.end(); }
|
||||||
|
|
||||||
|
void add (Parfactor* pf);
|
||||||
|
|
||||||
|
void add (const Parfactors& pfs);
|
||||||
|
|
||||||
|
void addShattered (Parfactor* pf);
|
||||||
|
|
||||||
|
list<Parfactor*>::iterator insertShattered (
|
||||||
|
list<Parfactor*>::iterator, Parfactor*);
|
||||||
|
|
||||||
|
list<Parfactor*>::iterator remove (list<Parfactor*>::iterator);
|
||||||
|
|
||||||
|
list<Parfactor*>::iterator removeAndDelete (list<Parfactor*>::iterator);
|
||||||
|
|
||||||
|
bool isAllShattered (void) const;
|
||||||
|
|
||||||
void print (void) const;
|
void print (void) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
|
bool isShattered (const Parfactor*, const Parfactor*) const;
|
||||||
|
|
||||||
static std::pair<Parfactors, Parfactors> shatter (
|
void addToShatteredList (Parfactor*);
|
||||||
ProbFormulas&,
|
|
||||||
Parfactor*,
|
std::pair<Parfactors, Parfactors> shatter (
|
||||||
ProbFormulas&,
|
Parfactor*, Parfactor*);
|
||||||
Parfactor*);
|
|
||||||
|
|
||||||
static std::pair<Parfactors, Parfactors> shatter (
|
std::pair<Parfactors, Parfactors> shatter (
|
||||||
ProbFormula&,
|
unsigned, Parfactor*, unsigned, Parfactor*);
|
||||||
Parfactor*,
|
|
||||||
ProbFormula&,
|
|
||||||
Parfactor*);
|
|
||||||
|
|
||||||
static Parfactors shatter (
|
Parfactors shatter (
|
||||||
Parfactor*,
|
Parfactor*,
|
||||||
const ProbFormula&,
|
unsigned,
|
||||||
ConstraintTree*,
|
ConstraintTree*,
|
||||||
ConstraintTree*,
|
ConstraintTree*,
|
||||||
unsigned);
|
unsigned);
|
||||||
|
|
||||||
void unifyGroups (unsigned group1, unsigned group2);
|
void unifyGroups (unsigned group1, unsigned group2);
|
||||||
|
|
||||||
list<Parfactor*> pfList_;
|
bool proper (
|
||||||
|
const ProbFormula&, ConstraintTree,
|
||||||
|
const ProbFormula&, ConstraintTree) const;
|
||||||
|
|
||||||
|
bool identical (
|
||||||
|
const ProbFormula&, ConstraintTree,
|
||||||
|
const ProbFormula&, ConstraintTree) const;
|
||||||
|
|
||||||
|
bool disjoint (
|
||||||
|
const ProbFormula&, ConstraintTree,
|
||||||
|
const ProbFormula&, ConstraintTree) const;
|
||||||
|
|
||||||
|
list<Parfactor*> pfList_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_PARFACTORLIST_H
|
#endif // HORUS_PARFACTORLIST_H
|
||||||
|
@ -16,8 +16,7 @@ ProbFormula::sameSkeletonAs (const ProbFormula& f) const
|
|||||||
bool
|
bool
|
||||||
ProbFormula::contains (LogVar lv) const
|
ProbFormula::contains (LogVar lv) const
|
||||||
{
|
{
|
||||||
return std::find (logVars_.begin(), logVars_.end(), lv) !=
|
return Util::contains (logVars_, lv);
|
||||||
logVars_.end();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -77,16 +76,15 @@ ProbFormula::rename (LogVar oldName, LogVar newName)
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
bool operator== (const ProbFormula& f1, const ProbFormula& f2)
|
||||||
bool
|
|
||||||
ProbFormula::operator== (const ProbFormula& f) const
|
|
||||||
{
|
{
|
||||||
return functor_ == f.functor_ && logVars_ == f.logVars_ ;
|
return f1.group_ == f2.group_;
|
||||||
|
//return functor_ == f.functor_ && logVars_ == f.logVars_ ;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ostream& operator<< (ostream &os, const ProbFormula& f)
|
std::ostream& operator<< (ostream &os, const ProbFormula& f)
|
||||||
{
|
{
|
||||||
os << f.functor_;
|
os << f.functor_;
|
||||||
if (f.isAtom() == false) {
|
if (f.isAtom() == false) {
|
||||||
@ -113,3 +111,13 @@ ProbFormula::getNewGroup (void)
|
|||||||
return freeGroup_;
|
return freeGroup_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ostream& operator<< (ostream &os, const ObservedFormula& of)
|
||||||
|
{
|
||||||
|
os << of.functor_ << "/" << of.arity_;
|
||||||
|
os << "|" << of.constr_.tupleSet();
|
||||||
|
os << " [evidence=" << of.evidence_ << "]";
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -8,14 +8,16 @@
|
|||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ProbFormula
|
class ProbFormula
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
ProbFormula (Symbol f, const LogVars& lvs, unsigned range)
|
ProbFormula (Symbol f, const LogVars& lvs, unsigned range)
|
||||||
: functor_(f), logVars_(lvs), range_(range),
|
: functor_(f), logVars_(lvs), range_(range),
|
||||||
countedLogVar_() { }
|
countedLogVar_(), group_(Util::maxUnsigned()) { }
|
||||||
|
|
||||||
ProbFormula (Symbol f, unsigned r) : functor_(f), range_(r) { }
|
ProbFormula (Symbol f, unsigned r)
|
||||||
|
: functor_(f), range_(r), group_(Util::maxUnsigned()) { }
|
||||||
|
|
||||||
Symbol functor (void) const { return functor_; }
|
Symbol functor (void) const { return functor_; }
|
||||||
|
|
||||||
@ -29,9 +31,9 @@ class ProbFormula
|
|||||||
|
|
||||||
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
|
LogVarSet logVarSet (void) const { return LogVarSet (logVars_); }
|
||||||
|
|
||||||
unsigned group (void) const { return groupId_; }
|
unsigned group (void) const { return group_; }
|
||||||
|
|
||||||
void setGroup (unsigned g) { groupId_ = g; }
|
void setGroup (unsigned g) { group_ = g; }
|
||||||
|
|
||||||
bool sameSkeletonAs (const ProbFormula&) const;
|
bool sameSkeletonAs (const ProbFormula&) const;
|
||||||
|
|
||||||
@ -49,23 +51,58 @@ class ProbFormula
|
|||||||
|
|
||||||
void rename (LogVar, LogVar);
|
void rename (LogVar, LogVar);
|
||||||
|
|
||||||
bool operator== (const ProbFormula& f) const;
|
|
||||||
|
|
||||||
friend ostream& operator<< (ostream &out, const ProbFormula& f);
|
|
||||||
|
|
||||||
static unsigned getNewGroup (void);
|
static unsigned getNewGroup (void);
|
||||||
|
|
||||||
|
friend std::ostream& operator<< (ostream &os, const ProbFormula& f);
|
||||||
|
|
||||||
|
friend bool operator== (const ProbFormula& f1, const ProbFormula& f2);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Symbol functor_;
|
Symbol functor_;
|
||||||
LogVars logVars_;
|
LogVars logVars_;
|
||||||
unsigned range_;
|
unsigned range_;
|
||||||
LogVar countedLogVar_;
|
LogVar countedLogVar_;
|
||||||
unsigned groupId_;
|
unsigned group_;
|
||||||
static int freeGroup_;
|
static int freeGroup_;
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef vector<ProbFormula> ProbFormulas;
|
typedef vector<ProbFormula> ProbFormulas;
|
||||||
|
|
||||||
|
|
||||||
|
class ObservedFormula
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
ObservedFormula (Symbol f, unsigned a, unsigned ev)
|
||||||
|
: functor_(f), arity_(a), evidence_(ev), constr_(a) { }
|
||||||
|
|
||||||
|
ObservedFormula (Symbol f, unsigned ev, const Tuple& tuple)
|
||||||
|
: functor_(f), arity_(tuple.size()), evidence_(ev), constr_(arity_)
|
||||||
|
{
|
||||||
|
constr_.addTuple (tuple);
|
||||||
|
}
|
||||||
|
|
||||||
|
Symbol functor (void) const { return functor_; }
|
||||||
|
|
||||||
|
unsigned arity (void) const { return arity_; }
|
||||||
|
|
||||||
|
unsigned evidence (void) const { return evidence_; }
|
||||||
|
|
||||||
|
ConstraintTree& constr (void) { return constr_; }
|
||||||
|
|
||||||
|
bool isAtom (void) const { return arity_ == 0; }
|
||||||
|
|
||||||
|
void addTuple (const Tuple& tuple) { constr_.addTuple (tuple); }
|
||||||
|
|
||||||
|
friend ostream& operator<< (ostream &os, const ObservedFormula& of);
|
||||||
|
|
||||||
|
private:
|
||||||
|
Symbol functor_;
|
||||||
|
unsigned arity_;
|
||||||
|
unsigned evidence_;
|
||||||
|
ConstraintTree constr_;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef vector<ObservedFormula> ObservedFormulas;
|
||||||
|
|
||||||
#endif // HORUS_PROBFORMULA_H
|
#endif // HORUS_PROBFORMULA_H
|
||||||
|
|
||||||
|
@ -3,51 +3,35 @@
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Solver::printAllPosterioris (void)
|
Solver::printAnswer (const VarIds& vids)
|
||||||
{
|
{
|
||||||
const VarNodes& vars = gm_->getVariableNodes();
|
Vars unobservedVars;
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
VarIds unobservedVids;
|
||||||
printPosterioriOf (vars[i]->varId());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Solver::printPosterioriOf (VarId vid)
|
|
||||||
{
|
|
||||||
VarNode* var = gm_->getVariableNode (vid);
|
|
||||||
const Params& posterioriDist = getPosterioriOf (vid);
|
|
||||||
const States& states = var->states();
|
|
||||||
for (unsigned i = 0; i < states.size(); i++) {
|
|
||||||
cout << "P(" << var->label() << "=" << states[i] << ") = " ;
|
|
||||||
cout << setprecision (PRECISION) << posterioriDist[i];
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
Solver::printJointDistributionOf (const VarIds& vids)
|
|
||||||
{
|
|
||||||
VarNodes vars;
|
|
||||||
VarIds vidsWithoutEvidence;
|
|
||||||
for (unsigned i = 0; i < vids.size(); i++) {
|
for (unsigned i = 0; i < vids.size(); i++) {
|
||||||
VarNode* var = gm_->getVariableNode (vids[i]);
|
VarNode* vn = fg.getVarNode (vids[i]);
|
||||||
if (var->hasEvidence() == false) {
|
if (vn->hasEvidence() == false) {
|
||||||
vars.push_back (var);
|
unobservedVars.push_back (vn);
|
||||||
vidsWithoutEvidence.push_back (vids[i]);
|
unobservedVids.push_back (vids[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const Params& jointDist = getJointDistributionOf (vidsWithoutEvidence);
|
Params res = solveQuery (unobservedVids);
|
||||||
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
vector<string> stateLines = Util::getStateLines (unobservedVars);
|
||||||
for (unsigned i = 0; i < jointDist.size(); i++) {
|
for (unsigned i = 0; i < res.size(); i++) {
|
||||||
cout << "P(" << jointStrings[i] << ") = " ;
|
cout << "P(" << stateLines[i] << ") = " ;
|
||||||
cout << setprecision (PRECISION) << jointDist[i];
|
cout << std::setprecision (Constants::PRECISION) << res[i];
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Solver::printAllPosterioris (void)
|
||||||
|
{
|
||||||
|
const VarNodes& vars = fg.varNodes();
|
||||||
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
|
printAnswer ({vars[i]->varId()});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -3,29 +3,27 @@
|
|||||||
|
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
|
|
||||||
#include "GraphicalModel.h"
|
#include "Var.h"
|
||||||
#include "VarNode.h"
|
#include "FactorGraph.h"
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
class Solver
|
class Solver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Solver (const GraphicalModel* gm)
|
Solver (const FactorGraph& factorGraph) : fg(factorGraph) { }
|
||||||
{
|
|
||||||
gm_ = gm;
|
virtual ~Solver() { } // ensure that subclass destructor is called
|
||||||
}
|
|
||||||
virtual ~Solver() {} // to ensure that subclass destructor is called
|
virtual Params solveQuery (VarIds queryVids) = 0;
|
||||||
virtual void runSolver (void) = 0;
|
|
||||||
virtual Params getPosterioriOf (VarId) = 0;
|
void printAnswer (const VarIds& vids);
|
||||||
virtual Params getJointDistributionOf (const VarIds&) = 0;
|
|
||||||
|
|
||||||
void printAllPosterioris (void);
|
void printAllPosterioris (void);
|
||||||
void printPosterioriOf (VarId vid);
|
|
||||||
void printJointDistributionOf (const VarIds& vids);
|
|
||||||
|
|
||||||
private:
|
protected:
|
||||||
const GraphicalModel* gm_;
|
const FactorGraph& fg;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // HORUS_SOLVER_H
|
#endif // HORUS_SOLVER_H
|
||||||
|
4
packages/CLPBN/clpbn/bp/TODO
Normal file
4
packages/CLPBN/clpbn/bp/TODO
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
TODO
|
||||||
|
- add a way to calculate combinations and factorials with large numbers
|
||||||
|
- refactor sumOut in parfactor -> is really ugly code
|
||||||
|
- Indexer: start receiving ranges as constant reference
|
@ -1,21 +1,22 @@
|
|||||||
|
#include <limits>
|
||||||
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
#include "Indexer.h"
|
#include "Indexer.h"
|
||||||
#include "GraphicalModel.h"
|
|
||||||
|
|
||||||
|
|
||||||
namespace Globals {
|
namespace Globals {
|
||||||
bool logDomain = false;
|
bool logDomain = false;
|
||||||
|
|
||||||
|
//InfAlgs infAlgorithm = InfAlgorithms::VE;
|
||||||
|
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
|
||||||
|
//InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
|
||||||
|
InfAlgorithms infAlgorithm = InfAlgorithms::CBP;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
namespace InfAlgorithms {
|
|
||||||
//InfAlgs infAlgorithm = InfAlgorithms::VE;
|
|
||||||
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
|
|
||||||
InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
|
|
||||||
//InfAlgs infAlgorithm = InfAlgorithms::CBP;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
namespace BpOptions {
|
namespace BpOptions {
|
||||||
@ -23,13 +24,11 @@ Schedule schedule = BpOptions::Schedule::SEQ_FIXED;
|
|||||||
//Schedule schedule = BpOptions::Schedule::SEQ_RANDOM;
|
//Schedule schedule = BpOptions::Schedule::SEQ_RANDOM;
|
||||||
//Schedule schedule = BpOptions::Schedule::PARALLEL;
|
//Schedule schedule = BpOptions::Schedule::PARALLEL;
|
||||||
//Schedule schedule = BpOptions::Schedule::MAX_RESIDUAL;
|
//Schedule schedule = BpOptions::Schedule::MAX_RESIDUAL;
|
||||||
double accuracy = 0.0001;
|
double accuracy = 0.0001;
|
||||||
unsigned maxIter = 1000;
|
unsigned maxIter = 1000;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
unordered_map<VarId,VariableInfo> GraphicalModel::varsInfo_;
|
|
||||||
unordered_map<unsigned,Distribution*> GraphicalModel::distsInfo_;
|
|
||||||
|
|
||||||
vector<NetInfo> Statistics::netInfo_;
|
vector<NetInfo> Statistics::netInfo_;
|
||||||
vector<CompressInfo> Statistics::compressInfo_;
|
vector<CompressInfo> Statistics::compressInfo_;
|
||||||
@ -58,76 +57,6 @@ fromLog (Params& v)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
normalize (Params& v)
|
|
||||||
{
|
|
||||||
double sum;
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
sum = addIdenty();
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
logSum (sum, v[i]);
|
|
||||||
}
|
|
||||||
assert (sum != -numeric_limits<double>::infinity());
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] -= sum;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
sum = 0.0;
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
sum += v[i];
|
|
||||||
}
|
|
||||||
assert (sum != 0.0);
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] /= sum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
pow (Params& v, double expoent)
|
|
||||||
{
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] *= expoent;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] = std::pow (v[i], expoent);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
pow (Params& v, unsigned expoent)
|
|
||||||
{
|
|
||||||
if (expoent == 1) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (Globals::logDomain) {
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] *= expoent;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
v[i] = std::pow (v[i], expoent);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
double
|
|
||||||
pow (double p, unsigned expoent)
|
|
||||||
{
|
|
||||||
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
double
|
double
|
||||||
factorial (double num)
|
factorial (double num)
|
||||||
{
|
{
|
||||||
@ -153,6 +82,151 @@ nrCombinations (unsigned n, unsigned r)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned
|
||||||
|
expectedSize (const Ranges& ranges)
|
||||||
|
{
|
||||||
|
unsigned prod = 1;
|
||||||
|
for (unsigned i = 0; i < ranges.size(); i++) {
|
||||||
|
prod *= ranges[i];
|
||||||
|
}
|
||||||
|
return prod;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
unsigned
|
||||||
|
getNumberOfDigits (int number)
|
||||||
|
{
|
||||||
|
unsigned count = 1;
|
||||||
|
while (number >= 10) {
|
||||||
|
number /= 10;
|
||||||
|
count ++;
|
||||||
|
}
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
isInteger (const string& s)
|
||||||
|
{
|
||||||
|
stringstream ss1 (s);
|
||||||
|
stringstream ss2;
|
||||||
|
int integer;
|
||||||
|
ss1 >> integer;
|
||||||
|
ss2 << integer;
|
||||||
|
return (ss1.str() == ss2.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
string
|
||||||
|
parametersToString (const Params& v, unsigned precision)
|
||||||
|
{
|
||||||
|
stringstream ss;
|
||||||
|
ss.precision (precision);
|
||||||
|
ss << "[" ;
|
||||||
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
if (i != 0) ss << ", " ;
|
||||||
|
ss << v[i];
|
||||||
|
}
|
||||||
|
ss << "]" ;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vector<string>
|
||||||
|
getStateLines (const Vars& vars)
|
||||||
|
{
|
||||||
|
StatesIndexer idx (vars);
|
||||||
|
vector<string> jointStrings;
|
||||||
|
while (idx.valid()) {
|
||||||
|
stringstream ss;
|
||||||
|
for (unsigned i = 0; i < vars.size(); i++) {
|
||||||
|
if (i != 0) ss << ", " ;
|
||||||
|
ss << vars[i]->label() << "=" << vars[i]->states()[(idx[i])];
|
||||||
|
}
|
||||||
|
jointStrings.push_back (ss.str());
|
||||||
|
++ idx;
|
||||||
|
}
|
||||||
|
return jointStrings;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
printHeader (string header, std::ostream& os)
|
||||||
|
{
|
||||||
|
printAsteriskLine (os);
|
||||||
|
os << header << endl;
|
||||||
|
printAsteriskLine (os);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
printSubHeader (string header, std::ostream& os)
|
||||||
|
{
|
||||||
|
printDashedLine (os);
|
||||||
|
os << header << endl;
|
||||||
|
printDashedLine (os);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
printAsteriskLine (std::ostream& os)
|
||||||
|
{
|
||||||
|
os << "********************************" ;
|
||||||
|
os << "********************************" ;
|
||||||
|
os << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
printDashedLine (std::ostream& os)
|
||||||
|
{
|
||||||
|
os << "--------------------------------" ;
|
||||||
|
os << "--------------------------------" ;
|
||||||
|
os << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
namespace LogAware {
|
||||||
|
|
||||||
|
void
|
||||||
|
normalize (Params& v)
|
||||||
|
{
|
||||||
|
double sum;
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
sum = LogAware::addIdenty();
|
||||||
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
sum = Util::logSum (sum, v[i]);
|
||||||
|
}
|
||||||
|
assert (sum != -numeric_limits<double>::infinity());
|
||||||
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
v[i] -= sum;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sum = 0.0;
|
||||||
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
sum += v[i];
|
||||||
|
}
|
||||||
|
assert (sum != 0.0);
|
||||||
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
v[i] /= sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
double
|
double
|
||||||
getL1Distance (const Params& v1, const Params& v2)
|
getL1Distance (const Params& v1, const Params& v2)
|
||||||
{
|
{
|
||||||
@ -196,67 +270,57 @@ getMaxNorm (const Params& v1, const Params& v2)
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
double
|
||||||
|
pow (double p, unsigned expoent)
|
||||||
|
{
|
||||||
|
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
|
||||||
|
}
|
||||||
|
|
||||||
unsigned
|
|
||||||
getNumberOfDigits (int number) {
|
|
||||||
unsigned count = 1;
|
double
|
||||||
while (number >= 10) {
|
pow (double p, double expoent)
|
||||||
number /= 10;
|
{
|
||||||
count ++;
|
// assumes that `expoent' is never in log domain
|
||||||
|
return Globals::logDomain ? p * expoent : std::pow (p, expoent);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
pow (Params& v, unsigned expoent)
|
||||||
|
{
|
||||||
|
if (expoent == 1) {
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
return count;
|
if (Globals::logDomain) {
|
||||||
}
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
v[i] *= expoent;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
bool
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
isInteger (const string& s)
|
v[i] = std::pow (v[i], expoent);
|
||||||
{
|
|
||||||
stringstream ss1 (s);
|
|
||||||
stringstream ss2;
|
|
||||||
int integer;
|
|
||||||
ss1 >> integer;
|
|
||||||
ss2 << integer;
|
|
||||||
return (ss1.str() == ss2.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
string
|
|
||||||
parametersToString (const Params& v, unsigned precision)
|
|
||||||
{
|
|
||||||
stringstream ss;
|
|
||||||
ss.precision (precision);
|
|
||||||
ss << "[" ;
|
|
||||||
for (unsigned i = 0; i < v.size(); i++) {
|
|
||||||
if (i != 0) ss << ", " ;
|
|
||||||
ss << v[i];
|
|
||||||
}
|
|
||||||
ss << "]" ;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vector<string>
|
|
||||||
getJointStateStrings (const VarNodes& vars)
|
|
||||||
{
|
|
||||||
StatesIndexer idx (vars);
|
|
||||||
vector<string> jointStrings;
|
|
||||||
while (idx.valid()) {
|
|
||||||
stringstream ss;
|
|
||||||
for (unsigned i = 0; i < vars.size(); i++) {
|
|
||||||
if (i != 0) ss << ", " ;
|
|
||||||
ss << vars[i]->label() << "=" << vars[i]->states()[(idx[i])];
|
|
||||||
}
|
}
|
||||||
jointStrings.push_back (ss.str());
|
|
||||||
++ idx;
|
|
||||||
}
|
}
|
||||||
return jointStrings;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
pow (Params& v, double expoent)
|
||||||
|
{
|
||||||
|
// assumes that `expoent' is never in log domain
|
||||||
|
if (Globals::logDomain) {
|
||||||
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
v[i] *= expoent;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (unsigned i = 0; i < v.size(); i++) {
|
||||||
|
v[i] = std::pow (v[i], expoent);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -286,8 +350,11 @@ Statistics::getPrimaryNetworksCounting (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Statistics::updateStatistics (unsigned size, bool loopy,
|
Statistics::updateStatistics (
|
||||||
unsigned nIters, double time)
|
unsigned size,
|
||||||
|
bool loopy,
|
||||||
|
unsigned nIters,
|
||||||
|
double time)
|
||||||
{
|
{
|
||||||
netInfo_.push_back (NetInfo (size, loopy, nIters, time));
|
netInfo_.push_back (NetInfo (size, loopy, nIters, time));
|
||||||
}
|
}
|
||||||
@ -303,12 +370,12 @@ Statistics::printStatistics (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Statistics::writeStatisticsToFile (const char* fileName)
|
Statistics::writeStatistics (const char* fileName)
|
||||||
{
|
{
|
||||||
ofstream out (fileName);
|
ofstream out (fileName);
|
||||||
if (!out.is_open()) {
|
if (!out.is_open()) {
|
||||||
cerr << "error: cannot open file to write at " ;
|
cerr << "error: cannot open file to write at " ;
|
||||||
cerr << "Statistics::writeStatisticsToFile()" << endl;
|
cerr << "Statistics::writeStats()" << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
out << getStatisticString();
|
out << getStatisticString();
|
||||||
@ -318,13 +385,14 @@ Statistics::writeStatisticsToFile (const char* fileName)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
Statistics::updateCompressingStatistics (unsigned nGroundVars,
|
Statistics::updateCompressingStatistics (
|
||||||
unsigned nGroundFactors,
|
unsigned nrGroundVars,
|
||||||
unsigned nClusterVars,
|
unsigned nrGroundFactors,
|
||||||
unsigned nClusterFactors,
|
unsigned nrClusterVars,
|
||||||
unsigned nWithoutNeighs) {
|
unsigned nrClusterFactors,
|
||||||
compressInfo_.push_back (CompressInfo (nGroundVars, nGroundFactors,
|
unsigned nrNeighborless) {
|
||||||
nClusterVars, nClusterFactors, nWithoutNeighs));
|
compressInfo_.push_back (CompressInfo (nrGroundVars, nrGroundFactors,
|
||||||
|
nrClusterVars, nrClusterFactors, nrNeighborless));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -334,26 +402,30 @@ Statistics::getStatisticString (void)
|
|||||||
{
|
{
|
||||||
stringstream ss2, ss3, ss4, ss1;
|
stringstream ss2, ss3, ss4, ss1;
|
||||||
ss1 << "running mode: " ;
|
ss1 << "running mode: " ;
|
||||||
switch (InfAlgorithms::infAlgorithm) {
|
switch (Globals::infAlgorithm) {
|
||||||
case InfAlgorithms::VE: ss1 << "ve" << endl; break;
|
case InfAlgorithms::VE: ss1 << "ve" << endl; break;
|
||||||
case InfAlgorithms::BN_BP: ss1 << "bn_bp" << endl; break;
|
case InfAlgorithms::BP: ss1 << "bp" << endl; break;
|
||||||
case InfAlgorithms::FG_BP: ss1 << "fg_bp" << endl; break;
|
case InfAlgorithms::CBP: ss1 << "cbp" << endl; break;
|
||||||
case InfAlgorithms::CBP: ss1 << "cbp" << endl; break;
|
|
||||||
}
|
}
|
||||||
ss1 << "message schedule: " ;
|
ss1 << "message schedule: " ;
|
||||||
switch (BpOptions::schedule) {
|
switch (BpOptions::schedule) {
|
||||||
case BpOptions::Schedule::SEQ_FIXED: ss1 << "sequential fixed" << endl; break;
|
case BpOptions::Schedule::SEQ_FIXED:
|
||||||
case BpOptions::Schedule::SEQ_RANDOM: ss1 << "sequential random" << endl; break;
|
ss1 << "sequential fixed" << endl;
|
||||||
case BpOptions::Schedule::PARALLEL: ss1 << "parallel" << endl; break;
|
break;
|
||||||
case BpOptions::Schedule::MAX_RESIDUAL: ss1 << "max residual" << endl; break;
|
case BpOptions::Schedule::SEQ_RANDOM:
|
||||||
|
ss1 << "sequential random" << endl;
|
||||||
|
break;
|
||||||
|
case BpOptions::Schedule::PARALLEL:
|
||||||
|
ss1 << "parallel" << endl;
|
||||||
|
break;
|
||||||
|
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||||
|
ss1 << "max residual" << endl;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
ss1 << "max iterations: " << BpOptions::maxIter << endl;
|
ss1 << "max iterations: " << BpOptions::maxIter << endl;
|
||||||
ss1 << "accuracy " << BpOptions::accuracy << endl;
|
ss1 << "accuracy " << BpOptions::accuracy << endl;
|
||||||
ss1 << endl << endl;
|
ss1 << endl << endl;
|
||||||
|
Util::printSubHeader ("Network information", ss2);
|
||||||
ss2 << "---------------------------------------------------" << endl;
|
|
||||||
ss2 << " Network information" << endl;
|
|
||||||
ss2 << "---------------------------------------------------" << endl;
|
|
||||||
ss2 << left;
|
ss2 << left;
|
||||||
ss2 << setw (15) << "Network Size" ;
|
ss2 << setw (15) << "Network Size" ;
|
||||||
ss2 << setw (9) << "Loopy" ;
|
ss2 << setw (9) << "Loopy" ;
|
||||||
@ -387,24 +459,22 @@ Statistics::getStatisticString (void)
|
|||||||
|
|
||||||
unsigned c1 = 0, c2 = 0, c3 = 0, c4 = 0;
|
unsigned c1 = 0, c2 = 0, c3 = 0, c4 = 0;
|
||||||
if (compressInfo_.size() > 0) {
|
if (compressInfo_.size() > 0) {
|
||||||
ss3 << "---------------------------------------------------" << endl;
|
Util::printSubHeader ("Compress information", ss3);
|
||||||
ss3 << " Compression information" << endl;
|
|
||||||
ss3 << "---------------------------------------------------" << endl;
|
|
||||||
ss3 << left;
|
ss3 << left;
|
||||||
ss3 << "Ground Cluster Ground Cluster Neighborless" << endl;
|
ss3 << "Ground Cluster Ground Cluster Neighborless" << endl;
|
||||||
ss3 << "Vars Vars Factors Factors Vars" << endl;
|
ss3 << "Vars Vars Factors Factors Vars" << endl;
|
||||||
for (unsigned i = 0; i < compressInfo_.size(); i++) {
|
for (unsigned i = 0; i < compressInfo_.size(); i++) {
|
||||||
ss3 << setw (9) << compressInfo_[i].nGroundVars;
|
ss3 << setw (9) << compressInfo_[i].nrGroundVars;
|
||||||
ss3 << setw (10) << compressInfo_[i].nClusterVars;
|
ss3 << setw (10) << compressInfo_[i].nrClusterVars;
|
||||||
ss3 << setw (10) << compressInfo_[i].nGroundFactors;
|
ss3 << setw (10) << compressInfo_[i].nrGroundFactors;
|
||||||
ss3 << setw (10) << compressInfo_[i].nClusterFactors;
|
ss3 << setw (10) << compressInfo_[i].nrClusterFactors;
|
||||||
ss3 << setw (10) << compressInfo_[i].nWithoutNeighs;
|
ss3 << setw (10) << compressInfo_[i].nrNeighborless;
|
||||||
ss3 << endl;
|
ss3 << endl;
|
||||||
c1 += compressInfo_[i].nGroundVars - compressInfo_[i].nWithoutNeighs;
|
c1 += compressInfo_[i].nrGroundVars - compressInfo_[i].nrNeighborless;
|
||||||
c2 += compressInfo_[i].nClusterVars;
|
c2 += compressInfo_[i].nrClusterVars;
|
||||||
c3 += compressInfo_[i].nGroundFactors - compressInfo_[i].nWithoutNeighs;
|
c3 += compressInfo_[i].nrGroundFactors - compressInfo_[i].nrNeighborless;
|
||||||
c4 += compressInfo_[i].nClusterFactors;
|
c4 += compressInfo_[i].nrClusterFactors;
|
||||||
if (compressInfo_[i].nWithoutNeighs != 0) {
|
if (compressInfo_[i].nrNeighborless != 0) {
|
||||||
c2 --;
|
c2 --;
|
||||||
c4 --;
|
c4 --;
|
||||||
}
|
}
|
||||||
|
@ -1,53 +1,141 @@
|
|||||||
#ifndef HORUS_UTIL_H
|
#ifndef HORUS_UTIL_H
|
||||||
#define HORUS_UTIL_H
|
#define HORUS_UTIL_H
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <cassert>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <set>
|
||||||
|
#include <queue>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
namespace Util {
|
namespace Util {
|
||||||
|
|
||||||
void toLog (Params&);
|
template <typename T> void addToVector (vector<T>&, const vector<T>&);
|
||||||
void fromLog (Params&);
|
|
||||||
void normalize (Params&);
|
template <typename T> void addToSet (set<T>&, const vector<T>&);
|
||||||
void logSum (double&, double);
|
|
||||||
void multiply (Params&, const Params&);
|
template <typename T> void addToQueue (queue<T>&, const vector<T>&);
|
||||||
void multiply (Params&, const Params&, unsigned);
|
|
||||||
void add (Params&, const Params&);
|
template <typename T> bool contains (const vector<T>&, const T&);
|
||||||
void add (Params&, const Params&, unsigned);
|
|
||||||
void pow (Params&, double);
|
template <typename T> bool contains (const set<T>&, const T&);
|
||||||
void pow (Params&, unsigned);
|
|
||||||
double pow (double, unsigned);
|
template <typename K, typename V> bool contains (
|
||||||
double factorial (double);
|
const unordered_map<K, V>&, const K&);
|
||||||
unsigned nrCombinations (unsigned, unsigned);
|
|
||||||
double getL1Distance (const Params&, const Params&);
|
template <typename T> std::string toString (const T&);
|
||||||
double getMaxNorm (const Params&, const Params&);
|
|
||||||
unsigned getNumberOfDigits (int);
|
void toLog (Params&);
|
||||||
bool isInteger (const string&);
|
|
||||||
string parametersToString (const Params&, unsigned = PRECISION);
|
void fromLog (Params&);
|
||||||
vector<string> getJointStateStrings (const VarNodes&);
|
|
||||||
double tl (double);
|
double logSum (double, double);
|
||||||
double fl (double);
|
|
||||||
double multIdenty();
|
void multiply (Params&, const Params&);
|
||||||
double addIdenty();
|
|
||||||
double withEvidence();
|
void multiply (Params&, const Params&, unsigned);
|
||||||
double noEvidence();
|
|
||||||
double one();
|
void add (Params&, const Params&);
|
||||||
double zero();
|
|
||||||
|
void add (Params&, const Params&, unsigned);
|
||||||
|
|
||||||
|
double factorial (double);
|
||||||
|
|
||||||
|
unsigned nrCombinations (unsigned, unsigned);
|
||||||
|
|
||||||
|
unsigned expectedSize (const Ranges&);
|
||||||
|
|
||||||
|
unsigned getNumberOfDigits (int);
|
||||||
|
|
||||||
|
bool isInteger (const string&);
|
||||||
|
|
||||||
|
string parametersToString (const Params&, unsigned = Constants::PRECISION);
|
||||||
|
|
||||||
|
vector<string> getStateLines (const Vars&);
|
||||||
|
|
||||||
|
void printHeader (string, std::ostream& os = std::cout);
|
||||||
|
|
||||||
|
void printSubHeader (string, std::ostream& os = std::cout);
|
||||||
|
|
||||||
|
void printAsteriskLine (std::ostream& os = std::cout);
|
||||||
|
|
||||||
|
void printDashedLine (std::ostream& os = std::cout);
|
||||||
|
|
||||||
|
unsigned maxUnsigned (void);
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
std::string toString (const T& t)
|
template <typename T> void
|
||||||
|
Util::addToVector (vector<T>& v, const vector<T>& elements)
|
||||||
|
{
|
||||||
|
v.insert (v.end(), elements.begin(), elements.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T> void
|
||||||
|
Util::addToSet (set<T>& s, const vector<T>& elements)
|
||||||
|
{
|
||||||
|
s.insert (elements.begin(), elements.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T> void
|
||||||
|
Util::addToQueue (queue<T>& q, const vector<T>& elements)
|
||||||
|
{
|
||||||
|
for (unsigned i = 0; i < elements.size(); i++) {
|
||||||
|
q.push (elements[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T> bool
|
||||||
|
Util::contains (const vector<T>& v, const T& e)
|
||||||
|
{
|
||||||
|
return std::find (v.begin(), v.end(), e) != v.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T> bool
|
||||||
|
Util::contains (const set<T>& s, const T& e)
|
||||||
|
{
|
||||||
|
return s.find (e) != s.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename K, typename V> bool
|
||||||
|
Util::contains (const unordered_map<K, V>& m, const K& k)
|
||||||
|
{
|
||||||
|
return m.find (k) != m.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T> std::string
|
||||||
|
Util::toString (const T& t)
|
||||||
{
|
{
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << t;
|
ss << t;
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -62,28 +150,31 @@ std::ostream& operator << (std::ostream& os, const vector<T>& v)
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
const double INF = -numeric_limits<double>::infinity();
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
inline void
|
inline double
|
||||||
Util::logSum (double& x, double y)
|
Util::logSum (double x, double y)
|
||||||
{
|
{
|
||||||
x = log (exp (x) + exp (y)); return;
|
return log (exp (x) + exp (y));
|
||||||
assert (isfinite (x) && isfinite (y));
|
assert (isfinite (x) && isfinite (y));
|
||||||
// If one value is much smaller than the other, keep the larger value.
|
// If one value is much smaller than the other, keep the larger value.
|
||||||
if (x < (y - log (1e200))) {
|
if (x < (y - log (1e200))) {
|
||||||
x = y;
|
return y;
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
if (y < (x - log (1e200))) {
|
if (y < (x - log (1e200))) {
|
||||||
return;
|
return x;
|
||||||
}
|
}
|
||||||
double diff = x - y;
|
double diff = x - y;
|
||||||
assert (isfinite (diff) && isfinite (x) && isfinite (y));
|
assert (isfinite (diff) && isfinite (x) && isfinite (y));
|
||||||
if (!isfinite (exp (diff))) { // difference is too large
|
if (!isfinite (exp (diff))) {
|
||||||
x = x > y ? x : y;
|
// difference is too large
|
||||||
} else { // otherwise return the sum.
|
return x > y ? x : y;
|
||||||
x = y + log (static_cast<double>(1.0) + exp (diff));
|
|
||||||
}
|
}
|
||||||
|
// otherwise return the sum.
|
||||||
|
return y + log (static_cast<double>(1.0) + exp (diff));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -140,52 +231,87 @@ Util::add (Params& v1, const Params& v2, unsigned repetitions)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
inline double
|
inline unsigned
|
||||||
Util::tl (double v)
|
Util::maxUnsigned (void)
|
||||||
{
|
{
|
||||||
return Globals::logDomain ? log(v) : v;
|
return numeric_limits<unsigned>::max();
|
||||||
}
|
}
|
||||||
|
|
||||||
inline double
|
|
||||||
Util::fl (double v)
|
|
||||||
{
|
namespace LogAware {
|
||||||
return Globals::logDomain ? exp(v) : v;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline double
|
inline double
|
||||||
Util::multIdenty() {
|
one()
|
||||||
return Globals::logDomain ? 0.0 : 1.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline double
|
|
||||||
Util::addIdenty()
|
|
||||||
{
|
|
||||||
return Globals::logDomain ? INF : 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline double
|
|
||||||
Util::withEvidence()
|
|
||||||
{
|
{
|
||||||
return Globals::logDomain ? 0.0 : 1.0;
|
return Globals::logDomain ? 0.0 : 1.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline double
|
|
||||||
Util::noEvidence() {
|
|
||||||
return Globals::logDomain ? INF : 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline double
|
inline double
|
||||||
Util::one()
|
zero() {
|
||||||
{
|
|
||||||
return Globals::logDomain ? 0.0 : 1.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline double
|
|
||||||
Util::zero() {
|
|
||||||
return Globals::logDomain ? INF : 0.0 ;
|
return Globals::logDomain ? INF : 0.0 ;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline double
|
||||||
|
addIdenty()
|
||||||
|
{
|
||||||
|
return Globals::logDomain ? INF : 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline double
|
||||||
|
multIdenty()
|
||||||
|
{
|
||||||
|
return Globals::logDomain ? 0.0 : 1.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline double
|
||||||
|
withEvidence()
|
||||||
|
{
|
||||||
|
return Globals::logDomain ? 0.0 : 1.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline double
|
||||||
|
noEvidence() {
|
||||||
|
return Globals::logDomain ? INF : 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline double
|
||||||
|
tl (double v)
|
||||||
|
{
|
||||||
|
return Globals::logDomain ? log (v) : v;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline double
|
||||||
|
fl (double v)
|
||||||
|
{
|
||||||
|
return Globals::logDomain ? exp (v) : v;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void normalize (Params&);
|
||||||
|
|
||||||
|
double getL1Distance (const Params&, const Params&);
|
||||||
|
|
||||||
|
double getMaxNorm (const Params&, const Params&);
|
||||||
|
|
||||||
|
double pow (double, unsigned);
|
||||||
|
|
||||||
|
double pow (double, double);
|
||||||
|
|
||||||
|
void pow (Params&, unsigned);
|
||||||
|
|
||||||
|
void pow (Params&, double);
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
struct NetInfo
|
struct NetInfo
|
||||||
{
|
{
|
||||||
NetInfo (unsigned size, bool loopy, unsigned nIters, double time)
|
NetInfo (unsigned size, bool loopy, unsigned nIters, double time)
|
||||||
@ -206,17 +332,17 @@ struct CompressInfo
|
|||||||
{
|
{
|
||||||
CompressInfo (unsigned a, unsigned b, unsigned c, unsigned d, unsigned e)
|
CompressInfo (unsigned a, unsigned b, unsigned c, unsigned d, unsigned e)
|
||||||
{
|
{
|
||||||
nGroundVars = a;
|
nrGroundVars = a;
|
||||||
nGroundFactors = b;
|
nrGroundFactors = b;
|
||||||
nClusterVars = c;
|
nrClusterVars = c;
|
||||||
nClusterFactors = d;
|
nrClusterFactors = d;
|
||||||
nWithoutNeighs = e;
|
nrNeighborless = e;
|
||||||
}
|
}
|
||||||
unsigned nGroundVars;
|
unsigned nrGroundVars;
|
||||||
unsigned nGroundFactors;
|
unsigned nrGroundFactors;
|
||||||
unsigned nClusterVars;
|
unsigned nrClusterVars;
|
||||||
unsigned nClusterFactors;
|
unsigned nrClusterFactors;
|
||||||
unsigned nWithoutNeighs;
|
unsigned nrNeighborless;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -224,11 +350,17 @@ class Statistics
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
static unsigned getSolvedNetworksCounting (void);
|
static unsigned getSolvedNetworksCounting (void);
|
||||||
|
|
||||||
static void incrementPrimaryNetworksCounting (void);
|
static void incrementPrimaryNetworksCounting (void);
|
||||||
|
|
||||||
static unsigned getPrimaryNetworksCounting (void);
|
static unsigned getPrimaryNetworksCounting (void);
|
||||||
|
|
||||||
static void updateStatistics (unsigned, bool, unsigned, double);
|
static void updateStatistics (unsigned, bool, unsigned, double);
|
||||||
|
|
||||||
static void printStatistics (void);
|
static void printStatistics (void);
|
||||||
static void writeStatisticsToFile (const char*);
|
|
||||||
|
static void writeStatistics (const char*);
|
||||||
|
|
||||||
static void updateCompressingStatistics (
|
static void updateCompressingStatistics (
|
||||||
unsigned, unsigned, unsigned, unsigned, unsigned);
|
unsigned, unsigned, unsigned, unsigned, unsigned);
|
||||||
|
|
||||||
|
102
packages/CLPBN/clpbn/bp/Var.cpp
Normal file
102
packages/CLPBN/clpbn/bp/Var.cpp
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "Var.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
unordered_map<VarId, VarInfo> Var::varsInfo_;
|
||||||
|
|
||||||
|
|
||||||
|
Var::Var (const Var* v)
|
||||||
|
{
|
||||||
|
varId_ = v->varId();
|
||||||
|
range_ = v->range();
|
||||||
|
evidence_ = v->getEvidence();
|
||||||
|
index_ = std::numeric_limits<unsigned>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Var::Var (VarId varId, unsigned range, int evidence)
|
||||||
|
{
|
||||||
|
assert (range != 0);
|
||||||
|
assert (evidence < (int) range);
|
||||||
|
varId_ = varId;
|
||||||
|
range_ = range;
|
||||||
|
evidence_ = evidence;
|
||||||
|
index_ = std::numeric_limits<unsigned>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Var::isValidState (int stateIndex)
|
||||||
|
{
|
||||||
|
return stateIndex >= 0 && stateIndex < (int) range_;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
bool
|
||||||
|
Var::isValidState (const string& stateName)
|
||||||
|
{
|
||||||
|
States states = Var::getVarInfo (varId_).states;
|
||||||
|
return Util::contains (states, stateName);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Var::setEvidence (int ev)
|
||||||
|
{
|
||||||
|
assert (ev < (int) range_);
|
||||||
|
evidence_ = ev;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void
|
||||||
|
Var::setEvidence (const string& ev)
|
||||||
|
{
|
||||||
|
States states = Var::getVarInfo (varId_).states;
|
||||||
|
for (unsigned i = 0; i < states.size(); i++) {
|
||||||
|
if (states[i] == ev) {
|
||||||
|
evidence_ = i;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert (false);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
string
|
||||||
|
Var::label (void) const
|
||||||
|
{
|
||||||
|
if (Var::varsHaveInfo()) {
|
||||||
|
return Var::getVarInfo (varId_).label;
|
||||||
|
}
|
||||||
|
stringstream ss;
|
||||||
|
ss << "x" << varId_;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
States
|
||||||
|
Var::states (void) const
|
||||||
|
{
|
||||||
|
if (Var::varsHaveInfo()) {
|
||||||
|
return Var::getVarInfo (varId_).states;
|
||||||
|
}
|
||||||
|
States states;
|
||||||
|
for (unsigned i = 0; i < range_; i++) {
|
||||||
|
stringstream ss;
|
||||||
|
ss << i ;
|
||||||
|
states.push_back (ss.str());
|
||||||
|
}
|
||||||
|
return states;
|
||||||
|
}
|
||||||
|
|
108
packages/CLPBN/clpbn/bp/Var.h
Normal file
108
packages/CLPBN/clpbn/bp/Var.h
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
#ifndef HORUS_Var_H
|
||||||
|
#define HORUS_Var_H
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "Util.h"
|
||||||
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
|
||||||
|
struct VarInfo
|
||||||
|
{
|
||||||
|
VarInfo (string l, const States& sts) : label(l), states(sts) { }
|
||||||
|
string label;
|
||||||
|
States states;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Var
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
Var (const Var*);
|
||||||
|
|
||||||
|
Var (VarId, unsigned, int = Constants::NO_EVIDENCE);
|
||||||
|
|
||||||
|
virtual ~Var (void) { };
|
||||||
|
|
||||||
|
unsigned varId (void) const { return varId_; }
|
||||||
|
|
||||||
|
unsigned range (void) const { return range_; }
|
||||||
|
|
||||||
|
int getEvidence (void) const { return evidence_; }
|
||||||
|
|
||||||
|
unsigned getIndex (void) const { return index_; }
|
||||||
|
|
||||||
|
void setIndex (unsigned idx) { index_ = idx; }
|
||||||
|
|
||||||
|
operator unsigned () const { return index_; }
|
||||||
|
|
||||||
|
bool hasEvidence (void) const
|
||||||
|
{
|
||||||
|
return evidence_ != Constants::NO_EVIDENCE;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator== (const Var& var) const
|
||||||
|
{
|
||||||
|
assert (!(varId_ == var.varId() && range_ != var.range()));
|
||||||
|
return varId_ == var.varId();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator!= (const Var& var) const
|
||||||
|
{
|
||||||
|
assert (!(varId_ == var.varId() && range_ != var.range()));
|
||||||
|
return varId_ != var.varId();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isValidState (int);
|
||||||
|
|
||||||
|
bool isValidState (const string&);
|
||||||
|
|
||||||
|
void setEvidence (int);
|
||||||
|
|
||||||
|
void setEvidence (const string&);
|
||||||
|
|
||||||
|
string label (void) const;
|
||||||
|
|
||||||
|
States states (void) const;
|
||||||
|
|
||||||
|
static void addVarInfo (
|
||||||
|
VarId vid, string label, const States& states)
|
||||||
|
{
|
||||||
|
assert (Util::contains (varsInfo_, vid) == false);
|
||||||
|
varsInfo_.insert (make_pair (vid, VarInfo (label, states)));
|
||||||
|
}
|
||||||
|
|
||||||
|
static VarInfo getVarInfo (VarId vid)
|
||||||
|
{
|
||||||
|
assert (Util::contains (varsInfo_, vid));
|
||||||
|
return varsInfo_.find (vid)->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool varsHaveInfo (void)
|
||||||
|
{
|
||||||
|
return varsInfo_.size() != 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void clearVarsInfo (void)
|
||||||
|
{
|
||||||
|
varsInfo_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
VarId varId_;
|
||||||
|
unsigned range_;
|
||||||
|
int evidence_;
|
||||||
|
unsigned index_;
|
||||||
|
|
||||||
|
static unordered_map<VarId, VarInfo> varsInfo_;
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // BP_Var_H
|
||||||
|
|
@ -6,61 +6,27 @@
|
|||||||
#include "Util.h"
|
#include "Util.h"
|
||||||
|
|
||||||
|
|
||||||
VarElimSolver::VarElimSolver (const BayesNet& bn) : Solver (&bn)
|
|
||||||
{
|
|
||||||
bayesNet_ = &bn;
|
|
||||||
factorGraph_ = new FactorGraph (bn);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarElimSolver::VarElimSolver (const FactorGraph& fg) : Solver (&fg)
|
|
||||||
{
|
|
||||||
bayesNet_ = 0;
|
|
||||||
factorGraph_ = &fg;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarElimSolver::~VarElimSolver (void)
|
VarElimSolver::~VarElimSolver (void)
|
||||||
{
|
{
|
||||||
if (bayesNet_) {
|
delete factorList_.back();
|
||||||
delete factorGraph_;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
Params
|
||||||
VarElimSolver::getPosterioriOf (VarId vid)
|
VarElimSolver::solveQuery (VarIds queryVids)
|
||||||
{
|
|
||||||
assert (factorGraph_->getFgVarNode (vid));
|
|
||||||
FgVarNode* vn = factorGraph_->getFgVarNode (vid);
|
|
||||||
if (vn->hasEvidence()) {
|
|
||||||
Params params (vn->nrStates(), 0.0);
|
|
||||||
params[vn->getEvidence()] = 1.0;
|
|
||||||
return params;
|
|
||||||
}
|
|
||||||
return getJointDistributionOf (VarIds() = {vid});
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Params
|
|
||||||
VarElimSolver::getJointDistributionOf (const VarIds& vids)
|
|
||||||
{
|
{
|
||||||
factorList_.clear();
|
factorList_.clear();
|
||||||
varFactors_.clear();
|
varFactors_.clear();
|
||||||
elimOrder_.clear();
|
elimOrder_.clear();
|
||||||
createFactorList();
|
createFactorList();
|
||||||
introduceEvidence();
|
absorveEvidence();
|
||||||
chooseEliminationOrder (vids);
|
findEliminationOrder (queryVids);
|
||||||
processFactorList (vids);
|
processFactorList (queryVids);
|
||||||
Params params = factorList_.back()->getParameters();
|
Params params = factorList_.back()->params();
|
||||||
if (Globals::logDomain) {
|
if (Globals::logDomain) {
|
||||||
Util::fromLog (params);
|
Util::fromLog (params);
|
||||||
}
|
}
|
||||||
delete factorList_.back();
|
|
||||||
return params;
|
return params;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,11 +35,11 @@ VarElimSolver::getJointDistributionOf (const VarIds& vids)
|
|||||||
void
|
void
|
||||||
VarElimSolver::createFactorList (void)
|
VarElimSolver::createFactorList (void)
|
||||||
{
|
{
|
||||||
const FgFacSet& factorNodes = factorGraph_->getFactorNodes();
|
const FacNodes& facNodes = fg.facNodes();
|
||||||
factorList_.reserve (factorNodes.size() * 2);
|
factorList_.reserve (facNodes.size() * 2);
|
||||||
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||||
factorList_.push_back (new Factor (*factorNodes[i]->factor()));
|
factorList_.push_back (new Factor (facNodes[i]->factor()));
|
||||||
const FgVarSet& neighs = factorNodes[i]->neighbors();
|
const VarNodes& neighs = facNodes[i]->neighbors();
|
||||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||||
unordered_map<VarId,vector<unsigned> >::iterator it
|
unordered_map<VarId,vector<unsigned> >::iterator it
|
||||||
= varFactors_.find (neighs[j]->varId());
|
= varFactors_.find (neighs[j]->varId());
|
||||||
@ -89,16 +55,16 @@ VarElimSolver::createFactorList (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
VarElimSolver::introduceEvidence (void)
|
VarElimSolver::absorveEvidence (void)
|
||||||
{
|
{
|
||||||
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
const VarNodes& varNodes = fg.varNodes();
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||||
if (varNodes[i]->hasEvidence()) {
|
if (varNodes[i]->hasEvidence()) {
|
||||||
const vector<unsigned>& idxs =
|
const vector<unsigned>& idxs =
|
||||||
varFactors_.find (varNodes[i]->varId())->second;
|
varFactors_.find (varNodes[i]->varId())->second;
|
||||||
for (unsigned j = 0; j < idxs.size(); j++) {
|
for (unsigned j = 0; j < idxs.size(); j++) {
|
||||||
Factor* factor = factorList_[idxs[j]];
|
Factor* factor = factorList_[idxs[j]];
|
||||||
if (factor->nrVariables() == 1) {
|
if (factor->nrArguments() == 1) {
|
||||||
factorList_[idxs[j]] = 0;
|
factorList_[idxs[j]] = 0;
|
||||||
} else {
|
} else {
|
||||||
factorList_[idxs[j]]->absorveEvidence (
|
factorList_[idxs[j]]->absorveEvidence (
|
||||||
@ -112,21 +78,9 @@ VarElimSolver::introduceEvidence (void)
|
|||||||
|
|
||||||
|
|
||||||
void
|
void
|
||||||
VarElimSolver::chooseEliminationOrder (const VarIds& vids)
|
VarElimSolver::findEliminationOrder (const VarIds& vids)
|
||||||
{
|
{
|
||||||
if (bayesNet_) {
|
elimOrder_ = ElimGraph::getEliminationOrder (factorList_, vids);
|
||||||
ElimGraph graph (*bayesNet_);
|
|
||||||
elimOrder_ = graph.getEliminatingOrder (vids);
|
|
||||||
} else {
|
|
||||||
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
|
||||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
|
||||||
VarId vid = varNodes[i]->varId();
|
|
||||||
if (std::find (vids.begin(), vids.end(), vid) == vids.end()
|
|
||||||
&& !varNodes[i]->hasEvidence()) {
|
|
||||||
elimOrder_.push_back (vid);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -149,12 +103,12 @@ VarElimSolver::processFactorList (const VarIds& vids)
|
|||||||
|
|
||||||
VarIds unobservedVids;
|
VarIds unobservedVids;
|
||||||
for (unsigned i = 0; i < vids.size(); i++) {
|
for (unsigned i = 0; i < vids.size(); i++) {
|
||||||
if (factorGraph_->getFgVarNode (vids[i])->hasEvidence() == false) {
|
if (fg.getVarNode (vids[i])->hasEvidence() == false) {
|
||||||
unobservedVids.push_back (vids[i]);
|
unobservedVids.push_back (vids[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
finalFactor->reorderVariables (unobservedVids);
|
finalFactor->reorderArguments (unobservedVids);
|
||||||
finalFactor->normalize();
|
finalFactor->normalize();
|
||||||
factorList_.push_back (finalFactor);
|
factorList_.push_back (finalFactor);
|
||||||
}
|
}
|
||||||
@ -165,13 +119,12 @@ void
|
|||||||
VarElimSolver::eliminate (VarId elimVar)
|
VarElimSolver::eliminate (VarId elimVar)
|
||||||
{
|
{
|
||||||
Factor* result = 0;
|
Factor* result = 0;
|
||||||
FgVarNode* vn = factorGraph_->getFgVarNode (elimVar);
|
|
||||||
vector<unsigned>& idxs = varFactors_.find (elimVar)->second;
|
vector<unsigned>& idxs = varFactors_.find (elimVar)->second;
|
||||||
for (unsigned i = 0; i < idxs.size(); i++) {
|
for (unsigned i = 0; i < idxs.size(); i++) {
|
||||||
unsigned idx = idxs[i];
|
unsigned idx = idxs[i];
|
||||||
if (factorList_[idx]) {
|
if (factorList_[idx]) {
|
||||||
if (result == 0) {
|
if (result == 0) {
|
||||||
result = new Factor(*factorList_[idx]);
|
result = new Factor (*factorList_[idx]);
|
||||||
} else {
|
} else {
|
||||||
result->multiply (*factorList_[idx]);
|
result->multiply (*factorList_[idx]);
|
||||||
}
|
}
|
||||||
@ -179,10 +132,10 @@ VarElimSolver::eliminate (VarId elimVar)
|
|||||||
factorList_[idx] = 0;
|
factorList_[idx] = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (result != 0 && result->nrVariables() != 1) {
|
if (result != 0 && result->nrArguments() != 1) {
|
||||||
result->sumOut (vn->varId());
|
result->sumOut (elimVar);
|
||||||
factorList_.push_back (result);
|
factorList_.push_back (result);
|
||||||
const VarIds& resultVarIds = result->getVarIds();
|
const VarIds& resultVarIds = result->arguments();
|
||||||
for (unsigned i = 0; i < resultVarIds.size(); i++) {
|
for (unsigned i = 0; i < resultVarIds.size(); i++) {
|
||||||
vector<unsigned>& idxs =
|
vector<unsigned>& idxs =
|
||||||
varFactors_.find (resultVarIds[i])->second;
|
varFactors_.find (resultVarIds[i])->second;
|
||||||
@ -199,7 +152,6 @@ VarElimSolver::printActiveFactors (void)
|
|||||||
for (unsigned i = 0; i < factorList_.size(); i++) {
|
for (unsigned i = 0; i < factorList_.size(); i++) {
|
||||||
if (factorList_[i] != 0) {
|
if (factorList_[i] != 0) {
|
||||||
factorList_[i]->print();
|
factorList_[i]->print();
|
||||||
cout << endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,7 +5,6 @@
|
|||||||
|
|
||||||
#include "Solver.h"
|
#include "Solver.h"
|
||||||
#include "FactorGraph.h"
|
#include "FactorGraph.h"
|
||||||
#include "BayesNet.h"
|
|
||||||
#include "Horus.h"
|
#include "Horus.h"
|
||||||
|
|
||||||
|
|
||||||
@ -15,23 +14,25 @@ using namespace std;
|
|||||||
class VarElimSolver : public Solver
|
class VarElimSolver : public Solver
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
VarElimSolver (const BayesNet&);
|
VarElimSolver (const FactorGraph& fg) : Solver (fg) { }
|
||||||
VarElimSolver (const FactorGraph&);
|
|
||||||
~VarElimSolver (void);
|
~VarElimSolver (void);
|
||||||
void runSolver (void) { }
|
|
||||||
Params getPosterioriOf (VarId);
|
Params solveQuery (VarIds);
|
||||||
Params getJointDistributionOf (const VarIds&);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void createFactorList (void);
|
void createFactorList (void);
|
||||||
void introduceEvidence (void);
|
|
||||||
void chooseEliminationOrder (const VarIds&);
|
void absorveEvidence (void);
|
||||||
|
|
||||||
|
void findEliminationOrder (const VarIds&);
|
||||||
|
|
||||||
void processFactorList (const VarIds&);
|
void processFactorList (const VarIds&);
|
||||||
|
|
||||||
void eliminate (VarId);
|
void eliminate (VarId);
|
||||||
|
|
||||||
void printActiveFactors (void);
|
void printActiveFactors (void);
|
||||||
|
|
||||||
const BayesNet* bayesNet_;
|
|
||||||
const FactorGraph* factorGraph_;
|
|
||||||
vector<Factor*> factorList_;
|
vector<Factor*> factorList_;
|
||||||
VarIds elimOrder_;
|
VarIds elimOrder_;
|
||||||
unordered_map<VarId, vector<unsigned>> varFactors_;
|
unordered_map<VarId, vector<unsigned>> varFactors_;
|
||||||
|
@ -1,100 +0,0 @@
|
|||||||
#include <algorithm>
|
|
||||||
#include <sstream>
|
|
||||||
|
|
||||||
#include "VarNode.h"
|
|
||||||
#include "GraphicalModel.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
|
|
||||||
VarNode::VarNode (const VarNode* v)
|
|
||||||
{
|
|
||||||
varId_ = v->varId();
|
|
||||||
nrStates_ = v->nrStates();
|
|
||||||
evidence_ = v->getEvidence();
|
|
||||||
index_ = std::numeric_limits<unsigned>::max();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
VarNode::VarNode (VarId varId, unsigned nrStates, int evidence)
|
|
||||||
{
|
|
||||||
assert (nrStates != 0);
|
|
||||||
assert (evidence < (int) nrStates);
|
|
||||||
varId_ = varId;
|
|
||||||
nrStates_ = nrStates;
|
|
||||||
evidence_ = evidence;
|
|
||||||
index_ = std::numeric_limits<unsigned>::max();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
VarNode::isValidState (int stateIndex)
|
|
||||||
{
|
|
||||||
return stateIndex >= 0 && stateIndex < (int) nrStates_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool
|
|
||||||
VarNode::isValidState (const string& stateName)
|
|
||||||
{
|
|
||||||
States states = GraphicalModel::getVariableInformation (varId_).states;
|
|
||||||
return find (states.begin(), states.end(), stateName) != states.end();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
VarNode::setEvidence (int ev)
|
|
||||||
{
|
|
||||||
assert (ev < (int) nrStates_);
|
|
||||||
evidence_ = ev;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void
|
|
||||||
VarNode::setEvidence (const string& ev)
|
|
||||||
{
|
|
||||||
States states = GraphicalModel::getVariableInformation (varId_).states;
|
|
||||||
for (unsigned i = 0; i < states.size(); i++) {
|
|
||||||
if (states[i] == ev) {
|
|
||||||
evidence_ = i;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert (false);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
string
|
|
||||||
VarNode::label (void) const
|
|
||||||
{
|
|
||||||
if (GraphicalModel::variablesHaveInformation()) {
|
|
||||||
return GraphicalModel::getVariableInformation (varId_).label;
|
|
||||||
}
|
|
||||||
stringstream ss;
|
|
||||||
ss << "x" << varId_;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
States
|
|
||||||
VarNode::states (void) const
|
|
||||||
{
|
|
||||||
if (GraphicalModel::variablesHaveInformation()) {
|
|
||||||
return GraphicalModel::getVariableInformation (varId_).states;
|
|
||||||
}
|
|
||||||
States states;
|
|
||||||
for (unsigned i = 0; i < nrStates_; i++) {
|
|
||||||
stringstream ss;
|
|
||||||
ss << i ;
|
|
||||||
states.push_back (ss.str());
|
|
||||||
}
|
|
||||||
return states;
|
|
||||||
}
|
|
||||||
|
|
@ -1,54 +0,0 @@
|
|||||||
#ifndef HORUS_VARNODE_H
|
|
||||||
#define HORUS_VARNODE_H
|
|
||||||
|
|
||||||
#include "Horus.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
class VarNode
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
VarNode (const VarNode*);
|
|
||||||
VarNode (VarId, unsigned, int = NO_EVIDENCE);
|
|
||||||
virtual ~VarNode (void) {};
|
|
||||||
|
|
||||||
bool isValidState (int);
|
|
||||||
bool isValidState (const string&);
|
|
||||||
void setEvidence (int);
|
|
||||||
void setEvidence (const string&);
|
|
||||||
string label (void) const;
|
|
||||||
States states (void) const;
|
|
||||||
|
|
||||||
unsigned varId (void) const { return varId_; }
|
|
||||||
unsigned nrStates (void) const { return nrStates_; }
|
|
||||||
bool hasEvidence (void) const { return evidence_ != NO_EVIDENCE; }
|
|
||||||
int getEvidence (void) const { return evidence_; }
|
|
||||||
unsigned getIndex (void) const { return index_; }
|
|
||||||
void setIndex (unsigned idx) { index_ = idx; }
|
|
||||||
|
|
||||||
operator unsigned () const { return index_; }
|
|
||||||
|
|
||||||
bool operator== (const VarNode& var) const
|
|
||||||
{
|
|
||||||
cout << "equal operator called" << endl;
|
|
||||||
assert (!(varId_ == var.varId() && nrStates_ != var.nrStates()));
|
|
||||||
return varId_ == var.varId();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool operator!= (const VarNode& var) const
|
|
||||||
{
|
|
||||||
cout << "diff operator called" << endl;
|
|
||||||
assert (!(varId_ == var.varId() && nrStates_ != var.nrStates()));
|
|
||||||
return varId_ != var.varId();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
VarId varId_;
|
|
||||||
unsigned nrStates_;
|
|
||||||
int evidence_;
|
|
||||||
unsigned index_;
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // BP_VARNODE_H
|
|
||||||
|
|
35
packages/CLPBN/clpbn/bp/benchmarks/benchs.sh
Executable file
35
packages/CLPBN/clpbn/bp/benchmarks/benchs.sh
Executable file
@ -0,0 +1,35 @@
|
|||||||
|
|
||||||
|
if [ $1 ] && [ $1 == "clear" ]; then
|
||||||
|
rm *~
|
||||||
|
rm -f school/*.log school/*~
|
||||||
|
rm -f city/*.log city/*~
|
||||||
|
rm -f workshop_attrs/*.log workshop_attrs/*~
|
||||||
|
fi
|
||||||
|
|
||||||
|
function run_solver
|
||||||
|
{
|
||||||
|
constraint=$1
|
||||||
|
solver_flag=true
|
||||||
|
if [ -n "$2" ]; then
|
||||||
|
if [ $SOLVER = hve ]; then
|
||||||
|
extra_flag=clpbn_horus:set_horus_flag\(elim_heuristic,$2\)
|
||||||
|
elif [ $SOLVER = bp ]; then
|
||||||
|
extra_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
|
||||||
|
elif [ $SOLVER = cbp ]; then
|
||||||
|
extra_flag=clpbn_horus:set_horus_flag\(schedule,$2\)
|
||||||
|
else
|
||||||
|
echo "unknow flag $2"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
/usr/bin/time -o $LOG_FILE -a -f "real:%E\tuser:%U\tsys:%S" \
|
||||||
|
$YAP << EOF >> $LOG_FILE 2>> ignore.$LOG_FILE
|
||||||
|
[$NETWORK].
|
||||||
|
[$constraint].
|
||||||
|
clpbn_horus:set_solver($SOLVER).
|
||||||
|
clpbn_horus:set_horus_flag(use_logarithms, true).
|
||||||
|
$solver_flag.
|
||||||
|
$QUERY.
|
||||||
|
open("$LOG_FILE", 'append', S), format(S, '$constraint: ~15+ ', []), close(S).
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
@ -1,50 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
cp ~/bin/yap ~/bin/town_bnbp
|
|
||||||
YAP=~/bin/town_bnbp
|
|
||||||
|
|
||||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
|
||||||
OUT_FILE_NAME=bnbp.log
|
|
||||||
rm -f $OUT_FILE_NAME
|
|
||||||
rm -f ignore.$OUT_FILE_NAME
|
|
||||||
|
|
||||||
|
|
||||||
function run_solver
|
|
||||||
{
|
|
||||||
if [ $2 = bp ]
|
|
||||||
then
|
|
||||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
|
||||||
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
|
|
||||||
else
|
|
||||||
extra_flag1=true
|
|
||||||
extra_flag2=true
|
|
||||||
fi
|
|
||||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
|
||||||
[$1].
|
|
||||||
clpbn:set_clpbn_flag(solver,$2),
|
|
||||||
clpbn_bp:set_horus_flag(use_logarithms, true),
|
|
||||||
$extra_flag1, $extra_flag2,
|
|
||||||
run_query(_R),
|
|
||||||
open("$OUT_FILE_NAME", 'append',S),
|
|
||||||
format(S, '$3: ~15+ ',[]),
|
|
||||||
close(S).
|
|
||||||
EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
function run_all_graphs
|
|
||||||
{
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
|
||||||
run_solver town_5000 $1 town_5000 $3 $4 $5
|
|
||||||
run_solver town_10000 $1 town_10000 $3 $4 $5
|
|
||||||
run_solver town_50000 $1 town_50000 $3 $4 $5
|
|
||||||
run_solver town_100000 $1 town_100000 $3 $4 $5
|
|
||||||
run_solver town_500000 $1 town_500000 $3 $4 $5
|
|
||||||
run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
|
||||||
}
|
|
||||||
|
|
||||||
run_all_graphs bp "bn_bp(seq_fixed) " bn_bp seq_fixed
|
|
||||||
|
|
17
packages/CLPBN/clpbn/bp/benchmarks/city/bp_tests.sh
Executable file
17
packages/CLPBN/clpbn/bp/benchmarks/city/bp_tests.sh
Executable file
@ -0,0 +1,17 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source city.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="bp"
|
||||||
|
|
||||||
|
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||||
|
|
||||||
|
LOG_FILE=$SOLVER.log
|
||||||
|
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
||||||
|
|
||||||
|
rm -f $LOG_FILE
|
||||||
|
rm -f ignore.$LOG_FILE
|
||||||
|
|
||||||
|
run_all_graphs "bp(shedule=seq_fixed) " seq_fixed
|
||||||
|
|
@ -1,54 +1,17 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
cp ~/bin/yap ~/bin/town_cbp
|
source city.sh
|
||||||
YAP=~/bin/town_cbp
|
source ../benchs.sh
|
||||||
|
|
||||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
SOLVER="cbp"
|
||||||
OUT_FILE_NAME=cbp.log
|
|
||||||
rm -f $OUT_FILE_NAME
|
|
||||||
rm -f ignore.$OUT_FILE_NAME
|
|
||||||
|
|
||||||
|
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||||
|
|
||||||
function run_solver
|
LOG_FILE=$SOLVER.log
|
||||||
{
|
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
||||||
if [ $2 = bp ]
|
|
||||||
then
|
|
||||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
|
||||||
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
|
|
||||||
else
|
|
||||||
extra_flag1=true
|
|
||||||
extra_flag2=true
|
|
||||||
fi
|
|
||||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
|
||||||
[$1].
|
|
||||||
clpbn:set_clpbn_flag(solver,$2),
|
|
||||||
clpbn_bp:set_horus_flag(use_logarithms, true),
|
|
||||||
$extra_flag1, $extra_flag2,
|
|
||||||
run_query(_R),
|
|
||||||
open("$OUT_FILE_NAME", 'append',S),
|
|
||||||
format(S, '$3: ~15+ ',[]),
|
|
||||||
close(S).
|
|
||||||
EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
|
rm -f $LOG_FILE
|
||||||
|
rm -f ignore.$LOG_FILE
|
||||||
|
|
||||||
function run_all_graphs
|
run_all_graphs "cbp(shedule=seq_fixed) " seq_fixed
|
||||||
{
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
|
||||||
run_solver town_5000 $1 town_5000 $3 $4 $5
|
|
||||||
run_solver town_10000 $1 town_10000 $3 $4 $5
|
|
||||||
run_solver town_50000 $1 town_50000 $3 $4 $5
|
|
||||||
run_solver town_100000 $1 town_100000 $3 $4 $5
|
|
||||||
run_solver town_500000 $1 town_500000 $3 $4 $5
|
|
||||||
run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
|
||||||
run_solver town_2500000 $1 town_2500000 $3 $4 $5
|
|
||||||
run_solver town_5000000 $1 town_5000000 $3 $4 $5
|
|
||||||
run_solver town_7500000 $1 town_7500000 $3 $4 $5
|
|
||||||
run_solver town_10000000 $1 town_10000000 $3 $4 $5
|
|
||||||
}
|
|
||||||
|
|
||||||
run_all_graphs bp "cbp(seq_fixed) " cbp seq_fixed
|
|
||||||
|
|
||||||
|
25
packages/CLPBN/clpbn/bp/benchmarks/city/city.sh
Executable file
25
packages/CLPBN/clpbn/bp/benchmarks/city/city.sh
Executable file
@ -0,0 +1,25 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
NETWORK="'../../examples/city'"
|
||||||
|
SHORTNAME="city"
|
||||||
|
QUERY="is_joe_guilty(X)"
|
||||||
|
|
||||||
|
|
||||||
|
function run_all_graphs
|
||||||
|
{
|
||||||
|
cp ~/bin/yap $YAP
|
||||||
|
echo -n "**********************************" >> $LOG_FILE
|
||||||
|
echo "**********************************" >> $LOG_FILE
|
||||||
|
echo "results for solver $1" >> $LOG_FILE
|
||||||
|
echo -n "**********************************" >> $LOG_FILE
|
||||||
|
echo "**********************************" >> $LOG_FILE
|
||||||
|
run_solver city_5 $2
|
||||||
|
#run_solver city_1000 $2
|
||||||
|
#run_solver city_5000 $2
|
||||||
|
#run_solver city_10000 $2
|
||||||
|
#run_solver city_50000 $2
|
||||||
|
#run_solver city_100000 $2
|
||||||
|
#run_solver city_500000 $2
|
||||||
|
#run_solver city_1000000 $2
|
||||||
|
}
|
||||||
|
|
37
packages/CLPBN/clpbn/bp/benchmarks/city/city_generator.sh
Executable file
37
packages/CLPBN/clpbn/bp/benchmarks/city/city_generator.sh
Executable file
@ -0,0 +1,37 @@
|
|||||||
|
#!/home/tiago/bin/yap -L --
|
||||||
|
|
||||||
|
|
||||||
|
:- initialization(main).
|
||||||
|
|
||||||
|
|
||||||
|
main :-
|
||||||
|
unix(argv([H])),
|
||||||
|
generate_town(H).
|
||||||
|
|
||||||
|
|
||||||
|
generate_town(N) :-
|
||||||
|
atomic_concat(['city_', N, '.yap'], FileName),
|
||||||
|
open(FileName, 'write', S),
|
||||||
|
atom_number(N, N2),
|
||||||
|
generate_people(S, N2, 4),
|
||||||
|
write(S, '\n'),
|
||||||
|
generate_query(S, N2, 4),
|
||||||
|
write(S, '\n'),
|
||||||
|
close(S).
|
||||||
|
|
||||||
|
|
||||||
|
generate_people(S, N, Counting) :-
|
||||||
|
Counting > N, !.
|
||||||
|
generate_people(S, N, Counting) :-
|
||||||
|
format(S, 'people(p~w, nyc).~n', [Counting]),
|
||||||
|
Counting1 is Counting + 1,
|
||||||
|
generate_people(S, N, Counting1).
|
||||||
|
|
||||||
|
|
||||||
|
generate_query(S, N, Counting) :-
|
||||||
|
Counting > N, !.
|
||||||
|
generate_query(S, N, Counting) :- !,
|
||||||
|
format(S, 'ev(descn(p~w, t)).~n', [Counting]),
|
||||||
|
Counting1 is Counting + 1,
|
||||||
|
generate_query(S, N, Counting1).
|
||||||
|
|
@ -1,50 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
cp ~/bin/yap ~/bin/town_fgbp
|
|
||||||
YAP=~/bin/town_fgbp
|
|
||||||
|
|
||||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
|
||||||
OUT_FILE_NAME=fb_bp.log
|
|
||||||
rm -f $OUT_FILE_NAME
|
|
||||||
rm -f ignore.$OUT_FILE_NAME
|
|
||||||
|
|
||||||
|
|
||||||
function run_solver
|
|
||||||
{
|
|
||||||
if [ $2 = bp ]
|
|
||||||
then
|
|
||||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
|
||||||
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
|
|
||||||
else
|
|
||||||
extra_flag1=true
|
|
||||||
extra_flag2=true
|
|
||||||
fi
|
|
||||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
|
||||||
[$1].
|
|
||||||
clpbn:set_clpbn_flag(solver,$2),
|
|
||||||
clpbn_bp:set_horus_flag(use_logarithms, true),
|
|
||||||
$extra_flag1, $extra_flag2,
|
|
||||||
run_query(_R),
|
|
||||||
open("$OUT_FILE_NAME", 'append',S),
|
|
||||||
format(S, '$3: ~15+ ',[]),
|
|
||||||
close(S).
|
|
||||||
EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
function run_all_graphs
|
|
||||||
{
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
|
||||||
#run_solver town_5000 $1 town_5000 $3 $4 $5
|
|
||||||
#run_solver town_10000 $1 town_10000 $3 $4 $5
|
|
||||||
#run_solver town_50000 $1 town_50000 $3 $4 $5
|
|
||||||
#run_solver town_100000 $1 town_100000 $3 $4 $5
|
|
||||||
#run_solver town_500000 $1 town_500000 $3 $4 $5
|
|
||||||
#run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
|
||||||
}
|
|
||||||
|
|
||||||
run_all_graphs bp "fg_bp(seq_fixed) " fg_bp seq_fixed
|
|
||||||
|
|
17
packages/CLPBN/clpbn/bp/benchmarks/city/fove_tests.sh
Executable file
17
packages/CLPBN/clpbn/bp/benchmarks/city/fove_tests.sh
Executable file
@ -0,0 +1,17 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source city.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="fove"
|
||||||
|
|
||||||
|
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||||
|
|
||||||
|
LOG_FILE=$SOLVER.log
|
||||||
|
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
||||||
|
|
||||||
|
rm -f $LOG_FILE
|
||||||
|
rm -f ignore.$LOG_FILEE
|
||||||
|
|
||||||
|
run_all_graphs "fove "
|
||||||
|
|
@ -1,50 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
cp ~/bin/yap ~/bin/town_gibbs
|
|
||||||
YAP=~/bin/town_gibbs
|
|
||||||
|
|
||||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
|
||||||
OUT_FILE_NAME=gibbs.log
|
|
||||||
rm -f $OUT_FILE_NAME
|
|
||||||
rm -f ignore.$OUT_FILE_NAME
|
|
||||||
|
|
||||||
|
|
||||||
function run_solver
|
|
||||||
{
|
|
||||||
if [ $2 = bp ]
|
|
||||||
then
|
|
||||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
|
||||||
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
|
|
||||||
else
|
|
||||||
extra_flag1=true
|
|
||||||
extra_flag2=true
|
|
||||||
fi
|
|
||||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
|
||||||
[$1].
|
|
||||||
clpbn:set_clpbn_flag(solver,$2),
|
|
||||||
clpbn_bp:set_horus_flag(use_logarithms, true),
|
|
||||||
$extra_flag1, $extra_flag2,
|
|
||||||
run_query(_R),
|
|
||||||
open("$OUT_FILE_NAME", 'append',S),
|
|
||||||
format(S, '$3: ~15+ ',[]),
|
|
||||||
close(S).
|
|
||||||
EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
function run_all_graphs
|
|
||||||
{
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
|
||||||
run_solver town_5000 $1 town_5000 $3 $4 $5
|
|
||||||
run_solver town_10000 $1 town_10000 $3 $4 $5
|
|
||||||
run_solver town_50000 $1 town_50000 $3 $4 $5
|
|
||||||
run_solver town_100000 $1 town_100000 $3 $4 $5
|
|
||||||
run_solver town_500000 $1 town_500000 $3 $4 $5
|
|
||||||
run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
|
||||||
}
|
|
||||||
|
|
||||||
run_all_graphs gibbs "gibbs "
|
|
||||||
|
|
17
packages/CLPBN/clpbn/bp/benchmarks/city/hve_tests.sh
Executable file
17
packages/CLPBN/clpbn/bp/benchmarks/city/hve_tests.sh
Executable file
@ -0,0 +1,17 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source city.sh
|
||||||
|
source ../benchs.sh
|
||||||
|
|
||||||
|
SOLVER="hve"
|
||||||
|
|
||||||
|
YAP=~/bin/$SHORTNAME-$SOLVER
|
||||||
|
|
||||||
|
LOG_FILE=$SOLVER.log
|
||||||
|
#LOG_FILE=results`date "+ %H:%M:%S %d-%m-%Y"`.
|
||||||
|
|
||||||
|
rm -f $LOG_FILE
|
||||||
|
rm -f ignore.$LOG_FILE
|
||||||
|
|
||||||
|
run_all_graphs "hve(elim_heuristic=min_neighbors) " min_neighbors
|
||||||
|
|
@ -1,50 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
cp ~/bin/yap ~/bin/town_jt
|
|
||||||
YAP=~/bin/town_jt
|
|
||||||
|
|
||||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
|
||||||
OUT_FILE_NAME=jt.log
|
|
||||||
rm -f $OUT_FILE_NAME
|
|
||||||
rm -f ignore.$OUT_FILE_NAME
|
|
||||||
|
|
||||||
|
|
||||||
function run_solver
|
|
||||||
{
|
|
||||||
if [ $2 = bp ]
|
|
||||||
then
|
|
||||||
extra_flag1=clpbn_bp:set_horus_flag\(inf_alg,$4\)
|
|
||||||
extra_flag2=clpbn_bp:set_horus_flag\(schedule,$5\)
|
|
||||||
else
|
|
||||||
extra_flag1=true
|
|
||||||
extra_flag2=true
|
|
||||||
fi
|
|
||||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
|
||||||
[$1].
|
|
||||||
clpbn:set_clpbn_flag(solver,$2),
|
|
||||||
clpbn_bp:set_horus_flag(use_logarithms, true),
|
|
||||||
$extra_flag1, $extra_flag2,
|
|
||||||
run_query(_R),
|
|
||||||
open("$OUT_FILE_NAME", 'append',S),
|
|
||||||
format(S, '$3: ~15+ ',[]),
|
|
||||||
close(S).
|
|
||||||
EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
function run_all_graphs
|
|
||||||
{
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
|
||||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
|
||||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
|
||||||
run_solver town_5000 $1 town_5000 $3 $4 $5
|
|
||||||
run_solver town_10000 $1 town_10000 $3 $4 $5
|
|
||||||
run_solver town_50000 $1 town_50000 $3 $4 $5
|
|
||||||
run_solver town_100000 $1 town_100000 $3 $4 $5
|
|
||||||
run_solver town_500000 $1 town_500000 $3 $4 $5
|
|
||||||
run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
|
||||||
}
|
|
||||||
|
|
||||||
run_all_graphs jt "jt "
|
|
||||||
|
|
@ -1,65 +0,0 @@
|
|||||||
|
|
||||||
conservative_city(City, Cons) :-
|
|
||||||
cons_table(City, ConsDist),
|
|
||||||
{ Cons = conservative_city(City) with p([y,n], ConsDist) }.
|
|
||||||
|
|
||||||
|
|
||||||
gender(X, Gender) :-
|
|
||||||
gender_table(X, GenderDist),
|
|
||||||
{ Gender = gender(X) with p([m,f], GenderDist) }.
|
|
||||||
|
|
||||||
|
|
||||||
hair_color(X, Color) :-
|
|
||||||
lives(X, City),
|
|
||||||
conservative_city(City, Cons),
|
|
||||||
hair_color_table(X,ColorTable),
|
|
||||||
{ Color = hair_color(X) with
|
|
||||||
p([t,f], ColorTable,[Cons]) }.
|
|
||||||
|
|
||||||
|
|
||||||
car_color(X, Color) :-
|
|
||||||
hair_color(X, HColor),
|
|
||||||
car_color_table(X,CColorTable),
|
|
||||||
{ Color = car_color(X) with
|
|
||||||
p([t,f], CColorTable,[HColor]) }.
|
|
||||||
|
|
||||||
|
|
||||||
height(X, Height) :-
|
|
||||||
gender(X, Gender),
|
|
||||||
height_table(X,HeightTable),
|
|
||||||
{ Height = height(X) with
|
|
||||||
p([t,f], HeightTable,[Gender]) }.
|
|
||||||
|
|
||||||
|
|
||||||
shoe_size(X, Shoesize) :-
|
|
||||||
height(X, Height),
|
|
||||||
shoe_size_table(X,ShoesizeTable),
|
|
||||||
{ Shoesize = shoe_size(X) with
|
|
||||||
p([t,f], ShoesizeTable,[Height]) }.
|
|
||||||
|
|
||||||
|
|
||||||
guilty(X, Guilt) :-
|
|
||||||
guilty_table(X, GuiltDist),
|
|
||||||
{ Guilt = guilty(X) with p([y,n], GuiltDist) }.
|
|
||||||
|
|
||||||
|
|
||||||
descn(X, Descn) :-
|
|
||||||
car_color(X, Car),
|
|
||||||
hair_color(X, Hair),
|
|
||||||
height(X, Height),
|
|
||||||
guilty(X, Guilt),
|
|
||||||
descn_table(X, DescTable),
|
|
||||||
{ Descn = descn(X) with
|
|
||||||
p([t,f], DescTable,[Car,Hair,Height,Guilt]) }.
|
|
||||||
|
|
||||||
|
|
||||||
witness(City, Witness) :-
|
|
||||||
descn(joe, DescnJ),
|
|
||||||
descn(p2, Descn2),
|
|
||||||
wit_table(WitTable),
|
|
||||||
{ Witness = witness(City) with
|
|
||||||
p([t,f], WitTable,[DescnJ, Descn2]) }.
|
|
||||||
|
|
||||||
|
|
||||||
:- ensure_loaded(tables).
|
|
||||||
|
|
@ -1,46 +0,0 @@
|
|||||||
|
|
||||||
cons_table(amsterdam, [0.2, 0.8]) :- !.
|
|
||||||
cons_table(_, [0.8, 0.2]).
|
|
||||||
|
|
||||||
|
|
||||||
gender_table(_, [0.55, 0.44]).
|
|
||||||
|
|
||||||
|
|
||||||
hair_color_table(_,
|
|
||||||
/* conservative_city */
|
|
||||||
/* y n */
|
|
||||||
[ 0.05, 0.1,
|
|
||||||
0.95, 0.9 ]).
|
|
||||||
|
|
||||||
|
|
||||||
car_color_table(_,
|
|
||||||
/* t f */
|
|
||||||
[ 0.9, 0.2,
|
|
||||||
0.1, 0.8 ]).
|
|
||||||
|
|
||||||
|
|
||||||
height_table(_,
|
|
||||||
/* m f */
|
|
||||||
[ 0.6, 0.4,
|
|
||||||
0.4, 0.6 ]).
|
|
||||||
|
|
||||||
|
|
||||||
shoe_size_table(_,
|
|
||||||
/* t f */
|
|
||||||
[ 0.9, 0.1,
|
|
||||||
0.1, 0.9 ]).
|
|
||||||
|
|
||||||
|
|
||||||
guilty_table(_, [0.23, 0.77]).
|
|
||||||
|
|
||||||
|
|
||||||
descn_table(_,
|
|
||||||
/* color, hair, height, guilt */
|
|
||||||
/* ttttt tttf ttft ttff tfttt tftf tfft tfff ttttt fttf ftft ftff ffttt fftf ffft ffff */
|
|
||||||
[ 0.99, 0.5, 0.23, 0.88, 0.41, 0.3, 0.76, 0.87, 0.44, 0.43, 0.29, 0.72, 0.33, 0.91, 0.95, 0.92,
|
|
||||||
0.01, 0.5, 0.77, 0.12, 0.59, 0.7, 0.24, 0.13, 0.56, 0.57, 0.61, 0.28, 0.77, 0.09, 0.05, 0.08]).
|
|
||||||
|
|
||||||
|
|
||||||
wit_table([0.2, 0.45, 0.24, 0.34,
|
|
||||||
0.8, 0.55, 0.76, 0.66]).
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user