diff --git a/OPTYap/opt.config.h b/OPTYap/opt.config.h index f0318643d..a6e12586e 100644 --- a/OPTYap/opt.config.h +++ b/OPTYap/opt.config.h @@ -344,7 +344,6 @@ #endif #if defined(YAPOR) || defined(THREADS) -#undef MODE_DIRECTED_TABLING #undef TABLING_EARLY_COMPLETION #undef INCOMPLETE_TABLING #undef LIMIT_TABLING diff --git a/OPTYap/opt.preds.c b/OPTYap/opt.preds.c index a9f1aad7a..34429c6de 100644 --- a/OPTYap/opt.preds.c +++ b/OPTYap/opt.preds.c @@ -74,26 +74,32 @@ static inline void answer_to_stdout(char *answer); #endif /* YAPOR */ #ifdef TABLING -static inline long show_statistics_table_entries(IOSTREAM *out); -static inline long show_statistics_subgoal_frames(IOSTREAM *out); -static inline long show_statistics_dependency_frames(IOSTREAM *out); -static inline long show_statistics_subgoal_trie_nodes(IOSTREAM *out); -static inline long show_statistics_subgoal_trie_hashes(IOSTREAM *out); -static inline long show_statistics_answer_trie_nodes(IOSTREAM *out); -static inline long show_statistics_answer_trie_hashes(IOSTREAM *out); -static inline long show_statistics_global_trie_nodes(IOSTREAM *out); -static inline long show_statistics_global_trie_hashes(IOSTREAM *out); +static inline struct page_statistics show_statistics_table_entries(IOSTREAM *out); +#if defined(THREADS_FULL_SHARING) || defined(THREADS_CONSUMER_SHARING) +static inline struct page_statistics show_statistics_subgoal_entries(IOSTREAM *out); +#endif /* THREADS_FULL_SHARING || THREADS_CONSUMER_SHARING */ +static inline struct page_statistics show_statistics_subgoal_frames(IOSTREAM *out); +static inline struct page_statistics show_statistics_dependency_frames(IOSTREAM *out); +static inline struct page_statistics show_statistics_subgoal_trie_nodes(IOSTREAM *out); +static inline struct page_statistics show_statistics_subgoal_trie_hashes(IOSTREAM *out); +static inline struct page_statistics show_statistics_answer_trie_nodes(IOSTREAM *out); +static inline struct page_statistics show_statistics_answer_trie_hashes(IOSTREAM *out); +#if defined(THREADS_FULL_SHARING) +static inline struct page_statistics show_statistics_answer_ref_nodes(IOSTREAM *out); +#endif /* THREADS_FULL_SHARING */ +static inline struct page_statistics show_statistics_global_trie_nodes(IOSTREAM *out); +static inline struct page_statistics show_statistics_global_trie_hashes(IOSTREAM *out); #endif /* TABLING */ #ifdef YAPOR -static inline long show_statistics_or_frames(IOSTREAM *out); -static inline long show_statistics_query_goal_solution_frames(IOSTREAM *out); -static inline long show_statistics_query_goal_answer_frames(IOSTREAM *out); +static inline struct page_statistics show_statistics_or_frames(IOSTREAM *out); +static inline struct page_statistics show_statistics_query_goal_solution_frames(IOSTREAM *out); +static inline struct page_statistics show_statistics_query_goal_answer_frames(IOSTREAM *out); #endif /* YAPOR */ #if defined(YAPOR) && defined(TABLING) -static inline long show_statistics_suspension_frames(IOSTREAM *out); +static inline struct page_statistics show_statistics_suspension_frames(IOSTREAM *out); #ifdef TABLING_INNER_CUTS -static inline long show_statistics_table_subgoal_solution_frames(IOSTREAM *out); -static inline long show_statistics_table_subgoal_answer_frames(IOSTREAM *out); +static inline struct page_statistics show_statistics_table_subgoal_solution_frames(IOSTREAM *out); +static inline struct page_statistics show_statistics_table_subgoal_answer_frames(IOSTREAM *out); #endif /* TABLING_INNER_CUTS */ #endif /* YAPOR && TABLING */ @@ -113,16 +119,19 @@ struct page_statistics { #ifdef USE_PAGES_MALLOC long pages_allocated; /* same as struct pages (opt.structs.h) */ #endif /* USE_PAGES_MALLOC */ - long structs_in_use; /* same as struct pages (opt.structs.h) */ + long structs_in_use; /* same as struct pages (opt.structs.h) */ + long bytes_in_use; }; +#define Pg_bytes_in_use(STATS) STATS.bytes_in_use + #ifdef USE_PAGES_MALLOC #ifdef DEBUG_TABLING -#define CHECK_PAGE_FREE_STRUCTS(STR_TYPE, STR_PAGES) \ +#define CHECK_PAGE_FREE_STRUCTS(STR_TYPE, PAGE) \ { pg_hd_ptr pg_hd; \ STR_TYPE *aux_ptr; \ long cont = 0; \ - pg_hd = Pg_free_pg(STR_PAGES); \ + pg_hd = Pg_free_pg(PAGE); \ while (pg_hd) { \ aux_ptr = PgHd_free_str(pg_hd); \ while (aux_ptr) { \ @@ -131,59 +140,66 @@ struct page_statistics { } \ pg_hd = PgHd_next(pg_hd); \ } \ - if(Pg_str_free(STR_PAGES) != cont)printf("ERRRO!!!!!!!!\n");\ - TABLING_ERROR_CHECKING(CHECK_PAGE_FREE_STRUCTS, Pg_str_free(STR_PAGES) != cont); \ + TABLING_ERROR_CHECKING(CHECK_PAGE_FREE_STRUCTS, Pg_str_free(PAGE) != cont); \ } #else -#define CHECK_PAGE_FREE_STRUCTS(STR_TYPE,STR_PAGES) +#define CHECK_PAGE_FREE_STRUCTS(STR_TYPE, PAGE) #endif /* DEBUG_TABLING */ #define INIT_PAGE_STATS(STATS) \ Pg_pg_alloc(STATS) = 0; \ Pg_str_in_use(STATS) = 0 -#define INCREMENT_PAGE_STATS(STATS,PAGE) \ +#define INCREMENT_PAGE_STATS(STATS, PAGE) \ Pg_pg_alloc(STATS) += Pg_pg_alloc(PAGE); \ Pg_str_in_use(STATS) += Pg_str_in_use(PAGE) -#define SHOW_PAGE_STATS_MSG(STR_NAME) " " STR_NAME " %10ld bytes (%ld pages and %ld structs in use)\n" -#define SHOW_PAGE_STATS_ARGS(STATS,STR_TYPE) Pg_str_in_use(STATS) * sizeof(STR_TYPE), Pg_pg_alloc(STATS), Pg_str_in_use(STATS) +#define INCREMENT_AUX_STATS(STATS, BYTES, PAGES) \ + BYTES += Pg_bytes_in_use(STATS); \ + PAGES += Pg_pg_alloc(STATS) +#define SHOW_PAGE_STATS_MSG(STR_NAME) " " STR_NAME " %10ld bytes (%ld pages and %ld structs in use)\n" +#define SHOW_PAGE_STATS_ARGS(STATS, STR_TYPE) Pg_str_in_use(STATS) * sizeof(STR_TYPE), Pg_pg_alloc(STATS), Pg_str_in_use(STATS) #else /* !USE_PAGES_MALLOC */ +#define CHECK_PAGE_FREE_STRUCTS(STR_TYPE, PAGE) #define INIT_PAGE_STATS(STATS) \ Pg_str_in_use(STATS) = 0 -#define INCREMENT_PAGE_STATS(STATS,PAGE) \ +#define INCREMENT_PAGE_STATS(STATS, PAGE) \ Pg_str_in_use(STATS) += Pg_str_in_use(PAGE) -#define CHECK_PAGE_FREE_STRUCTS(STR_TYPE,STR_PAGES) -#define SHOW_PAGE_STATS_MSG(STR_NAME) " " STR_NAME " %10ld bytes (%ld structs in use)\n" -#define SHOW_PAGE_STATS_ARGS(STATS,STR_TYPE) Pg_str_in_use(STATS) * sizeof(STR_TYPE), Pg_str_in_use(STATS) -#endif +#define INCREMENT_AUX_STATS(STATS, BYTES, PAGES) \ + BYTES += Pg_bytes_in_use(STATS) +#define SHOW_PAGE_STATS_MSG(STR_NAME) " " STR_NAME " %10ld bytes (%ld structs in use)\n" +#define SHOW_PAGE_STATS_ARGS(STATS, STR_TYPE) Pg_str_in_use(STATS) * sizeof(STR_TYPE), Pg_str_in_use(STATS) +#endif /* USE_PAGES_MALLOC */ -#define GET_GLOBAL_PAGE_STATS(STATS,STR_PAGES) \ - INIT_PAGE_STATS(STATS); \ - CHECK_PAGE_FREE_STRUCTS(STR_TYPE,STR_PAGES); \ - INCREMENT_PAGE_STATS(STATS,STR_PAGES) -#define GET_REMOTE_PAGE_STATS(STATS,STR_PAGES) \ - INIT_PAGE_STATS(STATS); \ - LOCK(GLOBAL_ThreadHandlesLock); \ - { int wid; \ - for (wid = 0; wid < MAX_THREADS; wid++) { \ - if (!Yap_local[wid]) \ - break; \ - if (REMOTE_ThreadHandle(wid).in_use) { \ - CHECK_PAGE_FREE_STRUCTS(STR_TYPE,STR_PAGES); \ - INCREMENT_PAGE_STATS(STATS,STR_PAGES(wid)); \ - } \ - } \ - } \ - UNLOCK(GLOBAL_ThreadHandlesLock) -#define SHOW_GLOBAL_PAGE_STATS(OUT_STREAM,STR_TYPE,STR_PAGES,STR_NAME) \ - { struct page_statistics stats; \ - GET_GLOBAL_PAGE_STATS(stats,STR_PAGES); \ - Sfprintf(OUT_STREAM,SHOW_PAGE_STATS_MSG(STR_NAME),SHOW_PAGE_STATS_ARGS(stats,STR_TYPE)); \ - return Pg_str_in_use(stats) * sizeof(STR_TYPE); \ +#define GET_GLOBAL_PAGE_STATS(STATS, STR_TYPE, STR_PAGES) \ + INIT_PAGE_STATS(STATS); \ + CHECK_PAGE_FREE_STRUCTS(STR_TYPE, STR_PAGES); \ + INCREMENT_PAGE_STATS(STATS, STR_PAGES); \ + Pg_bytes_in_use(STATS) = Pg_str_in_use(STATS) * sizeof(STR_TYPE) +#define GET_REMOTE_PAGE_STATS(STATS, STR_TYPE, STR_PAGES) \ + INIT_PAGE_STATS(STATS); \ + LOCK(GLOBAL_ThreadHandlesLock); \ + { int wid; \ + for (wid = 0; wid < MAX_THREADS; wid++) { \ + if (! Yap_local[wid]) \ + break; \ + if (REMOTE_ThreadHandle(wid).in_use) { \ + CHECK_PAGE_FREE_STRUCTS(STR_TYPE, STR_PAGES(wid)); \ + INCREMENT_PAGE_STATS(STATS, STR_PAGES(wid)); \ + } \ + } \ + } \ + UNLOCK(GLOBAL_ThreadHandlesLock); \ + Pg_bytes_in_use(STATS) = Pg_str_in_use(STATS) * sizeof(STR_TYPE) + +#define SHOW_GLOBAL_PAGE_STATS(OUT_STREAM, STR_TYPE, STR_PAGES, STR_NAME) \ + { struct page_statistics stats; \ + GET_GLOBAL_PAGE_STATS(stats, STR_TYPE, STR_PAGES); \ + Sfprintf(OUT_STREAM, SHOW_PAGE_STATS_MSG(STR_NAME), SHOW_PAGE_STATS_ARGS(stats, STR_TYPE)); \ + return stats; \ } -#define SHOW_REMOTE_PAGE_STATS(OUT_STREAM,STR_TYPE,STR_PAGES,STR_NAME) \ - { struct page_statistics stats; \ - GET_REMOTE_PAGE_STATS(stats,STR_PAGES); \ - Sfprintf(OUT_STREAM,SHOW_PAGE_STATS_MSG(STR_NAME),SHOW_PAGE_STATS_ARGS(stats,STR_TYPE)); \ - return Pg_str_in_use(stats) * sizeof(STR_TYPE); \ +#define SHOW_REMOTE_PAGE_STATS(OUT_STREAM, STR_TYPE, STR_PAGES, STR_NAME) \ + { struct page_statistics stats; \ + GET_REMOTE_PAGE_STATS(stats, STR_TYPE, STR_PAGES); \ + Sfprintf(OUT_STREAM, SHOW_PAGE_STATS_MSG(STR_NAME), SHOW_PAGE_STATS_ARGS(stats, STR_TYPE)); \ + return stats; \ } @@ -636,35 +652,56 @@ static Int p_show_statistics_table( USES_REGS1 ) { static Int p_show_statistics_tabling( USES_REGS1 ) { + struct page_statistics stats; + long bytes, total_bytes = 0; +#ifdef USE_PAGES_MALLOC + long total_pages = 0; +#endif /* USE_PAGES_MALLOC */ IOSTREAM *out; - long total_bytes = 0, aux_bytes; if (!PL_get_stream_handle(Yap_InitSlot(Deref(ARG1) PASS_REGS), &out)) return (FALSE); - aux_bytes = 0; + bytes = 0; Sfprintf(out, "Execution data structures\n"); - aux_bytes += show_statistics_table_entries(out); - aux_bytes += show_statistics_subgoal_frames(out); - aux_bytes += show_statistics_dependency_frames(out); - Sfprintf(out, " Memory in use (I): %10ld bytes\n\n", aux_bytes); - total_bytes += aux_bytes; - aux_bytes = 0; + stats = show_statistics_table_entries(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); +#if defined(THREADS_FULL_SHARING) || defined(THREADS_CONSUMER_SHARING) + stats = show_statistics_subgoal_entries(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); +#endif /* THREADS_FULL_SHARING || THREADS_CONSUMER_SHARING */ + stats = show_statistics_subgoal_frames(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); + stats = show_statistics_dependency_frames(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); + Sfprintf(out, " Memory in use (I): %10ld bytes\n\n", bytes); + total_bytes += bytes; + bytes = 0; Sfprintf(out, "Local trie data structures\n"); - aux_bytes += show_statistics_subgoal_trie_nodes(out); - aux_bytes += show_statistics_answer_trie_nodes(out); - aux_bytes += show_statistics_subgoal_trie_hashes(out); - aux_bytes += show_statistics_answer_trie_hashes(out); - Sfprintf(out, " Memory in use (II): %10ld bytes\n\n", aux_bytes); - total_bytes += aux_bytes; - aux_bytes = 0; + stats = show_statistics_subgoal_trie_nodes(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); + stats = show_statistics_answer_trie_nodes(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); + stats = show_statistics_subgoal_trie_hashes(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); + stats = show_statistics_answer_trie_hashes(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); +#if defined(THREADS_FULL_SHARING) + stats = show_statistics_answer_ref_nodes(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); +#endif /* THREADS_FULL_SHARING */ + Sfprintf(out, " Memory in use (II): %10ld bytes\n\n", bytes); + total_bytes += bytes; + bytes = 0; Sfprintf(out, "Global trie data structures\n"); - aux_bytes += show_statistics_global_trie_nodes(out); - aux_bytes += show_statistics_global_trie_hashes(out); - Sfprintf(out, " Memory in use (III): %10ld bytes\n\n", aux_bytes); - total_bytes += aux_bytes; + stats = show_statistics_global_trie_nodes(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); + stats = show_statistics_global_trie_hashes(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); + Sfprintf(out, " Memory in use (III): %10ld bytes\n\n", bytes); + total_bytes += bytes; #ifdef USE_PAGES_MALLOC Sfprintf(out, "Total memory in use (I+II+III): %10ld bytes (%ld pages in use)\n", - total_bytes, Pg_str_in_use(GLOBAL_pages_void)); + total_bytes, total_pages); Sfprintf(out, "Total memory allocated: %10ld bytes (%ld pages in total)\n", Pg_pg_alloc(GLOBAL_pages_void) * Yap_page_size, Pg_pg_alloc(GLOBAL_pages_void)); #else @@ -674,6 +711,7 @@ static Int p_show_statistics_tabling( USES_REGS1 ) { return (TRUE); } + static Int p_show_statistics_global_trie( USES_REGS1 ) { IOSTREAM *out; @@ -686,6 +724,7 @@ static Int p_show_statistics_global_trie( USES_REGS1 ) { #endif /* TABLING */ + /********************************* ** YapOr C Predicates ** *********************************/ @@ -779,25 +818,32 @@ static Int p_parallel_new_answer( USES_REGS1 ) { static Int p_show_statistics_or( USES_REGS1 ) { + struct page_statistics stats; + long bytes, total_bytes = 0; +#ifdef USE_PAGES_MALLOC + long total_pages = 0; +#endif /* USE_PAGES_MALLOC */ IOSTREAM *out; - long total_bytes = 0, aux_bytes; if (!PL_get_stream_handle(Yap_InitSlot(Deref(ARG1) PASS_REGS), &out)) return (FALSE); - aux_bytes = 0; + bytes = 0; Sfprintf(out, "Execution data structures\n"); - aux_bytes += show_statistics_or_frames(out); - Sfprintf(out, " Memory in use (I): %10ld bytes\n\n", aux_bytes); - total_bytes += aux_bytes; - aux_bytes = 0; + stats = show_statistics_or_frames(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); + Sfprintf(out, " Memory in use (I): %10ld bytes\n\n", bytes); + total_bytes += bytes; + bytes = 0; Sfprintf(out, "Cut support data structures\n"); - aux_bytes += show_statistics_query_goal_solution_frames(out); - aux_bytes += show_statistics_query_goal_answer_frames(out); - Sfprintf(out, " Memory in use (II): %10ld bytes\n\n", aux_bytes); - total_bytes += aux_bytes; + stats = show_statistics_query_goal_solution_frames(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); + stats = show_statistics_query_goal_answer_frames(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); + Sfprintf(out, " Memory in use (II): %10ld bytes\n\n", bytes); + total_bytes += bytes; #ifdef USE_PAGES_MALLOC Sfprintf(out, "Total memory in use (I+II): %10ld bytes (%ld pages in use)\n", - total_bytes, Pg_str_in_use(GLOBAL_pages_void)); + total_bytes, total_pages); Sfprintf(out, "Total memory allocated: %10ld bytes (%ld pages in total)\n", Pg_pg_alloc(GLOBAL_pages_void) * Yap_page_size, Pg_pg_alloc(GLOBAL_pages_void)); #else @@ -816,47 +862,74 @@ static Int p_show_statistics_or( USES_REGS1 ) { #if defined(YAPOR) && defined(TABLING) static Int p_show_statistics_opt( USES_REGS1 ) { + struct page_statistics stats; + long bytes, total_bytes = 0; +#ifdef USE_PAGES_MALLOC + long total_pages = 0; +#endif /* USE_PAGES_MALLOC */ IOSTREAM *out; - long total_bytes = 0, aux_bytes; if (!PL_get_stream_handle(Yap_InitSlot(Deref(ARG1) PASS_REGS), &out)) return (FALSE); - aux_bytes = 0; + bytes = 0; Sfprintf(out, "Execution data structures\n"); - aux_bytes += show_statistics_table_entries(out); - aux_bytes += show_statistics_subgoal_frames(out); - aux_bytes += show_statistics_dependency_frames(out); - aux_bytes += show_statistics_or_frames(out); - aux_bytes += show_statistics_suspension_frames(out); - Sfprintf(out, " Memory in use (I): %10ld bytes\n\n", aux_bytes); - total_bytes += aux_bytes; - aux_bytes = 0; + stats = show_statistics_table_entries(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); +#if defined(THREADS_FULL_SHARING) || defined(THREADS_CONSUMER_SHARING) + stats = show_statistics_subgoal_entries(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); +#endif /* THREADS_FULL_SHARING || THREADS_CONSUMER_SHARING */ + stats = show_statistics_subgoal_frames(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); + stats = show_statistics_dependency_frames(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); + stats = show_statistics_or_frames(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); + stats = show_statistics_suspension_frames(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); + Sfprintf(out, " Memory in use (I): %10ld bytes\n\n", bytes); + total_bytes += bytes; + bytes = 0; Sfprintf(out, "Local trie data structures\n"); - aux_bytes += show_statistics_subgoal_trie_nodes(out); - aux_bytes += show_statistics_answer_trie_nodes(out); - aux_bytes += show_statistics_subgoal_trie_hashes(out); - aux_bytes += show_statistics_answer_trie_hashes(out); - Sfprintf(out, " Memory in use (II): %10ld bytes\n\n", aux_bytes); - total_bytes += aux_bytes; - aux_bytes = 0; + stats = show_statistics_subgoal_trie_nodes(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); + stats = show_statistics_answer_trie_nodes(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); + stats = show_statistics_subgoal_trie_hashes(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); + stats = show_statistics_answer_trie_hashes(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); +#if defined(THREADS_FULL_SHARING) + stats = show_statistics_answer_ref_nodes(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); +#endif /* THREADS_FULL_SHARING */ + Sfprintf(out, " Memory in use (II): %10ld bytes\n\n", bytes); + total_bytes += bytes; + bytes = 0; Sfprintf(out, "Global trie data structures\n"); - aux_bytes += show_statistics_global_trie_nodes(out); - aux_bytes += show_statistics_global_trie_hashes(out); - Sfprintf(out, " Memory in use (III): %10ld bytes\n\n", aux_bytes); - total_bytes += aux_bytes; - aux_bytes = 0; + stats = show_statistics_global_trie_nodes(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); + stats = show_statistics_global_trie_hashes(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); + Sfprintf(out, " Memory in use (III): %10ld bytes\n\n", bytes); + total_bytes += bytes; + bytes = 0; Sfprintf(out, "Cut support data structures\n"); - aux_bytes += show_statistics_query_goal_solution_frames(out); - aux_bytes += show_statistics_query_goal_answer_frames(out); + stats = show_statistics_query_goal_solution_frames(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); + stats = show_statistics_query_goal_answer_frames(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); #ifdef TABLING_INNER_CUTS - aux_bytes += show_statistics_table_subgoal_solution_frames(out); - aux_bytes += show_statistics_table_subgoal_answer_frames(out); + stats = show_statistics_table_subgoal_solution_frames(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); + stats = show_statistics_table_subgoal_answer_frames(out); + INCREMENT_AUX_STATS(stats, bytes, total_pages); #endif /* TABLING_INNER_CUTS */ - Sfprintf(out, " Memory in use (IV): %10ld bytes\n\n", aux_bytes); - total_bytes += aux_bytes; + Sfprintf(out, " Memory in use (IV): %10ld bytes\n\n", bytes); + total_bytes += bytes; #ifdef USE_PAGES_MALLOC Sfprintf(out, "Total memory in use (I+II+III+IV): %10ld bytes (%ld pages in use)\n", - total_bytes, Pg_str_in_use(GLOBAL_pages_void)); + total_bytes, total_pages); Sfprintf(out, "Total memory allocated: %10ld bytes (%ld pages in total)\n", Pg_pg_alloc(GLOBAL_pages_void) * Yap_page_size, Pg_pg_alloc(GLOBAL_pages_void)); #else @@ -876,121 +949,121 @@ static Int p_get_optyap_statistics( USES_REGS1 ) { value = IntOfTerm(Deref(ARG1)); #ifdef TABLING if (value == 0 || value == 1) { /* table_entries */ - GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_tab_ent); - bytes += Pg_str_in_use(stats) * sizeof(struct table_entry); + GET_GLOBAL_PAGE_STATS(stats, struct table_entry, GLOBAL_pages_tab_ent); + bytes += Pg_bytes_in_use(stats); if (value != 0) structs = Pg_str_in_use(stats); } #if defined(THREADS_FULL_SHARING) || defined(THREADS_CONSUMER_SHARING) if (value == 0 || value == 16) { /* subgoal_entries */ - GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_sg_entry); - bytes += Pg_str_in_use(stats) * sizeof(struct subgoal_entry); + GET_GLOBAL_PAGE_STATS(stats, struct subgoal_entry, GLOBAL_pages_sg_ent); + bytes += Pg_bytes_in_use(stats); if (value != 0) structs = Pg_str_in_use(stats); } #endif /* THREADS_FULL_SHARING || THREADS_CONSUMER_SHARING */ if (value == 0 || value == 2) { /* subgoal_frames */ #if !defined(THREADS_NO_SHARING) && !defined(THREADS_SUBGOAL_SHARING) && !defined(THREADS_FULL_SHARING) && !defined(THREADS_CONSUMER_SHARING) - GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_sg_fr); + GET_GLOBAL_PAGE_STATS(stats, struct subgoal_frame, GLOBAL_pages_sg_fr); #else - GET_REMOTE_PAGE_STATS(stats, REMOTE_pages_sg_fr); + GET_REMOTE_PAGE_STATS(stats, struct subgoal_frame, REMOTE_pages_sg_fr); #endif - bytes += Pg_str_in_use(stats) * sizeof(struct subgoal_frame); + bytes += Pg_bytes_in_use(stats); if (value != 0) structs = Pg_str_in_use(stats); } if (value == 0 || value == 3) { /* dependency_frames */ #if !defined(THREADS_NO_SHARING) && !defined(THREADS_SUBGOAL_SHARING) && !defined(THREADS_FULL_SHARING) && !defined(THREADS_CONSUMER_SHARING) - GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_dep_fr); + GET_GLOBAL_PAGE_STATS(stats, struct dependency_frame, GLOBAL_pages_dep_fr); #else - GET_REMOTE_PAGE_STATS(stats, REMOTE_pages_dep_fr); + GET_REMOTE_PAGE_STATS(stats, struct dependency_frame, REMOTE_pages_dep_fr); #endif - bytes += Pg_str_in_use(stats) * sizeof(struct dependency_frame); + bytes += Pg_bytes_in_use(stats); if (value != 0) structs = Pg_str_in_use(stats); } if (value == 0 || value == 6) { /* subgoal_trie_nodes */ #if !defined(THREADS_NO_SHARING) - GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_sg_node); + GET_GLOBAL_PAGE_STATS(stats, struct subgoal_trie_node, GLOBAL_pages_sg_node); #else - GET_REMOTE_PAGE_STATS(stats, REMOTE_pages_sg_node); + GET_REMOTE_PAGE_STATS(stats, struct subgoal_trie_node, REMOTE_pages_sg_node); #endif - bytes += Pg_str_in_use(stats) * sizeof(struct subgoal_trie_node); + bytes += Pg_bytes_in_use(stats); if (value != 0) structs = Pg_str_in_use(stats); } if (value == 0 || value == 8) { /* subgoal_trie_hashes */ #if !defined(THREADS_NO_SHARING) - GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_sg_hash); + GET_GLOBAL_PAGE_STATS(stats, struct subgoal_trie_hash, GLOBAL_pages_sg_hash); #else - GET_REMOTE_PAGE_STATS(stats, REMOTE_pages_sg_hash); + GET_REMOTE_PAGE_STATS(stats, struct subgoal_trie_hash, REMOTE_pages_sg_hash); #endif - bytes += Pg_str_in_use(stats) * sizeof(struct subgoal_trie_hash); + bytes += Pg_bytes_in_use(stats); if (value != 0) structs = Pg_str_in_use(stats); } if (value == 0 || value == 7) { /* answer_trie_nodes */ #if !defined(THREADS_NO_SHARING) && !defined(THREADS_SUBGOAL_SHARING) - GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_ans_node); + GET_GLOBAL_PAGE_STATS(stats, struct answer_trie_node, GLOBAL_pages_ans_node); #else - GET_REMOTE_PAGE_STATS(stats, REMOTE_pages_ans_node); + GET_REMOTE_PAGE_STATS(stats, struct answer_trie_node, REMOTE_pages_ans_node); #endif - bytes += Pg_str_in_use(stats) * sizeof(struct answer_trie_node); + bytes += Pg_bytes_in_use(stats); if (value != 0) structs = Pg_str_in_use(stats); } if (value == 0 || value == 9) { /* answer_trie_hashes */ #if !defined(THREADS_NO_SHARING) && !defined(THREADS_SUBGOAL_SHARING) - GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_ans_hash); + GET_GLOBAL_PAGE_STATS(stats, struct answer_trie_hash, GLOBAL_pages_ans_hash); #else - GET_REMOTE_PAGE_STATS(stats, REMOTE_pages_ans_hash); + GET_REMOTE_PAGE_STATS(stats, struct answer_trie_hash, REMOTE_pages_ans_hash); #endif - bytes += Pg_str_in_use(stats) * sizeof(struct answer_trie_hash); + bytes += Pg_bytes_in_use(stats); if (value != 0) structs = Pg_str_in_use(stats); } #if defined(THREADS_FULL_SHARING) if (value == 0 || value == 17) { /* answer_ref_nodes */ - GET_REMOTE_PAGE_STATS(stats, REMOTE_pages_ans_ref_node); - bytes += Pg_str_in_use(stats) * sizeof(struct answer_ref_node); + GET_REMOTE_PAGE_STATS(stats, struct answer_ref_node, REMOTE_pages_ans_ref_node); + bytes += Pg_bytes_in_use(stats); if (value != 0) structs = Pg_str_in_use(stats); } #endif /* THREADS_FULL_SHARING */ if (value == 0 || value == 10) { /* global_trie_nodes */ - GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_gt_node); - bytes += Pg_str_in_use(stats) * sizeof(struct global_trie_node); + GET_GLOBAL_PAGE_STATS(stats, struct global_trie_node, GLOBAL_pages_gt_node); + bytes += Pg_bytes_in_use(stats); if (value != 0) structs = Pg_str_in_use(stats); } if (value == 0 || value == 11) { /* global_trie_hashes */ - GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_gt_hash); - bytes += Pg_str_in_use(stats) * sizeof(struct global_trie_hash); + GET_GLOBAL_PAGE_STATS(stats, struct global_trie_hash, GLOBAL_pages_gt_hash); + bytes += Pg_bytes_in_use(stats); if (value != 0) structs = Pg_str_in_use(stats); } #endif /* TABLING */ #ifdef YAPOR if (value == 0 || value == 4) { /* or_frames */ - GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_or_fr); - bytes += Pg_str_in_use(stats) * sizeof(struct or_frame); + GET_GLOBAL_PAGE_STATS(stats, struct or_frame, GLOBAL_pages_or_fr); + bytes += Pg_bytes_in_use(stats); if (value != 0) structs = Pg_str_in_use(stats); } if (value == 0 || value == 12) { /* query_goal_solution_frames */ - GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_qg_sol_fr); - bytes += Pg_str_in_use(stats) * sizeof(struct query_goal_solution_frame); + GET_GLOBAL_PAGE_STATS(stats, struct query_goal_solution_frame, GLOBAL_pages_qg_sol_fr); + bytes += Pg_bytes_in_use(stats); if (value != 0) structs = Pg_str_in_use(stats); } if (value == 0 || value == 13) { /* query_goal_answer_frames */ - GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_qg_ans_fr); - bytes += Pg_str_in_use(stats) * sizeof(struct query_goal_answer_frame); + GET_GLOBAL_PAGE_STATS(stats, struct query_goal_answer_frame, GLOBAL_pages_qg_ans_fr); + bytes += Pg_bytes_in_use(stats); if (value != 0) structs = Pg_str_in_use(stats); } #endif /* YAPOR */ #if defined(YAPOR) && defined(TABLING) if (value == 0 || value == 5) { /* suspension_frames */ - GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_susp_fr); - bytes += Pg_str_in_use(stats) * sizeof(struct suspension_frame); + GET_GLOBAL_PAGE_STATS(stats, struct suspension_frame, GLOBAL_pages_susp_fr); + bytes += Pg_bytes_in_use(stats); if (value != 0) structs = Pg_str_in_use(stats); } #ifdef TABLING_INNER_CUTS if (value == 0 || value == 14) { /* table_subgoal_solution_frames */ - GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_tg_sol_fr); - bytes += Pg_str_in_use(stats) * sizeof(struct table_subgoal_solution_frame); + GET_GLOBAL_PAGE_STATS(stats, struct table_subgoal_solution_frame, GLOBAL_pages_tg_sol_fr); + bytes += Pg_bytes_in_use(stats); if (value != 0) structs = Pg_str_in_use(stats); } if (value == 0 || value == 15) { /* table_subgoal_answer_frames */ - GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_tg_ans_fr); - bytes += Pg_str_in_use(stats) * sizeof(struct table_subgoal_answer_frame); + GET_GLOBAL_PAGE_STATS(stats, struct table_subgoal_answer_frame, GLOBAL_pages_tg_ans_fr); + bytes += Pg_bytes_in_use(stats); if (value != 0) structs = Pg_str_in_use(stats); } #endif /* TABLING_INNER_CUTS */ @@ -1120,11 +1193,17 @@ static inline void answer_to_stdout(char *answer) { #ifdef TABLING -static inline long show_statistics_table_entries(IOSTREAM *out) { +static inline struct page_statistics show_statistics_table_entries(IOSTREAM *out) { SHOW_GLOBAL_PAGE_STATS(out, struct table_entry, GLOBAL_pages_tab_ent, "Table entries: "); } -static inline long show_statistics_subgoal_frames(IOSTREAM *out) { +#if defined(THREADS_FULL_SHARING) || defined(THREADS_CONSUMER_SHARING) +static inline struct page_statistics show_statistics_subgoal_entries(IOSTREAM *out) { + SHOW_GLOBAL_PAGE_STATS(out, struct subgoal_entry, GLOBAL_pages_sg_ent, "Subgoal entries: "); +} +#endif /* THREADS_FULL_SHARING || THREADS_CONSUMER_SHARING */ + +static inline struct page_statistics show_statistics_subgoal_frames(IOSTREAM *out) { #if !defined(THREADS_NO_SHARING) && !defined(THREADS_SUBGOAL_SHARING) && !defined(THREADS_FULL_SHARING) && !defined(THREADS_CONSUMER_SHARING) SHOW_GLOBAL_PAGE_STATS(out, struct subgoal_frame, GLOBAL_pages_sg_fr, "Subgoal frames: "); #else @@ -1132,7 +1211,7 @@ static inline long show_statistics_subgoal_frames(IOSTREAM *out) { #endif } -static inline long show_statistics_dependency_frames(IOSTREAM *out) { +static inline struct page_statistics show_statistics_dependency_frames(IOSTREAM *out) { #if !defined(THREADS_NO_SHARING) && !defined(THREADS_SUBGOAL_SHARING) && !defined(THREADS_FULL_SHARING) && !defined(THREADS_CONSUMER_SHARING) SHOW_GLOBAL_PAGE_STATS(out, struct dependency_frame, GLOBAL_pages_dep_fr, "Dependency frames: "); #else @@ -1140,7 +1219,7 @@ static inline long show_statistics_dependency_frames(IOSTREAM *out) { #endif } -static inline long show_statistics_subgoal_trie_nodes(IOSTREAM *out) { +static inline struct page_statistics show_statistics_subgoal_trie_nodes(IOSTREAM *out) { #if !defined(THREADS_NO_SHARING) SHOW_GLOBAL_PAGE_STATS(out, struct subgoal_trie_node, GLOBAL_pages_sg_node, "Subgoal trie nodes: "); #else @@ -1148,7 +1227,7 @@ static inline long show_statistics_subgoal_trie_nodes(IOSTREAM *out) { #endif } -static inline long show_statistics_subgoal_trie_hashes(IOSTREAM *out) { +static inline struct page_statistics show_statistics_subgoal_trie_hashes(IOSTREAM *out) { #if !defined(THREADS_NO_SHARING) SHOW_GLOBAL_PAGE_STATS(out, struct subgoal_trie_hash, GLOBAL_pages_sg_hash, "Subgoal trie hashes: "); #else @@ -1156,7 +1235,7 @@ static inline long show_statistics_subgoal_trie_hashes(IOSTREAM *out) { #endif } -static inline long show_statistics_answer_trie_nodes(IOSTREAM *out) { +static inline struct page_statistics show_statistics_answer_trie_nodes(IOSTREAM *out) { #if !defined(THREADS_NO_SHARING) && !defined(THREADS_SUBGOAL_SHARING) SHOW_GLOBAL_PAGE_STATS(out, struct answer_trie_node, GLOBAL_pages_ans_node, "Answer trie nodes: "); #else @@ -1164,7 +1243,7 @@ static inline long show_statistics_answer_trie_nodes(IOSTREAM *out) { #endif } -static inline long show_statistics_answer_trie_hashes(IOSTREAM *out) { +static inline struct page_statistics show_statistics_answer_trie_hashes(IOSTREAM *out) { #if !defined(THREADS_NO_SHARING) && !defined(THREADS_SUBGOAL_SHARING) SHOW_GLOBAL_PAGE_STATS(out, struct answer_trie_hash, GLOBAL_pages_ans_hash, "Answer trie hashes: "); #else @@ -1172,42 +1251,48 @@ static inline long show_statistics_answer_trie_hashes(IOSTREAM *out) { #endif } -static inline long show_statistics_global_trie_nodes(IOSTREAM *out) { +#if defined(THREADS_FULL_SHARING) +static inline struct page_statistics show_statistics_answer_ref_nodes(IOSTREAM *out) { + SHOW_GLOBAL_PAGE_STATS(out, struct answer_ref_node, REMOTE_pages_ans_ref_node, "Answer ref nodes: "); +} +#endif /* THREADS_FULL_SHARING */ + +static inline struct page_statistics show_statistics_global_trie_nodes(IOSTREAM *out) { SHOW_GLOBAL_PAGE_STATS(out, struct global_trie_node, GLOBAL_pages_gt_node, "Global trie nodes: "); } -static inline long show_statistics_global_trie_hashes(IOSTREAM *out) { +static inline struct page_statistics show_statistics_global_trie_hashes(IOSTREAM *out) { SHOW_GLOBAL_PAGE_STATS(out, struct global_trie_hash, GLOBAL_pages_gt_hash, "Global trie hashes: "); } #endif /* TABLING */ #ifdef YAPOR -static inline long show_statistics_or_frames(IOSTREAM *out) { +static inline struct page_statistics show_statistics_or_frames(IOSTREAM *out) { SHOW_GLOBAL_PAGE_STATS(out, struct or_frame, GLOBAL_pages_or_fr, "Or-frames: "); } -static inline long show_statistics_query_goal_solution_frames(IOSTREAM *out) { +static inline struct page_statistics show_statistics_query_goal_solution_frames(IOSTREAM *out) { SHOW_GLOBAL_PAGE_STATS(out, struct query_goal_solution_frame, GLOBAL_pages_qg_sol_fr, "Query goal solution frames: "); } -static inline long show_statistics_query_goal_answer_frames(IOSTREAM *out) { +static inline struct page_statistics show_statistics_query_goal_answer_frames(IOSTREAM *out) { SHOW_GLOBAL_PAGE_STATS(out, struct query_goal_answer_frame, GLOBAL_pages_qg_ans_fr, "Query goal answer frames: "); } #endif /* YAPOR */ #if defined(YAPOR) && defined(TABLING) -static inline long show_statistics_suspension_frames(IOSTREAM *out) { +static inline struct page_statistics show_statistics_suspension_frames(IOSTREAM *out) { SHOW_GLOBAL_PAGE_STATS(out, struct suspension_frame, GLOBAL_pages_susp_fr, "Suspension frames: "); } #ifdef TABLING_INNER_CUTS -static inline long show_statistics_table_subgoal_solution_frames(IOSTREAM *out) { +static inline struct page_statistics show_statistics_table_subgoal_solution_frames(IOSTREAM *out) { SHOW_GLOBAL_PAGE_STATS(out, struct table_subgoal_solution_frame, GLOBAL_pages_tg_sol_fr, "Table subgoal solution frames:"); } -static inline long show_statistics_table_subgoal_answer_frames(IOSTREAM *out) { +static inline struct page_statistics show_statistics_table_subgoal_answer_frames(IOSTREAM *out) { SHOW_GLOBAL_PAGE_STATS(out, struct table_subgoal_answer_frame, GLOBAL_pages_tg_ans_fr, "Table subgoal answer frames: "); } #endif /* TABLING_INNER_CUTS */ diff --git a/OPTYap/tab.macros.h b/OPTYap/tab.macros.h index c63ef573a..7960daeb0 100644 --- a/OPTYap/tab.macros.h +++ b/OPTYap/tab.macros.h @@ -606,6 +606,7 @@ static inline void adjust_freeze_registers(void) { static inline void mark_as_completed(sg_fr_ptr sg_fr) { + CACHE_REGS LOCK_SG_FR(sg_fr); SgFr_state(sg_fr) = complete; UNLOCK_SG_FR(sg_fr); diff --git a/OPTYap/tab.tries.i b/OPTYap/tab.tries.i index 180819ce7..e36d08ae3 100644 --- a/OPTYap/tab.tries.i +++ b/OPTYap/tab.tries.i @@ -1388,6 +1388,7 @@ static inline ans_node_ptr answer_search_loop(sg_fr_ptr sg_fr, ans_node_ptr curr } static inline ans_node_ptr answer_search_min_max(sg_fr_ptr sg_fr, ans_node_ptr current_node, Term t, int mode) { + CACHE_REGS ans_node_ptr child_node; Term child_term; Float trie_value, term_value; @@ -1486,6 +1487,8 @@ static inline ans_node_ptr answer_search_min_max(sg_fr_ptr sg_fr, ans_node_ptr c SgFr_invalid_chain(SG_FR) = NODE static void invalidate_answer_trie(ans_node_ptr current_node, sg_fr_ptr sg_fr, int position) { + CACHE_REGS + if (IS_ANSWER_TRIE_HASH(current_node)) { ans_hash_ptr hash; ans_node_ptr *bucket, *last_bucket; diff --git a/packages/CLPBN/clpbn/bp.yap b/packages/CLPBN/clpbn/bp.yap index e5b00b4da..c7d9f7c94 100644 --- a/packages/CLPBN/clpbn/bp.yap +++ b/packages/CLPBN/clpbn/bp.yap @@ -8,6 +8,8 @@ :- module(clpbn_bp, [bp/3, check_if_bp_done/1, + set_solver_parameter/2, + use_log_space/0, init_bp_solver/4, run_bp_solver/3, finalize_bp_solver/1 @@ -34,116 +36,136 @@ :- attribute id/1. -:- dynamic num_bayes_nets/1. +:- dynamic network_counting/1. check_if_bp_done(_Var). -num_bayes_nets(0). +network_counting(0). + + +:- set_solver_parameter(run_mode, normal). +%:- set_solver_parameter(run_mode, convert). +%: -set_solver_parameter(run_mode, compress). + +:- set_solver_parameter(schedule, seq_fixed). +%:- set_solver_parameter(schedule, seq_random). +%:- set_solver_parameter(schedule, parallel). +%:- set_solver_parameter(schedule, max_residual). + +:- set_solver_parameter(accuracy, 0.0001). + +:- set_solver_parameter(max_iter, 1000). + +:- set_solver_parameter(always_loopy_solver, false). + +% :- use_log_space. bp([[]],_,_) :- !. bp([QueryVars], AllVars, Output) :- - init_bp_solver(_, AllVars, _, BayesNet), - run_bp_solver([QueryVars], LPs, BayesNet), - finalize_bp_solver(BayesNet), - clpbn_bind_vals([QueryVars], LPs, Output). + init_bp_solver(_, AllVars, _, BayesNet), + run_bp_solver([QueryVars], LPs, BayesNet), + finalize_bp_solver(BayesNet), + clpbn_bind_vals([QueryVars], LPs, Output). init_bp_solver(_, AllVars, _, (BayesNet, DistIds)) :- - %inc_num_bayes_nets, - %(showprofres(50) -> true ; true), - process_ids(AllVars, 0, DistIds0), - get_vars_info(AllVars, VarsInfo), - sort(DistIds0, DistIds), - %(num_bayes_nets(0) -> writeln(vars:VarsInfo) ; true), - %(num_bayes_nets(0) -> writeln(dists:DistsInfo) ; true), - create_network(VarsInfo, BayesNet). - %set_extra_vars_info(BayesNet, ExtraVarsInfo). + %inc_network_counting, + process_ids(AllVars, 0, DistIds0), + get_vars_info(AllVars, VarsInfo), + sort(DistIds0, DistIds), + %(network_counting(0) -> writeln(vars:VarsInfo) ; true), + %(network_counting(0) -> writeln(distsids:DistIds) ; true), + create_network(VarsInfo, BayesNet). + %get_extra_vars_info(AllVars, ExtraVarsInfo), + %(network_counting(0) -> writeln(extra:ExtraVarsInfo) ; true), + %set_extra_vars_info(BayesNet, ExtraVarsInfo). process_ids([], _, []). process_ids([V|Vs], VarId0, [DistId|DistIds]) :- - clpbn:get_atts(V, [dist(DistId, _)]), !, - put_atts(V, [id(VarId0)]), - VarId is VarId0 + 1, - process_ids(Vs, VarId, DistIds). + clpbn:get_atts(V, [dist(DistId, _)]), !, + put_atts(V, [id(VarId0)]), + VarId is VarId0 + 1, + process_ids(Vs, VarId, DistIds). process_ids([_|Vs], VarId, DistIds) :- - process_ids(Vs, VarId, DistIds). + process_ids(Vs, VarId, DistIds). get_vars_info([], []). get_vars_info([V|Vs], [var(VarId, DSize, Ev, ParentIds, DistId)|VarsInfo]) :- - clpbn:get_atts(V, [dist(DistId, Parents)]), !, - get_atts(V, [id(VarId)]), - get_dist_domain_size(DistId, DSize), - get_evidence(V, Ev), - vars2ids(Parents, ParentIds), - get_vars_info(Vs, VarsInfo). + clpbn:get_atts(V, [dist(DistId, Parents)]), !, + get_atts(V, [id(VarId)]), + get_dist_domain_size(DistId, DSize), + get_evidence(V, Ev), + vars2ids(Parents, ParentIds), + get_vars_info(Vs, VarsInfo). get_vars_info([_|Vs], VarsInfo) :- - get_vars_info(Vs, VarsInfo). + get_vars_info(Vs, VarsInfo). vars2ids([], []). vars2ids([V|QueryVars], [VarId|Ids]) :- - get_atts(V, [id(VarId)]), - vars2ids(QueryVars, Ids). + get_atts(V, [id(VarId)]), + vars2ids(QueryVars, Ids). get_evidence(V, Ev) :- - clpbn:get_atts(V, [evidence(Ev)]), !. + clpbn:get_atts(V, [evidence(Ev)]), !. get_evidence(_V, -1). % no evidence !!! get_extra_vars_info([], []). get_extra_vars_info([V|Vs], [v(VarId, Label, Domain)|VarsInfo]) :- - get_atts(V, [id(VarId)]), !, - clpbn:get_atts(V, [key(Key),dist(DistId, _)]), - term_to_atom(Key, Label), - get_dist_domain(DistId, Domain0), - numbers2atoms(Domain0, Domain), - get_extra_vars_info(Vs, VarsInfo). + get_atts(V, [id(VarId)]), !, + clpbn:get_atts(V, [key(Key),dist(DistId, _)]), + term_to_atom(Key, Label), + get_dist_domain(DistId, Domain0), + numbers2atoms(Domain0, Domain), + get_extra_vars_info(Vs, VarsInfo). +get_extra_vars_info([_|Vs], VarsInfo) :- + get_extra_vars_info(Vs, VarsInfo). numbers2atoms([], []). numbers2atoms([Atom|L0], [Atom|L]) :- - atom(Atom), !, - numbers2atoms(L0, L). + atom(Atom), !, + numbers2atoms(L0, L). numbers2atoms([Number|L0], [Atom|L]) :- - number_atom(Number, Atom), - numbers2atoms(L0, L). + number_atom(Number, Atom), + numbers2atoms(L0, L). run_bp_solver(QVsL0, LPs, (BayesNet, DistIds)) :- - get_dists_parameters(DistIds, DistsParams), - set_parameters(BayesNet, DistsParams), - process_query_list(QVsL0, QVsL), - %writeln(' qvs':QVsL), - %(num_bayes_nets(1506) -> writeln(qvs:QVsL) ; true), - run_solver(BayesNet, QVsL, LPs). + get_dists_parameters(DistIds, DistsParams), + set_parameters(BayesNet, DistsParams), + process_query_list(QVsL0, QVsL), + %(network_counting(0) -> writeln(qvs:QVsL) ; true), + run_solver(BayesNet, QVsL, LPs). process_query_list([], []). process_query_list([[V]|QueryVars], [VarId|Ids]) :- !, - get_atts(V, [id(VarId)]), - process_query_list(QueryVars, Ids). + get_atts(V, [id(VarId)]), + process_query_list(QueryVars, Ids). process_query_list([Vs|QueryVars], [VarIds|Ids]) :- - vars2ids(Vs, VarIds), - process_query_list(QueryVars, Ids). + vars2ids(Vs, VarIds), + process_query_list(QueryVars, Ids). get_dists_parameters([],[]). get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :- - get_dist_params(Id, Params), - get_dists_parameters(Ids, DistsInfo). + get_dist_params(Id, Params), + get_dists_parameters(Ids, DistsInfo). finalize_bp_solver((BayesNet, _)) :- - delete_bayes_net(BayesNet). + free_bayesian_network(BayesNet). -inc_num_bayes_nets :- - retract(num_bayes_nets(Count0)), - Count is Count0 + 1, - assert(num_bayes_nets(Count)). +inc_network_counting :- + retract(network_counting(Count0)), + Count is Count0 + 1, + assert(network_counting(Count)). diff --git a/packages/CLPBN/clpbn/bp/BPNodeInfo.cpp b/packages/CLPBN/clpbn/bp/BPNodeInfo.cpp deleted file mode 100755 index f566bcbfb..000000000 --- a/packages/CLPBN/clpbn/bp/BPNodeInfo.cpp +++ /dev/null @@ -1,149 +0,0 @@ -#include -#include - -#include - -#include "BPNodeInfo.h" -#include "BPSolver.h" - -BPNodeInfo::BPNodeInfo (BayesNode* node) -{ - node_ = node; - ds_ = node->getDomainSize(); - piValsCalc_ = false; - ldValsCalc_ = false; - nPiMsgsRcv_ = 0; - nLdMsgsRcv_ = 0; - piVals_.resize (ds_, 1); - ldVals_.resize (ds_, 1); - const BnNodeSet& childs = node->getChilds(); - for (unsigned i = 0; i < childs.size(); i++) { - cmsgs_.insert (make_pair (childs[i], false)); - } - const BnNodeSet& parents = node->getParents(); - for (unsigned i = 0; i < parents.size(); i++) { - pmsgs_.insert (make_pair (parents[i], false)); - } -} - - - -ParamSet -BPNodeInfo::getBeliefs (void) const -{ - double sum = 0.0; - ParamSet beliefs (ds_); - for (unsigned xi = 0; xi < ds_; xi++) { - double prod = piVals_[xi] * ldVals_[xi]; - beliefs[xi] = prod; - sum += prod; - } - assert (sum); - //normalize the beliefs - for (unsigned xi = 0; xi < ds_; xi++) { - beliefs[xi] /= sum; - } - return beliefs; -} - - - -bool -BPNodeInfo::readyToSendPiMsgTo (const BayesNode* child) const -{ - for (unsigned i = 0; i < inChildLinks_.size(); i++) { - if (inChildLinks_[i]->getSource() != child - && !inChildLinks_[i]->messageWasSended()) { - return false; - } - } - return true; -} - - - -bool -BPNodeInfo::readyToSendLambdaMsgTo (const BayesNode* parent) const -{ - for (unsigned i = 0; i < inParentLinks_.size(); i++) { - if (inParentLinks_[i]->getSource() != parent - && !inParentLinks_[i]->messageWasSended()) { - return false; - } - } - return true; -} - - - -double -BPNodeInfo::getPiValue (unsigned idx) const -{ - assert (idx >=0 && idx < ds_); - return piVals_[idx]; -} - - - -void -BPNodeInfo::setPiValue (unsigned idx, Param value) -{ - assert (idx >=0 && idx < ds_); - piVals_[idx] = value; -} - - - -double -BPNodeInfo::getLambdaValue (unsigned idx) const -{ - assert (idx >=0 && idx < ds_); - return ldVals_[idx]; -} - - - -void -BPNodeInfo::setLambdaValue (unsigned idx, Param value) -{ - assert (idx >=0 && idx < ds_); - ldVals_[idx] = value; -} - - - -double -BPNodeInfo::getBeliefChange (void) -{ - double change = 0.0; - if (oldBeliefs_.size() == 0) { - oldBeliefs_ = getBeliefs(); - change = 9999999999.0; - } else { - ParamSet currentBeliefs = getBeliefs(); - for (unsigned xi = 0; xi < ds_; xi++) { - change += abs (currentBeliefs[xi] - oldBeliefs_[xi]); - } - oldBeliefs_ = currentBeliefs; - } - return change; -} - - - -bool -BPNodeInfo::receivedBottomInfluence (void) const -{ - // if all lambda values are equal, then neither - // this node neither its descendents have evidence, - // we can use this to don't send lambda messages his parents - bool childInfluenced = false; - for (unsigned xi = 1; xi < ds_; xi++) { - if (ldVals_[xi] != ldVals_[0]) { - childInfluenced = true; - break; - } - } - return childInfluenced; -} - diff --git a/packages/CLPBN/clpbn/bp/BPNodeInfo.h b/packages/CLPBN/clpbn/bp/BPNodeInfo.h deleted file mode 100755 index 966702652..000000000 --- a/packages/CLPBN/clpbn/bp/BPNodeInfo.h +++ /dev/null @@ -1,82 +0,0 @@ -#ifndef BP_BP_NODE_H -#define BP_BP_NODE_H - -#include -#include - -#include "BPSolver.h" -#include "BayesNode.h" -#include "Shared.h" - -//class Edge; - -using namespace std; - -class BPNodeInfo -{ - public: - BPNodeInfo (int); - BPNodeInfo (BayesNode*); - - ParamSet getBeliefs (void) const; - double getPiValue (unsigned) const; - void setPiValue (unsigned, Param); - double getLambdaValue (unsigned) const; - void setLambdaValue (unsigned, Param); - double getBeliefChange (void); - bool receivedBottomInfluence (void) const; - - ParamSet& getPiValues (void) { return piVals_; } - ParamSet& getLambdaValues (void) { return ldVals_; } - bool arePiValuesCalculated (void) { return piValsCalc_; } - bool areLambdaValuesCalculated (void) { return ldValsCalc_; } - void markPiValuesAsCalculated (void) { piValsCalc_ = true; } - void markLambdaValuesAsCalculated (void) { ldValsCalc_ = true; } - void incNumPiMsgsRcv (void) { nPiMsgsRcv_ ++; } - void incNumLambdaMsgsRcv (void) { nLdMsgsRcv_ ++; } - - bool receivedAllPiMessages (void) - { - return node_->getParents().size() == nPiMsgsRcv_; - } - - bool receivedAllLambdaMessages (void) - { - return node_->getChilds().size() == nLdMsgsRcv_; - } - - bool readyToSendPiMsgTo (const BayesNode*) const ; - bool readyToSendLambdaMsgTo (const BayesNode*) const; - - CEdgeSet getIncomingParentLinks (void) { return inParentLinks_; } - CEdgeSet getIncomingChildLinks (void) { return inChildLinks_; } - CEdgeSet getOutcomingParentLinks (void) { return outParentLinks_; } - CEdgeSet getOutcomingChildLinks (void) { return outChildLinks_; } - - void addIncomingParentLink (Edge* l) { inParentLinks_.push_back (l); } - void addIncomingChildLink (Edge* l) { inChildLinks_.push_back (l); } - void addOutcomingParentLink (Edge* l) { outParentLinks_.push_back (l); } - void addOutcomingChildLink (Edge* l) { outChildLinks_.push_back (l); } - - private: - DISALLOW_COPY_AND_ASSIGN (BPNodeInfo); - - ParamSet piVals_; // pi values - ParamSet ldVals_; // lambda values - ParamSet oldBeliefs_; - unsigned nPiMsgsRcv_; - unsigned nLdMsgsRcv_; - bool piValsCalc_; - bool ldValsCalc_; - EdgeSet inParentLinks_; - EdgeSet inChildLinks_; - EdgeSet outParentLinks_; - EdgeSet outChildLinks_; - unsigned ds_; - const BayesNode* node_; - map pmsgs_; - map cmsgs_; -}; - -#endif //BP_BP_NODE_H - diff --git a/packages/CLPBN/clpbn/bp/BPSolver.cpp b/packages/CLPBN/clpbn/bp/BPSolver.cpp deleted file mode 100644 index 67407d80c..000000000 --- a/packages/CLPBN/clpbn/bp/BPSolver.cpp +++ /dev/null @@ -1,905 +0,0 @@ -#include -#include -#include - -#include -#include -#include - -#include "BPSolver.h" - -BPSolver::BPSolver (const BayesNet& bn) : Solver (&bn) -{ - bn_ = &bn; - useAlwaysLoopySolver_ = false; - //jointCalcType_ = CHAIN_RULE; - jointCalcType_ = JUNCTION_NODE; -} - - - -BPSolver::~BPSolver (void) -{ - for (unsigned i = 0; i < nodesI_.size(); i++) { - delete nodesI_[i]; - } - for (unsigned i = 0; i < links_.size(); i++) { - delete links_[i]; - } -} - - - -void -BPSolver::runSolver (void) -{ - clock_t start_ = clock(); - unsigned size = bn_->getNumberOfNodes(); - unsigned nIters = 0; - initializeSolver(); - if (bn_->isSingleConnected() && !useAlwaysLoopySolver_) { - runPolyTreeSolver(); - Statistics::numSolvedPolyTrees ++; - } else { - runLoopySolver(); - Statistics::numSolvedLoopyNets ++; - if (nIter_ >= SolverOptions::maxIter) { - Statistics::numUnconvergedRuns ++; - } else { - nIters = nIter_; - } - if (DL >= 2) { - cout << endl; - if (nIter_ < SolverOptions::maxIter) { - cout << "Belief propagation converged in " ; - cout << nIter_ << " iterations" << endl; - } else { - cout << "The maximum number of iterations was hit, terminating..." ; - cout << endl; - } - } - } - double time = (double (clock() - start_)) / CLOCKS_PER_SEC; - Statistics::updateStats (size, nIters, time); - if (EXPORT_TO_DOT && size > EXPORT_MIN_SIZE) { - stringstream ss; - ss << size << "." ; - ss << Statistics::getCounting (size) << ".dot" ; - bn_->exportToDotFormat (ss.str().c_str()); - } -} - - - -ParamSet -BPSolver::getPosterioriOf (Vid vid) const -{ - BayesNode* node = bn_->getBayesNode (vid); - assert (node); - return nodesI_[node->getIndex()]->getBeliefs(); -} - - - -ParamSet -BPSolver::getJointDistributionOf (const VidSet& jointVids) -{ - if (DL >= 2) { - cout << "calculating joint distribution on: " ; - for (unsigned i = 0; i < jointVids.size(); i++) { - Variable* var = bn_->getBayesNode (jointVids[i]); - cout << var->getLabel() << " " ; - } - cout << endl; - } - - if (jointCalcType_ == JUNCTION_NODE) { - return getJointByJunctionNode (jointVids); - } else { - return getJointByChainRule (jointVids); - } -} - - - -void -BPSolver::initializeSolver (void) -{ - if (DL >= 2) { - if (!useAlwaysLoopySolver_) { - cout << "-> solver type = polytree solver" << endl; - cout << "-> schedule = n/a"; - } else { - cout << "-> solver = loopy solver" << endl; - cout << "-> schedule = "; - switch (SolverOptions::schedule) { - case SolverOptions::S_SEQ_FIXED: cout << "sequential fixed" ; break; - case SolverOptions::S_SEQ_RANDOM: cout << "sequential random" ; break; - case SolverOptions::S_PARALLEL: cout << "parallel" ; break; - case SolverOptions::S_MAX_RESIDUAL: cout << "max residual" ; break; - } - } - cout << endl; - cout << "-> joint method = " ; - if (jointCalcType_ == JUNCTION_NODE) { - cout << "junction node" << endl; - } else { - cout << "chain rule " << endl; - } - cout << "-> max iters = " << SolverOptions::maxIter << endl; - cout << "-> accuracy = " << SolverOptions::accuracy << endl; - cout << endl; - } - - CBnNodeSet nodes = bn_->getBayesNodes(); - for (unsigned i = 0; i < nodesI_.size(); i++) { - delete nodesI_[i]; - } - nodesI_.clear(); - nodesI_.reserve (nodes.size()); - links_.clear(); - sortedOrder_.clear(); - edgeMap_.clear(); - - for (unsigned i = 0; i < nodes.size(); i++) { - nodesI_.push_back (new BPNodeInfo (nodes[i])); - } - - BnNodeSet roots = bn_->getRootNodes(); - for (unsigned i = 0; i < roots.size(); i++) { - const ParamSet& params = roots[i]->getParameters(); - ParamSet& piVals = M(roots[i])->getPiValues(); - for (unsigned ri = 0; ri < roots[i]->getDomainSize(); ri++) { - piVals[ri] = params[ri]; - } - } - - for (unsigned i = 0; i < nodes.size(); i++) { - CBnNodeSet parents = nodes[i]->getParents(); - for (unsigned j = 0; j < parents.size(); j++) { - Edge* newLink = new Edge (parents[j], nodes[i], PI_MSG); - links_.push_back (newLink); - M(nodes[i])->addIncomingParentLink (newLink); - M(parents[j])->addOutcomingChildLink (newLink); - } - CBnNodeSet childs = nodes[i]->getChilds(); - for (unsigned j = 0; j < childs.size(); j++) { - Edge* newLink = new Edge (childs[j], nodes[i], LAMBDA_MSG); - links_.push_back (newLink); - M(nodes[i])->addIncomingChildLink (newLink); - M(childs[j])->addOutcomingParentLink (newLink); - } - } - - for (unsigned i = 0; i < nodes.size(); i++) { - if (nodes[i]->hasEvidence()) { - ParamSet& piVals = M(nodes[i])->getPiValues(); - ParamSet& ldVals = M(nodes[i])->getLambdaValues(); - for (unsigned xi = 0; xi < nodes[i]->getDomainSize(); xi++) { - piVals[xi] = 0.0; - ldVals[xi] = 0.0; - } - piVals[nodes[i]->getEvidence()] = 1.0; - ldVals[nodes[i]->getEvidence()] = 1.0; - } - } -} - - - -void -BPSolver::runPolyTreeSolver (void) -{ - CBnNodeSet nodes = bn_->getBayesNodes(); - for (unsigned i = 0; i < nodes.size(); i++) { - if (nodes[i]->isRoot()) { - M(nodes[i])->markPiValuesAsCalculated(); - } - if (nodes[i]->isLeaf()) { - M(nodes[i])->markLambdaValuesAsCalculated(); - } - } - - bool finish = false; - while (!finish) { - finish = true; - for (unsigned i = 0; i < nodes.size(); i++) { - if (M(nodes[i])->arePiValuesCalculated() == false - && M(nodes[i])->receivedAllPiMessages()) { - if (!nodes[i]->hasEvidence()) { - updatePiValues (nodes[i]); - } - M(nodes[i])->markPiValuesAsCalculated(); - finish = false; - } - - if (M(nodes[i])->areLambdaValuesCalculated() == false - && M(nodes[i])->receivedAllLambdaMessages()) { - if (!nodes[i]->hasEvidence()) { - updateLambdaValues (nodes[i]); - } - M(nodes[i])->markLambdaValuesAsCalculated(); - finish = false; - } - - if (M(nodes[i])->arePiValuesCalculated()) { - CEdgeSet outChildLinks = M(nodes[i])->getOutcomingChildLinks(); - for (unsigned j = 0; j < outChildLinks.size(); j++) { - BayesNode* child = outChildLinks[j]->getDestination(); - if (!outChildLinks[j]->messageWasSended()) { - if (M(nodes[i])->readyToSendPiMsgTo (child)) { - outChildLinks[j]->setNextMessage (getMessage (outChildLinks[j])); - outChildLinks[j]->updateMessage(); - M(child)->incNumPiMsgsRcv(); - } - finish = false; - } - } - } - - if (M(nodes[i])->areLambdaValuesCalculated()) { - CEdgeSet outParentLinks = M(nodes[i])->getOutcomingParentLinks(); - for (unsigned j = 0; j < outParentLinks.size(); j++) { - BayesNode* parent = outParentLinks[j]->getDestination(); - if (!outParentLinks[j]->messageWasSended()) { - if (M(nodes[i])->readyToSendLambdaMsgTo (parent)) { - outParentLinks[j]->setNextMessage (getMessage (outParentLinks[j])); - outParentLinks[j]->updateMessage(); - M(parent)->incNumLambdaMsgsRcv(); - } - finish = false; - } - } - } - } - } -} - - - -void -BPSolver::runLoopySolver() -{ - nIter_ = 0; - while (!converged() && nIter_ < SolverOptions::maxIter) { - - nIter_++; - if (DL >= 2) { - cout << endl; - cout << "****************************************" ; - cout << "****************************************" ; - cout << endl; - cout << " Iteration " << nIter_ << endl; - cout << "****************************************" ; - cout << "****************************************" ; - cout << endl; - } - - switch (SolverOptions::schedule) { - - case SolverOptions::S_SEQ_RANDOM: - random_shuffle (links_.begin(), links_.end()); - // no break - - case SolverOptions::S_SEQ_FIXED: - for (unsigned i = 0; i < links_.size(); i++) { - links_[i]->setNextMessage (getMessage (links_[i])); - links_[i]->updateMessage(); - updateValues (links_[i]); - } - break; - - case SolverOptions::S_PARALLEL: - for (unsigned i = 0; i < links_.size(); i++) { - //calculateNextMessage (links_[i]); - } - for (unsigned i = 0; i < links_.size(); i++) { - //updateMessage (links_[i]); - //updateValues (links_[i]); - } - break; - - case SolverOptions::S_MAX_RESIDUAL: - maxResidualSchedule(); - break; - - } - } -} - - - -bool -BPSolver::converged (void) const -{ - // this can happen if the graph is fully disconnected - if (links_.size() == 0) { - return true; - } - if (nIter_ == 0 || nIter_ == 1) { - return false; - } - bool converged = true; - if (SolverOptions::schedule == SolverOptions::S_MAX_RESIDUAL) { - Param maxResidual = (*(sortedOrder_.begin()))->getResidual(); - if (maxResidual < SolverOptions::accuracy) { - converged = true; - } else { - converged = false; - } - } else { - CBnNodeSet nodes = bn_->getBayesNodes(); - for (unsigned i = 0; i < nodes.size(); i++) { - if (!nodes[i]->hasEvidence()) { - double change = M(nodes[i])->getBeliefChange(); - if (DL >= 2) { - cout << nodes[i]->getLabel() + " belief change = " ; - cout << change << endl; - } - if (change > SolverOptions::accuracy) { - converged = false; - if (DL == 0) break; - } - } - } - } - return converged; -} - - - -void -BPSolver::maxResidualSchedule (void) -{ - if (nIter_ == 1) { - for (unsigned i = 0; i < links_.size(); i++) { - links_[i]->setNextMessage (getMessage (links_[i])); - links_[i]->updateResidual(); - SortedOrder::iterator it = sortedOrder_.insert (links_[i]); - edgeMap_.insert (make_pair (links_[i], it)); - if (DL >= 2) { - cout << "calculating " << links_[i]->toString() << endl; - } - } - return; - } - - for (unsigned c = 0; c < sortedOrder_.size(); c++) { - if (DL >= 2) { - cout << endl << "current residuals:" << endl; - for (SortedOrder::iterator it = sortedOrder_.begin(); - it != sortedOrder_.end(); it ++) { - cout << " " << setw (30) << left << (*it)->toString(); - cout << "residual = " << (*it)->getResidual() << endl; - } - } - - SortedOrder::iterator it = sortedOrder_.begin(); - Edge* edge = *it; - if (DL >= 2) { - cout << "updating " << edge->toString() << endl; - } - if (edge->getResidual() < SolverOptions::accuracy) { - return; - } - edge->updateMessage(); - updateValues (edge); - edge->clearResidual(); - sortedOrder_.erase (it); - edgeMap_.find (edge)->second = sortedOrder_.insert (edge); - - // update the messages that depend on message source --> destin - CEdgeSet outChildLinks = - M(edge->getDestination())->getOutcomingChildLinks(); - for (unsigned i = 0; i < outChildLinks.size(); i++) { - if (outChildLinks[i]->getDestination() != edge->getSource()) { - if (DL >= 2) { - cout << " calculating " << outChildLinks[i]->toString() << endl; - } - outChildLinks[i]->setNextMessage (getMessage (outChildLinks[i])); - outChildLinks[i]->updateResidual(); - EdgeMap::iterator iter = edgeMap_.find (outChildLinks[i]); - sortedOrder_.erase (iter->second); - iter->second = sortedOrder_.insert (outChildLinks[i]); - } - } - CEdgeSet outParentLinks = - M(edge->getDestination())->getOutcomingParentLinks(); - for (unsigned i = 0; i < outParentLinks.size(); i++) { - if (outParentLinks[i]->getDestination() != edge->getSource()) { - //&& !outParentLinks[i]->getDestination()->hasEvidence()) FIXME{ - if (DL >= 2) { - cout << " calculating " << outParentLinks[i]->toString() << endl; - } - outParentLinks[i]->setNextMessage (getMessage (outParentLinks[i])); - outParentLinks[i]->updateResidual(); - EdgeMap::iterator iter = edgeMap_.find (outParentLinks[i]); - sortedOrder_.erase (iter->second); - iter->second = sortedOrder_.insert (outParentLinks[i]); - } - } - - } -} - - - -void -BPSolver::updatePiValues (BayesNode* x) -{ - // π(Xi) - if (DL >= 3) { - cout << "updating " << PI << " values for " << x->getLabel() << endl; - } - CEdgeSet parentLinks = M(x)->getIncomingParentLinks(); - assert (x->getParents() == parentLinks.size()); - const vector& entries = x->getCptEntries(); - stringstream* calcs1 = 0; - stringstream* calcs2 = 0; - - ParamSet messageProducts (entries.size()); - for (unsigned k = 0; k < entries.size(); k++) { - if (DL >= 5) { - calcs1 = new stringstream; - calcs2 = new stringstream; - } - double messageProduct = 1.0; - const DConf& conf = entries[k].getDomainConfiguration(); - for (unsigned i = 0; i < parentLinks.size(); i++) { - assert (parentLinks[i]->getSource() == parents[i]); - assert (parentLinks[i]->getDestination() == x); - messageProduct *= parentLinks[i]->getMessage()[conf[i]]; - if (DL >= 5) { - if (i != 0) *calcs1 << "." ; - if (i != 0) *calcs2 << "*" ; - *calcs1 << PI << "(" << parentLinks[i]->getSource()->getLabel(); - *calcs1 << " --> " << x->getLabel() << ")" ; - *calcs1 << "[" ; - *calcs1 << parentLinks[i]->getSource()->getDomain()[conf[i]]; - *calcs1 << "]"; - *calcs2 << parentLinks[i]->getMessage()[conf[i]]; - } - } - messageProducts[k] = messageProduct; - if (DL >= 5) { - cout << " mp" << k; - cout << " = " << (*calcs1).str(); - if (parentLinks.size() == 1) { - cout << " = " << messageProduct << endl; - } else { - cout << " = " << (*calcs2).str(); - cout << " = " << messageProduct << endl; - } - delete calcs1; - delete calcs2; - } - } - - for (unsigned xi = 0; xi < x->getDomainSize(); xi++) { - double sum = 0.0; - if (DL >= 5) { - calcs1 = new stringstream; - calcs2 = new stringstream; - } - for (unsigned k = 0; k < entries.size(); k++) { - sum += x->getProbability (xi, entries[k]) * messageProducts[k]; - if (DL >= 5) { - if (k != 0) *calcs1 << " + " ; - if (k != 0) *calcs2 << " + " ; - *calcs1 << x->cptEntryToString (xi, entries[k]); - *calcs1 << ".mp" << k; - *calcs2 << x->getProbability (xi, entries[k]); - *calcs2 << "*" << messageProducts[k]; - } - } - M(x)->setPiValue (xi, sum); - if (DL >= 5) { - cout << " " << PI << "(" << x->getLabel() << ")" ; - cout << "[" << x->getDomain()[xi] << "]" ; - cout << " = " << (*calcs1).str(); - cout << " = " << (*calcs2).str(); - cout << " = " << sum << endl; - delete calcs1; - delete calcs2; - } - } -} - - - -void -BPSolver::updateLambdaValues (BayesNode* x) -{ - // λ(Xi) - if (DL >= 3) { - cout << "updating " << LD << " values for " << x->getLabel() << endl; - } - CEdgeSet childLinks = M(x)->getIncomingChildLinks(); - stringstream* calcs1 = 0; - stringstream* calcs2 = 0; - - for (unsigned xi = 0; xi < x->getDomainSize(); xi++) { - double product = 1.0; - if (DL >= 5) { - calcs1 = new stringstream; - calcs2 = new stringstream; - } - for (unsigned i = 0; i < childLinks.size(); i++) { - assert (childLinks[i]->getDestination() == x); - product *= childLinks[i]->getMessage()[xi]; - if (DL >= 5) { - if (i != 0) *calcs1 << "." ; - if (i != 0) *calcs2 << "*" ; - *calcs1 << LD << "(" << childLinks[i]->getSource()->getLabel(); - *calcs1 << "-->" << x->getLabel() << ")" ; - *calcs1 << "[" << x->getDomain()[xi] << "]" ; - *calcs2 << childLinks[i]->getMessage()[xi]; - } - } - M(x)->setLambdaValue (xi, product); - if (DL >= 5) { - cout << " " << LD << "(" << x->getLabel() << ")" ; - cout << "[" << x->getDomain()[xi] << "]" ; - cout << " = " << (*calcs1).str(); - if (childLinks.size() == 1) { - cout << " = " << product << endl; - } else { - cout << " = " << (*calcs2).str(); - cout << " = " << product << endl; - } - delete calcs1; - delete calcs2; - } - } -} - - - -ParamSet -BPSolver::calculateNextPiMessage (Edge* edge) -{ - // πX(Zi) - BayesNode* z = edge->getSource(); - BayesNode* x = edge->getDestination(); - ParamSet zxPiNextMessage (z->getDomainSize()); - CEdgeSet zChildLinks = M(z)->getIncomingChildLinks(); - stringstream* calcs1 = 0; - stringstream* calcs2 = 0; - - - for (unsigned zi = 0; zi < z->getDomainSize(); zi++) { - double product = M(z)->getPiValue (zi); - if (DL >= 5) { - calcs1 = new stringstream; - calcs2 = new stringstream; - *calcs1 << PI << "(" << z->getLabel() << ")"; - *calcs1 << "[" << z->getDomain()[zi] << "]" ; - *calcs2 << product; - } - for (unsigned i = 0; i < zChildLinks.size(); i++) { - assert (zChildLinks[i]->getDestination() == z); - if (zChildLinks[i]->getSource() != x) { - product *= zChildLinks[i]->getMessage()[zi]; - if (DL >= 5) { - *calcs1 << "." << LD << "(" << zChildLinks[i]->getSource()->getLabel(); - *calcs1 << "-->" << z->getLabel() << ")"; - *calcs1 << "[" << z->getDomain()[zi] + "]" ; - *calcs2 << " * " << zChildLinks[i]->getMessage()[zi]; - } - } - } - zxPiNextMessage[zi] = product; - if (DL >= 5) { - cout << " " << PI << "(" << z->getLabel(); - cout << "-->" << x->getLabel() << ")" ; - cout << "[" << z->getDomain()[zi] << "]" ; - cout << " = " << (*calcs1).str(); - if (zChildLinks.size() == 1) { - cout << " = " << product << endl; - } else { - cout << " = " << (*calcs2).str(); - cout << " = " << product << endl; - } - delete calcs1; - delete calcs2; - } - } - return zxPiNextMessage; -} - - - -ParamSet -BPSolver::calculateNextLambdaMessage (Edge* edge) -{ - // λY(Xi) - BayesNode* y = edge->getSource(); - BayesNode* x = edge->getDestination(); - if (!M(y)->receivedBottomInfluence()) { - //cout << "returning 1" << endl; - //return edge->getMessage(); - } - if (x->hasEvidence()) { - //cout << "returning 2" << endl; - //return edge->getMessage(); - } - ParamSet yxLambdaNextMessage (x->getDomainSize()); - CEdgeSet yParentLinks = M(y)->getIncomingParentLinks(); - const vector& allEntries = y->getCptEntries(); - int parentIndex = y->getIndexOfParent (x); - stringstream* calcs1 = 0; - stringstream* calcs2 = 0; - - vector entries; - DConstraint constr = make_pair (parentIndex, 0); - for (unsigned i = 0; i < allEntries.size(); i++) { - if (allEntries[i].matchConstraints(constr)) { - entries.push_back (allEntries[i]); - } - } - - ParamSet messageProducts (entries.size()); - for (unsigned k = 0; k < entries.size(); k++) { - if (DL >= 5) { - calcs1 = new stringstream; - calcs2 = new stringstream; - } - double messageProduct = 1.0; - const DConf& conf = entries[k].getDomainConfiguration(); - for (unsigned i = 0; i < yParentLinks.size(); i++) { - assert (yParentLinks[i]->getDestination() == y); - if (yParentLinks[i]->getSource() != x) { - if (DL >= 5) { - if (messageProduct != 1.0) *calcs1 << "*" ; - if (messageProduct != 1.0) *calcs2 << "*" ; - *calcs1 << PI << "(" << yParentLinks[i]->getSource()->getLabel(); - *calcs1 << "-->" << y->getLabel() << ")" ; - *calcs1 << "[" ; - *calcs1 << yParentLinks[i]->getSource()->getDomain()[conf[i]]; - *calcs1 << "]" ; - *calcs2 << yParentLinks[i]->getMessage()[conf[i]]; - } - messageProduct *= yParentLinks[i]->getMessage()[conf[i]]; - } - } - messageProducts[k] = messageProduct; - if (DL >= 5) { - cout << " mp" << k; - cout << " = " << (*calcs1).str(); - if (yParentLinks.size() == 1) { - cout << 1 << endl; - } else if (yParentLinks.size() == 2) { - cout << " = " << messageProduct << endl; - } else { - cout << " = " << (*calcs2).str(); - cout << " = " << messageProduct << endl; - } - delete calcs1; - delete calcs2; - } - } - - for (unsigned xi = 0; xi < x->getDomainSize(); xi++) { - if (DL >= 5) { - calcs1 = new stringstream; - calcs2 = new stringstream; - } - vector entries; - DConstraint constr = make_pair (parentIndex, xi); - for (unsigned i = 0; i < allEntries.size(); i++) { - if (allEntries[i].matchConstraints(constr)) { - entries.push_back (allEntries[i]); - } - } - double outerSum = 0.0; - for (unsigned yi = 0; yi < y->getDomainSize(); yi++) { - if (DL >= 5) { - (yi != 0) ? *calcs1 << " + {" : *calcs1 << "{" ; - (yi != 0) ? *calcs2 << " + {" : *calcs2 << "{" ; - } - double innerSum = 0.0; - for (unsigned k = 0; k < entries.size(); k++) { - if (DL >= 5) { - if (k != 0) *calcs1 << " + " ; - if (k != 0) *calcs2 << " + " ; - *calcs1 << y->cptEntryToString (yi, entries[k]); - *calcs1 << ".mp" << k; - *calcs2 << y->getProbability (yi, entries[k]); - *calcs2 << "*" << messageProducts[k]; - } - innerSum += y->getProbability (yi, entries[k]) * messageProducts[k]; - } - outerSum += innerSum * M(y)->getLambdaValue (yi); - if (DL >= 5) { - *calcs1 << "}." << LD << "(" << y->getLabel() << ")" ; - *calcs1 << "[" << y->getDomain()[yi] << "]"; - *calcs2 << "}*" << M(y)->getLambdaValue (yi); - } - } - yxLambdaNextMessage[xi] = outerSum; - if (DL >= 5) { - cout << " " << LD << "(" << y->getLabel(); - cout << "-->" << x->getLabel() << ")" ; - cout << "[" << x->getDomain()[xi] << "]" ; - cout << " = " << (*calcs1).str(); - cout << " = " << (*calcs2).str(); - cout << " = " << outerSum << endl; - delete calcs1; - delete calcs2; - } - } - return yxLambdaNextMessage; -} - - - -ParamSet -BPSolver::getJointByJunctionNode (const VidSet& jointVids) const -{ - BnNodeSet jointVars; - for (unsigned i = 0; i < jointVids.size(); i++) { - jointVars.push_back (bn_->getBayesNode (jointVids[i])); - } - - BayesNet* mrn = bn_->getMinimalRequesiteNetwork (jointVids); - - BnNodeSet parents; - unsigned dsize = 1; - for (unsigned i = 0; i < jointVars.size(); i++) { - parents.push_back (mrn->getBayesNode (jointVids[i])); - dsize *= jointVars[i]->getDomainSize(); - } - - unsigned nParams = dsize * dsize; - ParamSet params (nParams); - - for (unsigned i = 0; i < nParams; i++) { - unsigned row = i / dsize; - unsigned col = i % dsize; - if (row == col) { - params[i] = 1; - } else { - params[i] = 0; - } - } - - unsigned maxVid = std::numeric_limits::max(); - Distribution* dist = new Distribution (params); - - mrn->addNode (maxVid, dsize, NO_EVIDENCE, parents, dist); - mrn->setIndexes(); - - BPSolver solver (*mrn); - solver.runSolver(); - - const ParamSet& results = solver.getPosterioriOf (maxVid); - - delete mrn; - delete dist; - - return results; -} - - - -ParamSet -BPSolver::getJointByChainRule (const VidSet& jointVids) const -{ - BnNodeSet jointVars; - for (unsigned i = 0; i < jointVids.size(); i++) { - jointVars.push_back (bn_->getBayesNode (jointVids[i])); - } - - BayesNet* mrn = bn_->getMinimalRequesiteNetwork (jointVids[0]); - BPSolver solver (*mrn); - solver.runSolver(); - ParamSet prevBeliefs = solver.getPosterioriOf (jointVids[0]); - delete mrn; - - VarSet observedVars = {jointVars[0]}; - - for (unsigned i = 1; i < jointVids.size(); i++) { - mrn = bn_->getMinimalRequesiteNetwork (jointVids[i]); - ParamSet newBeliefs; - vector confs = - Util::getDomainConfigurations (observedVars); - for (unsigned j = 0; j < confs.size(); j++) { - for (unsigned k = 0; k < observedVars.size(); k++) { - if (!observedVars[k]->hasEvidence()) { - BayesNode* node = mrn->getBayesNode (observedVars[k]->getVarId()); - if (node) { - node->setEvidence (confs[j][k]); - } - } - } - BPSolver solver (*mrn); - solver.runSolver(); - ParamSet beliefs = solver.getPosterioriOf (jointVids[i]); - for (unsigned k = 0; k < beliefs.size(); k++) { - newBeliefs.push_back (beliefs[k]); - } - } - - int count = -1; - for (unsigned j = 0; j < newBeliefs.size(); j++) { - if (j % jointVars[i]->getDomainSize() == 0) { - count ++; - } - newBeliefs[j] *= prevBeliefs[count]; - } - prevBeliefs = newBeliefs; - observedVars.push_back (jointVars[i]); - delete mrn; - } - return prevBeliefs; -} - - - -void -BPSolver::printMessageStatusOf (const BayesNode* var) const -{ - cout << left; - cout << setw (10) << "domain" ; - cout << setw (20) << PI << "(" + var->getLabel() + ")" ; - cout << setw (20) << LD << "(" + var->getLabel() + ")" ; - cout << setw (16) << "belief" ; - cout << endl; - cout << "--------------------------------" ; - cout << "--------------------------------" ; - cout << endl; - - BPNodeInfo* x = M(var); - ParamSet& piVals = x->getPiValues(); - ParamSet& ldVals = x->getLambdaValues(); - ParamSet beliefs = x->getBeliefs(); - const Domain& domain = var->getDomain(); - CBnNodeSet& childs = var->getChilds(); - - for (unsigned xi = 0; xi < var->getDomainSize(); xi++) { - cout << setw (10) << domain[xi]; - cout << setw (19) << piVals[xi]; - cout << setw (19) << ldVals[xi]; - cout.precision (PRECISION); - cout << setw (16) << beliefs[xi]; - cout << endl; - } - cout << endl; - if (childs.size() > 0) { - string s = "(" + var->getLabel() + ")" ; - for (unsigned j = 0; j < childs.size(); j++) { - cout << setw (10) << "domain" ; - cout << setw (28) << PI + childs[j]->getLabel() + s; - cout << setw (28) << LD + childs[j]->getLabel() + s; - cout << endl; - cout << "--------------------------------" ; - cout << "--------------------------------" ; - cout << endl; - /* FIXME - const ParamSet& piMessage = x->getPiMessage (childs[j]); - const ParamSet& lambdaMessage = x->getLambdaMessage (childs[j]); - for (unsigned xi = 0; xi < var->getDomainSize(); xi++) { - cout << setw (10) << domain[xi]; - cout.precision (PRECISION); - cout << setw (27) << piMessage[xi]; - cout.precision (PRECISION); - cout << setw (27) << lambdaMessage[xi]; - cout << endl; - } - cout << endl; - */ - } - } -} - - - -void -BPSolver::printAllMessageStatus (void) const -{ - CBnNodeSet nodes = bn_->getBayesNodes(); - for (unsigned i = 0; i < nodes.size(); i++) { - printMessageStatusOf (nodes[i]); - } -} - diff --git a/packages/CLPBN/clpbn/bp/BPSolver.h b/packages/CLPBN/clpbn/bp/BPSolver.h deleted file mode 100644 index c3b8ee9f1..000000000 --- a/packages/CLPBN/clpbn/bp/BPSolver.h +++ /dev/null @@ -1,192 +0,0 @@ -#ifndef BP_BP_SOLVER_H -#define BP_BP_SOLVER_H - -#include -#include - -#include "Solver.h" -#include "BayesNet.h" -#include "BPNodeInfo.h" -#include "Shared.h" - -using namespace std; - -class BPNodeInfo; - -static const string PI = "pi" ; -static const string LD = "ld" ; - - -enum MessageType {PI_MSG, LAMBDA_MSG}; -enum JointCalcType {CHAIN_RULE, JUNCTION_NODE}; - -class Edge -{ - public: - Edge (BayesNode* s, BayesNode* d, MessageType t) - { - source_ = s; - destin_ = d; - type_ = t; - if (type_ == PI_MSG) { - currMsg_.resize (s->getDomainSize(), 1); - nextMsg_.resize (s->getDomainSize(), 1); - } else { - currMsg_.resize (d->getDomainSize(), 1); - nextMsg_.resize (d->getDomainSize(), 1); - } - msgSended_ = false; - residual_ = 0.0; - } - - //void setMessage (ParamSet msg) - //{ - // Util::normalize (msg); - // residual_ = Util::getMaxNorm (currMsg_, msg); - // currMsg_ = msg; - //} - - void setNextMessage (CParamSet msg) - { - nextMsg_ = msg; - Util::normalize (nextMsg_); - residual_ = Util::getMaxNorm (currMsg_, nextMsg_); - } - - void updateMessage (void) - { - currMsg_ = nextMsg_; - if (DL >= 3) { - cout << "updating " << toString() << endl; - } - msgSended_ = true; - } - - void updateResidual (void) - { - residual_ = Util::getMaxNorm (currMsg_, nextMsg_); - } - - string toString (void) const - { - stringstream ss; - if (type_ == PI_MSG) { - ss << PI; - } else if (type_ == LAMBDA_MSG) { - ss << LD; - } else { - abort(); - } - ss << "(" << source_->getLabel(); - ss << " --> " << destin_->getLabel() << ")" ; - return ss.str(); - } - - BayesNode* getSource (void) const { return source_; } - BayesNode* getDestination (void) const { return destin_; } - MessageType getMessageType (void) const { return type_; } - CParamSet getMessage (void) const { return currMsg_; } - bool messageWasSended (void) const { return msgSended_; } - double getResidual (void) const { return residual_; } - void clearResidual (void) { residual_ = 0.0; } - - private: - BayesNode* source_; - BayesNode* destin_; - MessageType type_; - ParamSet currMsg_; - ParamSet nextMsg_; - bool msgSended_; - double residual_; -}; - - -class BPSolver : public Solver -{ - public: - BPSolver (const BayesNet&); - ~BPSolver (void); - - void runSolver (void); - ParamSet getPosterioriOf (Vid) const; - ParamSet getJointDistributionOf (const VidSet&); - - - private: - DISALLOW_COPY_AND_ASSIGN (BPSolver); - - void initializeSolver (void); - void runPolyTreeSolver (void); - void runLoopySolver (void); - void maxResidualSchedule (void); - bool converged (void) const; - void updatePiValues (BayesNode*); - void updateLambdaValues (BayesNode*); - ParamSet calculateNextLambdaMessage (Edge* edge); - ParamSet calculateNextPiMessage (Edge* edge); - ParamSet getJointByJunctionNode (const VidSet&) const; - ParamSet getJointByChainRule (const VidSet&) const; - void printMessageStatusOf (const BayesNode*) const; - void printAllMessageStatus (void) const; - - ParamSet getMessage (Edge* edge) - { - if (DL >= 3) { - cout << " calculating " << edge->toString() << endl; - } - if (edge->getMessageType() == PI_MSG) { - return calculateNextPiMessage (edge); - } else if (edge->getMessageType() == LAMBDA_MSG) { - return calculateNextLambdaMessage (edge); - } else { - abort(); - } - return ParamSet(); - } - - void updateValues (Edge* edge) - { - if (!edge->getDestination()->hasEvidence()) { - if (edge->getMessageType() == PI_MSG) { - updatePiValues (edge->getDestination()); - } else if (edge->getMessageType() == LAMBDA_MSG) { - updateLambdaValues (edge->getDestination()); - } else { - abort(); - } - } - } - - BPNodeInfo* M (const BayesNode* node) const - { - assert (node); - assert (node == bn_->getBayesNode (node->getVarId())); - assert (node->getIndex() < nodesI_.size()); - return nodesI_[node->getIndex()]; - } - - const BayesNet* bn_; - vector nodesI_; - unsigned nIter_; - vector links_; - bool useAlwaysLoopySolver_; - JointCalcType jointCalcType_; - - struct compare - { - inline bool operator() (const Edge* e1, const Edge* e2) - { - return e1->getResidual() > e2->getResidual(); - } - }; - - typedef multiset SortedOrder; - SortedOrder sortedOrder_; - - typedef map EdgeMap; - EdgeMap edgeMap_; - -}; - -#endif //BP_BP_SOLVER_H - diff --git a/packages/CLPBN/clpbn/bp/BayesNet.cpp b/packages/CLPBN/clpbn/bp/BayesNet.cpp index fe9a52bd4..632b383a6 100644 --- a/packages/CLPBN/clpbn/bp/BayesNet.cpp +++ b/packages/CLPBN/clpbn/bp/BayesNet.cpp @@ -4,111 +4,12 @@ #include #include #include -#include #include "xmlParser/xmlParser.h" #include "BayesNet.h" -BayesNet::BayesNet (const char* fileName) -{ - map domains; - XMLNode xMainNode = XMLNode::openFileHelper (fileName, "BIF"); - // only the first network is parsed, others are ignored - XMLNode xNode = xMainNode.getChildNode ("NETWORK"); - unsigned nVars = xNode.nChildNode ("VARIABLE"); - for (unsigned i = 0; i < nVars; i++) { - XMLNode var = xNode.getChildNode ("VARIABLE", i); - string type = var.getAttribute ("TYPE"); - if (type != "nature") { - cerr << "error: only \"nature\" variables are supported" << endl; - abort(); - } - Domain domain; - string varLabel = var.getChildNode("NAME").getText(); - unsigned dsize = var.nChildNode ("OUTCOME"); - for (unsigned j = 0; j < dsize; j++) { - if (var.getChildNode("OUTCOME", j).getText() == 0) { - stringstream ss; - ss << j + 1; - domain.push_back (ss.str()); - } else { - domain.push_back (var.getChildNode("OUTCOME", j).getText()); - } - } - domains.insert (make_pair (varLabel, domain)); - } - - unsigned nDefs = xNode.nChildNode ("DEFINITION"); - if (nVars != nDefs) { - cerr << "error: different number of variables and definitions" << endl; - abort(); - } - - queue indexes; - for (unsigned i = 0; i < nDefs; i++) { - indexes.push (i); - } - - while (!indexes.empty()) { - unsigned index = indexes.front(); - indexes.pop(); - XMLNode def = xNode.getChildNode ("DEFINITION", index); - string varLabel = def.getChildNode("FOR").getText(); - map::const_iterator iter; - iter = domains.find (varLabel); - if (iter == domains.end()) { - cerr << "error: unknow variable `" << varLabel << "'" << endl; - abort(); - } - bool processItLatter = false; - BnNodeSet parents; - unsigned nParams = iter->second.size(); - for (int j = 0; j < def.nChildNode ("GIVEN"); j++) { - string parentLabel = def.getChildNode("GIVEN", j).getText(); - BayesNode* parentNode = getBayesNode (parentLabel); - if (parentNode) { - nParams *= parentNode->getDomainSize(); - parents.push_back (parentNode); - } - else { - iter = domains.find (parentLabel); - if (iter == domains.end()) { - cerr << "error: unknow parent `" << parentLabel << "'" << endl; - abort(); - } else { - // this definition contains a parent that doesn't - // have a corresponding bayesian node instance yet, - // so process this definition latter - indexes.push (index); - processItLatter = true; - break; - } - } - } - - if (!processItLatter) { - unsigned count = 0; - ParamSet params (nParams); - stringstream s (def.getChildNode("TABLE").getText()); - while (!s.eof() && count < nParams) { - s >> params[count]; - count ++; - } - if (count != nParams) { - cerr << "error: invalid number of parameters " ; - cerr << "for variable `" << varLabel << "'" << endl; - abort(); - } - params = reorderParameters (params, iter->second.size()); - addNode (varLabel, iter->second, parents, params); - } - } - setIndexes(); -} - - BayesNet::~BayesNet (void) { @@ -119,26 +20,130 @@ BayesNet::~BayesNet (void) -BayesNode* -BayesNet::addNode (Vid vid) +void +BayesNet::readFromBifFormat (const char* fileName) { - indexMap_.insert (make_pair (vid, nodes_.size())); - nodes_.push_back (new BayesNode (vid)); - return nodes_.back(); + XMLNode xMainNode = XMLNode::openFileHelper (fileName, "BIF"); + // only the first network is parsed, others are ignored + XMLNode xNode = xMainNode.getChildNode ("NETWORK"); + unsigned nVars = xNode.nChildNode ("VARIABLE"); + for (unsigned i = 0; i < nVars; i++) { + XMLNode var = xNode.getChildNode ("VARIABLE", i); + if (string (var.getAttribute ("TYPE")) != "nature") { + cerr << "error: only \"nature\" variables are supported" << endl; + abort(); + } + States states; + string label = var.getChildNode("NAME").getText(); + unsigned nrStates = var.nChildNode ("OUTCOME"); + for (unsigned j = 0; j < nrStates; j++) { + if (var.getChildNode("OUTCOME", j).getText() == 0) { + stringstream ss; + ss << j + 1; + states.push_back (ss.str()); + } else { + states.push_back (var.getChildNode("OUTCOME", j).getText()); + } + } + addNode (label, states); + } + + unsigned nDefs = xNode.nChildNode ("DEFINITION"); + if (nVars != nDefs) { + cerr << "error: different number of variables and definitions" << endl; + abort(); + } + for (unsigned i = 0; i < nDefs; i++) { + XMLNode def = xNode.getChildNode ("DEFINITION", i); + string label = def.getChildNode("FOR").getText(); + BayesNode* node = getBayesNode (label); + if (!node) { + cerr << "error: unknow variable `" << label << "'" << endl; + abort(); + } + BnNodeSet parents; + unsigned nParams = node->nrStates(); + for (int j = 0; j < def.nChildNode ("GIVEN"); j++) { + string parentLabel = def.getChildNode("GIVEN", j).getText(); + BayesNode* parentNode = getBayesNode (parentLabel); + if (!parentNode) { + cerr << "error: unknow variable `" << parentLabel << "'" << endl; + abort(); + } + nParams *= parentNode->nrStates(); + parents.push_back (parentNode); + } + node->setParents (parents); + unsigned count = 0; + ParamSet params (nParams); + stringstream s (def.getChildNode("TABLE").getText()); + while (!s.eof() && count < nParams) { + s >> params[count]; + count ++; + } + if (count != nParams) { + cerr << "error: invalid number of parameters " ; + cerr << "for variable `" << label << "'" << endl; + abort(); + } + params = reorderParameters (params, node->nrStates()); + Distribution* dist = new Distribution (params); + node->setDistribution (dist); + addDistribution (dist); + } + + setIndexes(); + if (NSPACE == NumberSpace::LOGARITHM) { + distributionsToLogs(); + } +} + + + +void +BayesNet::addNode (BayesNode* n) +{ + indexMap_.insert (make_pair (n->varId(), nodes_.size())); + nodes_.push_back (n); } BayesNode* -BayesNet::addNode (Vid vid, +BayesNet::addNode (string label, const States& states) +{ + VarId vid = nodes_.size(); + indexMap_.insert (make_pair (vid, nodes_.size())); + GraphicalModel::addVariableInformation (vid, label, states); + BayesNode* node = new BayesNode (VarNode (vid, states.size())); + nodes_.push_back (node); + return node; +} + + + +BayesNode* +BayesNet::addNode (VarId vid, unsigned dsize, int evidence, BnNodeSet& parents, Distribution* dist) { indexMap_.insert (make_pair (vid, nodes_.size())); - nodes_.push_back (new BayesNode ( - vid, dsize, evidence, parents, dist)); + nodes_.push_back (new BayesNode (vid, dsize, evidence, parents, dist)); + return nodes_.back(); +} + + + +BayesNode* +BayesNet::addNode (VarId vid, + unsigned dsize, + int evidence, + Distribution* dist) +{ + indexMap_.insert (make_pair (vid, nodes_.size())); + nodes_.push_back (new BayesNode (vid, dsize, evidence, dist)); return nodes_.back(); } @@ -146,14 +151,16 @@ BayesNet::addNode (Vid vid, BayesNode* BayesNet::addNode (string label, - Domain domain, + States states, BnNodeSet& parents, ParamSet& params) { - indexMap_.insert (make_pair (nodes_.size(), nodes_.size())); + VarId vid = nodes_.size(); + indexMap_.insert (make_pair (vid, nodes_.size())); + GraphicalModel::addVariableInformation (vid, label, states); Distribution* dist = new Distribution (params); BayesNode* node = new BayesNode ( - nodes_.size(), label, domain, parents, dist); + vid, states.size(), NO_EVIDENCE, parents, dist); dists_.push_back (dist); nodes_.push_back (node); return node; @@ -162,7 +169,7 @@ BayesNet::addNode (string label, BayesNode* -BayesNet::getBayesNode (Vid vid) const +BayesNet::getBayesNode (VarId vid) const { IndexMap::const_iterator it = indexMap_.find (vid); if (it == indexMap_.end()) { @@ -179,7 +186,7 @@ BayesNet::getBayesNode (string label) const { BayesNode* node = 0; for (unsigned i = 0; i < nodes_.size(); i++) { - if (nodes_[i]->getLabel() == label) { + if (nodes_[i]->label() == label) { node = nodes_[i]; break; } @@ -190,10 +197,25 @@ BayesNet::getBayesNode (string label) const -Variable* -BayesNet::getVariable (Vid vid) const +VarNode* +BayesNet::getVariableNode (VarId vid) const { - return getBayesNode (vid); + BayesNode* node = getBayesNode (vid); + assert (node); + return node; +} + + + + +VarNodes +BayesNet::getVariableNodes (void) const +{ + VarNodes vars; + for (unsigned i = 0; i < nodes_.size(); i++) { + vars.push_back (nodes_[i]); + } + return vars; } @@ -230,7 +252,7 @@ BayesNet::getBayesNodes (void) const unsigned -BayesNet::getNumberOfNodes (void) const +BayesNet::nrNodes (void) const { return nodes_.size(); } @@ -265,37 +287,25 @@ BayesNet::getLeafNodes (void) const -VarSet -BayesNet::getVariables (void) const +BayesNet* +BayesNet::getMinimalRequesiteNetwork (VarId vid) const { - VarSet vars; - for (unsigned i = 0; i < nodes_.size(); i++) { - vars.push_back (nodes_[i]); - } - return vars; + return getMinimalRequesiteNetwork (VarIdSet() = {vid}); } BayesNet* -BayesNet::getMinimalRequesiteNetwork (Vid vid) const -{ - return getMinimalRequesiteNetwork (VidSet() = {vid}); -} - - - -BayesNet* -BayesNet::getMinimalRequesiteNetwork (const VidSet& queryVids) const +BayesNet::getMinimalRequesiteNetwork (const VarIdSet& queryVarIds) const { BnNodeSet queryVars; - for (unsigned i = 0; i < queryVids.size(); i++) { - assert (getBayesNode (queryVids[i])); - queryVars.push_back (getBayesNode (queryVids[i])); + for (unsigned i = 0; i < queryVarIds.size(); i++) { + assert (getBayesNode (queryVarIds[i])); + queryVars.push_back (getBayesNode (queryVarIds[i])); } // cout << "query vars: " ; // for (unsigned i = 0; i < queryVars.size(); i++) { - // cout << queryVars[i]->getLabel() << " " ; + // cout << queryVars[i]->label() << " " ; // } // cout << endl; @@ -344,7 +354,7 @@ BayesNet::getMinimalRequesiteNetwork (const VidSet& queryVids) const cout << "----------------------------------------------------------" ; cout << endl; for (unsigned i = 0; i < states.size(); i++) { - cout << nodes_[i]->getLabel() << ":\t\t" ; + cout << nodes_[i]->label() << ":\t\t" ; if (states[i]) { states[i]->markedOnTop ? cout << "yes\t" : cout << "no\t" ; states[i]->markedOnBottom ? cout << "yes\t" : cout << "no\t" ; @@ -374,51 +384,46 @@ void BayesNet::constructGraph (BayesNet* bn, const vector& states) const { + BnNodeSet mrnNodes; + vector parents; for (unsigned i = 0; i < nodes_.size(); i++) { bool isRequired = false; if (states[i]) { isRequired = (nodes_[i]->hasEvidence() && states[i]->visited) - || + || states[i]->markedOnTop; } if (isRequired) { - BnNodeSet parents; + parents.push_back (VarIdSet()); if (states[i]->markedOnTop) { const BnNodeSet& ps = nodes_[i]->getParents(); for (unsigned j = 0; j < ps.size(); j++) { - BayesNode* parent = bn->getBayesNode (ps[j]->getVarId()); - if (!parent) { - parent = bn->addNode (ps[j]->getVarId()); - } - parents.push_back (parent); + parents.back().push_back (ps[j]->varId()); } } - BayesNode* node = bn->getBayesNode (nodes_[i]->getVarId()); - if (node) { - node->setData (nodes_[i]->getDomainSize(), - nodes_[i]->getEvidence(), parents, - nodes_[i]->getDistribution()); - } else { - node = bn->addNode (nodes_[i]->getVarId(), - nodes_[i]->getDomainSize(), - nodes_[i]->getEvidence(), parents, - nodes_[i]->getDistribution()); - } - if (nodes_[i]->hasDomain()) { - node->setDomain (nodes_[i]->getDomain()); - } - if (nodes_[i]->hasLabel()) { - node->setLabel (nodes_[i]->getLabel()); - } + assert (bn->getBayesNode (nodes_[i]->varId()) == 0); + BayesNode* mrnNode = bn->addNode (nodes_[i]->varId(), + nodes_[i]->nrStates(), + nodes_[i]->getEvidence(), + nodes_[i]->getDistribution()); + mrnNodes.push_back (mrnNode); } } + for (unsigned i = 0; i < mrnNodes.size(); i++) { + BnNodeSet ps; + for (unsigned j = 0; j < parents[i].size(); j++) { + assert (bn->getBayesNode (parents[i][j]) != 0); + ps.push_back (bn->getBayesNode (parents[i][j])); + } + mrnNodes[i]->setParents (ps); + } bn->setIndexes(); } bool -BayesNet::isSingleConnected (void) const +BayesNet::isPolyTree (void) const { return !containsUndirectedCycle(); } @@ -435,6 +440,16 @@ BayesNet::setIndexes (void) +void +BayesNet::distributionsToLogs (void) +{ + for (unsigned i = 0; i < dists_.size(); i++) { + Util::toLog (dists_[i]->params); + } +} + + + void BayesNet::freeDistributions (void) { @@ -456,9 +471,9 @@ BayesNet::printGraphicalModel (void) const void -BayesNet::exportToDotFormat (const char* fileName, - bool showNeighborless, - CVidSet& highlightVids) const +BayesNet::exportToGraphViz (const char* fileName, + bool showNeighborless, + const VarIdSet& highlightVarIds) const { ofstream out (fileName); if (!out.is_open()) { @@ -467,27 +482,32 @@ BayesNet::exportToDotFormat (const char* fileName, abort(); } - out << "digraph \"" << fileName << "\" {" << endl; - + out << "digraph {" << endl; + out << "ranksep=1" << endl; for (unsigned i = 0; i < nodes_.size(); i++) { if (showNeighborless || nodes_[i]->hasNeighbors()) { - out << '"' << nodes_[i]->getLabel() << '"' ; + out << nodes_[i]->varId() ; if (nodes_[i]->hasEvidence()) { - out << " [style=filled, fillcolor=yellow]" << endl; + out << " [" ; + out << "label=\"" << nodes_[i]->label() << "\"," ; + out << "style=filled, fillcolor=yellow" ; + out << "]" ; } else { - out << endl; + out << " [" ; + out << "label=\"" << nodes_[i]->label() << "\"" ; + out << "]" ; } + out << endl; } } - for (unsigned i = 0; i < highlightVids.size(); i++) { - BayesNode* node = getBayesNode (highlightVids[i]); + for (unsigned i = 0; i < highlightVarIds.size(); i++) { + BayesNode* node = getBayesNode (highlightVarIds[i]); if (node) { - out << '"' << node->getLabel() << '"' ; - // out << " [shape=polygon, sides=6]" << endl; + out << node->varId() ; out << " [shape=box3d]" << endl; } else { - cout << "error: invalid variable id: " << highlightVids[i] << endl; + cout << "error: invalid variable id: " << highlightVarIds[i] << endl; abort(); } } @@ -495,8 +515,7 @@ BayesNet::exportToDotFormat (const char* fileName, for (unsigned i = 0; i < nodes_.size(); i++) { const BnNodeSet& childs = nodes_[i]->getChilds(); for (unsigned j = 0; j < childs.size(); j++) { - out << '"' << nodes_[i]->getLabel() << '"' << " -> " ; - out << '"' << childs[j]->getLabel() << '"' << endl; + out << nodes_[i]->varId() << " -> " << childs[j]->varId() << " [style=bold]" << endl ; } } @@ -521,24 +540,24 @@ BayesNet::exportToBifFormat (const char* fileName) const out << "" << fileName << "" << endl << endl; for (unsigned i = 0; i < nodes_.size(); i++) { out << "" << endl; - out << "\t" << nodes_[i]->getLabel() << "" << endl; - const Domain& domain = nodes_[i]->getDomain(); - for (unsigned j = 0; j < domain.size(); j++) { - out << "\t" << domain[j] << "" << endl; + out << "\t" << nodes_[i]->label() << "" << endl; + const States& states = nodes_[i]->states(); + for (unsigned j = 0; j < states.size(); j++) { + out << "\t" << states[j] << "" << endl; } out << "" << endl << endl; } for (unsigned i = 0; i < nodes_.size(); i++) { out << "" << endl; - out << "\t" << nodes_[i]->getLabel() << "" << endl; + out << "\t" << nodes_[i]->label() << "" << endl; const BnNodeSet& parents = nodes_[i]->getParents(); for (unsigned j = 0; j < parents.size(); j++) { - out << "\t" << parents[j]->getLabel(); + out << "\t" << parents[j]->label(); out << "" << endl; } ParamSet params = revertParameterReorder (nodes_[i]->getParameters(), - nodes_[i]->getDomainSize()); + nodes_[i]->nrStates()); out << "\t" ; for (unsigned j = 0; j < params.size(); j++) { out << " " << params[j]; @@ -571,9 +590,7 @@ BayesNet::containsUndirectedCycle (void) const bool -BayesNet::containsUndirectedCycle (int v, - int p, - vector& visited) const +BayesNet::containsUndirectedCycle (int v, int p, vector& visited) const { visited[v] = true; vector adjacencies = getAdjacentNodes (v); @@ -611,8 +628,7 @@ BayesNet::getAdjacentNodes (int v) const ParamSet -BayesNet::reorderParameters (CParamSet params, - unsigned domainSize) const +BayesNet::reorderParameters (const ParamSet& params, unsigned dsize) const { // the interchange format for bayesian networks keeps the probabilities // in the following order: @@ -623,13 +639,13 @@ BayesNet::reorderParameters (CParamSet params, // p(a1|b1,c1) p(a1|b1,c2) p(a1|b2,c1) p(a1|b2,c2) p(a2|b1,c1) p(a2|b1,c2) // p(a2|b2,c1) p(a2|b2,c2). unsigned count = 0; - unsigned rowSize = params.size() / domainSize; + unsigned rowSize = params.size() / dsize; ParamSet reordered; while (reordered.size() < params.size()) { unsigned idx = count; for (unsigned i = 0; i < rowSize; i++) { reordered.push_back (params[idx]); - idx += domainSize; + idx += dsize ; } count++; } @@ -639,15 +655,14 @@ BayesNet::reorderParameters (CParamSet params, ParamSet -BayesNet::revertParameterReorder (CParamSet params, - unsigned domainSize) const +BayesNet::revertParameterReorder (const ParamSet& params, unsigned dsize) const { unsigned count = 0; - unsigned rowSize = params.size() / domainSize; + unsigned rowSize = params.size() / dsize; ParamSet reordered; while (reordered.size() < params.size()) { unsigned idx = count; - for (unsigned i = 0; i < domainSize; i++) { + for (unsigned i = 0; i < dsize; i++) { reordered.push_back (params[idx]); idx += rowSize; } diff --git a/packages/CLPBN/clpbn/bp/BayesNet.h b/packages/CLPBN/clpbn/bp/BayesNet.h index a20f2f46c..ae4713b3b 100644 --- a/packages/CLPBN/clpbn/bp/BayesNet.h +++ b/packages/CLPBN/clpbn/bp/BayesNet.h @@ -1,5 +1,5 @@ -#ifndef BP_BAYES_NET_H -#define BP_BAYES_NET_H +#ifndef HORUS_BAYESNET_H +#define HORUS_BAYESNET_H #include #include @@ -44,61 +44,59 @@ struct StateInfo typedef vector DistSet; typedef queue > Scheduling; -typedef map Histogram; -typedef map Times; class BayesNet : public GraphicalModel { public: BayesNet (void) {}; - BayesNet (const char*); ~BayesNet (void); - BayesNode* addNode (unsigned); - BayesNode* addNode (unsigned, unsigned, int, BnNodeSet&, - Distribution*); - BayesNode* addNode (string, Domain, BnNodeSet&, ParamSet&); - BayesNode* getBayesNode (Vid) const; - BayesNode* getBayesNode (string) const; - Variable* getVariable (Vid) const; - void addDistribution (Distribution*); - Distribution* getDistribution (unsigned) const; - const BnNodeSet& getBayesNodes (void) const; - unsigned getNumberOfNodes (void) const; - BnNodeSet getRootNodes (void) const; - BnNodeSet getLeafNodes (void) const; - VarSet getVariables (void) const; - BayesNet* getMinimalRequesiteNetwork (Vid) const; - BayesNet* getMinimalRequesiteNetwork (const VidSet&) const; - void constructGraph (BayesNet*, - const vector&) const; - bool isSingleConnected (void) const; - void setIndexes (void); - void freeDistributions (void); - void printGraphicalModel (void) const; - void exportToDotFormat (const char*, bool = true, - CVidSet = VidSet()) const; - void exportToBifFormat (const char*) const; - - static Histogram histogram_; - static Times times_; + void readFromBifFormat (const char*); + void addNode (BayesNode*); + BayesNode* addNode (string, const States&); + BayesNode* addNode (VarId, unsigned, int, BnNodeSet&, Distribution*); + BayesNode* addNode (VarId, unsigned, int, Distribution*); + BayesNode* addNode (string, States, BnNodeSet&, ParamSet&); + BayesNode* getBayesNode (VarId) const; + BayesNode* getBayesNode (string) const; + VarNode* getVariableNode (VarId) const; + VarNodes getVariableNodes (void) const; + void addDistribution (Distribution*); + Distribution* getDistribution (unsigned) const; + const BnNodeSet& getBayesNodes (void) const; + unsigned nrNodes (void) const; + BnNodeSet getRootNodes (void) const; + BnNodeSet getLeafNodes (void) const; + BayesNet* getMinimalRequesiteNetwork (VarId) const; + BayesNet* getMinimalRequesiteNetwork (const VarIdSet&) const; + void constructGraph ( + BayesNet*, const vector&) const; + bool isPolyTree (void) const; + void setIndexes (void); + void distributionsToLogs (void); + void freeDistributions (void); + void printGraphicalModel (void) const; + void exportToGraphViz (const char*, bool = true, + const VarIdSet& = VarIdSet()) const; + void exportToBifFormat (const char*) const; private: DISALLOW_COPY_AND_ASSIGN (BayesNet); - bool containsUndirectedCycle (void) const; - bool containsUndirectedCycle (int, int, - vector&)const; - vector getAdjacentNodes (int) const ; - ParamSet reorderParameters (CParamSet, unsigned) const; - ParamSet revertParameterReorder (CParamSet, unsigned) const; - void scheduleParents (const BayesNode*, Scheduling&) const; - void scheduleChilds (const BayesNode*, Scheduling&) const; + bool containsUndirectedCycle (void) const; + bool containsUndirectedCycle (int, int, vector&)const; + vector getAdjacentNodes (int) const; + ParamSet reorderParameters (const ParamSet&, unsigned) const; + ParamSet revertParameterReorder (const ParamSet&, unsigned) const; + void scheduleParents (const BayesNode*, Scheduling&) const; + void scheduleChilds (const BayesNode*, Scheduling&) const; - BnNodeSet nodes_; - DistSet dists_; - IndexMap indexMap_; + BnNodeSet nodes_; + DistSet dists_; + + typedef unordered_map IndexMap; + IndexMap indexMap_; }; @@ -123,5 +121,5 @@ BayesNet::scheduleChilds (const BayesNode* n, Scheduling& sch) const } } -#endif //BP_BAYES_NET_H +#endif // HORUS_BAYESNET_H diff --git a/packages/CLPBN/clpbn/bp/BayesNode.cpp b/packages/CLPBN/clpbn/bp/BayesNode.cpp index d2ac88cb0..d828d5eb3 100644 --- a/packages/CLPBN/clpbn/bp/BayesNode.cpp +++ b/packages/CLPBN/clpbn/bp/BayesNode.cpp @@ -1,34 +1,30 @@ #include #include +#include #include #include -#include #include "BayesNode.h" -BayesNode::BayesNode (Vid vid, +BayesNode::BayesNode (VarId vid, unsigned dsize, int evidence, - const BnNodeSet& parents, - Distribution* dist) : Variable (vid, dsize, evidence) + Distribution* dist) + : VarNode (vid, dsize, evidence) { - parents_ = parents; - dist_ = dist; - for (unsigned int i = 0; i < parents.size(); i++) { - parents[i]->addChild (this); - } + dist_ = dist; } -BayesNode::BayesNode (Vid vid, - string label, - const Domain& domain, +BayesNode::BayesNode (VarId vid, + unsigned dsize, + int evidence, const BnNodeSet& parents, - Distribution* dist) : Variable (vid, domain, - NO_EVIDENCE, label) + Distribution* dist) + : VarNode (vid, dsize, evidence) { parents_ = parents; dist_ = dist; @@ -40,18 +36,12 @@ BayesNode::BayesNode (Vid vid, void -BayesNode::setData (unsigned dsize, - int evidence, - const BnNodeSet& parents, - Distribution* dist) +BayesNode::setParents (const BnNodeSet& parents) { - setDomainSize (dsize); - setEvidence (evidence); - parents_ = parents; - dist_ = dist; + parents_ = parents; for (unsigned int i = 0; i < parents.size(); i++) { parents[i]->addChild (this); - } + } } @@ -64,6 +54,15 @@ BayesNode::addChild (BayesNode* node) +void +BayesNode::setDistribution (Distribution* dist) +{ + assert (dist); + dist_ = dist; +} + + + Distribution* BayesNode::getDistribution (void) { @@ -140,14 +139,14 @@ BayesNode::getCptEntries (void) for (int i = parents_.size() - 1; i >= 0; i--) { unsigned index = 0; while (index < rowSize) { - for (unsigned j = 0; j < parents_[i]->getDomainSize(); j++) { + for (unsigned j = 0; j < parents_[i]->nrStates(); j++) { for (unsigned r = 0; r < nReps; r++) { confs[index][i] = j; index++; } } } - nReps *= parents_[i]->getDomainSize(); + nReps *= parents_[i]->nrStates(); } dist_->entries.reserve (rowSize); @@ -180,14 +179,14 @@ BayesNode::cptEntryToString (const CptEntry& entry) const ss << "p(" ; const DConf& conf = entry.getDomainConfiguration(); int row = entry.getParameterIndex() / getRowSize(); - ss << getDomain()[row]; + ss << states()[row]; if (parents_.size() > 0) { ss << "|" ; for (unsigned int i = 0; i < conf.size(); i++) { if (i != 0) { ss << ","; } - ss << parents_[i]->getDomain()[conf[i]]; + ss << parents_[i]->states()[conf[i]]; } } ss << ")" ; @@ -202,14 +201,14 @@ BayesNode::cptEntryToString (int row, const CptEntry& entry) const stringstream ss; ss << "p(" ; const DConf& conf = entry.getDomainConfiguration(); - ss << getDomain()[row]; + ss << states()[row]; if (parents_.size() > 0) { ss << "|" ; for (unsigned int i = 0; i < conf.size(); i++) { if (i != 0) { ss << ","; } - ss << parents_[i]->getDomain()[conf[i]]; + ss << parents_[i]->states()[conf[i]]; } } ss << ")" ; @@ -226,21 +225,21 @@ BayesNode::getDomainHeaders (void) const unsigned nReps = 1; vector headers (rowSize); for (int i = nParents - 1; i >= 0; i--) { - Domain domain = parents_[i]->getDomain(); + States states = parents_[i]->states(); unsigned index = 0; while (index < rowSize) { - for (unsigned j = 0; j < parents_[i]->getDomainSize(); j++) { + for (unsigned j = 0; j < parents_[i]->nrStates(); j++) { for (unsigned r = 0; r < nReps; r++) { if (headers[index] != "") { - headers[index] = domain[j] + "," + headers[index]; + headers[index] = states[j] + "," + headers[index]; } else { - headers[index] = domain[j]; + headers[index] = states[j]; } index++; } } } - nReps *= parents_[i]->getDomainSize(); + nReps *= parents_[i]->nrStates(); } return headers; } @@ -251,8 +250,8 @@ ostream& operator << (ostream& o, const BayesNode& node) { o << "variable " << node.getIndex() << endl; - o << "Var Id: " << node.getVarId() << endl; - o << "Label: " << node.getLabel() << endl; + o << "Var Id: " << node.varId() << endl; + o << "Label: " << node.label() << endl; o << "Evidence: " ; if (node.hasEvidence()) { @@ -267,9 +266,9 @@ operator << (ostream& o, const BayesNode& node) const BnNodeSet& parents = node.getParents(); if (parents.size() != 0) { for (unsigned int i = 0; i < parents.size() - 1; i++) { - o << parents[i]->getLabel() << ", " ; + o << parents[i]->label() << ", " ; } - o << parents[parents.size() - 1]->getLabel(); + o << parents[parents.size() - 1]->label(); } o << endl; @@ -277,19 +276,19 @@ operator << (ostream& o, const BayesNode& node) const BnNodeSet& childs = node.getChilds(); if (childs.size() != 0) { for (unsigned int i = 0; i < childs.size() - 1; i++) { - o << childs[i]->getLabel() << ", " ; + o << childs[i]->label() << ", " ; } - o << childs[childs.size() - 1]->getLabel(); + o << childs[childs.size() - 1]->label(); } o << endl; o << "Domain: " ; - Domain domain = node.getDomain(); - for (unsigned int i = 0; i < domain.size() - 1; i++) { - o << domain[i] << ", " ; + States states = node.states(); + for (unsigned int i = 0; i < states.size() - 1; i++) { + o << states[i] << ", " ; } - if (domain.size() != 0) { - o << domain[domain.size() - 1]; + if (states.size() != 0) { + o << states[states.size() - 1]; } o << endl; @@ -298,10 +297,10 @@ operator << (ostream& o, const BayesNode& node) // min width of following columns const unsigned int MIN_COMBO_WIDTH = 12; - unsigned int domainWidth = domain[0].length(); - for (unsigned int i = 1; i < domain.size(); i++) { - if (domain[i].length() > domainWidth) { - domainWidth = domain[i].length(); + unsigned int domainWidth = states[0].length(); + for (unsigned int i = 1; i < states.size(); i++) { + if (states[i].length() > domainWidth) { + domainWidth = states[i].length(); } } domainWidth = (domainWidth < MIN_DOMAIN_WIDTH) @@ -334,9 +333,9 @@ operator << (ostream& o, const BayesNode& node) } o << endl; - for (unsigned int i = 0; i < domain.size(); i++) { + for (unsigned int i = 0; i < states.size(); i++) { ParamSet row = node.getRow (i); - o << left << setw (domainWidth) << domain[i] << right; + o << left << setw (domainWidth) << states[i] << right; for (unsigned j = 0; j < node.getRowSize(); j++) { o << setw (widths[j]) << row[j]; } diff --git a/packages/CLPBN/clpbn/bp/BayesNode.h b/packages/CLPBN/clpbn/bp/BayesNode.h index 7e8e7780e..2e0086aa6 100644 --- a/packages/CLPBN/clpbn/bp/BayesNode.h +++ b/packages/CLPBN/clpbn/bp/BayesNode.h @@ -1,9 +1,9 @@ -#ifndef BP_BAYES_NODE_H -#define BP_BAYES_NODE_H +#ifndef HORUS_BAYESNODE_H +#define HORUS_BAYESNODE_H #include -#include "Variable.h" +#include "VarNode.h" #include "CptEntry.h" #include "Distribution.h" #include "Shared.h" @@ -11,16 +11,16 @@ using namespace std; -class BayesNode : public Variable +class BayesNode : public VarNode { public: - BayesNode (Vid vid) : Variable (vid) {} - BayesNode (Vid, unsigned, int, const BnNodeSet&, Distribution*); - BayesNode (Vid, string, const Domain&, const BnNodeSet&, Distribution*); + BayesNode (const VarNode& v) : VarNode (v) {} + BayesNode (VarId, unsigned, int, Distribution*); + BayesNode (VarId, unsigned, int, const BnNodeSet&, Distribution*); - void setData (unsigned, int, const BnNodeSet&, - Distribution*); + void setParents (const BnNodeSet&); void addChild (BayesNode*); + void setDistribution (Distribution*); Distribution* getDistribution (void); const ParamSet& getParameters (void); ParamSet getRow (int) const; @@ -34,12 +34,12 @@ class BayesNode : public Variable string cptEntryToString (const CptEntry&) const; string cptEntryToString (int, const CptEntry&) const; - const BnNodeSet& getParents (void) const { return parents_; } - const BnNodeSet& getChilds (void) const { return childs_; } + const BnNodeSet& getParents (void) const { return parents_; } + const BnNodeSet& getChilds (void) const { return childs_; } unsigned getRowSize (void) const { - return dist_->params.size() / getDomainSize(); + return dist_->params.size() / nrStates(); } double getProbability (int row, const CptEntry& entry) @@ -52,7 +52,7 @@ class BayesNode : public Variable private: DISALLOW_COPY_AND_ASSIGN (BayesNode); - Domain getDomainHeaders (void) const; + States getDomainHeaders (void) const; friend ostream& operator << (ostream&, const BayesNode&); BnNodeSet parents_; @@ -62,5 +62,5 @@ class BayesNode : public Variable ostream& operator << (ostream&, const BayesNode&); -#endif //BP_BAYES_NODE_H +#endif // HORUS_BAYESNODE_H diff --git a/packages/CLPBN/clpbn/bp/BnBpSolver.cpp b/packages/CLPBN/clpbn/bp/BnBpSolver.cpp new file mode 100644 index 000000000..f5206c65c --- /dev/null +++ b/packages/CLPBN/clpbn/bp/BnBpSolver.cpp @@ -0,0 +1,962 @@ +#include +#include +#include + +#include + +#include +#include +#include + +#include "BnBpSolver.h" + +BnBpSolver::BnBpSolver (const BayesNet& bn) : Solver (&bn) +{ + bayesNet_ = &bn; + jointCalcType_ = CHAIN_RULE; + //jointCalcType_ = JUNCTION_NODE; +} + + + +BnBpSolver::~BnBpSolver (void) +{ + for (unsigned i = 0; i < nodesI_.size(); i++) { + delete nodesI_[i]; + } + for (unsigned i = 0; i < links_.size(); i++) { + delete links_[i]; + } +} + + + +void +BnBpSolver::runSolver (void) +{ + clock_t start; + if (COLLECT_STATISTICS) { + start = clock(); + } + initializeSolver(); + if (!BpOptions::useAlwaysLoopySolver && bayesNet_->isPolyTree()) { + runPolyTreeSolver(); + } else { + runLoopySolver(); + if (DL >= 2) { + cout << endl; + if (nIters_ < BpOptions::maxIter) { + cout << "Belief propagation converged in " ; + cout << nIters_ << " iterations" << endl; + } else { + cout << "The maximum number of iterations was hit, terminating..." ; + cout << endl; + } + } + } + unsigned size = bayesNet_->nrNodes(); + if (COLLECT_STATISTICS) { + unsigned nIters = 0; + bool loopy = bayesNet_->isPolyTree() == false; + if (loopy) nIters = nIters_; + double time = (double (clock() - start)) / CLOCKS_PER_SEC; + Statistics::updateStatistics (size, loopy, nIters, time); + } + if (EXPORT_TO_GRAPHVIZ && size > EXPORT_MINIMAL_SIZE) { + stringstream ss; + ss << Statistics::getSolvedNetworksCounting() << "." << size << ".dot" ; + bayesNet_->exportToGraphViz (ss.str().c_str()); + } +} + + + +ParamSet +BnBpSolver::getPosterioriOf (VarId vid) +{ + BayesNode* node = bayesNet_->getBayesNode (vid); + assert (node); + return nodesI_[node->getIndex()]->getBeliefs(); +} + + + +ParamSet +BnBpSolver::getJointDistributionOf (const VarIdSet& jointVarIds) +{ + if (DL >= 2) { + cout << "calculating joint distribution on: " ; + for (unsigned i = 0; i < jointVarIds.size(); i++) { + VarNode* var = bayesNet_->getBayesNode (jointVarIds[i]); + cout << var->label() << " " ; + } + cout << endl; + } + + if (jointCalcType_ == JUNCTION_NODE) { + return getJointByJunctionNode (jointVarIds); + } else { + return getJointByChainRule (jointVarIds); + } +} + + + +void +BnBpSolver::initializeSolver (void) +{ + const BnNodeSet& nodes = bayesNet_->getBayesNodes(); + for (unsigned i = 0; i < nodesI_.size(); i++) { + delete nodesI_[i]; + } + nodesI_.clear(); + nodesI_.reserve (nodes.size()); + links_.clear(); + sortedOrder_.clear(); + linkMap_.clear(); + + for (unsigned i = 0; i < nodes.size(); i++) { + nodesI_.push_back (new BpNodeInfo (nodes[i])); + } + + BnNodeSet roots = bayesNet_->getRootNodes(); + for (unsigned i = 0; i < roots.size(); i++) { + const ParamSet& params = roots[i]->getParameters(); + ParamSet& piVals = ninf(roots[i])->getPiValues(); + for (unsigned ri = 0; ri < roots[i]->nrStates(); ri++) { + piVals[ri] = params[ri]; + } + } + + for (unsigned i = 0; i < nodes.size(); i++) { + const BnNodeSet& parents = nodes[i]->getParents(); + for (unsigned j = 0; j < parents.size(); j++) { + BpLink* newLink = new BpLink ( + parents[j], nodes[i], LinkOrientation::DOWN); + links_.push_back (newLink); + ninf(nodes[i])->addIncomingParentLink (newLink); + ninf(parents[j])->addOutcomingChildLink (newLink); + } + const BnNodeSet& childs = nodes[i]->getChilds(); + for (unsigned j = 0; j < childs.size(); j++) { + BpLink* newLink = new BpLink ( + childs[j], nodes[i], LinkOrientation::UP); + links_.push_back (newLink); + ninf(nodes[i])->addIncomingChildLink (newLink); + ninf(childs[j])->addOutcomingParentLink (newLink); + } + } + + for (unsigned i = 0; i < nodes.size(); i++) { + if (nodes[i]->hasEvidence()) { + ParamSet& piVals = ninf(nodes[i])->getPiValues(); + ParamSet& ldVals = ninf(nodes[i])->getLambdaValues(); + for (unsigned xi = 0; xi < nodes[i]->nrStates(); xi++) { + piVals[xi] = Util::noEvidence(); + ldVals[xi] = Util::noEvidence(); + } + piVals[nodes[i]->getEvidence()] = Util::withEvidence(); + ldVals[nodes[i]->getEvidence()] = Util::withEvidence(); + } + } +} + + + +void +BnBpSolver::runPolyTreeSolver (void) +{ + const BnNodeSet& nodes = bayesNet_->getBayesNodes(); + for (unsigned i = 0; i < nodes.size(); i++) { + if (nodes[i]->isRoot()) { + ninf(nodes[i])->markPiValuesAsCalculated(); + } + if (nodes[i]->isLeaf()) { + ninf(nodes[i])->markLambdaValuesAsCalculated(); + } + } + + bool finish = false; + while (!finish) { + finish = true; + for (unsigned i = 0; i < nodes.size(); i++) { + if (ninf(nodes[i])->piValuesCalculated() == false + && ninf(nodes[i])->receivedAllPiMessages()) { + if (!nodes[i]->hasEvidence()) { + updatePiValues (nodes[i]); + } + ninf(nodes[i])->markPiValuesAsCalculated(); + finish = false; + } + + if (ninf(nodes[i])->lambdaValuesCalculated() == false + && ninf(nodes[i])->receivedAllLambdaMessages()) { + if (!nodes[i]->hasEvidence()) { + updateLambdaValues (nodes[i]); + } + ninf(nodes[i])->markLambdaValuesAsCalculated(); + finish = false; + } + + if (ninf(nodes[i])->piValuesCalculated()) { + const BpLinkSet& outChildLinks + = ninf(nodes[i])->getOutcomingChildLinks(); + for (unsigned j = 0; j < outChildLinks.size(); j++) { + BayesNode* child = outChildLinks[j]->getDestination(); + if (!outChildLinks[j]->messageWasSended()) { + if (ninf(nodes[i])->readyToSendPiMsgTo (child)) { + calculateAndUpdateMessage (outChildLinks[j], false); + ninf(child)->incNumPiMsgsReceived(); + } + finish = false; + } + } + } + + if (ninf(nodes[i])->lambdaValuesCalculated()) { + const BpLinkSet& outParentLinks = + ninf(nodes[i])->getOutcomingParentLinks(); + for (unsigned j = 0; j < outParentLinks.size(); j++) { + BayesNode* parent = outParentLinks[j]->getDestination(); + if (!outParentLinks[j]->messageWasSended()) { + if (ninf(nodes[i])->readyToSendLambdaMsgTo (parent)) { + calculateAndUpdateMessage (outParentLinks[j], false); + ninf(parent)->incNumLambdaMsgsReceived(); + } + finish = false; + } + } + } + } + } +} + + + +void +BnBpSolver::runLoopySolver() +{ + nIters_ = 0; + while (!converged() && nIters_ < BpOptions::maxIter) { + + nIters_++; + if (DL >= 2) { + cout << "****************************************" ; + cout << "****************************************" ; + cout << endl; + cout << " Iteration " << nIters_ << endl; + cout << "****************************************" ; + cout << "****************************************" ; + cout << endl; + } + + switch (BpOptions::schedule) { + + case BpOptions::Schedule::SEQ_RANDOM: + random_shuffle (links_.begin(), links_.end()); + // no break + + case BpOptions::Schedule::SEQ_FIXED: + for (unsigned i = 0; i < links_.size(); i++) { + calculateAndUpdateMessage (links_[i]); + updateValues (links_[i]); + } + break; + + case BpOptions::Schedule::PARALLEL: + for (unsigned i = 0; i < links_.size(); i++) { + calculateMessage (links_[i]); + } + for (unsigned i = 0; i < links_.size(); i++) { + updateMessage (links_[i]); + updateValues (links_[i]); + } + break; + + case BpOptions::Schedule::MAX_RESIDUAL: + maxResidualSchedule(); + break; + + } + if (DL >= 2) { + cout << endl; + } + } +} + + + +bool +BnBpSolver::converged (void) const +{ + // this can happen if the graph is fully disconnected + if (links_.size() == 0) { + return true; + } + if (nIters_ == 0 || nIters_ == 1) { + return false; + } + bool converged = true; + if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) { + Param maxResidual = (*(sortedOrder_.begin()))->getResidual(); + if (maxResidual < BpOptions::accuracy) { + converged = true; + } else { + converged = false; + } + } else { + for (unsigned i = 0; i < links_.size(); i++) { + Param residual = links_[i]->getResidual(); + if (DL >= 2) { + cout << links_[i]->toString() + " residual change = " ; + cout << residual << endl; + } + if (residual > BpOptions::accuracy) { + converged = false; + break; + } + } + } + return converged; +} + + + +void +BnBpSolver::maxResidualSchedule (void) +{ + if (nIters_ == 1) { + for (unsigned i = 0; i < links_.size(); i++) { + calculateMessage (links_[i]); + SortedOrder::iterator it = sortedOrder_.insert (links_[i]); + linkMap_.insert (make_pair (links_[i], it)); + } + return; + } + + for (unsigned c = 0; c < sortedOrder_.size(); c++) { + if (DL >= 2) { + cout << "current residuals:" << endl; + for (SortedOrder::iterator it = sortedOrder_.begin(); + it != sortedOrder_.end(); it ++) { + cout << " " << setw (30) << left << (*it)->toString(); + cout << "residual = " << (*it)->getResidual() << endl; + } + } + + SortedOrder::iterator it = sortedOrder_.begin(); + BpLink* link = *it; + if (link->getResidual() < BpOptions::accuracy) { + sortedOrder_.erase (it); + it = sortedOrder_.begin(); + return; + } + updateMessage (link); + updateValues (link); + link->clearResidual(); + sortedOrder_.erase (it); + linkMap_.find (link)->second = sortedOrder_.insert (link); + + const BpLinkSet& outParentLinks = + ninf(link->getDestination())->getOutcomingParentLinks(); + for (unsigned i = 0; i < outParentLinks.size(); i++) { + if (outParentLinks[i]->getDestination() != link->getSource() + && outParentLinks[i]->getDestination()->hasEvidence() == false) { + calculateMessage (outParentLinks[i]); + BpLinkMap::iterator iter = linkMap_.find (outParentLinks[i]); + sortedOrder_.erase (iter->second); + iter->second = sortedOrder_.insert (outParentLinks[i]); + } + } + const BpLinkSet& outChildLinks = + ninf(link->getDestination())->getOutcomingChildLinks(); + for (unsigned i = 0; i < outChildLinks.size(); i++) { + if (outChildLinks[i]->getDestination() != link->getSource()) { + calculateMessage (outChildLinks[i]); + BpLinkMap::iterator iter = linkMap_.find (outChildLinks[i]); + sortedOrder_.erase (iter->second); + iter->second = sortedOrder_.insert (outChildLinks[i]); + } + } + + if (DL >= 2) { + cout << "----------------------------------------" ; + cout << "----------------------------------------" << endl; + } + } +} + + + +void +BnBpSolver::updatePiValues (BayesNode* x) +{ + // π(Xi) + if (DL >= 3) { + cout << "updating " << PI_SYMBOL << " values for " << x->label() << endl; + } + ParamSet& piValues = ninf(x)->getPiValues(); + const BpLinkSet& parentLinks = ninf(x)->getIncomingParentLinks(); + const vector& entries = x->getCptEntries(); + stringstream* calcs1 = 0; + stringstream* calcs2 = 0; + + ParamSet messageProducts (entries.size()); + for (unsigned k = 0; k < entries.size(); k++) { + if (DL >= 5) { + calcs1 = new stringstream; + calcs2 = new stringstream; + } + double messageProduct = Util::multIdenty(); + const DConf& conf = entries[k].getDomainConfiguration(); + switch (NSPACE) { + case NumberSpace::NORMAL: + for (unsigned i = 0; i < parentLinks.size(); i++) { + messageProduct *= parentLinks[i]->getMessage()[conf[i]]; + if (DL >= 5) { + if (i != 0) *calcs1 << " + " ; + if (i != 0) *calcs2 << " + " ; + *calcs1 << parentLinks[i]->toString (conf[i]); + *calcs2 << parentLinks[i]->getMessage()[conf[i]]; + } + } + break; + case NumberSpace::LOGARITHM: + for (unsigned i = 0; i < parentLinks.size(); i++) { + messageProduct += parentLinks[i]->getMessage()[conf[i]]; + } + } + messageProducts[k] = messageProduct; + if (DL >= 5) { + cout << " mp" << k; + cout << " = " << (*calcs1).str(); + if (parentLinks.size() == 1) { + cout << " = " << messageProduct << endl; + } else { + cout << " = " << (*calcs2).str(); + cout << " = " << messageProduct << endl; + } + delete calcs1; + delete calcs2; + } + } + + for (unsigned xi = 0; xi < x->nrStates(); xi++) { + double sum = Util::addIdenty(); + if (DL >= 5) { + calcs1 = new stringstream; + calcs2 = new stringstream; + } + switch (NSPACE) { + case NumberSpace::NORMAL: + for (unsigned k = 0; k < entries.size(); k++) { + sum += x->getProbability (xi, entries[k]) * messageProducts[k]; + if (DL >= 5) { + if (k != 0) *calcs1 << " + " ; + if (k != 0) *calcs2 << " + " ; + *calcs1 << x->cptEntryToString (xi, entries[k]); + *calcs1 << ".mp" << k; + *calcs2 << Util::fl (x->getProbability (xi, entries[k])); + *calcs2 << "*" << messageProducts[k]; + } + } + break; + case NumberSpace::LOGARITHM: + for (unsigned k = 0; k < entries.size(); k++) { + Util::logSum (sum, + x->getProbability(xi,entries[k]) + messageProducts[k]); + } + } + piValues[xi] = sum; + if (DL >= 5) { + cout << " " << PI_SYMBOL << "(" << x->label() << ")" ; + cout << "[" << x->states()[xi] << "]" ; + cout << " = " << (*calcs1).str(); + cout << " = " << (*calcs2).str(); + cout << " = " << piValues[xi] << endl; + delete calcs1; + delete calcs2; + } + } +} + + + +void +BnBpSolver::updateLambdaValues (BayesNode* x) +{ + // λ(Xi) + if (DL >= 3) { + cout << "updating " << LD_SYMBOL << " values for " << x->label() << endl; + } + ParamSet& lambdaValues = ninf(x)->getLambdaValues(); + const BpLinkSet& childLinks = ninf(x)->getIncomingChildLinks(); + stringstream* calcs1 = 0; + stringstream* calcs2 = 0; + + for (unsigned xi = 0; xi < x->nrStates(); xi++) { + if (DL >= 5) { + calcs1 = new stringstream; + calcs2 = new stringstream; + } + double product = Util::multIdenty(); + switch (NSPACE) { + case NumberSpace::NORMAL: + for (unsigned i = 0; i < childLinks.size(); i++) { + product *= childLinks[i]->getMessage()[xi]; + if (DL >= 5) { + if (i != 0) *calcs1 << "." ; + if (i != 0) *calcs2 << "*" ; + *calcs1 << childLinks[i]->toString (xi); + *calcs2 << childLinks[i]->getMessage()[xi]; + } + } + break; + case NumberSpace::LOGARITHM: + for (unsigned i = 0; i < childLinks.size(); i++) { + product += childLinks[i]->getMessage()[xi]; + } + } + lambdaValues[xi] = product; + if (DL >= 5) { + cout << " " << LD_SYMBOL << "(" << x->label() << ")" ; + cout << "[" << x->states()[xi] << "]" ; + cout << " = " << (*calcs1).str(); + if (childLinks.size() == 1) { + cout << " = " << product << endl; + } else { + cout << " = " << (*calcs2).str(); + cout << " = " << lambdaValues[xi] << endl; + } + delete calcs1; + delete calcs2; + } + } +} + + + +void +BnBpSolver::calculatePiMessage (BpLink* link) +{ + // πX(Zi) + BayesNode* z = link->getSource(); + BayesNode* x = link->getDestination(); + ParamSet& zxPiNextMessage = link->getNextMessage(); + const BpLinkSet& zChildLinks = ninf(z)->getIncomingChildLinks(); + stringstream* calcs1 = 0; + stringstream* calcs2 = 0; + + const ParamSet& zPiValues = ninf(z)->getPiValues(); + for (unsigned zi = 0; zi < z->nrStates(); zi++) { + double product = zPiValues[zi]; + if (DL >= 5) { + calcs1 = new stringstream; + calcs2 = new stringstream; + *calcs1 << PI_SYMBOL << "(" << z->label() << ")"; + *calcs1 << "[" << z->states()[zi] << "]" ; + *calcs2 << product; + } + switch (NSPACE) { + case NumberSpace::NORMAL: + for (unsigned i = 0; i < zChildLinks.size(); i++) { + if (zChildLinks[i]->getSource() != x) { + product *= zChildLinks[i]->getMessage()[zi]; + if (DL >= 5) { + *calcs1 << "." << zChildLinks[i]->toString (zi); + *calcs2 << " * " << zChildLinks[i]->getMessage()[zi]; + } + } + } + break; + case NumberSpace::LOGARITHM: + for (unsigned i = 0; i < zChildLinks.size(); i++) { + if (zChildLinks[i]->getSource() != x) { + product += zChildLinks[i]->getMessage()[zi]; + } + } + } + zxPiNextMessage[zi] = product; + if (DL >= 5) { + cout << " " << link->toString(); + cout << "[" << z->states()[zi] << "]" ; + cout << " = " << (*calcs1).str(); + if (zChildLinks.size() == 1) { + cout << " = " << product << endl; + } else { + cout << " = " << (*calcs2).str(); + cout << " = " << product << endl; + } + delete calcs1; + delete calcs2; + } + } + Util::normalize (zxPiNextMessage); +} + + + +void +BnBpSolver::calculateLambdaMessage (BpLink* link) +{ + // λY(Xi) + BayesNode* y = link->getSource(); + BayesNode* x = link->getDestination(); + if (x->hasEvidence()) { + return; + } + ParamSet& yxLambdaNextMessage = link->getNextMessage(); + const BpLinkSet& yParentLinks = ninf(y)->getIncomingParentLinks(); + const ParamSet& yLambdaValues = ninf(y)->getLambdaValues(); + const vector& allEntries = y->getCptEntries(); + int parentIndex = y->getIndexOfParent (x); + stringstream* calcs1 = 0; + stringstream* calcs2 = 0; + + vector entries; + DConstraint constr = make_pair (parentIndex, 0); + for (unsigned i = 0; i < allEntries.size(); i++) { + if (allEntries[i].matchConstraints(constr)) { + entries.push_back (allEntries[i]); + } + } + + ParamSet messageProducts (entries.size()); + for (unsigned k = 0; k < entries.size(); k++) { + if (DL >= 5) { + calcs1 = new stringstream; + calcs2 = new stringstream; + } + double messageProduct = Util::multIdenty(); + const DConf& conf = entries[k].getDomainConfiguration(); + switch (NSPACE) { + case NumberSpace::NORMAL: + for (unsigned i = 0; i < yParentLinks.size(); i++) { + if (yParentLinks[i]->getSource() != x) { + if (DL >= 5) { + if (messageProduct != Util::multIdenty()) *calcs1 << "*" ; + if (messageProduct != Util::multIdenty()) *calcs2 << "*" ; + *calcs1 << yParentLinks[i]->toString (conf[i]); + *calcs2 << yParentLinks[i]->getMessage()[conf[i]]; + } + messageProduct *= yParentLinks[i]->getMessage()[conf[i]]; + } + } + break; + case NumberSpace::LOGARITHM: + for (unsigned i = 0; i < yParentLinks.size(); i++) { + if (yParentLinks[i]->getSource() != x) { + messageProduct += yParentLinks[i]->getMessage()[conf[i]]; + } + } + } + messageProducts[k] = messageProduct; + if (DL >= 5) { + cout << " mp" << k; + cout << " = " << (*calcs1).str(); + if (yParentLinks.size() == 1) { + cout << 1 << endl; + } else if (yParentLinks.size() == 2) { + cout << " = " << messageProduct << endl; + } else { + cout << " = " << (*calcs2).str(); + cout << " = " << messageProduct << endl; + } + delete calcs1; + delete calcs2; + } + } + + for (unsigned xi = 0; xi < x->nrStates(); xi++) { + if (DL >= 5) { + calcs1 = new stringstream; + calcs2 = new stringstream; + } + vector entries; + DConstraint constr = make_pair (parentIndex, xi); + for (unsigned i = 0; i < allEntries.size(); i++) { + if (allEntries[i].matchConstraints(constr)) { + entries.push_back (allEntries[i]); + } + } + double outerSum = Util::addIdenty(); + for (unsigned yi = 0; yi < y->nrStates(); yi++) { + if (DL >= 5) { + (yi != 0) ? *calcs1 << " + {" : *calcs1 << "{" ; + (yi != 0) ? *calcs2 << " + {" : *calcs2 << "{" ; + } + double innerSum = Util::addIdenty(); + switch (NSPACE) { + case NumberSpace::NORMAL: + for (unsigned k = 0; k < entries.size(); k++) { + if (DL >= 5) { + if (k != 0) *calcs1 << " + " ; + if (k != 0) *calcs2 << " + " ; + *calcs1 << y->cptEntryToString (yi, entries[k]); + *calcs1 << ".mp" << k; + *calcs2 << y->getProbability (yi, entries[k]); + *calcs2 << "*" << messageProducts[k]; + } + innerSum += y->getProbability (yi, entries[k]) * messageProducts[k]; + } + outerSum += innerSum * yLambdaValues[yi]; + break; + case NumberSpace::LOGARITHM: + for (unsigned k = 0; k < entries.size(); k++) { + Util::logSum (innerSum, + y->getProbability(yi, entries[k]) + messageProducts[k]); + } + Util::logSum (outerSum, innerSum + yLambdaValues[yi]); + } + if (DL >= 5) { + *calcs1 << "}." << LD_SYMBOL << "(" << y->label() << ")" ; + *calcs1 << "[" << y->states()[yi] << "]"; + *calcs2 << "}*" << yLambdaValues[yi]; + } + } + yxLambdaNextMessage[xi] = outerSum; + if (DL >= 5) { + cout << " " << link->toString(); + cout << "[" << x->states()[xi] << "]" ; + cout << " = " << (*calcs1).str(); + cout << " = " << (*calcs2).str(); + cout << " = " << yxLambdaNextMessage[xi] << endl; + delete calcs1; + delete calcs2; + } + } + Util::normalize (yxLambdaNextMessage); +} + + + +ParamSet +BnBpSolver::getJointByJunctionNode (const VarIdSet& jointVarIds) +{ + unsigned msgSize = 1; + vector dsizes (jointVarIds.size()); + for (unsigned i = 0; i < jointVarIds.size(); i++) { + dsizes[i] = bayesNet_->getBayesNode (jointVarIds[i])->nrStates(); + msgSize *= dsizes[i]; + } + unsigned reps = 1; + ParamSet jointDist (msgSize, Util::multIdenty()); + for (int i = jointVarIds.size() - 1 ; i >= 0; i--) { + Util::multiply (jointDist, getPosterioriOf (jointVarIds[i]), reps); + reps *= dsizes[i] ; + } + return jointDist; +} + + + +ParamSet +BnBpSolver::getJointByChainRule (const VarIdSet& jointVarIds) const +{ + BnNodeSet jointVars; + for (unsigned i = 0; i < jointVarIds.size(); i++) { + jointVars.push_back (bayesNet_->getBayesNode (jointVarIds[i])); + } + + BayesNet* mrn = bayesNet_->getMinimalRequesiteNetwork (jointVarIds[0]); + BnBpSolver solver (*mrn); + solver.runSolver(); + ParamSet prevBeliefs = solver.getPosterioriOf (jointVarIds[0]); + delete mrn; + + VarNodes observedVars = {jointVars[0]}; + + for (unsigned i = 1; i < jointVarIds.size(); i++) { + mrn = bayesNet_->getMinimalRequesiteNetwork (jointVarIds[i]); + ParamSet newBeliefs; + vector confs = + Util::getDomainConfigurations (observedVars); + for (unsigned j = 0; j < confs.size(); j++) { + for (unsigned k = 0; k < observedVars.size(); k++) { + if (!observedVars[k]->hasEvidence()) { + BayesNode* node = mrn->getBayesNode (observedVars[k]->varId()); + if (node) { + node->setEvidence (confs[j][k]); + } + } + } + BnBpSolver solver (*mrn); + solver.runSolver(); + ParamSet beliefs = solver.getPosterioriOf (jointVarIds[i]); + for (unsigned k = 0; k < beliefs.size(); k++) { + newBeliefs.push_back (beliefs[k]); + } + } + + int count = -1; + for (unsigned j = 0; j < newBeliefs.size(); j++) { + if (j % jointVars[i]->nrStates() == 0) { + count ++; + } + newBeliefs[j] *= prevBeliefs[count]; + } + prevBeliefs = newBeliefs; + observedVars.push_back (jointVars[i]); + delete mrn; + } + return prevBeliefs; +} + + + +void +BnBpSolver::printPiLambdaValues (const BayesNode* var) const +{ + cout << left; + cout << setw (10) << "states" ; + cout << setw (20) << PI_SYMBOL << "(" + var->label() + ")" ; + cout << setw (20) << LD_SYMBOL << "(" + var->label() + ")" ; + cout << setw (16) << "belief" ; + cout << endl; + cout << "--------------------------------" ; + cout << "--------------------------------" ; + cout << endl; + const States& states = var->states(); + const ParamSet& piVals = ninf(var)->getPiValues(); + const ParamSet& ldVals = ninf(var)->getLambdaValues(); + const ParamSet& beliefs = ninf(var)->getBeliefs(); + for (unsigned xi = 0; xi < var->nrStates(); xi++) { + cout << setw (10) << states[xi]; + cout << setw (19) << piVals[xi]; + cout << setw (19) << ldVals[xi]; + cout.precision (PRECISION); + cout << setw (16) << beliefs[xi]; + cout << endl; + } + cout << endl; +} + + + +void +BnBpSolver::printAllMessageStatus (void) const +{ + const BnNodeSet& nodes = bayesNet_->getBayesNodes(); + for (unsigned i = 0; i < nodes.size(); i++) { + printPiLambdaValues (nodes[i]); + } +} + + + +BpNodeInfo::BpNodeInfo (BayesNode* node) +{ + node_ = node; + piValsCalc_ = false; + ldValsCalc_ = false; + nPiMsgsRcv_ = 0; + nLdMsgsRcv_ = 0; + piVals_.resize (node->nrStates(), Util::one()); + ldVals_.resize (node->nrStates(), Util::one()); +} + + + +ParamSet +BpNodeInfo::getBeliefs (void) const +{ + double sum = 0.0; + ParamSet beliefs (node_->nrStates()); + switch (NSPACE) { + case NumberSpace::NORMAL: + for (unsigned xi = 0; xi < node_->nrStates(); xi++) { + beliefs[xi] = piVals_[xi] * ldVals_[xi]; + sum += beliefs[xi]; + } + break; + case NumberSpace::LOGARITHM: + for (unsigned xi = 0; xi < node_->nrStates(); xi++) { + beliefs[xi] = exp (piVals_[xi] + ldVals_[xi]); + sum += beliefs[xi]; + } + } + assert (sum); + for (unsigned xi = 0; xi < node_->nrStates(); xi++) { + beliefs[xi] /= sum; + } + return beliefs; +} + + + +void +BpNodeInfo::markPiValuesAsCalculated (void) +{ + piValsCalc_ = true; +} + + + +void +BpNodeInfo::markLambdaValuesAsCalculated (void) +{ + ldValsCalc_ = true; +} + + + +bool +BpNodeInfo::receivedAllPiMessages (void) +{ + return node_->getParents().size() == nPiMsgsRcv_; +} + + + +bool +BpNodeInfo::receivedAllLambdaMessages (void) +{ + return node_->getChilds().size() == nLdMsgsRcv_; +} + + + +bool +BpNodeInfo::readyToSendPiMsgTo (const BayesNode* child) const +{ + for (unsigned i = 0; i < inChildLinks_.size(); i++) { + if (inChildLinks_[i]->getSource() != child + && inChildLinks_[i]->messageWasSended() == false) { + return false; + } + } + return true; +} + + + +bool +BpNodeInfo::readyToSendLambdaMsgTo (const BayesNode* parent) const +{ + for (unsigned i = 0; i < inParentLinks_.size(); i++) { + if (inParentLinks_[i]->getSource() != parent + && inParentLinks_[i]->messageWasSended() == false) { + return false; + } + } + return true; +} + + + +bool +BpNodeInfo::receivedBottomInfluence (void) const +{ + // if all lambda values are equal, then neither + // this node neither its descendents have evidence, + // we can use this to don't send lambda messages his parents + bool childInfluenced = false; + for (unsigned xi = 1; xi < node_->nrStates(); xi++) { + if (ldVals_[xi] != ldVals_[0]) { + childInfluenced = true; + break; + } + } + return childInfluenced; +} + diff --git a/packages/CLPBN/clpbn/bp/BnBpSolver.h b/packages/CLPBN/clpbn/bp/BnBpSolver.h new file mode 100644 index 000000000..c4b5d78d3 --- /dev/null +++ b/packages/CLPBN/clpbn/bp/BnBpSolver.h @@ -0,0 +1,262 @@ +#ifndef HORUS_BNBPSOLVER_H +#define HORUS_BNBPSOLVER_H + +#include +#include + +#include "Solver.h" +#include "BayesNet.h" +#include "Shared.h" + +using namespace std; + +class BpNodeInfo; + +static const string PI_SYMBOL = "pi" ; +static const string LD_SYMBOL = "ld" ; + +enum LinkOrientation {UP, DOWN}; +enum JointCalcType {CHAIN_RULE, JUNCTION_NODE}; + +class BpLink +{ + public: + BpLink (BayesNode* s, BayesNode* d, LinkOrientation o) + { + source_ = s; + destin_ = d; + orientation_ = o; + if (orientation_ == LinkOrientation::DOWN) { + v1_.resize (s->nrStates(), Util::tl (1.0/s->nrStates())); + v2_.resize (s->nrStates(), Util::tl (1.0/s->nrStates())); + } else { + v1_.resize (d->nrStates(), Util::tl (1.0/d->nrStates())); + v2_.resize (d->nrStates(), Util::tl (1.0/d->nrStates())); + } + currMsg_ = &v1_; + nextMsg_ = &v2_; + residual_ = 0; + msgSended_ = false; + } + + void updateMessage (void) + { + swap (currMsg_, nextMsg_); + msgSended_ = true; + } + + void updateResidual (void) + { + residual_ = Util::getMaxNorm (v1_, v2_); + } + + string toString (void) const + { + stringstream ss; + if (orientation_ == LinkOrientation::DOWN) { + ss << PI_SYMBOL; + } else { + ss << LD_SYMBOL; + } + ss << "(" << source_->label(); + ss << " --> " << destin_->label() << ")" ; + return ss.str(); + } + + string toString (unsigned stateIndex) const + { + stringstream ss; + ss << toString() << "[" ; + if (orientation_ == LinkOrientation::DOWN) { + ss << source_->states()[stateIndex] << "]" ; + } else { + ss << destin_->states()[stateIndex] << "]" ; + } + return ss.str(); + } + + BayesNode* getSource (void) const { return source_; } + BayesNode* getDestination (void) const { return destin_; } + LinkOrientation getOrientation (void) const { return orientation_; } + const ParamSet& getMessage (void) const { return *currMsg_; } + ParamSet& getNextMessage (void) { return *nextMsg_; } + bool messageWasSended (void) const { return msgSended_; } + double getResidual (void) const { return residual_; } + void clearResidual (void) { residual_ = 0;} + + private: + BayesNode* source_; + BayesNode* destin_; + LinkOrientation orientation_; + ParamSet v1_; + ParamSet v2_; + ParamSet* currMsg_; + ParamSet* nextMsg_; + bool msgSended_; + double residual_; +}; + + +typedef vector BpLinkSet; + + +class BpNodeInfo +{ + public: + BpNodeInfo (BayesNode*); + + ParamSet getBeliefs (void) const; + bool receivedBottomInfluence (void) const; + + ParamSet& getPiValues (void) { return piVals_; } + ParamSet& getLambdaValues (void) { return ldVals_; } + void incNumPiMsgsReceived (void) { nPiMsgsRcv_ ++; } + void incNumLambdaMsgsReceived (void) { nLdMsgsRcv_ ++; } + bool piValuesCalculated (void) { return piValsCalc_; } + bool lambdaValuesCalculated (void) { return ldValsCalc_; } + + void markPiValuesAsCalculated (void); + void markLambdaValuesAsCalculated (void); + bool receivedAllPiMessages (void); + bool receivedAllLambdaMessages (void); + bool readyToSendPiMsgTo (const BayesNode*) const ; + bool readyToSendLambdaMsgTo (const BayesNode*) const; + + const BpLinkSet& getIncomingParentLinks (void) { return inParentLinks_; } + const BpLinkSet& getIncomingChildLinks (void) { return inChildLinks_; } + const BpLinkSet& getOutcomingParentLinks (void) { return outParentLinks_; } + const BpLinkSet& getOutcomingChildLinks (void) { return outChildLinks_; } + + void addIncomingParentLink (BpLink* l) { inParentLinks_.push_back (l); } + void addIncomingChildLink (BpLink* l) { inChildLinks_.push_back (l); } + void addOutcomingParentLink (BpLink* l) { outParentLinks_.push_back (l); } + void addOutcomingChildLink (BpLink* l) { outChildLinks_.push_back (l); } + + private: + DISALLOW_COPY_AND_ASSIGN (BpNodeInfo); + + ParamSet piVals_; // pi values + ParamSet ldVals_; // lambda values + unsigned nPiMsgsRcv_; + unsigned nLdMsgsRcv_; + bool piValsCalc_; + bool ldValsCalc_; + BpLinkSet inParentLinks_; + BpLinkSet inChildLinks_; + BpLinkSet outParentLinks_; + BpLinkSet outChildLinks_; + const BayesNode* node_; +}; + + + +class BnBpSolver : public Solver +{ + public: + BnBpSolver (const BayesNet&); + ~BnBpSolver (void); + + void runSolver (void); + ParamSet getPosterioriOf (VarId); + ParamSet getJointDistributionOf (const VarIdSet&); + + + private: + DISALLOW_COPY_AND_ASSIGN (BnBpSolver); + + void initializeSolver (void); + void runPolyTreeSolver (void); + void runLoopySolver (void); + void maxResidualSchedule (void); + bool converged (void) const; + void updatePiValues (BayesNode*); + void updateLambdaValues (BayesNode*); + void calculateLambdaMessage (BpLink*); + void calculatePiMessage (BpLink*); + ParamSet getJointByJunctionNode (const VarIdSet&); + ParamSet getJointByChainRule (const VarIdSet&) const; + void printPiLambdaValues (const BayesNode*) const; + void printAllMessageStatus (void) const; + + void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true) + { + if (DL >= 3) { + cout << "calculating & updating " << link->toString() << endl; + } + if (link->getOrientation() == LinkOrientation::DOWN) { + calculatePiMessage (link); + } else if (link->getOrientation() == LinkOrientation::UP) { + calculateLambdaMessage (link); + } + if (calcResidual) { + link->updateResidual(); + } + link->updateMessage(); + } + + void calculateMessage (BpLink* link, bool calcResidual = true) + { + if (DL >= 3) { + cout << "calculating " << link->toString() << endl; + } + if (link->getOrientation() == LinkOrientation::DOWN) { + calculatePiMessage (link); + } else if (link->getOrientation() == LinkOrientation::UP) { + calculateLambdaMessage (link); + } + if (calcResidual) { + link->updateResidual(); + } + } + + void updateMessage (BpLink* link) + { + if (DL >= 3) { + cout << "updating " << link->toString() << endl; + } + link->updateMessage(); + } + + void updateValues (BpLink* link) + { + if (!link->getDestination()->hasEvidence()) { + if (link->getOrientation() == LinkOrientation::DOWN) { + updatePiValues (link->getDestination()); + } else if (link->getOrientation() == LinkOrientation::UP) { + updateLambdaValues (link->getDestination()); + } + } + } + + BpNodeInfo* ninf (const BayesNode* node) const + { + assert (node); + assert (node == bayesNet_->getBayesNode (node->varId())); + assert (node->getIndex() < nodesI_.size()); + return nodesI_[node->getIndex()]; + } + + const BayesNet* bayesNet_; + vector links_; + vector nodesI_; + unsigned nIters_; + JointCalcType jointCalcType_; + + struct compare + { + inline bool operator() (const BpLink* e1, const BpLink* e2) + { + return e1->getResidual() > e2->getResidual(); + } + }; + + typedef multiset SortedOrder; + SortedOrder sortedOrder_; + + typedef unordered_map BpLinkMap; + BpLinkMap linkMap_; + +}; + +#endif // HORUS_BNBPSOLVER_H + diff --git a/packages/CLPBN/clpbn/bp/BpNetwork.cpp b/packages/CLPBN/clpbn/bp/BpNetwork.cpp deleted file mode 100644 index 905f7cbc3..000000000 --- a/packages/CLPBN/clpbn/bp/BpNetwork.cpp +++ /dev/null @@ -1,811 +0,0 @@ -#include -#include -#include -#include - -#include "BpNetwork.h" -#include "BpNode.h" -#include "CptEntry.h" - -BpNetwork::BpNetwork (void) -{ - schedule_ = SEQUENTIAL_SCHEDULE; - maxIter_ = 150; - stableThreashold_ = 0.00000000000000000001; -} - - - -BpNetwork::~BpNetwork (void) -{ - for (unsigned int i = 0; i < nodes_.size(); i++) { - delete nodes_[i]; - } - nodes_.clear(); -} - - - -void -BpNetwork::setSolverParameters (Schedule schedule, - int maxIter, - double stableThreashold) -{ - if (maxIter <= 0) { - cerr << "error: maxIter must be greater or equal to 1" << endl; - abort(); - } - if (stableThreashold <= 0.0 || stableThreashold >= 1.0) { - cerr << "error: stableThreashold must be greater than 0.0 " ; - cerr << "and lesser than 1.0" << endl; - abort(); - } - schedule_ = schedule; - maxIter_ = maxIter; - stableThreashold_ = stableThreashold; -} - - - -void -BpNetwork::runSolver (BayesianNode* queryVar) -{ - vector queryVars; - queryVars.push_back (queryVar); - runSolver (queryVars); -} - - - -void -BpNetwork::runSolver (vector queryVars) -{ - if (queryVars.size() > 1) { - addJunctionNode (queryVars); - } - else { - string varName = queryVars[0]->getVariableName(); - queryNode_ = static_cast (getNode (varName)); - } - - if (!isPolyTree()) { - if (DL_ >= 1) { - cout << "The graph is not single connected. " ; - cout << "Iterative belief propagation will be used." ; - cout << endl << endl; - } - schedule_ = PARALLEL_SCHEDULE; - } - - if (schedule_ == SEQUENTIAL_SCHEDULE) { - initializeSolver (queryVars); - runNeapolitanSolver(); - for (unsigned int i = 0; i < nodes_.size(); i++) { - if (nodes_[i]->hasEvidence()) { - BpNode* v = static_cast (nodes_[i]); - addEvidence (v); - vector parents = cast (v->getParents()); - for (unsigned int i = 0; i < parents.size(); i++) { - if (!parents[i]->hasEvidence()) { - sendLambdaMessage (v, parents[i]); - } - } - vector childs = cast (v->getChilds()); - for (unsigned int i = 0; i < childs.size(); i++) { - sendPiMessage (v, childs[i]); - } - } - } - } else if (schedule_ == PARALLEL_SCHEDULE) { - BpNode::enableParallelSchedule(); - initializeSolver (queryVars); - for (unsigned int i = 0; i < nodes_.size(); i++) { - if (nodes_[i]->hasEvidence()) { - addEvidence (static_cast (nodes_[i])); - } - } - runIterativeBpSolver(); - } -} - - - -void -BpNetwork::printCurrentStatus (void) -{ - for (unsigned int i = 0; i < nodes_.size(); i++) { - printCurrentStatusOf (static_cast (nodes_[i])); - } -} - - - -void -BpNetwork::printCurrentStatusOf (BpNode* x) -{ - vector childs = cast (x->getChilds()); - vector domain = x->getDomain(); - - cout << left; - cout << setw (10) << "domain" ; - cout << setw (20) << "π(" + x->getVariableName() + ")" ; - cout << setw (20) << "λ(" + x->getVariableName() + ")" ; - cout << setw (16) << "belief" ; - cout << endl; - - cout << "--------------------------------" ; - cout << "--------------------------------" ; - cout << endl; - - double* piValues = x->getPiValues(); - double* lambdaValues = x->getLambdaValues(); - double* beliefs = x->getBeliefs(); - for (int xi = 0; xi < x->getDomainSize(); xi++) { - cout << setw (10) << domain[xi]; - cout << setw (19) << piValues[xi]; - cout << setw (19) << lambdaValues[xi]; - cout.precision (PRECISION_); - cout << setw (16) << beliefs[xi]; - cout << endl; - } - cout << endl; - if (childs.size() > 0) { - string s = "(" + x->getVariableName() + ")" ; - for (unsigned int j = 0; j < childs.size(); j++) { - cout << setw (10) << "domain" ; - cout << setw (28) << "π" + childs[j]->getVariableName() + s; - cout << setw (28) << "λ" + childs[j]->getVariableName() + s; - cout << endl; - cout << "--------------------------------" ; - cout << "--------------------------------" ; - cout << endl; - for (int xi = 0; xi < x->getDomainSize(); xi++) { - cout << setw (10) << domain[xi]; - cout.precision (PRECISION_); - cout << setw (27) << x->getPiMessage(childs[j], xi); - cout.precision (PRECISION_); - cout << setw (27) << x->getLambdaMessage(childs[j], xi); - cout << endl; - } - cout << endl; - } - } -} - - - -void -BpNetwork::printBeliefs (void) -{ - for (unsigned int i = 0; i < nodes_.size(); i++) { - BpNode* x = static_cast (nodes_[i]); - vector domain = x->getDomain(); - cout << setw (20) << left << x->getVariableName() ; - cout << setw (26) << "belief" ; - cout << endl; - cout << "--------------------------------------" ; - cout << endl; - double* beliefs = x->getBeliefs(); - for (int xi = 0; xi < x->getDomainSize(); xi++) { - cout << setw (20) << domain[xi]; - cout.precision (PRECISION_); - cout << setw (26) << beliefs[xi]; - cout << endl; - } - cout << endl; - } -} - - - -vector -BpNetwork::getBeliefs (void) -{ - return getBeliefs (queryNode_); -} - - - -vector -BpNetwork::getBeliefs (BpNode* x) -{ - double* beliefs = x->getBeliefs(); - vector beliefsVec; - for (int xi = 0; xi < x->getDomainSize(); xi++) { - beliefsVec.push_back (beliefs[xi]); - } - return beliefsVec; -} - - - -void -BpNetwork::initializeSolver (vector queryVars) -{ - if (DL_ >= 1) { - cout << "Initializing solver" << endl; - if (schedule_ == SEQUENTIAL_SCHEDULE) { - cout << "-> schedule = sequential" << endl; - } else { - cout << "-> schedule = parallel" << endl; - } - cout << "-> max iters = " << maxIter_ << endl; - cout << "-> stable threashold = " << stableThreashold_ << endl; - cout << "-> query vars = " ; - for (unsigned int i = 0; i < queryVars.size(); i++) { - cout << queryVars[i]->getVariableName() << " " ; - } - cout << endl; - } - - nIter_ = 0; - - for (unsigned int i = 0; i < nodes_.size(); i++) { - BpNode* node = static_cast (nodes_[i]); - node->allocateMemory(); - } - - for (unsigned int i = 0; i < nodes_.size(); i++) { - BpNode* x = static_cast (nodes_[i]); - - double* piValues = x->getPiValues(); - double* lambdaValues = x->getLambdaValues(); - for (int xi = 0; xi < x->getDomainSize(); xi++) { - piValues[xi] = 1.0; - lambdaValues[xi] = 1.0; - } - - vector xChilds = cast (x->getChilds()); - for (unsigned int j = 0; j < xChilds.size(); j++) { - double* piMessages = x->getPiMessages (xChilds[j]); - for (int xi = 0; xi < x->getDomainSize(); xi++) { - piMessages[xi] = 1.0; - //x->setPiMessage (xChilds[j], xi, 1.0); - } - } - - vector xParents = cast (x->getParents()); - for (unsigned int j = 0; j < xParents.size(); j++) { - double* lambdaMessages = xParents[j]->getLambdaMessages (x); - for (int xi = 0; xi < xParents[j]->getDomainSize(); xi++) { - lambdaMessages[xi] = 1.0; - //xParents[j]->setLambdaMessage (x, xi, 1.0); - } - } - } - - for (unsigned int i = 0; i < nodes_.size(); i++) { - BpNode* x = static_cast (nodes_[i]); - x->normalizeMessages(); - } - printCurrentStatus(); - - - vector roots = cast (getRootNodes()); - for (unsigned int i = 0; i < roots.size(); i++) { - double* params = roots[i]->getParameters(); - double* piValues = roots[i]->getPiValues(); - for (int ri = 0; ri < roots[i]->getDomainSize(); ri++) { - piValues[ri] = params[ri]; - } - } -} - - - -void -BpNetwork::addJunctionNode (vector queryVars) -{ - const string VAR_NAME = "_Jn"; - int nStates = 1; - vector parents; - vector domain; - for (unsigned int i = 0; i < queryVars.size(); i++) { - parents.push_back (queryVars[i]); - nStates *= queryVars[i]->getDomainSize(); - } - - for (int i = 0; i < nStates; i++) { - stringstream ss; - ss << "_jn" << i; - domain.push_back (ss.str()); // FIXME make domain optional - } - - int nParams = nStates * nStates; - double* params = new double [nParams]; - for (int i = 0; i < nParams; i++) { - int row = i / nStates; - int col = i % nStates; - if (row == col) { - params[i] = 1; - } else { - params[i] = 0; - } - } - addNode (VAR_NAME, parents, params, nParams, domain); - queryNode_ = static_cast (getNode (VAR_NAME)); - printNetwork(); -} - - - -void -BpNetwork::addEvidence (BpNode* v) -{ - if (DL_ >= 1) { - cout << "Adding evidence: node " ; - cout << "`" << v->getVariableName() << "' was instantiated as " ; - cout << "`" << v->getDomain()[v->getEvidence()] << "'" ; - cout << endl; - } - double* piValues = v->getPiValues(); - double* lambdaValues = v->getLambdaValues(); - for (int vi = 0; vi < v->getDomainSize(); vi++) { - if (vi == v->getEvidence()) { - piValues[vi] = 1.0; - lambdaValues[vi] = 1.0; - } else { - piValues[vi] = 0.0; - lambdaValues[vi] = 0.0; - } - } -} - - - -void -BpNetwork::runNeapolitanSolver (void) -{ - vector roots = cast (getRootNodes()); - for (unsigned int i = 0; i < roots.size(); i++) { - vector childs = cast (roots[i]->getChilds()); - for (unsigned int j = 0; j < childs.size(); j++) { - sendPiMessage (roots[i], static_cast (childs[j])); - } - } -} - - - -void -BpNetwork::sendPiMessage (BpNode* z, BpNode* x) -{ - nIter_ ++; - if (!(maxIter_ == -1 || nIter_ < maxIter_)) { - cout << "the maximum number of iterations was achieved, terminating..." ; - cout << endl; - return; - } - - if (DL_ >= 1) { - cout << "π message " << z->getVariableName(); - cout << " --> " << x->getVariableName() << endl; - } - - updatePiMessages(z, x); - - if (!x->hasEvidence()) { - updatePiValues (x); - vector xChilds = cast (x->getChilds()); - for (unsigned int i = 0; i < xChilds.size(); i++) { - sendPiMessage (x, xChilds[i]); - } - } - - bool isAllOnes = true; - double* lambdaValues = x->getLambdaValues(); - for (int xi = 0; xi < x->getDomainSize(); xi++) { - if (lambdaValues[xi] != 1.0) { - isAllOnes = false; - break; - } - } - - if (!isAllOnes) { - vector xParents = cast (x->getParents()); - for (unsigned int i = 0; i < xParents.size(); i++) { - if (xParents[i] != z && !xParents[i]->hasEvidence()) { - sendLambdaMessage (x, xParents[i]); - } - } - } -} - - - -void -BpNetwork::sendLambdaMessage (BpNode* y, BpNode* x) -{ - nIter_ ++; - if (!(maxIter_ == -1 || nIter_ < maxIter_)) { - cout << "the maximum number of iterations was achieved, terminating..." ; - cout << endl; - return; - } - - if (DL_ >= 1) { - cout << "λ message " << y->getVariableName(); - cout << " --> " << x->getVariableName() << endl; - } - - updateLambdaMessages (x, y); - updateLambdaValues (x); - - vector xParents = cast (x->getParents()); - for (unsigned int i = 0; i < xParents.size(); i++) { - if (!xParents[i]->hasEvidence()) { - sendLambdaMessage (x, xParents[i]); - } - } - - vector xChilds = cast (x->getChilds()); - for (unsigned int i = 0; i < xChilds.size(); i++) { - if (xChilds[i] != y) { - sendPiMessage (x, xChilds[i]); - } - } -} - - - -void -BpNetwork::updatePiValues (BpNode* x) -{ - // π(Xi) - vector parents = cast (x->getParents()); - for (int xi = 0; xi < x->getDomainSize(); xi++) { - stringstream calcs1; - stringstream calcs2; - if (DL_ >= 2) { - calcs1 << "π("<< x->getDomain()[xi] << ")" << endl << "= " ; - } - double sum = 0.0; - vector > constraints; - vector entries = x->getCptEntriesOfRow (xi); - for (unsigned int k = 0; k < entries.size(); k++) { - double prod = x->getProbability (entries[k]); - if (DL_ >= 2) { - if (k != 0) calcs1 << endl << "+ " ; - calcs1 << x->entryToString (entries[k]); - if (DL_ >= 3) { - (k == 0) ? calcs2 << "(" << prod : calcs2 << endl << "+ (" << prod; - } - } - vector insts = entries[k].getDomainInstantiations(); - for (unsigned int i = 0; i < parents.size(); i++) { - double value = parents[i]->getPiMessage (x, insts[i + 1]); - prod *= value; - if (DL_ >= 2) { - calcs1 << ".π" << x->getVariableName(); - calcs1 << "(" << parents[i]->getDomain()[insts[i + 1]] << ")"; - if (DL_ >= 3) calcs2 << "x" << value; - } - } - sum += prod; - if (DL_ >= 3) calcs2 << ")"; - } - x->setPiValue (xi, sum); - if (DL_ >= 2) { - cout << calcs1.str(); - if (DL_ >= 3) cout << endl << "= " << calcs2.str(); - cout << " = " << sum << endl; - } - } -} - - - -void -BpNetwork::updatePiMessages (BpNode* z, BpNode* x) -{ - // πX(Zi) - vector zChilds = cast (z->getChilds()); - for (int zi = 0; zi < z->getDomainSize(); zi++) { - stringstream calcs1; - stringstream calcs2; - if (DL_ >= 2) { - calcs1 << "π" << x->getVariableName(); - calcs1 << "(" << z->getDomain()[zi] << ") = " ; - } - double prod = z->getPiValue (zi); - if (DL_ >= 2) { - calcs1 << "π(" << z->getDomain()[zi] << ")" ; - if (DL_ >= 3) calcs2 << prod; - } - for (unsigned int i = 0; i < zChilds.size(); i++) { - if (zChilds[i] != x) { - double value = z->getLambdaMessage (zChilds[i], zi); - prod *= value; - if (DL_ >= 2) { - calcs1 << ".λ" << zChilds[i]->getVariableName(); - calcs1 << "(" << z->getDomain()[zi] + ")" ; - if (DL_ >= 3) calcs2 << " x " << value; - } - } - } - z->setPiMessage (x, zi, prod); - if (DL_ >= 2) { - cout << calcs1.str(); - if (DL_ >= 3) cout << " = " << calcs2.str(); - cout << " = " << prod << endl; - } - } -} - - - -void -BpNetwork::updateLambdaValues (BpNode* x) -{ - // λ(Xi) - vector childs = cast (x->getChilds()); - for (int xi = 0; xi < x->getDomainSize(); xi++) { - stringstream calcs1; - stringstream calcs2; - if (DL_ >= 2) { - calcs1 << "λ" << "(" << x->getDomain()[xi] << ") = " ; - } - double prod = 1.0; - for (unsigned int i = 0; i < childs.size(); i++) { - double val = x->getLambdaMessage (childs[i], xi); - prod *= val; - if (DL_ >= 2) { - if (i != 0) calcs1 << "." ; - calcs1 << "λ" << childs[i]->getVariableName(); - calcs1 << "(" << x->getDomain()[xi] + ")" ; - if (DL_ >= 3) (i == 0) ? calcs2 << val : calcs2 << " x " << val; - } - } - x->setLambdaValue (xi, prod); - if (DL_ >= 2) { - cout << calcs1.str(); - if (childs.size() == 0) { - cout << 1 << endl; - } else { - if (DL_ >= 3) cout << " = " << calcs2.str(); - cout << " = " << prod << endl; - } - } - } -} - - - -void -BpNetwork::updateLambdaMessages (BpNode* x, BpNode* y) -{ - // λY(Xi) - int parentIndex = y->getIndexOfParent (x) + 1; - vector yParents = cast (y->getParents()); - for (int xi = 0; xi < x->getDomainSize(); xi++) { - stringstream calcs1; - stringstream calcs2; - if (DL_ >= 2) { - calcs1 << "λ" << y->getVariableName() ; - calcs1 << "(" << x->getDomain()[xi] << ")" << endl << "= " ; - } - double outer_sum = 0.0; - for (int yi = 0; yi < y->getDomainSize(); yi++) { - if (DL_ >= 2) { - (yi == 0) ? calcs1 << "[" : calcs1 << endl << "+ [" ; - if (DL_ >= 3) { - (yi == 0) ? calcs2 << "[" : calcs2 << endl << "+ [" ; - } - } - double inner_sum = 0.0; - vector > constraints; - constraints.push_back (make_pair (0, yi)); - constraints.push_back (make_pair (parentIndex, xi)); - vector entries = y->getCptEntries (constraints); - for (unsigned int k = 0; k < entries.size(); k++) { - double prod = y->getProbability (entries[k]); - if (DL_ >= 2) { - if (k != 0) calcs1 << " + " ; - calcs1 << y->entryToString (entries[k]); - if (DL_ >= 3) { - (k == 0) ? calcs2 << "(" << prod : calcs2 << " + (" << prod; - } - } - vector insts = entries[k].getDomainInstantiations(); - for (unsigned int i = 0; i < yParents.size(); i++) { - if (yParents[i] != x) { - double val = yParents[i]->getPiMessage (y, insts[i + 1]); - prod *= val; - if (DL_ >= 2) { - calcs1 << ".π" << y->getVariableName(); - calcs1 << "(" << yParents[i]->getDomain()[insts[i + 1]] << ")" ; - if (DL_ >= 3) calcs2 << "x" << val; - } - - } - } - inner_sum += prod; - if (DL_ >= 3) { - calcs2 << ")" ; - } - } - outer_sum += inner_sum * y->getLambdaValue (yi); - if (DL_ >= 2) { - calcs1 << "].λ(" << y->getDomain()[yi] << ")"; - if (DL_ >= 3) calcs2 << "]x" << y->getLambdaValue (yi); - } - } - x->setLambdaMessage (y, xi, outer_sum); - if (DL_ >= 2) { - cout << calcs1.str(); - if (DL_ >= 3) cout << endl << "= " << calcs2.str(); - cout << " = " << outer_sum << endl; - } - } -} - - - -void -BpNetwork::runIterativeBpSolver() -{ - int nIter = 0; - maxIter_ = 100; - bool converged = false; - while (nIter < maxIter_ && !converged) { - if (DL_ >= 1) { - cout << endl << endl; - cout << "****************************************" ; - cout << "****************************************" ; - cout << endl; - cout << " Iteration " << nIter + 1 << endl; - cout << "****************************************" ; - cout << "****************************************" ; - } - - for (unsigned int i = 0; i < nodes_.size(); i++) { - BpNode* x = static_cast(nodes_[i]); - vector xParents = cast (x->getParents()); - for (unsigned int j = 0; j < xParents.size(); j++) { - //if (!xParents[j]->hasEvidence()) { - if (DL_ >= 1) { - cout << endl << "λ message " << x->getVariableName(); - cout << " --> " << xParents[j]->getVariableName() << endl; - } - updateLambdaMessages (xParents[j], x); - //} - } - } - - for (unsigned int i = 0; i < nodes_.size(); i++) { - BpNode* x = static_cast(nodes_[i]); - vector xChilds = cast (x->getChilds()); - for (unsigned int j = 0; j < xChilds.size(); j++) { - if (DL_ >= 1) { - cout << endl << "π message " << x->getVariableName(); - cout << " --> " << xChilds[j]->getVariableName() << endl; - } - updatePiMessages (x, xChilds[j]); - } - } - - /* - for (unsigned int i = 0; i < nodes_.size(); i++) { - BpNode* x = static_cast(nodes_[i]); - vector xChilds = cast (x->getChilds()); - for (unsigned int j = 0; j < xChilds.size(); j++) { - if (DL_ >= 1) { - cout << "π message " << x->getVariableName(); - cout << " --> " << xChilds[j]->getVariableName() << endl; - } - updatePiMessages (x, xChilds[j]); - } - vector xParents = cast (x->getParents()); - for (unsigned int j = 0; j < xParents.size(); j++) { - //if (!xParents[j]->hasEvidence()) { - if (DL_ >= 1) { - cout << "λ message " << x->getVariableName(); - cout << " --> " << xParents[j]->getVariableName() << endl; - } - updateLambdaMessages (xParents[j], x); - //} - } - } - */ - - for (unsigned int i = 0; i < nodes_.size(); i++) { - BpNode* x = static_cast (nodes_[i]); - //cout << endl << "SWAPING MESSAGES FOR " << x->getVariableName() << ":" ; - //cout << endl << endl; - //printCurrentStatusOf (x); - x->swapMessages(); - x->normalizeMessages(); - //cout << endl << "messages swaped " << endl; - //printCurrentStatusOf (x); - } - - converged = true; - for (unsigned int i = 0; i < nodes_.size(); i++) { - BpNode* x = static_cast(nodes_[i]); - if (DL_ >= 1) { - cout << endl << "var " << x->getVariableName() << ":" << endl; - } - //if (!x->hasEvidence()) { - updatePiValues (x); - updateLambdaValues (x); - double change = x->getBeliefChange(); - if (DL_ >= 1) { - cout << "belief change = " << change << endl; - } - if (change > stableThreashold_) { - converged = false; - } - //} - } - - if (converged) { - // converged = false; - } - if (DL_ >= 2) { - cout << endl; - printCurrentStatus(); - } - nIter++; - } - - if (DL_ >= 1) { - cout << endl; - if (converged) { - cout << "Iterative belief propagation converged in " ; - cout << nIter << " iterations" << endl; - } else { - cout << "Iterative belief propagation converged didn't converge" ; - cout << endl; - } - if (DL_ == 1) { - cout << endl; - printBeliefs(); - } - cout << endl; - } -} - - - -void -BpNetwork::addNode (string varName, - vector parents, - int evidence, - int distId) -{ - for (unsigned int i = 0; i < dists_.size(); i++) { - if (dists_[i]->id == distId) { - BpNode* node = new BpNode (varName, parents, dists_[i], evidence); - nodes_.push_back (node); - break; - } - } -} - - - -void -BpNetwork::addNode (string varName, - vector parents, - double* params, - int nParams, - vector domain) -{ - Distribution* dist = new Distribution (params, nParams, domain); - BpNode* node = new BpNode (varName, parents, dist); - dists_.push_back (dist); - nodes_.push_back (node); -} - - - -vector -BpNetwork::cast (vector nodes) -{ - vector castedNodes (nodes.size()); - for (unsigned int i = 0; i < nodes.size(); i++) { - castedNodes[i] = static_cast (nodes[i]); - } - return castedNodes; -} - diff --git a/packages/CLPBN/clpbn/bp/BpNetwork.h b/packages/CLPBN/clpbn/bp/BpNetwork.h deleted file mode 100644 index a37b079af..000000000 --- a/packages/CLPBN/clpbn/bp/BpNetwork.h +++ /dev/null @@ -1,66 +0,0 @@ -#ifndef BP_BP_NETWORK_H -#define BP_BP_NETWORK_H - -#include -#include - -#include "BayesianNetwork.h" - -using namespace std; - -class BpNode; - -enum Schedule -{ - SEQUENTIAL_SCHEDULE, - PARALLEL_SCHEDULE -}; - -class BpNetwork : public BayesianNetwork -{ - public: - // constructs - BpNetwork (void); - // destruct - ~BpNetwork (void); - // methods - void setSolverParameters (Schedule, int, double); - void runSolver (BayesianNode* queryVar); - void runSolver (vector); - void printCurrentStatus (void); - void printCurrentStatusOf (BpNode*); - void printBeliefs (void); - vector getBeliefs (void); - vector getBeliefs (BpNode*); - - private: - BpNetwork (const BpNetwork&); // disallow copy - void operator= (const BpNetwork&); // disallow assign - // methods - void initializeSolver (vector); - void addJunctionNode (vector); - void addEvidence (BpNode*); - void runNeapolitanSolver (void); - void sendLambdaMessage (BpNode*, BpNode*); - void sendPiMessage (BpNode*, BpNode*); - void updatePiValues (BpNode*); - void updatePiMessages (BpNode*, BpNode*); - void updateLambdaValues (BpNode*); - void updateLambdaMessages (BpNode*, BpNode*); - void runIterativeBpSolver (void); - void addNode (string, vector, int, int); - void addNode (string, vector, - double*, int, vector); - vector cast (vector); - // members - Schedule schedule_; - int nIter_; - int maxIter_; - double stableThreashold_; - BpNode* queryNode_; - static const int DL_ = 3; - static const int PRECISION_ = 10; -}; - -#endif // BP_BP_NETWORK_H - diff --git a/packages/CLPBN/clpbn/bp/BpNode.cpp b/packages/CLPBN/clpbn/bp/BpNode.cpp deleted file mode 100644 index 4fd52f95c..000000000 --- a/packages/CLPBN/clpbn/bp/BpNode.cpp +++ /dev/null @@ -1,250 +0,0 @@ -#include -#include -#include - -#include "BpNode.h" - -bool BpNode::calculateMessageResidual_ = true; - - -BpNode::BpNode (BayesNode* node) -{ - ds_ = node->getDomainSize(); - const NodeSet& childs = node->getChilds(); - piVals_.resize (ds_, 1); - ldVals_.resize (ds_, 1); - if (calculateMessageResidual_) { - piResiduals_.resize (childs.size(), 0.0); - ldResiduals_.resize (childs.size(), 0.0); - } - childs_ = &childs; - for (unsigned i = 0; i < childs.size(); i++) { - //indexMap_.insert (make_pair (childs[i]->getVarId(), i)); - currPiMsgs_.push_back (ParamSet (ds_, 1)); - currLdMsgs_.push_back (ParamSet (ds_, 1)); - nextPiMsgs_.push_back (ParamSet (ds_, 1)); - nextLdMsgs_.push_back (ParamSet (ds_, 1)); - } -} - - - -ParamSet -BpNode::getBeliefs (void) const -{ - double sum = 0.0; - ParamSet beliefs (ds_); - for (int xi = 0; xi < ds_; xi++) { - double prod = piVals_[xi] * ldVals_[xi]; - beliefs[xi] = prod; - sum += prod; - } - assert (sum); - //normalize the beliefs - for (int xi = 0; xi < ds_; xi++) { - beliefs[xi] /= sum; - } - return beliefs; -} - - - -double -BpNode::getPiValue (int idx) const -{ - assert (idx >=0 && idx < ds_); - return piVals_[idx]; -} - - - -void -BpNode::setPiValue (int idx, double value) -{ - assert (idx >=0 && idx < ds_); - piVals_[idx] = value; -} - - - -double -BpNode::getLambdaValue (int idx) const -{ - assert (idx >=0 && idx < ds_); - return ldVals_[idx]; -} - - - -void -BpNode::setLambdaValue (int idx, double value) -{ - assert (idx >=0 && idx < ds_); - ldVals_[idx] = value; -} - - - -ParamSet& -BpNode::getPiValues (void) -{ - return piVals_; -} - - - -ParamSet& -BpNode::getLambdaValues (void) -{ - return ldVals_; -} - - - -double -BpNode::getPiMessageValue (const BayesNode* destination, int idx) const -{ - assert (idx >=0 && idx < ds_); - return currPiMsgs_[getIndex(destination)][idx]; -} - - - -double -BpNode::getLambdaMessageValue (const BayesNode* source, int idx) const -{ - assert (idx >=0 && idx < ds_); - return currLdMsgs_[getIndex(source)][idx]; -} - - - -const ParamSet& -BpNode::getPiMessage (const BayesNode* destination) const -{ - return currPiMsgs_[getIndex(destination)]; -} - - - -const ParamSet& -BpNode::getLambdaMessage (const BayesNode* source) const -{ - return currLdMsgs_[getIndex(source)]; -} - - - -ParamSet& -BpNode::piNextMessageReference (const BayesNode* destination) -{ - return nextPiMsgs_[getIndex(destination)]; -} - - - -ParamSet& -BpNode::lambdaNextMessageReference (const BayesNode* source) -{ - return nextLdMsgs_[getIndex(source)]; -} - - - -void -BpNode::updatePiMessage (const BayesNode* destination) -{ - int idx = getIndex (destination); - currPiMsgs_[idx] = nextPiMsgs_[idx]; - Util::normalize (currPiMsgs_[idx]); -} - - - -void -BpNode::updateLambdaMessage (const BayesNode* source) -{ - int idx = getIndex (source); - currLdMsgs_[idx] = nextLdMsgs_[idx]; - Util::normalize (currLdMsgs_[idx]); -} - - - -double -BpNode::getBeliefChange (void) -{ - double change = 0.0; - if (oldBeliefs_.size() == 0) { - oldBeliefs_ = getBeliefs(); - change = 9999999999.0; - } else { - ParamSet currentBeliefs = getBeliefs(); - for (int xi = 0; xi < ds_; xi++) { - change += abs (currentBeliefs[xi] - oldBeliefs_[xi]); - } - oldBeliefs_ = currentBeliefs; - } - return change; -} - - - -void -BpNode::updatePiResidual (const BayesNode* destination) -{ - int idx = getIndex (destination); - Util::normalize (nextPiMsgs_[idx]); - //piResiduals_[idx] = Util::getL1dist ( - // currPiMsgs_[idx], nextPiMsgs_[idx]); - piResiduals_[idx] = Util::getMaxNorm ( - currPiMsgs_[idx], nextPiMsgs_[idx]); -} - - - -void -BpNode::updateLambdaResidual (const BayesNode* source) -{ - int idx = getIndex (source); - Util::normalize (nextLdMsgs_[idx]); - //ldResiduals_[idx] = Util::getL1dist ( - // currLdMsgs_[idx], nextLdMsgs_[idx]); - ldResiduals_[idx] = Util::getMaxNorm ( - currLdMsgs_[idx], nextLdMsgs_[idx]); -} - - - -void -BpNode::clearPiResidual (const BayesNode* destination) -{ - piResiduals_[getIndex(destination)] = 0; -} - - - -void -BpNode::clearLambdaResidual (const BayesNode* source) -{ - ldResiduals_[getIndex(source)] = 0; -} - - - -bool -BpNode::hasReceivedChildInfluence (void) const -{ - // if all lambda values are equal, then neither - // this node neither its descendents have evidence, - // we can use this to don't send lambda messages his parents - bool childInfluenced = false; - for (int xi = 1; xi < ds_; xi++) { - if (ldVals_[xi] != ldVals_[0]) { - childInfluenced = true; - break; - } - } - return childInfluenced; -} - diff --git a/packages/CLPBN/clpbn/bp/BpNode.h b/packages/CLPBN/clpbn/bp/BpNode.h deleted file mode 100644 index 2b84a298d..000000000 --- a/packages/CLPBN/clpbn/bp/BpNode.h +++ /dev/null @@ -1,99 +0,0 @@ -#ifndef BP_BPNODE_H -#define BP_BPNODE_H - -#include -#include -#include -#include - -#include "BayesNode.h" -#include "Shared.h" - -using namespace std; - -class BpNode -{ - public: - BpNode (int); - BpNode (BayesNode*); - - ParamSet getBeliefs (void) const; - double getPiValue (int) const; - void setPiValue (int, double); - double getLambdaValue (int) const; - void setLambdaValue (int, double); - ParamSet& getPiValues (void); - ParamSet& getLambdaValues (void); - double getPiMessageValue (const BayesNode*, int) const; - double getLambdaMessageValue (const BayesNode*, int) const; - const ParamSet& getPiMessage (const BayesNode*) const; - const ParamSet& getLambdaMessage (const BayesNode*) const; - ParamSet& piNextMessageReference (const BayesNode*); - ParamSet& lambdaNextMessageReference (const BayesNode*); - void updatePiMessage (const BayesNode*); - void updateLambdaMessage (const BayesNode*); - double getBeliefChange (void); - void updatePiResidual (const BayesNode*); - void updateLambdaResidual (const BayesNode*); - void clearPiResidual (const BayesNode*); - void clearLambdaResidual (const BayesNode*); - bool hasReceivedChildInfluence (void) const; - // inlines - double getPiResidual (const BayesNode*); - double getLambdaResidual (const BayesNode*); - int getIndex (const BayesNode*) const; - - private: - DISALLOW_COPY_AND_ASSIGN (BpNode); - - IndexMap indexMap_; - ParamSet piVals_; // pi values - ParamSet ldVals_; // lambda values - vector currPiMsgs_; // current pi messages - vector currLdMsgs_; // current lambda messages - vector nextPiMsgs_; - vector nextLdMsgs_; - ParamSet oldBeliefs_; - ParamSet piResiduals_; - ParamSet ldResiduals_; - int ds_; - const NodeSet* childs_; - static bool calculateMessageResidual_; -// static const double MAX_CHANGE_ = 10000000.0; -}; - - - -inline double -BpNode::getPiResidual (const BayesNode* destination) -{ - return piResiduals_[getIndex(destination)]; -} - - -inline double -BpNode::getLambdaResidual (const BayesNode* source) -{ - return ldResiduals_[getIndex(source)]; -} - - - -inline int -BpNode::getIndex (const BayesNode* node) const -{ - assert (node); - //assert (indexMap_.find(node->getVarId()) != indexMap_.end()); - //return indexMap_.find(node->getVarId())->second; - for (unsigned i = 0; childs_->size(); i++) { - if ((*childs_)[i]->getVarId() == node->getVarId()) { - return i; - } - } - assert (false); - return -1; -} - - -#endif - diff --git a/packages/CLPBN/clpbn/bp/CFactorGraph.cpp b/packages/CLPBN/clpbn/bp/CFactorGraph.cpp new file mode 100644 index 000000000..26cfae009 --- /dev/null +++ b/packages/CLPBN/clpbn/bp/CFactorGraph.cpp @@ -0,0 +1,344 @@ + +#include "CFactorGraph.h" +#include "Factor.h" +#include "Distribution.h" + + +bool CFactorGraph::checkForIdenticalFactors_ = true; + +CFactorGraph::CFactorGraph (const FactorGraph& fg) +{ + groundFg_ = &fg; + freeColor_ = 0; + + const FgVarSet& varNodes = fg.getVarNodes(); + varSignatures_.reserve (varNodes.size()); + for (unsigned i = 0; i < varNodes.size(); i++) { + unsigned c = (varNodes[i]->neighbors().size() * 2) + 1; + varSignatures_.push_back (Signature (c)); + } + + const FgFacSet& facNodes = fg.getFactorNodes(); + factorSignatures_.reserve (facNodes.size()); + for (unsigned i = 0; i < facNodes.size(); i++) { + unsigned c = facNodes[i]->neighbors().size() + 1; + factorSignatures_.push_back (Signature (c)); + } + + varColors_.resize (varNodes.size()); + factorColors_.resize (facNodes.size()); + setInitialColors(); + createGroups(); +} + + + +CFactorGraph::~CFactorGraph (void) +{ + for (unsigned i = 0; i < varClusters_.size(); i++) { + delete varClusters_[i]; + } + for (unsigned i = 0; i < factorClusters_.size(); i++) { + delete factorClusters_[i]; + } +} + + + +void +CFactorGraph::setInitialColors (void) +{ + // create the initial variable colors + VarColorMap colorMap; + const FgVarSet& varNodes = groundFg_->getVarNodes(); + for (unsigned i = 0; i < varNodes.size(); i++) { + unsigned dsize = varNodes[i]->nrStates(); + VarColorMap::iterator it = colorMap.find (dsize); + if (it == colorMap.end()) { + it = colorMap.insert (make_pair ( + dsize, vector (dsize+1,-1))).first; + } + unsigned idx; + if (varNodes[i]->hasEvidence()) { + idx = varNodes[i]->getEvidence(); + } else { + idx = dsize; + } + vector& stateColors = it->second; + if (stateColors[idx] == -1) { + stateColors[idx] = getFreeColor(); + } + setColor (varNodes[i], stateColors[idx]); + } + + const FgFacSet& facNodes = groundFg_->getFactorNodes(); + if (checkForIdenticalFactors_) { + for (unsigned i = 0; i < facNodes.size() - 1; i++) { + // facNodes[i]->factor()->orderFactorVariables(); + // FIXME + } + for (unsigned i = 0, s = facNodes.size(); i < s; i++) { + Distribution* dist1 = facNodes[i]->getDistribution(); + for (unsigned j = 0; j < i; j++) { + Distribution* dist2 = facNodes[j]->getDistribution(); + if (dist1 != dist2 && dist1->params == dist2->params) { + facNodes[i]->factor()->setDistribution (dist2); + // delete dist2; + break; + } + /* + if (ok) { + const FgVarSet& fiVars = factors[i]->getFgVarNodes(); + const FgVarSet& fjVars = factors[j]->getFgVarNodes(); + if (fiVars.size() != fjVars.size()) continue; + for (unsigned k = 0; k < fiVars.size(); k++) { + if (fiVars[k]->nrStates() != fjVars[k]->nrStates()) { + ok = false; + break; + } + } + } + */ + } + } + } + + // create the initial factor colors + DistColorMap distColors; + for (unsigned i = 0; i < facNodes.size(); i++) { + const Distribution* dist = facNodes[i]->getDistribution(); + DistColorMap::iterator it = distColors.find (dist); + if (it == distColors.end()) { + it = distColors.insert (make_pair (dist, getFreeColor())).first; + } + setColor (facNodes[i], it->second); + } +} + + + +void +CFactorGraph::createGroups (void) +{ + VarSignMap varGroups; + FacSignMap factorGroups; + unsigned nIters = 0; + bool groupsHaveChanged = true; + const FgVarSet& varNodes = groundFg_->getVarNodes(); + const FgFacSet& facNodes = groundFg_->getFactorNodes(); + + while (groupsHaveChanged || nIters == 1) { + nIters ++; + + unsigned prevFactorGroupsSize = factorGroups.size(); + factorGroups.clear(); + // set a new color to the factors with the same signature + for (unsigned i = 0; i < facNodes.size(); i++) { + const Signature& signature = getSignature (facNodes[i]); + FacSignMap::iterator it = factorGroups.find (signature); + if (it == factorGroups.end()) { + it = factorGroups.insert (make_pair (signature, FgFacSet())).first; + } + it->second.push_back (facNodes[i]); + } + for (FacSignMap::iterator it = factorGroups.begin(); + it != factorGroups.end(); it++) { + Color newColor = getFreeColor(); + FgFacSet& groupMembers = it->second; + for (unsigned i = 0; i < groupMembers.size(); i++) { + setColor (groupMembers[i], newColor); + } + } + + // set a new color to the variables with the same signature + unsigned prevVarGroupsSize = varGroups.size(); + varGroups.clear(); + for (unsigned i = 0; i < varNodes.size(); i++) { + const Signature& signature = getSignature (varNodes[i]); + VarSignMap::iterator it = varGroups.find (signature); + if (it == varGroups.end()) { + it = varGroups.insert (make_pair (signature, FgVarSet())).first; + } + it->second.push_back (varNodes[i]); + } + for (VarSignMap::iterator it = varGroups.begin(); + it != varGroups.end(); it++) { + Color newColor = getFreeColor(); + FgVarSet& groupMembers = it->second; + for (unsigned i = 0; i < groupMembers.size(); i++) { + setColor (groupMembers[i], newColor); + } + } + + groupsHaveChanged = prevVarGroupsSize != varGroups.size() + || prevFactorGroupsSize != factorGroups.size(); + } + //printGroups (varGroups, factorGroups); + createClusters (varGroups, factorGroups); +} + + + +void +CFactorGraph::createClusters (const VarSignMap& varGroups, + const FacSignMap& factorGroups) +{ + varClusters_.reserve (varGroups.size()); + for (VarSignMap::const_iterator it = varGroups.begin(); + it != varGroups.end(); it++) { + const FgVarSet& groupVars = it->second; + VarCluster* vc = new VarCluster (groupVars); + for (unsigned i = 0; i < groupVars.size(); i++) { + vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc)); + } + varClusters_.push_back (vc); + } + + factorClusters_.reserve (factorGroups.size()); + for (FacSignMap::const_iterator it = factorGroups.begin(); + it != factorGroups.end(); it++) { + FgFacNode* groupFactor = it->second[0]; + const FgVarSet& neighs = groupFactor->neighbors(); + VarClusterSet varClusters; + varClusters.reserve (neighs.size()); + for (unsigned i = 0; i < neighs.size(); i++) { + VarId vid = neighs[i]->varId(); + varClusters.push_back (vid2VarCluster_.find (vid)->second); + } + factorClusters_.push_back (new FacCluster (it->second, varClusters)); + } +} + + + +const Signature& +CFactorGraph::getSignature (const FgVarNode* varNode) +{ + Signature& sign = varSignatures_[varNode->getIndex()]; + vector::iterator it = sign.colors.begin(); + const FgFacSet& neighs = varNode->neighbors(); + for (unsigned i = 0; i < neighs.size(); i++) { + *it = getColor (neighs[i]); + it ++; + *it = neighs[i]->factor()->getPositionOf (varNode->varId()); + it ++; + } + *it = getColor (varNode); + return sign; +} + + + +const Signature& +CFactorGraph::getSignature (const FgFacNode* facNode) +{ + Signature& sign = factorSignatures_[facNode->getIndex()]; + vector::iterator it = sign.colors.begin(); + const FgVarSet& neighs = facNode->neighbors(); + for (unsigned i = 0; i < neighs.size(); i++) { + *it = getColor (neighs[i]); + it ++; + } + *it = getColor (facNode); + return sign; +} + + + +FactorGraph* +CFactorGraph::getCompressedFactorGraph (void) +{ + FactorGraph* fg = new FactorGraph(); + for (unsigned i = 0; i < varClusters_.size(); i++) { + FgVarNode* var = varClusters_[i]->getGroundFgVarNodes()[0]; + FgVarNode* newVar = new FgVarNode (var); + varClusters_[i]->setRepresentativeVariable (newVar); + fg->addVariable (newVar); + } + + for (unsigned i = 0; i < factorClusters_.size(); i++) { + const VarClusterSet& myVarClusters = factorClusters_[i]->getVarClusters(); + VarNodes myGroundVars; + myGroundVars.reserve (myVarClusters.size()); + for (unsigned j = 0; j < myVarClusters.size(); j++) { + FgVarNode* v = myVarClusters[j]->getRepresentativeVariable(); + myGroundVars.push_back (v); + } + Factor* newFactor = new Factor (myGroundVars, + factorClusters_[i]->getGroundFactors()[0]->getDistribution()); + FgFacNode* fn = new FgFacNode (newFactor); + factorClusters_[i]->setRepresentativeFactor (fn); + fg->addFactor (fn); + for (unsigned j = 0; j < myGroundVars.size(); j++) { + fg->addEdge (fn, static_cast (myGroundVars[j])); + } + } + fg->setIndexes(); + return fg; +} + + + +unsigned +CFactorGraph::getGroundEdgeCount (const FacCluster* fc, + const VarCluster* vc) const +{ + const FgFacSet& clusterGroundFactors = fc->getGroundFactors(); + FgVarNode* varNode = vc->getGroundFgVarNodes()[0]; + unsigned count = 0; + for (unsigned i = 0; i < clusterGroundFactors.size(); i++) { + if (clusterGroundFactors[i]->factor()->getPositionOf (varNode->varId()) != -1) { + count ++; + } + } + // CFgVarSet vars = vc->getGroundFgVarNodes(); + // for (unsigned i = 1; i < vars.size(); i++) { + // FgVarNode* var = vc->getGroundFgVarNodes()[i]; + // unsigned count2 = 0; + // for (unsigned i = 0; i < clusterGroundFactors.size(); i++) { + // if (clusterGroundFactors[i]->getPositionOf (var) != -1) { + // count2 ++; + // } + // } + // if (count != count2) { cout << "oops!" << endl; abort(); } + // } + return count; +} + + + +void +CFactorGraph::printGroups (const VarSignMap& varGroups, + const FacSignMap& factorGroups) const +{ + unsigned count = 1; + cout << "variable groups:" << endl; + for (VarSignMap::const_iterator it = varGroups.begin(); + it != varGroups.end(); it++) { + const FgVarSet& groupMembers = it->second; + if (groupMembers.size() > 0) { + cout << count << ": " ; + for (unsigned i = 0; i < groupMembers.size(); i++) { + cout << groupMembers[i]->label() << " " ; + } + count ++; + cout << endl; + } + } + + count = 1; + cout << endl << "factor groups:" << endl; + for (FacSignMap::const_iterator it = factorGroups.begin(); + it != factorGroups.end(); it++) { + const FgFacSet& groupMembers = it->second; + if (groupMembers.size() > 0) { + cout << ++count << ": " ; + for (unsigned i = 0; i < groupMembers.size(); i++) { + cout << groupMembers[i]->getLabel() << " " ; + } + count ++; + cout << endl; + } + } +} + diff --git a/packages/CLPBN/clpbn/bp/CFactorGraph.h b/packages/CLPBN/clpbn/bp/CFactorGraph.h new file mode 100644 index 000000000..bc49de682 --- /dev/null +++ b/packages/CLPBN/clpbn/bp/CFactorGraph.h @@ -0,0 +1,237 @@ +#ifndef HORUS_CFACTORGRAPH_H +#define HORUS_CFACTORGRAPH_H + +#include + +#include "FactorGraph.h" +#include "Factor.h" +#include "Shared.h" + +class VarCluster; +class FacCluster; +class Distribution; +class Signature; + +class SignatureHash; + + +typedef long Color; +typedef unordered_map > VarColorMap; +typedef unordered_map DistColorMap; +typedef unordered_map VarId2VarCluster; +typedef vector VarClusterSet; +typedef vector FacClusterSet; +typedef unordered_map VarSignMap; +typedef unordered_map FacSignMap; + + + +struct Signature { + Signature (unsigned size) + { + colors.resize (size); + } + bool operator< (const Signature& sig) const + { + if (colors.size() < sig.colors.size()) { + return true; + } else if (colors.size() > sig.colors.size()) { + return false; + } else { + for (unsigned i = 0; i < colors.size(); i++) { + if (colors[i] < sig.colors[i]) { + return true; + } else if (colors[i] > sig.colors[i]) { + return false; + } + } + } + return false; + } + bool operator== (const Signature& sig) const + { + if (colors.size() != sig.colors.size()) { + return false; + } + for (unsigned i = 0; i < colors.size(); i++) { + if (colors[i] != sig.colors[i]) { + return false; + } + } + return true; + } + vector colors; +}; + + + +struct SignatureHash { + size_t operator() (const Signature &sig) const + { + size_t val = hash()(sig.colors.size()); + for (unsigned i = 0; i < sig.colors.size(); i++) { + val ^= hash()(sig.colors[i]); + } + return val; + } +}; + + + +class VarCluster +{ + public: + VarCluster (const FgVarSet& vs) + { + for (unsigned i = 0; i < vs.size(); i++) { + groundVars_.push_back (vs[i]); + } + } + + void addFacCluster (FacCluster* fc) + { + factorClusters_.push_back (fc); + } + + const FacClusterSet& getFacClusters (void) const + { + return factorClusters_; + } + + FgVarNode* getRepresentativeVariable (void) const { return representVar_; } + void setRepresentativeVariable (FgVarNode* v) { representVar_ = v; } + const FgVarSet& getGroundFgVarNodes (void) const { return groundVars_; } + + private: + FgVarSet groundVars_; + FacClusterSet factorClusters_; + FgVarNode* representVar_; +}; + + +class FacCluster +{ + public: + FacCluster (const FgFacSet& groundFactors, const VarClusterSet& vcs) + { + groundFactors_ = groundFactors; + varClusters_ = vcs; + for (unsigned i = 0; i < varClusters_.size(); i++) { + varClusters_[i]->addFacCluster (this); + } + } + + const VarClusterSet& getVarClusters (void) const + { + return varClusters_; + } + + bool containsGround (const FgFacNode* fn) + { + for (unsigned i = 0; i < groundFactors_.size(); i++) { + if (groundFactors_[i] == fn) { + return true; + } + } + return false; + } + + FgFacNode* getRepresentativeFactor (void) const + { + return representFactor_; + } + void setRepresentativeFactor (FgFacNode* fn) + { + representFactor_ = fn; + } + const FgFacSet& getGroundFactors (void) const + { + return groundFactors_; + } + + + private: + FgFacSet groundFactors_; + VarClusterSet varClusters_; + FgFacNode* representFactor_; +}; + + +class CFactorGraph +{ + public: + CFactorGraph (const FactorGraph&); + ~CFactorGraph (void); + + FactorGraph* getCompressedFactorGraph (void); + unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const; + + FgVarNode* getEquivalentVariable (VarId vid) + { + VarCluster* vc = vid2VarCluster_.find (vid)->second; + return vc->getRepresentativeVariable(); + } + + const VarClusterSet& getVariableClusters (void) { return varClusters_; } + const FacClusterSet& getFacClusters (void) { return factorClusters_; } + + static void enableCheckForIdenticalFactors (void) + { + checkForIdenticalFactors_ = true; + } + + static void disableCheckForIdenticalFactors (void) + { + checkForIdenticalFactors_ = false; + } + + private: + void setInitialColors (void); + void createGroups (void); + void createClusters (const VarSignMap&, const FacSignMap&); + const Signature& getSignature (const FgVarNode*); + const Signature& getSignature (const FgFacNode*); + void printGroups (const VarSignMap&, const FacSignMap&) const; + + Color getFreeColor (void) { + ++ freeColor_; + return freeColor_ - 1; + } + + Color getColor (const FgVarNode* vn) const + { + return varColors_[vn->getIndex()]; + } + Color getColor (const FgFacNode* fn) const { + return factorColors_[fn->getIndex()]; + } + + void setColor (const FgVarNode* vn, Color c) + { + varColors_[vn->getIndex()] = c; + } + + void setColor (const FgFacNode* fn, Color c) + { + factorColors_[fn->getIndex()] = c; + } + + VarCluster* getVariableCluster (VarId vid) const + { + return vid2VarCluster_.find (vid)->second; + } + + Color freeColor_; + vector varColors_; + vector factorColors_; + vector varSignatures_; + vector factorSignatures_; + VarClusterSet varClusters_; + FacClusterSet factorClusters_; + VarId2VarCluster vid2VarCluster_; + const FactorGraph* groundFg_; + bool static checkForIdenticalFactors_; +}; + +#endif // HORUS_CFACTORGRAPH_H + diff --git a/packages/CLPBN/clpbn/bp/CbpSolver.cpp b/packages/CLPBN/clpbn/bp/CbpSolver.cpp new file mode 100644 index 000000000..b45769b57 --- /dev/null +++ b/packages/CLPBN/clpbn/bp/CbpSolver.cpp @@ -0,0 +1,263 @@ +#include "CbpSolver.h" + + +CbpSolver::~CbpSolver (void) +{ + delete lfg_; + delete factorGraph_; + for (unsigned i = 0; i < links_.size(); i++) { + delete links_[i]; + } + links_.clear(); +} + + + +ParamSet +CbpSolver::getPosterioriOf (VarId vid) +{ + FgVarNode* var = lfg_->getEquivalentVariable (vid); + ParamSet probs; + if (var->hasEvidence()) { + probs.resize (var->nrStates(), Util::noEvidence()); + probs[var->getEvidence()] = Util::withEvidence(); + } else { + probs.resize (var->nrStates(), Util::multIdenty()); + const SpLinkSet& links = ninf(var)->getLinks(); + switch (NSPACE) { + case NumberSpace::NORMAL: + for (unsigned i = 0; i < links.size(); i++) { + CbpSolverLink* l = static_cast (links[i]); + Util::multiply (probs, l->getPoweredMessage()); + } + Util::normalize (probs); + break; + case NumberSpace::LOGARITHM: + for (unsigned i = 0; i < links.size(); i++) { + CbpSolverLink* l = static_cast (links[i]); + Util::add (probs, l->getPoweredMessage()); + } + Util::normalize (probs); + Util::fromLog (probs); + } + } + return probs; +} + + + +ParamSet +CbpSolver::getJointDistributionOf (const VarIdSet& jointVarIds) +{ + unsigned msgSize = 1; + vector dsizes (jointVarIds.size()); + for (unsigned i = 0; i < jointVarIds.size(); i++) { + dsizes[i] = lfg_->getEquivalentVariable (jointVarIds[i])->nrStates(); + msgSize *= dsizes[i]; + } + unsigned reps = 1; + ParamSet jointDist (msgSize, Util::multIdenty()); + for (int i = jointVarIds.size() - 1 ; i >= 0; i--) { + Util::multiply (jointDist, getPosterioriOf (jointVarIds[i]), reps); + reps *= dsizes[i]; + } + return jointDist; +} + + + +void +CbpSolver::initializeSolver (void) +{ + unsigned nGroundVars, nGroundFacs, nWithoutNeighs; + if (COLLECT_STATISTICS) { + nGroundVars = factorGraph_->getVarNodes().size(); + nGroundFacs = factorGraph_->getFactorNodes().size(); + const FgVarSet& vars = factorGraph_->getVarNodes(); + nWithoutNeighs = 0; + for (unsigned i = 0; i < vars.size(); i++) { + const FgFacSet& factors = vars[i]->neighbors(); + if (factors.size() == 1 && factors[0]->neighbors().size() == 1) { + nWithoutNeighs ++; + } + } + } + + lfg_ = new CFactorGraph (*factorGraph_); + + // cout << "Uncompressed Factor Graph" << endl; + // factorGraph_->printGraphicalModel(); + // factorGraph_->exportToGraphViz ("uncompressed_fg.dot"); + factorGraph_ = lfg_->getCompressedFactorGraph(); + + if (COLLECT_STATISTICS) { + unsigned nClusterVars = factorGraph_->getVarNodes().size(); + unsigned nClusterFacs = factorGraph_->getFactorNodes().size(); + Statistics::updateCompressingStatistics (nGroundVars, nGroundFacs, + nClusterVars, nClusterFacs, + nWithoutNeighs); + } + + // cout << "Compressed Factor Graph" << endl; + // factorGraph_->printGraphicalModel(); + // factorGraph_->exportToGraphViz ("compressed_fg.dot"); + // abort(); + FgBpSolver::initializeSolver(); +} + + + +void +CbpSolver::createLinks (void) +{ + const FacClusterSet fcs = lfg_->getFacClusters(); + for (unsigned i = 0; i < fcs.size(); i++) { + const VarClusterSet vcs = fcs[i]->getVarClusters(); + for (unsigned j = 0; j < vcs.size(); j++) { + unsigned c = lfg_->getGroundEdgeCount (fcs[i], vcs[j]); + links_.push_back (new CbpSolverLink (fcs[i]->getRepresentativeFactor(), + vcs[j]->getRepresentativeVariable(), c)); + } + } + return; +} + + + +void +CbpSolver::maxResidualSchedule (void) +{ + if (nIters_ == 1) { + for (unsigned i = 0; i < links_.size(); i++) { + calculateMessage (links_[i]); + SortedOrder::iterator it = sortedOrder_.insert (links_[i]); + linkMap_.insert (make_pair (links_[i], it)); + if (DL >= 2 && DL < 5) { + cout << "calculating " << links_[i]->toString() << endl; + } + } + return; + } + + for (unsigned c = 0; c < links_.size(); c++) { + if (DL >= 2) { + cout << endl << "current residuals:" << endl; + for (SortedOrder::iterator it = sortedOrder_.begin(); + it != sortedOrder_.end(); it ++) { + cout << " " << setw (30) << left << (*it)->toString(); + cout << "residual = " << (*it)->getResidual() << endl; + } + } + + SortedOrder::iterator it = sortedOrder_.begin(); + SpLink* link = *it; + if (DL >= 2) { + cout << "updating " << (*sortedOrder_.begin())->toString() << endl; + } + if (link->getResidual() < BpOptions::accuracy) { + return; + } + link->updateMessage(); + link->clearResidual(); + sortedOrder_.erase (it); + linkMap_.find (link)->second = sortedOrder_.insert (link); + + // update the messages that depend on message source --> destin + const FgFacSet& factorNeighbors = link->getVariable()->neighbors(); + for (unsigned i = 0; i < factorNeighbors.size(); i++) { + const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks(); + for (unsigned j = 0; j < links.size(); j++) { + if (links[j]->getVariable() != link->getVariable()) { + if (DL >= 2 && DL < 5) { + cout << " calculating " << links[j]->toString() << endl; + } + calculateMessage (links[j]); + SpLinkMap::iterator iter = linkMap_.find (links[j]); + sortedOrder_.erase (iter->second); + iter->second = sortedOrder_.insert (links[j]); + } + } + } + // in counting bp, the message that a variable X sends to + // to a factor F depends on the message that F sent to the X + const SpLinkSet& links = ninf(link->getFactor())->getLinks(); + for (unsigned i = 0; i < links.size(); i++) { + if (links[i]->getVariable() != link->getVariable()) { + if (DL >= 2 && DL < 5) { + cout << " calculating " << links[i]->toString() << endl; + } + calculateMessage (links[i]); + SpLinkMap::iterator iter = linkMap_.find (links[i]); + sortedOrder_.erase (iter->second); + iter->second = sortedOrder_.insert (links[i]); + } + } + } +} + + + +ParamSet +CbpSolver::getVar2FactorMsg (const SpLink* link) const +{ + ParamSet msg; + const FgVarNode* src = link->getVariable(); + const FgFacNode* dst = link->getFactor(); + const CbpSolverLink* l = static_cast (link); + if (src->hasEvidence()) { + msg.resize (src->nrStates(), Util::noEvidence()); + double value = link->getMessage()[src->getEvidence()]; + msg[src->getEvidence()] = Util::pow (value, l->getNumberOfEdges() - 1); + } else { + msg = link->getMessage(); + Util::pow (msg, l->getNumberOfEdges() - 1); + } + if (DL >= 5) { + cout << " " << "init: " << Util::parametersToString (msg) << endl; + } + const SpLinkSet& links = ninf(src)->getLinks(); + switch (NSPACE) { + case NumberSpace::NORMAL: + for (unsigned i = 0; i < links.size(); i++) { + if (links[i]->getFactor() != dst) { + CbpSolverLink* l = static_cast (links[i]); + Util::multiply (msg, l->getPoweredMessage()); + if (DL >= 5) { + cout << " msg from " << l->getFactor()->getLabel() << ": " ; + cout << Util::parametersToString (l->getPoweredMessage()) << endl; + } + } + } + break; + case NumberSpace::LOGARITHM: + for (unsigned i = 0; i < links.size(); i++) { + if (links[i]->getFactor() != dst) { + CbpSolverLink* l = static_cast (links[i]); + Util::add (msg, l->getPoweredMessage()); + } + } + } + if (DL >= 5) { + cout << " result = " << Util::parametersToString (msg) << endl; + } + return msg; +} + + + +void +CbpSolver::printLinkInformation (void) const +{ + for (unsigned i = 0; i < links_.size(); i++) { + CbpSolverLink* l = static_cast (links_[i]); + cout << l->toString() << ":" << endl; + cout << " curr msg = " ; + cout << Util::parametersToString (l->getMessage()) << endl; + cout << " next msg = " ; + cout << Util::parametersToString (l->getNextMessage()) << endl; + cout << " powered = " ; + cout << Util::parametersToString (l->getPoweredMessage()) << endl; + cout << " residual = " << l->getResidual() << endl; + } +} + diff --git a/packages/CLPBN/clpbn/bp/CbpSolver.h b/packages/CLPBN/clpbn/bp/CbpSolver.h new file mode 100644 index 000000000..69ce4ce7f --- /dev/null +++ b/packages/CLPBN/clpbn/bp/CbpSolver.h @@ -0,0 +1,58 @@ +#ifndef HORUS_CBP_H +#define HORUS_CBP_H + +#include "FgBpSolver.h" +#include "CFactorGraph.h" + +class Factor; + +class CbpSolverLink : public SpLink +{ + public: + CbpSolverLink (FgFacNode* fn, FgVarNode* vn, unsigned c) : SpLink (fn, vn) + { + edgeCount_ = c; + poweredMsg_.resize (vn->nrStates(), Util::one()); + } + + void updateMessage (void) + { + poweredMsg_ = *nextMsg_; + swap (currMsg_, nextMsg_); + msgSended_ = true; + Util::pow (poweredMsg_, edgeCount_); + } + + unsigned getNumberOfEdges (void) const { return edgeCount_; } + const ParamSet& getPoweredMessage (void) const { return poweredMsg_; } + + private: + ParamSet poweredMsg_; + unsigned edgeCount_; +}; + + + +class CbpSolver : public FgBpSolver +{ + public: + CbpSolver (FactorGraph& fg) : FgBpSolver (fg) { } + ~CbpSolver (void); + + ParamSet getPosterioriOf (VarId); + ParamSet getJointDistributionOf (const VarIdSet&); + + private: + void initializeSolver (void); + void createLinks (void); + + void maxResidualSchedule (void); + ParamSet getVar2FactorMsg (const SpLink*) const; + void printLinkInformation (void) const; + + + CFactorGraph* lfg_; +}; + +#endif // HORUS_CBP_H + diff --git a/packages/CLPBN/clpbn/bp/CountingBP.cpp b/packages/CLPBN/clpbn/bp/CountingBP.cpp deleted file mode 100644 index 645c61418..000000000 --- a/packages/CLPBN/clpbn/bp/CountingBP.cpp +++ /dev/null @@ -1,198 +0,0 @@ -#include "CountingBP.h" - - -CountingBP::~CountingBP (void) -{ - delete lfg_; - delete fg_; - for (unsigned i = 0; i < links_.size(); i++) { - delete links_[i]; - } - links_.clear(); -} - - - -ParamSet -CountingBP::getPosterioriOf (Vid vid) const -{ - FgVarNode* var = lfg_->getEquivalentVariable (vid); - ParamSet probs; - - if (var->hasEvidence()) { - probs.resize (var->getDomainSize(), 0.0); - probs[var->getEvidence()] = 1.0; - } else { - probs.resize (var->getDomainSize(), 1.0); - CLinkSet links = varsI_[var->getIndex()]->getLinks(); - for (unsigned i = 0; i < links.size(); i++) { - ParamSet msg = links[i]->getMessage(); - CountingBPLink* l = static_cast (links[i]); - Util::pow (msg, l->getNumberOfEdges()); - for (unsigned j = 0; j < msg.size(); j++) { - probs[j] *= msg[j]; - } - } - Util::normalize (probs); - } - return probs; -} - - - -void -CountingBP::initializeSolver (void) -{ - lfg_ = new LiftedFG (*fg_); - unsigned nUncVars = fg_->getFgVarNodes().size(); - unsigned nUncFactors = fg_->getFactors().size(); - CFgVarSet vars = fg_->getFgVarNodes(); - unsigned nNeighborLessVars = 0; - for (unsigned i = 0; i < vars.size(); i++) { - CFactorSet factors = vars[i]->getFactors(); - if (factors.size() == 1 && factors[0]->getFgVarNodes().size() == 1) { - nNeighborLessVars ++; - } - } - // cout << "UNCOMPRESSED FACTOR GRAPH" << endl; - // fg_->printGraphicalModel(); - fg_->exportToDotFormat ("uncompress.dot"); - - FactorGraph *temp; - temp = fg_; - fg_ = lfg_->getCompressedFactorGraph(); - unsigned nCompVars = fg_->getFgVarNodes().size(); - unsigned nCompFactors = fg_->getFactors().size(); - - Statistics::updateCompressingStats (nUncVars, - nUncFactors, - nCompVars, - nCompFactors, - nNeighborLessVars); - - cout << "COMPRESSED FACTOR GRAPH" << endl; - fg_->printGraphicalModel(); - //fg_->exportToDotFormat ("compress.dot"); - - SPSolver::initializeSolver(); -} - - - -void -CountingBP::createLinks (void) -{ - const FactorClusterSet fcs = lfg_->getFactorClusters(); - for (unsigned i = 0; i < fcs.size(); i++) { - const VarClusterSet vcs = fcs[i]->getVarClusters(); - for (unsigned j = 0; j < vcs.size(); j++) { - unsigned c = lfg_->getGroundEdgeCount (fcs[i], vcs[j]); - links_.push_back ( - new CountingBPLink (fcs[i]->getRepresentativeFactor(), - vcs[j]->getRepresentativeVariable(), c)); - //cout << (links_.back())->toString() << " edge count =" << c << endl; - } - } - return; -} - - - -void -CountingBP::deleteJunction (Factor* f, FgVarNode*) -{ - f->freeDistribution(); -} - - - -void -CountingBP::maxResidualSchedule (void) -{ - if (nIter_ == 1) { - for (unsigned i = 0; i < links_.size(); i++) { - links_[i]->setNextMessage (getFactor2VarMsg (links_[i])); - SortedOrder::iterator it = sortedOrder_.insert (links_[i]); - linkMap_.insert (make_pair (links_[i], it)); - if (DL >= 2 && DL < 5) { - cout << "calculating " << links_[i]->toString() << endl; - } - } - return; - } - - for (unsigned c = 0; c < links_.size(); c++) { - if (DL >= 2) { - cout << endl << "current residuals:" << endl; - for (SortedOrder::iterator it = sortedOrder_.begin(); - it != sortedOrder_.end(); it ++) { - cout << " " << setw (30) << left << (*it)->toString(); - cout << "residual = " << (*it)->getResidual() << endl; - } - } - - SortedOrder::iterator it = sortedOrder_.begin(); - Link* link = *it; - if (DL >= 2) { - cout << "updating " << (*sortedOrder_.begin())->toString() << endl; - } - if (link->getResidual() < SolverOptions::accuracy) { - return; - } - link->updateMessage(); - link->clearResidual(); - sortedOrder_.erase (it); - linkMap_.find (link)->second = sortedOrder_.insert (link); - - // update the messages that depend on message source --> destin - CFactorSet factorNeighbors = link->getVariable()->getFactors(); - for (unsigned i = 0; i < factorNeighbors.size(); i++) { - CLinkSet links = factorsI_[factorNeighbors[i]->getIndex()]->getLinks(); - for (unsigned j = 0; j < links.size(); j++) { - if (links[j]->getVariable() != link->getVariable()) { //FIXMEFIXME - if (DL >= 2 && DL < 5) { - cout << " calculating " << links[j]->toString() << endl; - } - links[j]->setNextMessage (getFactor2VarMsg (links[j])); - LinkMap::iterator iter = linkMap_.find (links[j]); - sortedOrder_.erase (iter->second); - iter->second = sortedOrder_.insert (links[j]); - } - } - } - } -} - - - -ParamSet -CountingBP::getVar2FactorMsg (const Link* link) const -{ - const FgVarNode* src = link->getVariable(); - const Factor* dest = link->getFactor(); - ParamSet msg; - if (src->hasEvidence()) { - cout << "has evidence" << endl; - msg.resize (src->getDomainSize(), 0.0); - msg[src->getEvidence()] = link->getMessage()[src->getEvidence()]; - cout << "-> " << link->getVariable()->getLabel() << " " << link->getFactor()->getLabel() << endl; - cout << "-> p2s " << Util::parametersToString (msg) << endl; - } else { - msg = link->getMessage(); - } - const CountingBPLink* l = static_cast (link); - Util::pow (msg, l->getNumberOfEdges() - 1); - CLinkSet links = varsI_[src->getIndex()]->getLinks(); - for (unsigned i = 0; i < links.size(); i++) { - if (links[i]->getFactor() != dest) { - ParamSet msgFromFactor = links[i]->getMessage(); - CountingBPLink* l = static_cast (links[i]); - Util::pow (msgFromFactor, l->getNumberOfEdges()); - for (unsigned j = 0; j < msgFromFactor.size(); j++) { - msg[j] *= msgFromFactor[j]; - } - } - } - return msg; -} - diff --git a/packages/CLPBN/clpbn/bp/CountingBP.h b/packages/CLPBN/clpbn/bp/CountingBP.h deleted file mode 100644 index 540817ba0..000000000 --- a/packages/CLPBN/clpbn/bp/CountingBP.h +++ /dev/null @@ -1,45 +0,0 @@ -#ifndef BP_COUNTING_BP_H -#define BP_COUNTING_BP_H - -#include "SPSolver.h" -#include "LiftedFG.h" - -class Factor; -class FgVarNode; - -class CountingBPLink : public Link -{ - public: - CountingBPLink (Factor* f, FgVarNode* v, unsigned c) : Link (f, v) - { - edgeCount_ = c; - } - - unsigned getNumberOfEdges (void) const { return edgeCount_; } - - private: - unsigned edgeCount_; -}; - - -class CountingBP : public SPSolver -{ - public: - CountingBP (FactorGraph& fg) : SPSolver (fg) { } - ~CountingBP (void); - - ParamSet getPosterioriOf (Vid) const; - - private: - void initializeSolver (void); - void createLinks (void); - void deleteJunction (Factor*, FgVarNode*); - - void maxResidualSchedule (void); - ParamSet getVar2FactorMsg (const Link*) const; - - LiftedFG* lfg_; -}; - -#endif // BP_COUNTING_BP_H - diff --git a/packages/CLPBN/clpbn/bp/CptEntry.h b/packages/CLPBN/clpbn/bp/CptEntry.h index c843f53e4..b5248bada 100644 --- a/packages/CLPBN/clpbn/bp/CptEntry.h +++ b/packages/CLPBN/clpbn/bp/CptEntry.h @@ -1,5 +1,5 @@ -#ifndef BP_CPT_ENTRY_H -#define BP_CPT_ENTRY_H +#ifndef HORUS_CPTENTRY_H +#define HORUS_CPTENTRY_H #include @@ -39,5 +39,5 @@ class CptEntry DConf conf_; }; -#endif //BP_CPT_ENTRY_H +#endif // HORUS_CPTENTRY_H diff --git a/packages/CLPBN/clpbn/bp/Distribution.cpp b/packages/CLPBN/clpbn/bp/Distribution.cpp deleted file mode 100644 index 309d8b73e..000000000 --- a/packages/CLPBN/clpbn/bp/Distribution.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include -#include - -#include - -Distribution::Distribution (int id, - double* params, - int nParams, - vector domain) -{ - this->id = id; - this->params = params; - this->nParams = nParams; - this->domain = domain; -} - - -Distribution::Distribution (double* params, - int nParams, - vector domain) -{ - this->id = -1; - this->params = params; - this->nParams = nParams; - this->domain = domain; -} - - - -/* -Distribution::~Distribution() -{ - delete params; - for (unsigned int i = 0; i < cptEntries.size(); i++) { - delete cptEntries[i]; - } -} -*/ - - diff --git a/packages/CLPBN/clpbn/bp/Distribution.h b/packages/CLPBN/clpbn/bp/Distribution.h index 16f4b826d..aef050838 100644 --- a/packages/CLPBN/clpbn/bp/Distribution.h +++ b/packages/CLPBN/clpbn/bp/Distribution.h @@ -1,5 +1,5 @@ -#ifndef BP_DISTRIBUTION_H -#define BP_DISTRIBUTION_H +#ifndef HORUS_DISTRIBUTION_H +#define HORUS_DISTRIBUTION_H #include @@ -11,18 +11,16 @@ using namespace std; struct Distribution { public: - Distribution (unsigned id, bool shared = false) + Distribution (unsigned id) { this->id = id; this->params = params; - this->shared = shared; } - Distribution (const ParamSet& params, bool shared = false) + Distribution (const ParamSet& params, unsigned id = -1) { - this->id = -1; + this->id = id; this->params = params; - this->shared = shared; } void updateParameters (const ParamSet& params) @@ -33,11 +31,10 @@ struct Distribution unsigned id; ParamSet params; vector entries; - bool shared; private: DISALLOW_COPY_AND_ASSIGN (Distribution); }; -#endif //BP_DISTRIBUTION_H +#endif // HORUS_DISTRIBUTION_H diff --git a/packages/CLPBN/clpbn/bp/ElimGraph.cpp b/packages/CLPBN/clpbn/bp/ElimGraph.cpp new file mode 100644 index 000000000..4c5d7581b --- /dev/null +++ b/packages/CLPBN/clpbn/bp/ElimGraph.cpp @@ -0,0 +1,322 @@ +#include + +#include "ElimGraph.h" +#include "BayesNet.h" + + +ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS; + + +ElimGraph::ElimGraph (const BayesNet& bayesNet) +{ + const BnNodeSet& bnNodes = bayesNet.getBayesNodes(); + for (unsigned i = 0; i < bnNodes.size(); i++) { + if (bnNodes[i]->hasEvidence() == false) { + addNode (new EgNode (bnNodes[i])); + } + } + + for (unsigned i = 0; i < bnNodes.size(); i++) { + if (bnNodes[i]->hasEvidence() == false) { + EgNode* n = getEgNode (bnNodes[i]->varId()); + const BnNodeSet& childs = bnNodes[i]->getChilds(); + for (unsigned j = 0; j < childs.size(); j++) { + if (childs[j]->hasEvidence() == false) { + addEdge (n, getEgNode (childs[j]->varId())); + } + } + } + } + + for (unsigned i = 0; i < bnNodes.size(); i++) { + vector neighs; + const vector& parents = bnNodes[i]->getParents(); + for (unsigned i = 0; i < parents.size(); i++) { + if (parents[i]->hasEvidence() == false) { + neighs.push_back (getEgNode (parents[i]->varId())); + } + } + if (neighs.size() > 0) { + for (unsigned i = 0; i < neighs.size() - 1; i++) { + for (unsigned j = i+1; j < neighs.size(); j++) { + if (!neighbors (neighs[i], neighs[j])) { + addEdge (neighs[i], neighs[j]); + } + } + } + } + } + + setIndexes(); +} + + + +ElimGraph::~ElimGraph (void) +{ + for (unsigned i = 0; i < nodes_.size(); i++) { + delete nodes_[i]; + } +} + + + +void +ElimGraph::addNode (EgNode* n) +{ + nodes_.push_back (n); + vid2nodes_.insert (make_pair (n->varId(), n)); +} + + + +EgNode* +ElimGraph::getEgNode (VarId vid) const +{ + unordered_map::const_iterator it = vid2nodes_.find (vid); + if (it == vid2nodes_.end()) { + return 0; + } else { + return it->second; + } +} + + + +VarIdSet +ElimGraph::getEliminatingOrder (const VarIdSet& exclude) +{ + VarIdSet elimOrder; + marked_.resize (nodes_.size(), false); + + for (unsigned i = 0; i < exclude.size(); i++) { + EgNode* node = getEgNode (exclude[i]); + assert (node); + marked_[*node] = true; + } + + unsigned nVarsToEliminate = nodes_.size() - exclude.size(); + for (unsigned i = 0; i < nVarsToEliminate; i++) { + EgNode* node = getLowestCostNode(); + marked_[*node] = true; + elimOrder.push_back (node->varId()); + connectAllNeighbors (node); + } + return elimOrder; +} + + + +EgNode* +ElimGraph::getLowestCostNode (void) const +{ + EgNode* bestNode = 0; + unsigned minCost = std::numeric_limits::max(); + for (unsigned i = 0; i < nodes_.size(); i++) { + if (marked_[i]) continue; + unsigned cost = 0; + switch (elimHeuristic_) { + case MIN_NEIGHBORS: + cost = getNeighborsCost (nodes_[i]); + break; + case MIN_WEIGHT: + cost = getWeightCost (nodes_[i]); + break; + case MIN_FILL: + cost = getFillCost (nodes_[i]); + break; + case WEIGHTED_MIN_FILL: + cost = getWeightedFillCost (nodes_[i]); + break; + default: + assert (false); + } + if (cost < minCost) { + bestNode = nodes_[i]; + minCost = cost; + } + } + assert (bestNode); + return bestNode; +} + + + +unsigned +ElimGraph::getNeighborsCost (const EgNode* n) const +{ + unsigned cost = 0; + const vector& neighs = n->neighbors(); + for (unsigned i = 0; i < neighs.size(); i++) { + if (marked_[*neighs[i]] == false) { + cost ++; + } + } + return cost; +} + + + +unsigned +ElimGraph::getWeightCost (const EgNode* n) const +{ + unsigned cost = 1; + const vector& neighs = n->neighbors(); + for (unsigned i = 0; i < neighs.size(); i++) { + if (marked_[*neighs[i]] == false) { + cost *= neighs[i]->nrStates(); + } + } + return cost; +} + + + +unsigned +ElimGraph::getFillCost (const EgNode* n) const +{ + unsigned cost = 0; + const vector& neighs = n->neighbors(); + if (neighs.size() > 0) { + for (unsigned i = 0; i < neighs.size() - 1; i++) { + if (marked_[*neighs[i]] == true) continue; + for (unsigned j = i+1; j < neighs.size(); j++) { + if (marked_[*neighs[j]] == true) continue; + if (!neighbors (neighs[i], neighs[j])) { + cost ++; + } + } + } + } + return cost; +} + + + +unsigned +ElimGraph::getWeightedFillCost (const EgNode* n) const +{ + unsigned cost = 0; + const vector& neighs = n->neighbors(); + if (neighs.size() > 0) { + for (unsigned i = 0; i < neighs.size() - 1; i++) { + if (marked_[*neighs[i]] == true) continue; + for (unsigned j = i+1; j < neighs.size(); j++) { + if (marked_[*neighs[j]] == true) continue; + if (!neighbors (neighs[i], neighs[j])) { + cost += neighs[i]->nrStates() * neighs[j]->nrStates(); + } + } + } + } + return cost; +} + + + +void +ElimGraph::connectAllNeighbors (const EgNode* n) +{ + const vector& neighs = n->neighbors(); + if (neighs.size() > 0) { + for (unsigned i = 0; i < neighs.size() - 1; i++) { + if (marked_[*neighs[i]] == true) continue; + for (unsigned j = i+1; j < neighs.size(); j++) { + if (marked_[*neighs[j]] == true) continue; + if (!neighbors (neighs[i], neighs[j])) { + addEdge (neighs[i], neighs[j]); + } + } + } + } +} + + + +bool +ElimGraph::neighbors (const EgNode* n1, const EgNode* n2) const +{ + const vector& neighs = n1->neighbors(); + for (unsigned i = 0; i < neighs.size(); i++) { + if (neighs[i] == n2) { + return true; + } + } + return false; +} + + + +void +ElimGraph::setIndexes (void) +{ + for (unsigned i = 0; i < nodes_.size(); i++) { + nodes_[i]->setIndex (i); + } +} + + + +void +ElimGraph::printGraphicalModel (void) const +{ + for (unsigned i = 0; i < nodes_.size(); i++) { + cout << "node " << nodes_[i]->label() << " neighs:" ; + vector neighs = nodes_[i]->neighbors(); + for (unsigned j = 0; j < neighs.size(); j++) { + cout << " " << neighs[j]->label(); + } + cout << endl; + } +} + + + +void +ElimGraph::exportToGraphViz (const char* fileName, + bool showNeighborless, + const VarIdSet& highlightVarIds) const +{ + ofstream out (fileName); + if (!out.is_open()) { + cerr << "error: cannot open file to write at " ; + cerr << "Markov::exportToDotFile()" << endl; + abort(); + } + + out << "strict graph {" << endl; + + for (unsigned i = 0; i < nodes_.size(); i++) { + if (showNeighborless || nodes_[i]->neighbors().size() != 0) { + out << '"' << nodes_[i]->label() << '"' ; + if (nodes_[i]->hasEvidence()) { + out << " [style=filled, fillcolor=yellow]" << endl; + } else { + out << endl; + } + } + } + + for (unsigned i = 0; i < highlightVarIds.size(); i++) { + EgNode* node =getEgNode (highlightVarIds[i]); + if (node) { + out << '"' << node->label() << '"' ; + out << " [shape=box3d]" << endl; + } else { + cout << "error: invalid variable id: " << highlightVarIds[i] << endl; + abort(); + } + } + + for (unsigned i = 0; i < nodes_.size(); i++) { + vector neighs = nodes_[i]->neighbors(); + for (unsigned j = 0; j < neighs.size(); j++) { + out << '"' << nodes_[i]->label() << '"' << " -- " ; + out << '"' << neighs[j]->label() << '"' << endl; + } + } + + out << "}" << endl; + out.close(); +} + diff --git a/packages/CLPBN/clpbn/bp/ElimGraph.h b/packages/CLPBN/clpbn/bp/ElimGraph.h new file mode 100644 index 000000000..de5cb80f6 --- /dev/null +++ b/packages/CLPBN/clpbn/bp/ElimGraph.h @@ -0,0 +1,76 @@ +#ifndef HORUS_ELIMGRAPH_H +#define HORUS_ELIMGRAPH_H + +#include "unordered_map" + +#include "FactorGraph.h" +#include "Shared.h" + +using namespace std; + +enum ElimHeuristic +{ + MIN_NEIGHBORS, + MIN_WEIGHT, + MIN_FILL, + WEIGHTED_MIN_FILL +}; + + +class EgNode : public VarNode { + public: + EgNode (VarNode* var) : VarNode (var) { } + void addNeighbor (EgNode* n) + { + neighs_.push_back (n); + } + + const vector& neighbors (void) const { return neighs_; } + private: + vector neighs_; +}; + + +class ElimGraph +{ + public: + ElimGraph (const BayesNet&); + ~ElimGraph (void); + + void addEdge (EgNode* n1, EgNode* n2) + { + assert (n1 != n2); + n1->addNeighbor (n2); + n2->addNeighbor (n1); + } + void addNode (EgNode*); + EgNode* getEgNode (VarId) const; + VarIdSet getEliminatingOrder (const VarIdSet&); + void printGraphicalModel (void) const; + void exportToGraphViz (const char*, bool = true, + const VarIdSet& = VarIdSet()) const; + void setIndexes(); + + static void setEliminationHeuristic (ElimHeuristic h) + { + elimHeuristic_ = h; + } + + private: + EgNode* getLowestCostNode (void) const; + unsigned getNeighborsCost (const EgNode*) const; + unsigned getWeightCost (const EgNode*) const; + unsigned getFillCost (const EgNode*) const; + unsigned getWeightedFillCost (const EgNode*) const; + void connectAllNeighbors (const EgNode*); + bool neighbors (const EgNode*, const EgNode*) const; + + + vector nodes_; + vector marked_; + unordered_map vid2nodes_; + static ElimHeuristic elimHeuristic_; +}; + +#endif // HORUS_ELIMGRAPH_H + diff --git a/packages/CLPBN/clpbn/bp/Factor.cpp b/packages/CLPBN/clpbn/bp/Factor.cpp index 7b3081ab2..e5f16f9e1 100644 --- a/packages/CLPBN/clpbn/bp/Factor.cpp +++ b/packages/CLPBN/clpbn/bp/Factor.cpp @@ -1,33 +1,38 @@ #include #include +#include + #include #include #include "Factor.h" -#include "FgVarNode.h" +#include "StatesIndexer.h" Factor::Factor (const Factor& g) { - copyFactor (g); + copyFromFactor (g); } -Factor::Factor (FgVarNode* var) +Factor::Factor (VarId vid, unsigned nStates) { - Factor (FgVarSet() = {var}); + varids_.push_back (vid); + ranges_.push_back (nStates); + dist_ = new Distribution (ParamSet (nStates, 1.0)); } -Factor::Factor (const FgVarSet& vars) +Factor::Factor (const VarNodes& vars) { - vars_ = vars; int nParams = 1; - for (unsigned i = 0; i < vars_.size(); i++) { - nParams *= vars_[i]->getDomainSize(); + for (unsigned i = 0; i < vars.size(); i++) { + varids_.push_back (vars[i]->varId()); + ranges_.push_back (vars[i]->nrStates()); + nParams *= vars[i]->nrStates(); } // create a uniform distribution double val = 1.0 / nParams; @@ -36,29 +41,44 @@ Factor::Factor (const FgVarSet& vars) -Factor::Factor (FgVarNode* var, - const ParamSet& params) +Factor::Factor (VarId vid, unsigned nStates, const ParamSet& params) { - vars_.push_back (var); + varids_.push_back (vid); + ranges_.push_back (nStates); dist_ = new Distribution (params); } -Factor::Factor (FgVarSet& vars, - Distribution* dist) +Factor::Factor (VarNodes& vars, Distribution* dist) { - vars_ = vars; + for (unsigned i = 0; i < vars.size(); i++) { + varids_.push_back (vars[i]->varId()); + ranges_.push_back (vars[i]->nrStates()); + } dist_ = dist; } -Factor::Factor (const FgVarSet& vars, +Factor::Factor (const VarNodes& vars, const ParamSet& params) +{ + for (unsigned i = 0; i < vars.size(); i++) { + varids_.push_back (vars[i]->varId()); + ranges_.push_back (vars[i]->nrStates()); + } + dist_ = new Distribution (params); +} + + + +Factor::Factor (const VarIdSet& vids, + const Ranges& ranges, const ParamSet& params) { - vars_ = vars; - dist_ = new Distribution (params); + varids_ = vids; + ranges_ = ranges; + dist_ = new Distribution (params); } @@ -73,9 +93,10 @@ Factor::setParameters (const ParamSet& params) void -Factor::copyFactor (const Factor& g) +Factor::copyFromFactor (const Factor& g) { - vars_ = g.getFgVarNodes(); + varids_ = g.getVarIds(); + ranges_ = g.getRanges(); dist_ = new Distribution (g.getDistribution()->params); } @@ -84,50 +105,43 @@ Factor::copyFactor (const Factor& g) void Factor::multiplyByFactor (const Factor& g, const vector* entries) { - if (vars_.size() == 0) { - copyFactor (g); + if (varids_.size() == 0) { + copyFromFactor (g); return; } - const FgVarSet& gVs = g.getFgVarNodes(); - const ParamSet& gPs = g.getParameters(); + const VarIdSet& gvarids = g.getVarIds(); + const Ranges& granges = g.getRanges(); + const ParamSet& gparams = g.getParameters(); - bool factorsAreEqual = true; - if (gVs.size() == vars_.size()) { - for (unsigned i = 0; i < vars_.size(); i++) { - if (gVs[i] != vars_[i]) { - factorsAreEqual = false; - break; - } - } - } else { - factorsAreEqual = false; - } - - if (factorsAreEqual) { + if (varids_ == gvarids) { // optimization: if the factors contain the same set of variables, - // we can do 1 to 1 operations on the parameteres - for (unsigned i = 0; i < dist_->params.size(); i++) { - dist_->params[i] *= gPs[i]; + // we can do a 1 to 1 operation on the parameters + switch (NSPACE) { + case NumberSpace::NORMAL: + Util::multiply (dist_->params, gparams); + break; + case NumberSpace::LOGARITHM: + Util::add (dist_->params, gparams); } } else { bool hasCommonVars = false; - vector gVsIndexes; - for (unsigned i = 0; i < gVs.size(); i++) { - int idx = getIndexOf (gVs[i]); - if (idx == -1) { - insertVariable (gVs[i]); - gVsIndexes.push_back (vars_.size() - 1); + vector gvarpos; + for (unsigned i = 0; i < gvarids.size(); i++) { + int pos = getPositionOf (gvarids[i]); + if (pos == -1) { + insertVariable (gvarids[i], granges[i]); + gvarpos.push_back (varids_.size() - 1); } else { hasCommonVars = true; - gVsIndexes.push_back (idx); + gvarpos.push_back (pos); } } if (hasCommonVars) { - vector gVsOffsets (gVs.size()); - gVsOffsets[gVs.size() - 1] = 1; - for (int i = gVs.size() - 2; i >= 0; i--) { - gVsOffsets[i] = gVsOffsets[i + 1] * gVs[i + 1]->getDomainSize(); + vector gvaroffsets (gvarids.size()); + gvaroffsets[gvarids.size() - 1] = 1; + for (int i = gvarids.size() - 2; i >= 0; i--) { + gvaroffsets[i] = gvaroffsets[i + 1] * granges[i + 1]; } if (entries == 0) { @@ -137,50 +151,88 @@ Factor::multiplyByFactor (const Factor& g, const vector* entries) for (unsigned i = 0; i < entries->size(); i++) { unsigned idx = 0; const DConf& conf = (*entries)[i].getDomainConfiguration(); - for (unsigned j = 0; j < gVsIndexes.size(); j++) { - idx += gVsOffsets[j] * conf[ gVsIndexes[j] ]; + for (unsigned j = 0; j < gvarpos.size(); j++) { + idx += gvaroffsets[j] * conf[ gvarpos[j] ]; + } + switch (NSPACE) { + case NumberSpace::NORMAL: + dist_->params[i] *= gparams[idx]; + break; + case NumberSpace::LOGARITHM: + dist_->params[i] += gparams[idx]; } - dist_->params[i] = dist_->params[i] * gPs[idx]; } } else { // optimization: if the original factors doesn't have common variables, // we don't need to marry the states of the common variables unsigned count = 0; for (unsigned i = 0; i < dist_->params.size(); i++) { - dist_->params[i] *= gPs[count]; + switch (NSPACE) { + case NumberSpace::NORMAL: + dist_->params[i] *= gparams[count]; + break; + case NumberSpace::LOGARITHM: + dist_->params[i] += gparams[count]; + } count ++; - if (count >= gPs.size()) { + if (count >= gparams.size()) { count = 0; } } } } + dist_->entries.clear(); } void -Factor::insertVariable (FgVarNode* var) +Factor::insertVariable (VarId vid, unsigned nStates) { - assert (getIndexOf (var) == -1); + assert (getPositionOf (vid) == -1); ParamSet newPs; - newPs.reserve (dist_->params.size() * var->getDomainSize()); + newPs.reserve (dist_->params.size() * nStates); for (unsigned i = 0; i < dist_->params.size(); i++) { - for (unsigned j = 0; j < var->getDomainSize(); j++) { + for (unsigned j = 0; j < nStates; j++) { newPs.push_back (dist_->params[i]); } } - vars_.push_back (var); + varids_.push_back (vid); + ranges_.push_back (nStates); dist_->updateParameters (newPs); + dist_->entries.clear(); } void -Factor::removeVariable (const FgVarNode* var) +Factor::removeAllVariablesExcept (VarId vid) { - int varIndex = getIndexOf (var); - assert (varIndex >= 0 && varIndex < (int)vars_.size()); + assert (getPositionOf (vid) != -1); + while (varids_.back() != vid) { + removeLastVariable(); + } + while (varids_.front() != vid) { + removeFirstVariable(); + } +} + + + +void +Factor::removeVariable (VarId vid) +{ + int pos = getPositionOf (vid); + assert (pos != -1); + + if (vid == varids_.back()) { + removeLastVariable(); // optimization + return; + } + if (vid == varids_.front()) { + removeFirstVariable(); // optimization + return; + } // number of parameters separating a different state of `var', // with the states of the remaining variables fixed @@ -190,36 +242,36 @@ Factor::removeVariable (const FgVarNode* var) // on the left of `var', with the states of the remaining vars fixed unsigned leftVarOffset = 1; - for (int i = vars_.size() - 1; i > varIndex; i--) { - varOffset *= vars_[i]->getDomainSize(); - leftVarOffset *= vars_[i]->getDomainSize(); + for (int i = varids_.size() - 1; i > pos; i--) { + varOffset *= ranges_[i]; + leftVarOffset *= ranges_[i]; } - leftVarOffset *= vars_[varIndex]->getDomainSize(); + leftVarOffset *= ranges_[pos]; unsigned offset = 0; unsigned count1 = 0; unsigned count2 = 0; - unsigned newPsSize = dist_->params.size() / vars_[varIndex]->getDomainSize(); + unsigned newPsSize = dist_->params.size() / ranges_[pos]; ParamSet newPs; newPs.reserve (newPsSize); - // stringstream ss; - // ss << "marginalizing " << vars_[varIndex]->getLabel(); - // ss << " from factor " << getLabel() << endl; while (newPs.size() < newPsSize) { - // ss << " sum = "; - double sum = 0.0; - for (unsigned i = 0; i < vars_[varIndex]->getDomainSize(); i++) { - // if (i != 0) ss << " + "; - // ss << dist_->params[offset]; - sum += dist_->params[offset]; + double sum = Util::addIdenty(); + for (unsigned i = 0; i < ranges_[pos]; i++) { + switch (NSPACE) { + case NumberSpace::NORMAL: + sum += dist_->params[offset]; + break; + case NumberSpace::LOGARITHM: + Util::logSum (sum, dist_->params[offset]); + } offset += varOffset; } newPs.push_back (sum); count1 ++; - if (varIndex == (int)vars_.size() - 1) { - offset = count1 * vars_[varIndex]->getDomainSize(); + if (pos == (int)varids_.size() - 1) { + offset = count1 * ranges_[pos]; } else { if (((offset - varOffset + 1) % leftVarOffset) == 0) { count1 = 0; @@ -227,11 +279,200 @@ Factor::removeVariable (const FgVarNode* var) } offset = (leftVarOffset * count2) + count1; } - // ss << " = " << sum << endl; } - // cout << ss.str() << endl; - vars_.erase (vars_.begin() + varIndex); + varids_.erase (varids_.begin() + pos); + ranges_.erase (ranges_.begin() + pos); dist_->updateParameters (newPs); + dist_->entries.clear(); +} + + + +void +Factor::removeFirstVariable (void) +{ + ParamSet& params = dist_->params; + unsigned nStates = ranges_.front(); + unsigned sep = params.size() / nStates; + switch (NSPACE) { + case NumberSpace::NORMAL: + for (unsigned i = sep; i < params.size(); i++) { + params[i % sep] += params[i]; + } + break; + case NumberSpace::LOGARITHM: + for (unsigned i = sep; i < params.size(); i++) { + Util::logSum (params[i % sep], params[i]); + } + } + params.resize (sep); + varids_.erase (varids_.begin()); + ranges_.erase (ranges_.begin()); + dist_->entries.clear(); +} + + + +void +Factor::removeLastVariable (void) +{ + ParamSet& params = dist_->params; + unsigned nStates = ranges_.back(); + unsigned idx1 = 0; + unsigned idx2 = 0; + switch (NSPACE) { + case NumberSpace::NORMAL: + while (idx1 < params.size()) { + params[idx2] = params[idx1]; + idx1 ++; + for (unsigned j = 1; j < nStates; j++) { + params[idx2] += params[idx1]; + idx1 ++; + } + idx2 ++; + } + break; + case NumberSpace::LOGARITHM: + while (idx1 < params.size()) { + params[idx2] = params[idx1]; + idx1 ++; + for (unsigned j = 1; j < nStates; j++) { + Util::logSum (params[idx2], params[idx1]); + idx1 ++; + } + idx2 ++; + } + } + params.resize (idx2); + varids_.pop_back(); + ranges_.pop_back(); + dist_->entries.clear(); +} + + + +void +Factor::orderVariables (void) +{ + VarIdSet sortedVarIds = varids_; + sort (sortedVarIds.begin(), sortedVarIds.end()); + orderVariables (sortedVarIds); +} + + + +void +Factor::orderVariables (const VarIdSet& newVarIdOrder) +{ + assert (newVarIdOrder.size() == varids_.size()); + if (newVarIdOrder == varids_) { + return; + } + + Ranges newRangeOrder; + for (unsigned i = 0; i < newVarIdOrder.size(); i++) { + unsigned pos = getPositionOf (newVarIdOrder[i]); + newRangeOrder.push_back (ranges_[pos]); + } + + vector positions; + for (unsigned i = 0; i < newVarIdOrder.size(); i++) { + positions.push_back (getPositionOf (newVarIdOrder[i])); + } + + unsigned N = ranges_.size(); + ParamSet newPs (dist_->params.size()); + for (unsigned i = 0; i < dist_->params.size(); i++) { + unsigned li = i; + // calculate vector index corresponding to linear index + vector vi (N); + for (int k = N-1; k >= 0; k--) { + vi[k] = li % ranges_[k]; + li /= ranges_[k]; + } + // convert permuted vector index to corresponding linear index + unsigned prod = 1; + unsigned new_li = 0; + for (int k = N-1; k >= 0; k--) { + new_li += vi[positions[k]] * prod; + prod *= ranges_[positions[k]]; + } + newPs[new_li] = dist_->params[i]; + } + varids_ = newVarIdOrder; + ranges_ = newRangeOrder; + dist_->params = newPs; + dist_->entries.clear(); +} + + + +void +Factor::removeInconsistentEntries (VarId vid, unsigned evidence) +{ + int pos = getPositionOf (vid); + assert (pos != -1); + ParamSet newPs; + newPs.reserve (dist_->params.size() / ranges_[pos]); + StatesIndexer idx (ranges_); + for (unsigned i = 0; i < evidence; i++) { + idx.incrementState (pos); + } + while (idx.valid()) { + newPs.push_back (dist_->params[idx.getLinearIndex()]); + idx.nextSameState (pos); + } + varids_.erase (varids_.begin() + pos); + ranges_.erase (ranges_.begin() + pos); + dist_->updateParameters (newPs); + dist_->entries.clear(); +} + + + +string +Factor::getLabel (void) const +{ + stringstream ss; + ss << "f(" ; + for (unsigned i = 0; i < varids_.size(); i++) { + if (i != 0) ss << "," ; + ss << VarNode (varids_[i], ranges_[i]).label(); + } + ss << ")" ; + return ss.str(); +} + + + +void +Factor::printFactor (void) const +{ + VarNodes vars; + for (unsigned i = 0; i < varids_.size(); i++) { + vars.push_back (new VarNode (varids_[i], ranges_[i])); + } + vector jointStrings = Util::getJointStateStrings (vars); + for (unsigned i = 0; i < dist_->params.size(); i++) { + cout << "f(" << jointStrings[i] << ")" ; + cout << " = " << dist_->params[i] << endl; + } + for (unsigned i = 0; i < vars.size(); i++) { + delete vars[i]; + } +} + + + +int +Factor::getPositionOf (VarId vid) const +{ + for (unsigned i = 0; i < varids_.size(); i++) { + if (varids_[i] == vid) { + return i; + } + } + return -1; } @@ -242,21 +483,20 @@ Factor::getCptEntries (void) const if (dist_->entries.size() == 0) { vector confs (dist_->params.size()); for (unsigned i = 0; i < dist_->params.size(); i++) { - confs[i].resize (vars_.size()); + confs[i].resize (varids_.size()); } - unsigned nReps = 1; - for (int i = vars_.size() - 1; i >= 0; i--) { + for (int i = varids_.size() - 1; i >= 0; i--) { unsigned index = 0; while (index < dist_->params.size()) { - for (unsigned j = 0; j < vars_[i]->getDomainSize(); j++) { + for (unsigned j = 0; j < ranges_[i]; j++) { for (unsigned r = 0; r < nReps; r++) { - confs[index][i] = j; + confs[index][i] = j; index++; } } } - nReps *= vars_[i]->getDomainSize(); + nReps *= ranges_[i]; } dist_->entries.clear(); dist_->entries.reserve (dist_->params.size()); @@ -267,53 +507,3 @@ Factor::getCptEntries (void) const return dist_->entries; } - - -string -Factor::getLabel (void) const -{ - stringstream ss; - ss << "Φ(" ; - for (unsigned i = 0; i < vars_.size(); i++) { - if (i != 0) ss << "," ; - ss << vars_[i]->getLabel(); - } - ss << ")" ; - return ss.str(); -} - - - -void -Factor::printFactor (void) -{ - stringstream ss; - ss << getLabel() << endl; - ss << "--------------------" << endl; - VarSet vs; - for (unsigned i = 0; i < vars_.size(); i++) { - vs.push_back (vars_[i]); - } - vector domainConfs = Util::getInstantiations (vs); - const vector& entries = getCptEntries(); - for (unsigned i = 0; i < entries.size(); i++) { - ss << "Φ(" << domainConfs[i] << ")" ; - unsigned idx = entries[i].getParameterIndex(); - ss << " = " << dist_->params[idx] << endl; - } - cout << ss.str(); -} - - - -int -Factor::getIndexOf (const FgVarNode* var) const -{ - for (unsigned i = 0; i < vars_.size(); i++) { - if (vars_[i] == var) { - return i; - } - } - return -1; -} - diff --git a/packages/CLPBN/clpbn/bp/Factor.h b/packages/CLPBN/clpbn/bp/Factor.h index 58ea63eb9..b24046c7f 100644 --- a/packages/CLPBN/clpbn/bp/Factor.h +++ b/packages/CLPBN/clpbn/bp/Factor.h @@ -1,48 +1,69 @@ -#ifndef BP_FACTOR_H -#define BP_FACTOR_H +#ifndef HORUS_FACTOR_H +#define HORUS_FACTOR_H #include #include "Distribution.h" #include "CptEntry.h" +#include "VarNode.h" + using namespace std; -class FgVarNode; class Distribution; + class Factor { public: Factor (void) { } Factor (const Factor&); - Factor (FgVarNode*); - Factor (CFgVarSet); - Factor (FgVarNode*, const ParamSet&); - Factor (FgVarSet&, Distribution*); - Factor (CFgVarSet, CParamSet); + Factor (VarId, unsigned); + Factor (const VarNodes&); + Factor (VarId, unsigned, const ParamSet&); + Factor (VarNodes&, Distribution*); + Factor (const VarNodes&, const ParamSet&); + Factor (const VarIdSet&, const Ranges&, const ParamSet&); - void setParameters (CParamSet); - void copyFactor (const Factor& f); - void multiplyByFactor (const Factor& f, const vector* = 0); - void insertVariable (FgVarNode* index); - void removeVariable (const FgVarNode* var); - const vector& getCptEntries (void) const; + void setParameters (const ParamSet&); + void copyFromFactor (const Factor& f); + void multiplyByFactor (const Factor&, const vector* = 0); + void insertVariable (VarId, unsigned); + void removeAllVariablesExcept (VarId); + void removeVariable (VarId); + void removeFirstVariable (void); + void removeLastVariable (void); + void orderVariables (void); + void orderVariables (const VarIdSet&); + void removeInconsistentEntries (VarId, unsigned); string getLabel (void) const; - void printFactor (void); + void printFactor (void) const; + int getPositionOf (VarId) const; + const vector& getCptEntries (void) const; - CFgVarSet getFgVarNodes (void) const { return vars_; } - CParamSet getParameters (void) const { return dist_->params; } - Distribution* getDistribution (void) const { return dist_; } - unsigned getIndex (void) const { return index_; } - void setIndex (unsigned index) { index_ = index; } - void freeDistribution (void) { delete dist_; dist_ = 0;} - int getIndexOf (const FgVarNode*) const; + const VarIdSet& getVarIds (void) const { return varids_; } + const Ranges& getRanges (void) const { return ranges_; } + const ParamSet& getParameters (void) const { return dist_->params; } + Distribution* getDistribution (void) const { return dist_; } + unsigned nrVariables (void) const { return varids_.size(); } + unsigned nrParameters() const { return dist_->params.size(); } + + void setDistribution (Distribution* dist) + { + dist_ = dist; + } + void freeDistribution (void) + { + delete dist_; + dist_ = 0; + } private: - FgVarSet vars_; - Distribution* dist_; - unsigned index_; + + VarIdSet varids_; + Ranges ranges_; + Distribution* dist_; }; -#endif //BP_FACTOR_H +#endif // HORUS_FACTOR_H + diff --git a/packages/CLPBN/clpbn/bp/FactorGraph.cpp b/packages/CLPBN/clpbn/bp/FactorGraph.cpp index 198c918fe..e98716410 100644 --- a/packages/CLPBN/clpbn/bp/FactorGraph.cpp +++ b/packages/CLPBN/clpbn/bp/FactorGraph.cpp @@ -1,20 +1,50 @@ -#include -#include #include +#include +#include #include #include #include #include "FactorGraph.h" -#include "FgVarNode.h" #include "Factor.h" #include "BayesNet.h" -FactorGraph::FactorGraph (const char* fileName) + +FactorGraph::FactorGraph (const BayesNet& bn) { - ifstream is (fileName); + const BnNodeSet& nodes = bn.getBayesNodes(); + for (unsigned i = 0; i < nodes.size(); i++) { + FgVarNode* varNode = new FgVarNode (nodes[i]); + addVariable (varNode); + } + + for (unsigned i = 0; i < nodes.size(); i++) { + const BnNodeSet& parents = nodes[i]->getParents(); + if (!(nodes[i]->hasEvidence() && parents.size() == 0)) { + VarNodes neighs; + neighs.push_back (varNodes_[nodes[i]->getIndex()]); + for (unsigned j = 0; j < parents.size(); j++) { + neighs.push_back (varNodes_[parents[j]->getIndex()]); + } + FgFacNode* fn = new FgFacNode ( + new Factor (neighs, nodes[i]->getDistribution())); + addFactor (fn); + for (unsigned j = 0; j < neighs.size(); j++) { + addEdge (fn, static_cast (neighs[j])); + } + } + } + setIndexes(); +} + + + +void +FactorGraph::readFromUaiFormat (const char* fileName) +{ + ifstream is (fileName); if (!is.is_open()) { cerr << "error: cannot read from file " + std::string (fileName) << endl; abort(); @@ -29,90 +59,159 @@ FactorGraph::FactorGraph (const char* fileName) } while (is.peek() == '#' || is.peek() == '\n') getline (is, line); - int nVars; + unsigned nVars; is >> nVars; while (is.peek() == '#' || is.peek() == '\n') getline (is, line); vector domainSizes (nVars); - for (int i = 0; i < nVars; i++) { - int ds; + for (unsigned i = 0; i < nVars; i++) { + unsigned ds; is >> ds; domainSizes[i] = ds; } while (is.peek() == '#' || is.peek() == '\n') getline (is, line); - for (int i = 0; i < nVars; i++) { + for (unsigned i = 0; i < nVars; i++) { addVariable (new FgVarNode (i, domainSizes[i])); } - int nFactors; + unsigned nFactors; is >> nFactors; - for (int i = 0; i < nFactors; i++) { + for (unsigned i = 0; i < nFactors; i++) { while (is.peek() == '#' || is.peek() == '\n') getline (is, line); - int nFactorVars; + unsigned nFactorVars; is >> nFactorVars; - FgVarSet factorVars; - for (int j = 0; j < nFactorVars; j++) { - int vid; + VarNodes neighs; + for (unsigned j = 0; j < nFactorVars; j++) { + unsigned vid; is >> vid; - FgVarNode* var = getFgVarNode (vid); - if (!var) { + FgVarNode* neigh = getFgVarNode (vid); + if (!neigh) { cerr << "error: invalid variable identifier (" << vid << ")" << endl; abort(); } - factorVars.push_back (var); + neighs.push_back (neigh); } - Factor* f = new Factor (factorVars); - factors_.push_back (f); - for (unsigned j = 0; j < factorVars.size(); j++) { - factorVars[j]->addFactor (f); + FgFacNode* fn = new FgFacNode (new Factor (neighs)); + addFactor (fn); + for (unsigned j = 0; j < neighs.size(); j++) { + addEdge (fn, static_cast (neighs[j])); } } - for (int i = 0; i < nFactors; i++) { + for (unsigned i = 0; i < nFactors; i++) { while (is.peek() == '#' || is.peek() == '\n') getline (is, line); - int nParams; + unsigned nParams; is >> nParams; + if (facNodes_[i]->getParameters().size() != nParams) { + cerr << "error: invalid number of parameters for factor " ; + cerr << facNodes_[i]->getLabel() ; + cerr << ", expected: " << facNodes_[i]->getParameters().size(); + cerr << ", given: " << nParams << endl; + abort(); + } ParamSet params (nParams); - for (int j = 0; j < nParams; j++) { + for (unsigned j = 0; j < nParams; j++) { double param; is >> param; params[j] = param; } - factors_[i]->setParameters (params); + if (NSPACE == NumberSpace::LOGARITHM) { + Util::toLog (params); + } + facNodes_[i]->factor()->setParameters (params); } is.close(); - - for (unsigned i = 0; i < varNodes_.size(); i++) { - varNodes_[i]->setIndex (i); - } + setIndexes(); } -FactorGraph::FactorGraph (const BayesNet& bn) +void +FactorGraph::readFromLibDaiFormat (const char* fileName) { - const BnNodeSet& nodes = bn.getBayesNodes(); - for (unsigned i = 0; i < nodes.size(); i++) { - FgVarNode* varNode = new FgVarNode (nodes[i]); - varNode->setIndex (i); - addVariable (varNode); + ifstream is (fileName); + if (!is.is_open()) { + cerr << "error: cannot read from file " + std::string (fileName) << endl; + abort(); } - for (unsigned i = 0; i < nodes.size(); i++) { - const BnNodeSet& parents = nodes[i]->getParents(); - if (!(nodes[i]->hasEvidence() && parents.size() == 0)) { - FgVarSet factorVars = { varNodes_[nodes[i]->getIndex()] }; - for (unsigned j = 0; j < parents.size(); j++) { - factorVars.push_back (varNodes_[parents[j]->getIndex()]); - } - Factor* f = new Factor (factorVars, nodes[i]->getDistribution()); - factors_.push_back (f); - for (unsigned j = 0; j < factorVars.size(); j++) { - factorVars[j]->addFactor (f); + string line; + unsigned nFactors; + + while ((is.peek()) == '#') getline (is, line); + is >> nFactors; + + if (is.fail()) { + cerr << "error: cannot read the number of factors" << endl; + abort(); + } + + getline (is, line); + if (is.fail() || line.size() > 0) { + cerr << "error: cannot read the number of factors" << endl; + abort(); + } + + for (unsigned i = 0; i < nFactors; i++) { + unsigned nVars; + while ((is.peek()) == '#') getline (is, line); + + is >> nVars; + VarIdSet vids; + for (unsigned j = 0; j < nVars; j++) { + VarId vid; + while ((is.peek()) == '#') getline (is, line); + is >> vid; + vids.push_back (vid); + } + + VarNodes neighs; + unsigned nParams = 1; + for (unsigned j = 0; j < nVars; j++) { + unsigned dsize; + while ((is.peek()) == '#') getline (is, line); + is >> dsize; + FgVarNode* var = getFgVarNode (vids[j]); + if (var == 0) { + var = new FgVarNode (vids[j], dsize); + addVariable (var); + } else { + if (var->nrStates() != dsize) { + cerr << "error: variable `" << vids[j] << "' appears in two or " ; + cerr << "more factors with different domain sizes" << endl; + } } + neighs.push_back (var); + nParams *= var->nrStates(); + } + ParamSet params (nParams, 0); + unsigned nNonzeros; + while ((is.peek()) == '#') + getline (is, line); + is >> nNonzeros; + + for (unsigned j = 0; j < nNonzeros; j++) { + unsigned index; + Param val; + while ((is.peek()) == '#') getline (is, line); + is >> index; + while ((is.peek()) == '#') getline (is, line); + is >> val; + params[index] = val; + } + reverse (neighs.begin(), neighs.end()); + if (NSPACE == NumberSpace::LOGARITHM) { + Util::toLog (params); + } + FgFacNode* fn = new FgFacNode (new Factor (neighs, params)); + addFactor (fn); + for (unsigned j = 0; j < neighs.size(); j++) { + addEdge (fn, static_cast (neighs[j])); } } + is.close(); + setIndexes(); } @@ -122,82 +221,63 @@ FactorGraph::~FactorGraph (void) for (unsigned i = 0; i < varNodes_.size(); i++) { delete varNodes_[i]; } - for (unsigned i = 0; i < factors_.size(); i++) { - delete factors_[i]; + for (unsigned i = 0; i < facNodes_.size(); i++) { + delete facNodes_[i]; } } void -FactorGraph::addVariable (FgVarNode* varNode) +FactorGraph::addVariable (FgVarNode* vn) { - varNodes_.push_back (varNode); - varNode->setIndex (varNodes_.size() - 1); - indexMap_.insert (make_pair (varNode->getVarId(), varNodes_.size() - 1)); + varNodes_.push_back (vn); + vn->setIndex (varNodes_.size() - 1); + indexMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1)); } void -FactorGraph::removeVariable (const FgVarNode* var) +FactorGraph::addFactor (FgFacNode* fn) { - if (varNodes_[varNodes_.size() - 1] == var) { - varNodes_.pop_back(); - } else { - for (unsigned i = 0; i < varNodes_.size(); i++) { - if (varNodes_[i] == var) { - varNodes_.erase (varNodes_.begin() + i); - return; - } - } - assert (false); - } - indexMap_.erase (indexMap_.find (var->getVarId())); + facNodes_.push_back (fn); + fn->setIndex (facNodes_.size() - 1); +} + + +void +FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn) +{ + vn->addNeighbor (fn); + fn->addNeighbor (vn); } void -FactorGraph::addFactor (Factor* f) +FactorGraph::addEdge (FgFacNode* fn, FgVarNode* vn) { - factors_.push_back (f); - const FgVarSet& factorVars = f->getFgVarNodes(); - for (unsigned i = 0; i < factorVars.size(); i++) { - factorVars[i]->addFactor (f); - } + fn->addNeighbor (vn); + vn->addNeighbor (fn); } -void -FactorGraph::removeFactor (const Factor* f) +VarNode* +FactorGraph::getVariableNode (VarId vid) const { - const FgVarSet& factorVars = f->getFgVarNodes(); - for (unsigned i = 0; i < factorVars.size(); i++) { - if (factorVars[i]) { - factorVars[i]->removeFactor (f); - } - } - if (factors_[factors_.size() - 1] == f) { - factors_.pop_back(); - } else { - for (unsigned i = 0; i < factors_.size(); i++) { - if (factors_[i] == f) { - factors_.erase (factors_.begin() + i); - return; - } - } - assert (false); - } + FgVarNode* vn = getFgVarNode (vid); + assert (vn); + return vn; } -VarSet -FactorGraph::getVariables (void) const +VarNodes +FactorGraph::getVariableNodes (void) const { - VarSet vars; + VarNodes vars; for (unsigned i = 0; i < varNodes_.size(); i++) { vars.push_back (varNodes_[i]); } @@ -206,10 +286,10 @@ FactorGraph::getVariables (void) const -Variable* -FactorGraph::getVariable (Vid vid) const +bool +FactorGraph::isTree (void) const { - return getFgVarNode (vid); + return !containsCycle(); } @@ -220,8 +300,8 @@ FactorGraph::setIndexes (void) for (unsigned i = 0; i < varNodes_.size(); i++) { varNodes_[i]->setIndex (i); } - for (unsigned i = 0; i < factors_.size(); i++) { - factors_[i]->setIndex (i); + for (unsigned i = 0; i < facNodes_.size(); i++) { + facNodes_[i]->setIndex (i); } } @@ -231,8 +311,8 @@ void FactorGraph::freeDistributions (void) { set dists; - for (unsigned i = 0; i < factors_.size(); i++) { - dists.insert (factors_[i]->getDistribution()); + for (unsigned i = 0; i < facNodes_.size(); i++) { + dists.insert (facNodes_[i]->factor()->getDistribution()); } for (set::iterator it = dists.begin(); it != dists.end(); it++) { @@ -246,19 +326,18 @@ void FactorGraph::printGraphicalModel (void) const { for (unsigned i = 0; i < varNodes_.size(); i++) { - cout << "variable number " << varNodes_[i]->getIndex() << endl; - cout << "Id = " << varNodes_[i]->getVarId() << endl; - cout << "Label = " << varNodes_[i]->getLabel() << endl; - cout << "Domain size = " << varNodes_[i]->getDomainSize() << endl; + cout << "VarId = " << varNodes_[i]->varId() << endl; + cout << "Label = " << varNodes_[i]->label() << endl; + cout << "Nr States = " << varNodes_[i]->nrStates() << endl; cout << "Evidence = " << varNodes_[i]->getEvidence() << endl; cout << "Factors = " ; - for (unsigned j = 0; j < varNodes_[i]->getFactors().size(); j++) { - cout << varNodes_[i]->getFactors()[j]->getLabel() << " " ; + for (unsigned j = 0; j < varNodes_[i]->neighbors().size(); j++) { + cout << varNodes_[i]->neighbors()[j]->getLabel() << " " ; } cout << endl << endl; } - for (unsigned i = 0; i < factors_.size(); i++) { - factors_[i]->printFactor(); + for (unsigned i = 0; i < facNodes_.size(); i++) { + facNodes_[i]->factor()->printFactor(); cout << endl; } } @@ -266,7 +345,7 @@ FactorGraph::printGraphicalModel (void) const void -FactorGraph::exportToDotFormat (const char* fileName) const +FactorGraph::exportToGraphViz (const char* fileName) const { ofstream out (fileName); if (!out.is_open()) { @@ -279,24 +358,23 @@ FactorGraph::exportToDotFormat (const char* fileName) const for (unsigned i = 0; i < varNodes_.size(); i++) { if (varNodes_[i]->hasEvidence()) { - out << '"' << varNodes_[i]->getLabel() << '"' ; + out << '"' << varNodes_[i]->label() << '"' ; out << " [style=filled, fillcolor=yellow]" << endl; } } - for (unsigned i = 0; i < factors_.size(); i++) { - out << '"' << factors_[i]->getLabel() << '"' ; - out << " [label=\"" << factors_[i]->getLabel() << "\\n("; - out << factors_[i]->getDistribution()->id << ")" << "\"" ; - out << ", shape=box]" << endl; + for (unsigned i = 0; i < facNodes_.size(); i++) { + out << '"' << facNodes_[i]->getLabel() << '"' ; + out << " [label=\"" << facNodes_[i]->getLabel(); + out << "\"" << ", shape=box]" << endl; } - for (unsigned i = 0; i < factors_.size(); i++) { - CFgVarSet myVars = factors_[i]->getFgVarNodes(); + for (unsigned i = 0; i < facNodes_.size(); i++) { + const FgVarSet& myVars = facNodes_[i]->neighbors(); for (unsigned j = 0; j < myVars.size(); j++) { - out << '"' << factors_[i]->getLabel() << '"' ; + out << '"' << facNodes_[i]->getLabel() << '"' ; out << " -- " ; - out << '"' << myVars[j]->getLabel() << '"' << endl; + out << '"' << myVars[j]->label() << '"' << endl; } } @@ -319,13 +397,13 @@ FactorGraph::exportToUaiFormat (const char* fileName) const out << "MARKOV" << endl; out << varNodes_.size() << endl; for (unsigned i = 0; i < varNodes_.size(); i++) { - out << varNodes_[i]->getDomainSize() << " " ; + out << varNodes_[i]->nrStates() << " " ; } out << endl; - out << factors_.size() << endl; - for (unsigned i = 0; i < factors_.size(); i++) { - CFgVarSet factorVars = factors_[i]->getFgVarNodes(); + out << facNodes_.size() << endl; + for (unsigned i = 0; i < facNodes_.size(); i++) { + const FgVarSet& factorVars = facNodes_[i]->neighbors(); out << factorVars.size(); for (unsigned j = 0; j < factorVars.size(); j++) { out << " " << factorVars[j]->getIndex(); @@ -333,8 +411,8 @@ FactorGraph::exportToUaiFormat (const char* fileName) const out << endl; } - for (unsigned i = 0; i < factors_.size(); i++) { - CParamSet params = factors_[i]->getParameters(); + for (unsigned i = 0; i < facNodes_.size(); i++) { + const ParamSet& params = facNodes_[i]->getParameters(); out << endl << params.size() << endl << " " ; for (unsigned j = 0; j < params.size(); j++) { out << params[j] << " " ; @@ -345,3 +423,102 @@ FactorGraph::exportToUaiFormat (const char* fileName) const out.close(); } + + +void +FactorGraph::exportToLibDaiFormat (const char* fileName) const +{ + ofstream out (fileName); + if (!out.is_open()) { + cerr << "error: cannot open file to write at " ; + cerr << "FactorGraph::exportToLibDaiFormat()" << endl; + abort(); + } + out << facNodes_.size() << endl << endl; + for (unsigned i = 0; i < facNodes_.size(); i++) { + const FgVarSet& factorVars = facNodes_[i]->neighbors(); + out << factorVars.size() << endl; + for (int j = factorVars.size() - 1; j >= 0; j--) { + out << factorVars[j]->varId() << " " ; + } + out << endl; + for (unsigned j = 0; j < factorVars.size(); j++) { + out << factorVars[j]->nrStates() << " " ; + } + out << endl; + const ParamSet& params = facNodes_[i]->factor()->getParameters(); + out << params.size() << endl; + for (unsigned j = 0; j < params.size(); j++) { + out << j << " " << params[j] << endl; + } + out << endl; + } + out.close(); +} + + + +bool +FactorGraph::containsCycle (void) const +{ + vector visitedVars (varNodes_.size(), false); + vector visitedFactors (facNodes_.size(), false); + for (unsigned i = 0; i < varNodes_.size(); i++) { + int v = varNodes_[i]->getIndex(); + if (!visitedVars[v]) { + if (containsCycle (varNodes_[i], 0, visitedVars, visitedFactors)) { + return true; + } + } + } + return false; +} + + + +bool +FactorGraph::containsCycle (const FgVarNode* v, + const FgFacNode* p, + vector& visitedVars, + vector& visitedFactors) const +{ + visitedVars[v->getIndex()] = true; + const FgFacSet& adjacencies = v->neighbors(); + for (unsigned i = 0; i < adjacencies.size(); i++) { + int w = adjacencies[i]->getIndex(); + if (!visitedFactors[w]) { + if (containsCycle (adjacencies[i], v, visitedVars, visitedFactors)) { + return true; + } + } + else if (visitedFactors[w] && adjacencies[i] != p) { + return true; + } + } + return false; // no cycle detected in this component +} + + + +bool +FactorGraph::containsCycle (const FgFacNode* v, + const FgVarNode* p, + vector& visitedVars, + vector& visitedFactors) const +{ + visitedFactors[v->getIndex()] = true; + const FgVarSet& adjacencies = v->neighbors(); + for (unsigned i = 0; i < adjacencies.size(); i++) { + int w = adjacencies[i]->getIndex(); + if (!visitedVars[w]) { + if (containsCycle (adjacencies[i], v, visitedVars, visitedFactors)) { + return true; + } + } + else if (visitedVars[w] && adjacencies[i] != p) { + return true; + } + } + return false; // no cycle detected in this component +} + diff --git a/packages/CLPBN/clpbn/bp/FactorGraph.h b/packages/CLPBN/clpbn/bp/FactorGraph.h index 297d02a6b..3700d731d 100644 --- a/packages/CLPBN/clpbn/bp/FactorGraph.h +++ b/packages/CLPBN/clpbn/bp/FactorGraph.h @@ -1,41 +1,116 @@ -#ifndef BP_FACTOR_GRAPH_H -#define BP_FACTOR_GRAPH_H +#ifndef HORUS_FACTORGRAPH_H +#define HORUS_FACTORGRAPH_H #include #include "GraphicalModel.h" #include "Shared.h" +#include "Distribution.h" +#include "Factor.h" using namespace std; -class FgVarNode; -class Factor; -class BayesNet; + +class FgFacNode; + +class FgVarNode : public VarNode +{ + public: + FgVarNode (VarId varId, unsigned nrStates) : VarNode (varId, nrStates) { } + FgVarNode (const VarNode* v) : VarNode (v) { } + void addNeighbor (FgFacNode* fn) + { + neighs_.push_back (fn); + } + const vector& neighbors (void) const + { + return neighs_; + } + + private: + DISALLOW_COPY_AND_ASSIGN (FgVarNode); + // members + vector neighs_; +}; + + +class FgFacNode +{ + public: + FgFacNode (Factor* factor) + { + factor_ = factor; + index_ = -1; + } + Factor* factor() const + { + return factor_; + } + void addNeighbor (FgVarNode* vn) + { + neighs_.push_back (vn); + } + const vector& neighbors (void) const + { + return neighs_; + } + int getIndex (void) const + { + assert (index_ != -1); + return index_; + } + void setIndex (int index) + { + index_ = index; + } + Distribution* getDistribution (void) + { + return factor_->getDistribution(); + } + const ParamSet& getParameters (void) const + { + return factor_->getParameters(); + } + string getLabel (void) + { + return factor_->getLabel(); + } + private: + DISALLOW_COPY_AND_ASSIGN (FgFacNode); + + Factor* factor_; + int index_; + vector neighs_; +}; + class FactorGraph : public GraphicalModel { public: FactorGraph (void) {}; - FactorGraph (const char*); FactorGraph (const BayesNet&); ~FactorGraph (void); - + + void readFromUaiFormat (const char*); + void readFromLibDaiFormat (const char*); void addVariable (FgVarNode*); - void removeVariable (const FgVarNode*); - void addFactor (Factor*); - void removeFactor (const Factor*); - VarSet getVariables (void) const; - Variable* getVariable (unsigned) const; + void addFactor (FgFacNode*); + void addEdge (FgVarNode*, FgFacNode*); + void addEdge (FgFacNode*, FgVarNode*); + VarNode* getVariableNode (unsigned) const; + VarNodes getVariableNodes (void) const; + bool isTree (void) const; void setIndexes (void); void freeDistributions (void); void printGraphicalModel (void) const; - void exportToDotFormat (const char*) const; + void exportToGraphViz (const char*) const; void exportToUaiFormat (const char*) const; - - const FgVarSet& getFgVarNodes (void) const { return varNodes_; } - const FactorSet& getFactors (void) const { return factors_; } + void exportToLibDaiFormat (const char*) const; + + const FgVarSet& getVarNodes (void) const { return varNodes_; } + const FgFacSet& getFactorNodes (void) const { return facNodes_; } - FgVarNode* getFgVarNode (Vid vid) const + FgVarNode* getFgVarNode (VarId vid) const { IndexMap::const_iterator it = indexMap_.find (vid); if (it == indexMap_.end()) { @@ -46,12 +121,20 @@ class FactorGraph : public GraphicalModel } private: + bool containsCycle (void) const; + bool containsCycle (const FgVarNode*, const FgFacNode*, + vector&, vector&) const; + bool containsCycle (const FgFacNode*, const FgVarNode*, + vector&, vector&) const; + DISALLOW_COPY_AND_ASSIGN (FactorGraph); - FgVarSet varNodes_; - FactorSet factors_; - IndexMap indexMap_; + FgVarSet varNodes_; + FgFacSet facNodes_; + + typedef unordered_map IndexMap; + IndexMap indexMap_; }; -#endif // BP_FACTOR_GRAPH_H +#endif // HORUS_FACTORGRAPH_H diff --git a/packages/CLPBN/clpbn/bp/FgBpSolver.cpp b/packages/CLPBN/clpbn/bp/FgBpSolver.cpp new file mode 100644 index 000000000..b7d5d301e --- /dev/null +++ b/packages/CLPBN/clpbn/bp/FgBpSolver.cpp @@ -0,0 +1,499 @@ +#include +#include + +#include + +#include + +#include "FgBpSolver.h" +#include "FactorGraph.h" +#include "Factor.h" +#include "Shared.h" + + +FgBpSolver::FgBpSolver (const FactorGraph& fg) : Solver (&fg) +{ + factorGraph_ = &fg; +} + + + +FgBpSolver::~FgBpSolver (void) +{ + for (unsigned i = 0; i < varsI_.size(); i++) { + delete varsI_[i]; + } + for (unsigned i = 0; i < facsI_.size(); i++) { + delete facsI_[i]; + } + for (unsigned i = 0; i < links_.size(); i++) { + delete links_[i]; + } +} + + + +void +FgBpSolver::runSolver (void) +{ + clock_t start; + if (COLLECT_STATISTICS) { + start = clock(); + } + if (false) { + //if (!BpOptions::useAlwaysLoopySolver && factorGraph_->isTree()) { + runTreeSolver(); + } else { + runLoopySolver(); + if (DL >= 2) { + cout << endl; + if (nIters_ < BpOptions::maxIter) { + cout << "Sum-Product converged in " ; + cout << nIters_ << " iterations" << endl; + } else { + cout << "The maximum number of iterations was hit, terminating..." ; + cout << endl; + } + } + } + unsigned size = factorGraph_->getVarNodes().size(); + if (COLLECT_STATISTICS) { + unsigned nIters = 0; + bool loopy = factorGraph_->isTree() == false; + if (loopy) nIters = nIters_; + double time = (double (clock() - start)) / CLOCKS_PER_SEC; + Statistics::updateStatistics (size, loopy, nIters, time); + } + if (EXPORT_TO_GRAPHVIZ && size > EXPORT_MINIMAL_SIZE) { + stringstream ss; + ss << Statistics::getSolvedNetworksCounting() << "." << size << ".dot" ; + factorGraph_->exportToGraphViz (ss.str().c_str()); + } +} + + + +ParamSet +FgBpSolver::getPosterioriOf (VarId vid) +{ + assert (factorGraph_->getFgVarNode (vid)); + FgVarNode* var = factorGraph_->getFgVarNode (vid); + ParamSet probs; + + if (var->hasEvidence()) { + probs.resize (var->nrStates(), Util::noEvidence()); + probs[var->getEvidence()] = Util::withEvidence(); + } else { + probs.resize (var->nrStates(), Util::multIdenty()); + const SpLinkSet& links = ninf(var)->getLinks(); + switch (NSPACE) { + case NumberSpace::NORMAL: + for (unsigned i = 0; i < links.size(); i++) { + Util::multiply (probs, links[i]->getMessage()); + } + Util::normalize (probs); + break; + case NumberSpace::LOGARITHM: + for (unsigned i = 0; i < links.size(); i++) { + Util::add (probs, links[i]->getMessage()); + } + Util::normalize (probs); + Util::fromLog (probs); + } + } + return probs; +} + + + +ParamSet +FgBpSolver::getJointDistributionOf (const VarIdSet& jointVarIds) +{ + unsigned msgSize = 1; + vector dsizes (jointVarIds.size()); + for (unsigned i = 0; i < jointVarIds.size(); i++) { + dsizes[i] = factorGraph_->getFgVarNode (jointVarIds[i])->nrStates(); + msgSize *= dsizes[i]; + } + unsigned reps = 1; + ParamSet jointDist (msgSize, Util::multIdenty()); + for (int i = jointVarIds.size() - 1 ; i >= 0; i--) { + Util::multiply (jointDist, getPosterioriOf (jointVarIds[i]), reps); + reps *= dsizes[i]; + } + return jointDist; +} + + + +void +FgBpSolver::runTreeSolver (void) +{ + initializeSolver(); + const FgFacSet& facNodes = factorGraph_->getFactorNodes(); + bool finish = false; + while (!finish) { + finish = true; + for (unsigned i = 0; i < facNodes.size(); i++) { + const SpLinkSet& links = ninf (facNodes[i])->getLinks(); + for (unsigned j = 0; j < links.size(); j++) { + if (!links[j]->messageWasSended()) { + if (readyToSendMessage (links[j])) { + calculateAndUpdateMessage (links[j], false); + } + finish = false; + } + } + } + } +} + + + +bool +FgBpSolver::readyToSendMessage (const SpLink* link) const +{ + const FgVarSet& neighbors = link->getFactor()->neighbors(); + for (unsigned i = 0; i < neighbors.size(); i++) { + if (neighbors[i] != link->getVariable()) { + const SpLinkSet& links = ninf (neighbors[i])->getLinks(); + for (unsigned j = 0; j < links.size(); j++) { + if (links[j]->getFactor() != link->getFactor() && + !links[j]->messageWasSended()) { + return false; + } + } + } + } + return true; +} + + + +void +FgBpSolver::runLoopySolver (void) +{ + initializeSolver(); + nIters_ = 0; + + while (!converged() && nIters_ < BpOptions::maxIter) { + + nIters_ ++; + if (DL >= 2) { + cout << "****************************************" ; + cout << "****************************************" ; + cout << endl; + cout << " Iteration " << nIters_ << endl; + cout << "****************************************" ; + cout << "****************************************" ; + cout << endl; + } + + switch (BpOptions::schedule) { + case BpOptions::Schedule::SEQ_RANDOM: + random_shuffle (links_.begin(), links_.end()); + // no break + + case BpOptions::Schedule::SEQ_FIXED: + for (unsigned i = 0; i < links_.size(); i++) { + calculateAndUpdateMessage (links_[i]); + } + break; + + case BpOptions::Schedule::PARALLEL: + for (unsigned i = 0; i < links_.size(); i++) { + calculateMessage (links_[i]); + } + for (unsigned i = 0; i < links_.size(); i++) { + updateMessage(links_[i]); + } + break; + + case BpOptions::Schedule::MAX_RESIDUAL: + maxResidualSchedule(); + break; + } + if (DL >= 2) { + cout << endl; + } + } +} + + + +void +FgBpSolver::initializeSolver (void) +{ + const FgVarSet& varNodes = factorGraph_->getVarNodes(); + for (unsigned i = 0; i < varsI_.size(); i++) { + delete varsI_[i]; + } + varsI_.reserve (varNodes.size()); + for (unsigned i = 0; i < varNodes.size(); i++) { + varsI_.push_back (new SPNodeInfo()); + } + + const FgFacSet& facNodes = factorGraph_->getFactorNodes(); + for (unsigned i = 0; i < facsI_.size(); i++) { + delete facsI_[i]; + } + facsI_.reserve (facNodes.size()); + for (unsigned i = 0; i < facNodes.size(); i++) { + facsI_.push_back (new SPNodeInfo()); + } + + for (unsigned i = 0; i < links_.size(); i++) { + delete links_[i]; + } + createLinks(); + + for (unsigned i = 0; i < links_.size(); i++) { + FgFacNode* src = links_[i]->getFactor(); + FgVarNode* dst = links_[i]->getVariable(); + ninf (dst)->addSpLink (links_[i]); + ninf (src)->addSpLink (links_[i]); + } +} + + + +void +FgBpSolver::createLinks (void) +{ + const FgFacSet& facNodes = factorGraph_->getFactorNodes(); + for (unsigned i = 0; i < facNodes.size(); i++) { + const FgVarSet& neighbors = facNodes[i]->neighbors(); + for (unsigned j = 0; j < neighbors.size(); j++) { + links_.push_back (new SpLink (facNodes[i], neighbors[j])); + } + } +} + + + +bool +FgBpSolver::converged (void) +{ + if (links_.size() == 0) { + return true; + } + if (nIters_ == 0 || nIters_ == 1) { + return false; + } + bool converged = true; + if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) { + Param maxResidual = (*(sortedOrder_.begin()))->getResidual(); + if (maxResidual > BpOptions::accuracy) { + converged = false; + } else { + converged = true; + } + } else { + for (unsigned i = 0; i < links_.size(); i++) { + double residual = links_[i]->getResidual(); + if (DL >= 2) { + cout << links_[i]->toString() + " residual = " << residual << endl; + } + if (residual > BpOptions::accuracy) { + converged = false; + if (DL == 0) break; + } + } + } + return converged; +} + + + +void +FgBpSolver::maxResidualSchedule (void) +{ + if (nIters_ == 1) { + for (unsigned i = 0; i < links_.size(); i++) { + calculateMessage (links_[i]); + SortedOrder::iterator it = sortedOrder_.insert (links_[i]); + linkMap_.insert (make_pair (links_[i], it)); + } + return; + } + + for (unsigned c = 0; c < links_.size(); c++) { + if (DL >= 2) { + cout << "current residuals:" << endl; + for (SortedOrder::iterator it = sortedOrder_.begin(); + it != sortedOrder_.end(); it ++) { + cout << " " << setw (30) << left << (*it)->toString(); + cout << "residual = " << (*it)->getResidual() << endl; + } + } + + SortedOrder::iterator it = sortedOrder_.begin(); + SpLink* link = *it; + if (link->getResidual() < BpOptions::accuracy) { + return; + } + updateMessage (link); + link->clearResidual(); + sortedOrder_.erase (it); + linkMap_.find (link)->second = sortedOrder_.insert (link); + + // update the messages that depend on message source --> destin + const FgFacSet& factorNeighbors = link->getVariable()->neighbors(); + for (unsigned i = 0; i < factorNeighbors.size(); i++) { + if (factorNeighbors[i] != link->getFactor()) { + const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks(); + for (unsigned j = 0; j < links.size(); j++) { + if (links[j]->getVariable() != link->getVariable()) { + calculateMessage (links[j]); + SpLinkMap::iterator iter = linkMap_.find (links[j]); + sortedOrder_.erase (iter->second); + iter->second = sortedOrder_.insert (links[j]); + } + } + } + } + if (DL >= 2) { + cout << "----------------------------------------" ; + cout << "----------------------------------------" << endl; + } + } +} + + + +void +FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const +{ + const FgFacNode* src = link->getFactor(); + const FgVarNode* dst = link->getVariable(); + const SpLinkSet& links = ninf(src)->getLinks(); + // calculate the product of messages that were sent + // to factor `src', except from var `dst' + unsigned msgSize = 1; + for (unsigned i = 0; i < links.size(); i++) { + msgSize *= links[i]->getVariable()->nrStates(); + } + unsigned repetitions = 1; + ParamSet msgProduct (msgSize, Util::multIdenty()); + switch (NSPACE) { + case NumberSpace::NORMAL: + for (int i = links.size() - 1; i >= 0; i--) { + if (links[i]->getVariable() != dst) { + if (DL >= 5) { + cout << " message from " << links[i]->getVariable()->label(); + cout << ": " << endl; + } + Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions); + repetitions *= links[i]->getVariable()->nrStates(); + } else { + unsigned ds = links[i]->getVariable()->nrStates(); + Util::multiply (msgProduct, ParamSet (ds, 1.0), repetitions); + repetitions *= ds; + } + } + break; + case NumberSpace::LOGARITHM: + for (int i = links.size() - 1; i >= 0; i--) { + if (links[i]->getVariable() != dst) { + Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions); + repetitions *= links[i]->getVariable()->nrStates(); + } else { + unsigned ds = links[i]->getVariable()->nrStates(); + Util::add (msgProduct, ParamSet (ds, 1.0), repetitions); + repetitions *= ds; + } + } + } + Factor result (src->factor()->getVarIds(), + src->factor()->getRanges(), + msgProduct); + result.multiplyByFactor (*(src->factor())); + if (DL >= 5) { + cout << " message product: " ; + cout << Util::parametersToString (msgProduct) << endl; + cout << " original factor: " ; + cout << Util::parametersToString (src->getParameters()) << endl; + cout << " factor product: " ; + cout << Util::parametersToString (result.getParameters()) << endl; + } + result.removeAllVariablesExcept (dst->varId()); + if (DL >= 5) { + cout << " marginalized: " ; + cout << Util::parametersToString (result.getParameters()) << endl; + } + const ParamSet& resultParams = result.getParameters(); + ParamSet& message = link->getNextMessage(); + for (unsigned i = 0; i < resultParams.size(); i++) { + message[i] = resultParams[i]; + } + Util::normalize (message); + if (DL >= 5) { + cout << " curr msg: " ; + cout << Util::parametersToString (link->getMessage()) << endl; + cout << " next msg: " ; + cout << Util::parametersToString (message) << endl; + } + result.freeDistribution(); +} + + + +ParamSet +FgBpSolver::getVar2FactorMsg (const SpLink* link) const +{ + const FgVarNode* src = link->getVariable(); + const FgFacNode* dst = link->getFactor(); + ParamSet msg; + if (src->hasEvidence()) { + msg.resize (src->nrStates(), Util::noEvidence()); + msg[src->getEvidence()] = Util::withEvidence(); + if (DL >= 5) { + cout << Util::parametersToString (msg); + } + } else { + msg.resize (src->nrStates(), Util::one()); + } + if (DL >= 5) { + cout << Util::parametersToString (msg); + } + const SpLinkSet& links = ninf (src)->getLinks(); + switch (NSPACE) { + case NumberSpace::NORMAL: + for (unsigned i = 0; i < links.size(); i++) { + if (links[i]->getFactor() != dst) { + Util::multiply (msg, links[i]->getMessage()); + if (DL >= 5) { + cout << " x " << Util::parametersToString (links[i]->getMessage()); + } + } + } + break; + case NumberSpace::LOGARITHM: + for (unsigned i = 0; i < links.size(); i++) { + if (links[i]->getFactor() != dst) { + Util::add (msg, links[i]->getMessage()); + } + } + } + if (DL >= 5) { + cout << " = " << Util::parametersToString (msg); + } + return msg; +} + + + +void +FgBpSolver::printLinkInformation (void) const +{ + for (unsigned i = 0; i < links_.size(); i++) { + SpLink* l = links_[i]; + cout << l->toString() << ":" << endl; + cout << " curr msg = " ; + cout << Util::parametersToString (l->getMessage()) << endl; + cout << " next msg = " ; + cout << Util::parametersToString (l->getNextMessage()) << endl; + cout << " residual = " << l->getResidual() << endl; + } +} + diff --git a/packages/CLPBN/clpbn/bp/FgBpSolver.h b/packages/CLPBN/clpbn/bp/FgBpSolver.h new file mode 100644 index 000000000..67fca0697 --- /dev/null +++ b/packages/CLPBN/clpbn/bp/FgBpSolver.h @@ -0,0 +1,175 @@ +#ifndef HORUS_FGBPSOLVER_H +#define HORUS_FGBPSOLVER_H + +#include +#include +#include + +#include "Solver.h" +#include "Factor.h" +#include "FactorGraph.h" + +using namespace std; + + + +class SpLink +{ + public: + SpLink (FgFacNode* fn, FgVarNode* vn) + { + fac_ = fn; + var_ = vn; + v1_.resize (vn->nrStates(), Util::tl (1.0 / vn->nrStates())); + v2_.resize (vn->nrStates(), Util::tl (1.0 / vn->nrStates())); + currMsg_ = &v1_; + nextMsg_ = &v2_; + msgSended_ = false; + residual_ = 0.0; + } + + virtual ~SpLink (void) {}; + + virtual void updateMessage (void) + { + swap (currMsg_, nextMsg_); + msgSended_ = true; + } + + void updateResidual (void) + { + residual_ = Util::getMaxNorm (v1_, v2_); + } + + string toString (void) const + { + stringstream ss; + ss << fac_->getLabel(); + ss << " -- " ; + ss << var_->label(); + return ss.str(); + } + + FgFacNode* getFactor (void) const { return fac_; } + FgVarNode* getVariable (void) const { return var_; } + const ParamSet& getMessage (void) const { return *currMsg_; } + ParamSet& getNextMessage (void) { return *nextMsg_; } + bool messageWasSended (void) const { return msgSended_; } + double getResidual (void) const { return residual_; } + void clearResidual (void) { residual_ = 0.0; } + + protected: + FgFacNode* fac_; + FgVarNode* var_; + ParamSet v1_; + ParamSet v2_; + ParamSet* currMsg_; + ParamSet* nextMsg_; + bool msgSended_; + double residual_; +}; + + +typedef vector SpLinkSet; + + +class SPNodeInfo +{ + public: + void addSpLink (SpLink* link) { links_.push_back (link); } + const SpLinkSet& getLinks (void) { return links_; } + + private: + SpLinkSet links_; +}; + + +class FgBpSolver : public Solver +{ + public: + FgBpSolver (const FactorGraph&); + virtual ~FgBpSolver (void); + + void runSolver (void); + virtual ParamSet getPosterioriOf (VarId); + virtual ParamSet getJointDistributionOf (const VarIdSet&); + + protected: + virtual void initializeSolver (void); + virtual void createLinks (void); + virtual void maxResidualSchedule (void); + virtual void calculateFactor2VariableMsg (SpLink*) const; + virtual ParamSet getVar2FactorMsg (const SpLink*) const; + virtual void printLinkInformation (void) const; + + void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true) + { + if (DL >= 3) { + cout << "calculating & updating " << link->toString() << endl; + } + calculateFactor2VariableMsg (link); + if (calcResidual) { + link->updateResidual(); + } + link->updateMessage(); + } + + void calculateMessage (SpLink* link, bool calcResidual = true) + { + if (DL >= 3) { + cout << "calculating " << link->toString() << endl; + } + calculateFactor2VariableMsg (link); + if (calcResidual) { + link->updateResidual(); + } + } + + void updateMessage (SpLink* link) + { + link->updateMessage(); + if (DL >= 3) { + cout << "updating " << link->toString() << endl; + } + } + + SPNodeInfo* ninf (const FgVarNode* var) const + { + return varsI_[var->getIndex()]; + } + + SPNodeInfo* ninf (const FgFacNode* fac) const + { + return facsI_[fac->getIndex()]; + } + + struct CompareResidual { + inline bool operator() (const SpLink* link1, const SpLink* link2) + { + return link1->getResidual() > link2->getResidual(); + } + }; + + SpLinkSet links_; + unsigned nIters_; + vector varsI_; + vector facsI_; + const FactorGraph* factorGraph_; + + typedef multiset SortedOrder; + SortedOrder sortedOrder_; + + typedef unordered_map SpLinkMap; + SpLinkMap linkMap_; + + private: + void runTreeSolver (void); + bool readyToSendMessage (const SpLink*) const; + void runLoopySolver (void); + bool converged (void); + + +}; + +#endif // HORUS_FGBPSOLVER_H + diff --git a/packages/CLPBN/clpbn/bp/FgVarNode.h b/packages/CLPBN/clpbn/bp/FgVarNode.h deleted file mode 100644 index e46c88f57..000000000 --- a/packages/CLPBN/clpbn/bp/FgVarNode.h +++ /dev/null @@ -1,43 +0,0 @@ -#ifndef BP_FG_VAR_NODE_H -#define BP_FG_VAR_NODE_H - -#include - -#include "Variable.h" -#include "Shared.h" - -using namespace std; - -class Factor; - -class FgVarNode : public Variable -{ - public: - FgVarNode (unsigned vid, unsigned dsize) : Variable (vid, dsize) { } - FgVarNode (const Variable* v) : Variable (v) { } - - void addFactor (Factor* f) { factors_.push_back (f); } - CFactorSet getFactors (void) const { return factors_; } - - void removeFactor (const Factor* f) - { - if (factors_[factors_.size() -1] == f) { - factors_.pop_back(); - } else { - for (unsigned i = 0; i < factors_.size(); i++) { - if (factors_[i] == f) { - factors_.erase (factors_.begin() + i); - return; - } - } - assert (false); - } - } - - private: - DISALLOW_COPY_AND_ASSIGN (FgVarNode); - // members - FactorSet factors_; -}; - -#endif // BP_FG_VAR_NODE_H diff --git a/packages/CLPBN/clpbn/bp/GraphicalModel.h b/packages/CLPBN/clpbn/bp/GraphicalModel.h index e357f2d27..1e4c4d28a 100644 --- a/packages/CLPBN/clpbn/bp/GraphicalModel.h +++ b/packages/CLPBN/clpbn/bp/GraphicalModel.h @@ -1,18 +1,54 @@ -#ifndef BP_GRAPHICAL_MODEL_H -#define BP_GRAPHICAL_MODEL_H +#ifndef HORUS_GRAPHICALMODEL_H +#define HORUS_GRAPHICALMODEL_H -#include "Variable.h" +#include "VarNode.h" #include "Shared.h" -using namespace std; +using namespace std; + +struct VariableInfo +{ + VariableInfo (string l, const States& sts) + { + label = l; + states = sts; + } + string label; + States states; +}; + class GraphicalModel { public: virtual ~GraphicalModel (void) {}; - virtual Variable* getVariable (Vid) const = 0; - virtual VarSet getVariables (void) const = 0; + virtual VarNode* getVariableNode (VarId) const = 0; + virtual VarNodes getVariableNodes (void) const = 0; virtual void printGraphicalModel (void) const = 0; + + static void addVariableInformation (VarId vid, string label, + const States& states) + { + assert (varsInfo_.find (vid) == varsInfo_.end()); + varsInfo_.insert (make_pair (vid, VariableInfo (label, states))); + } + static VariableInfo getVariableInformation (VarId vid) + { + assert (varsInfo_.find (vid) != varsInfo_.end()); + return varsInfo_.find (vid)->second; + } + static bool variablesHaveInformation (void) + { + return varsInfo_.size() != 0; + } + static void clearVariablesInformation (void) + { + varsInfo_.clear(); + } + + private: + static unordered_map varsInfo_; }; -#endif // BP_GRAPHICAL_MODEL_H +#endif // HORUS_GRAPHICALMODEL_H + diff --git a/packages/CLPBN/clpbn/bp/HorusCli.cpp b/packages/CLPBN/clpbn/bp/HorusCli.cpp index 5a0033bcd..0d49abc1a 100644 --- a/packages/CLPBN/clpbn/bp/HorusCli.cpp +++ b/packages/CLPBN/clpbn/bp/HorusCli.cpp @@ -5,15 +5,18 @@ #include "BayesNet.h" #include "FactorGraph.h" -#include "SPSolver.h" -#include "BPSolver.h" -#include "CountingBP.h" +#include "VarElimSolver.h" +#include "BnBpSolver.h" +#include "FgBpSolver.h" +#include "CbpSolver.h" + +#include "StatesIndexer.h" using namespace std; -void BayesianNetwork (int, const char* []); -void markovNetwork (int, const char* []); -void runSolver (Solver*, const VarSet&); +void processArguments (BayesNet&, int, const char* []); +void processArguments (FactorGraph&, int, const char* []); +void runSolver (Solver*, const VarNodes&); const string USAGE = "usage: \ ./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ; @@ -21,33 +24,7 @@ const string USAGE = "usage: \ int main (int argc, const char* argv[]) -{ - /* - FactorGraph fg; - FgVarNode* varNode1 = new FgVarNode (0, 2); - FgVarNode* varNode2 = new FgVarNode (1, 2); - FgVarNode* varNode3 = new FgVarNode (2, 2); - fg.addVariable (varNode1); - fg.addVariable (varNode2); - fg.addVariable (varNode3); - Distribution* dist = new Distribution (ParamSet() = {1.2, 1.4, 2.0, 0.4}); - fg.addFactor (new Factor (FgVarSet() = {varNode1, varNode2}, dist)); - fg.addFactor (new Factor (FgVarSet() = {varNode3, varNode2}, dist)); - //fg.printGraphicalModel(); - //SPSolver sp (fg); - //sp.runSolver(); - //sp.printAllPosterioris(); - //ParamSet p = sp.getJointDistributionOf (VidSet() = {0, 1, 2}); - //cout << Util::parametersToString (p) << endl; - CountingBP cbp (fg); - //cbp.runSolver(); - //cbp.printAllPosterioris(); - ParamSet p2 = cbp.getJointDistributionOf (VidSet() = {0, 1, 2}); - cout << Util::parametersToString (p2) << endl; - fg.freeDistributions(); - Statistics::printCompressingStats ("compressing.stats"); - return 0; - */ +{ if (!argv[1]) { cerr << "error: no graphical model specified" << endl; cerr << USAGE << endl; @@ -56,12 +33,20 @@ main (int argc, const char* argv[]) const string& fileName = argv[1]; const string& extension = fileName.substr (fileName.find_last_of ('.') + 1); if (extension == "xml") { - BayesianNetwork (argc, argv); + BayesNet bn; + bn.readFromBifFormat (argv[1]); + processArguments (bn, argc, argv); } else if (extension == "uai") { - markovNetwork (argc, argv); + FactorGraph fg; + fg.readFromUaiFormat (argv[1]); + processArguments (fg, argc, argv); + } else if (extension == "fg") { + FactorGraph fg; + fg.readFromLibDaiFormat (argv[1]); + processArguments (fg, argc, argv); } else { cerr << "error: the graphical model must be defined either " ; - cerr << "in a xml file or uai file" << endl; + cerr << "in a xml, uai or libDAI file" << endl; exit (0); } return 0; @@ -70,12 +55,9 @@ main (int argc, const char* argv[]) void -BayesianNetwork (int argc, const char* argv[]) +processArguments (BayesNet& bn, int argc, const char* argv[]) { - BayesNet bn (argv[1]); - //bn.printGraphicalModel(); - - VarSet queryVars; + VarNodes queryVars; for (int i = 2; i < argc; i++) { const string& arg = argv[i]; if (arg.find ('=') == std::string::npos) { @@ -86,6 +68,7 @@ BayesianNetwork (int argc, const char* argv[]) cerr << "error: there isn't a variable labeled of " ; cerr << "`" << arg << "'" ; cerr << endl; + bn.freeDistributions(); exit (0); } } else { @@ -95,11 +78,13 @@ BayesianNetwork (int argc, const char* argv[]) if (label.empty()) { cerr << "error: missing left argument" << endl; cerr << USAGE << endl; + bn.freeDistributions(); exit (0); } if (state.empty()) { cerr << "error: missing right argument" << endl; cerr << USAGE << endl; + bn.freeDistributions(); exit (0); } BayesNode* node = bn.getBayesNode (label); @@ -109,42 +94,54 @@ BayesianNetwork (int argc, const char* argv[]) } else { cerr << "error: `" << state << "' " ; cerr << "is not a valid state for " ; - cerr << "`" << node->getLabel() << "'" ; + cerr << "`" << node->label() << "'" ; cerr << endl; + bn.freeDistributions(); exit (0); } } else { cerr << "error: there isn't a variable labeled of " ; cerr << "`" << label << "'" ; cerr << endl; + bn.freeDistributions(); exit (0); } } } - Solver* solver; - if (SolverOptions::convertBn2Fg) { - FactorGraph* fg = new FactorGraph (bn); - fg->printGraphicalModel(); - solver = new SPSolver (*fg); - runSolver (solver, queryVars); - delete fg; - } else { - solver = new BPSolver (bn); - runSolver (solver, queryVars); + Solver* solver = 0; + FactorGraph* fg = 0; + switch (InfAlgorithms::infAlgorithm) { + case InfAlgorithms::VE: + fg = new FactorGraph (bn); + solver = new VarElimSolver (*fg); + break; + case InfAlgorithms::BN_BP: + solver = new BnBpSolver (bn); + break; + case InfAlgorithms::FG_BP: + fg = new FactorGraph (bn); + fg->printGraphicalModel(); + solver = new FgBpSolver (*fg); + break; + case InfAlgorithms::CBP: + fg = new FactorGraph (bn); + solver = new CbpSolver (*fg); + break; + default: + assert (false); } + runSolver (solver, queryVars); + delete fg; bn.freeDistributions(); } void -markovNetwork (int argc, const char* argv[]) +processArguments (FactorGraph& fg, int argc, const char* argv[]) { - FactorGraph fg (argv[1]); - //fg.printGraphicalModel(); - - VarSet queryVars; + VarNodes queryVars; for (int i = 2; i < argc; i++) { const string& arg = argv[i]; if (arg.find ('=') == std::string::npos) { @@ -152,19 +149,21 @@ markovNetwork (int argc, const char* argv[]) cerr << "error: `" << arg << "' " ; cerr << "is not a valid variable id" ; cerr << endl; + fg.freeDistributions(); exit (0); } - Vid vid; + VarId vid; stringstream ss; ss << arg; ss >> vid; - Variable* queryVar = fg.getFgVarNode (vid); + VarNode* queryVar = fg.getFgVarNode (vid); if (queryVar) { queryVars.push_back (queryVar); } else { cerr << "error: there isn't a variable with " ; cerr << "`" << vid << "' as id" ; cerr << endl; + fg.freeDistributions(); exit (0); } } else { @@ -172,53 +171,73 @@ markovNetwork (int argc, const char* argv[]) if (arg.substr (0, pos).empty()) { cerr << "error: missing left argument" << endl; cerr << USAGE << endl; + fg.freeDistributions(); exit (0); } if (arg.substr (pos + 1).empty()) { cerr << "error: missing right argument" << endl; cerr << USAGE << endl; + fg.freeDistributions(); exit (0); } if (!Util::isInteger (arg.substr (0, pos))) { cerr << "error: `" << arg.substr (0, pos) << "' " ; cerr << "is not a variable id" ; cerr << endl; + fg.freeDistributions(); exit (0); } - Vid vid; + VarId vid; stringstream ss; ss << arg.substr (0, pos); ss >> vid; - Variable* var = fg.getFgVarNode (vid); + VarNode* var = fg.getFgVarNode (vid); if (var) { if (!Util::isInteger (arg.substr (pos + 1))) { cerr << "error: `" << arg.substr (pos + 1) << "' " ; cerr << "is not a state index" ; cerr << endl; + fg.freeDistributions(); exit (0); } int stateIndex; stringstream ss; ss << arg.substr (pos + 1); ss >> stateIndex; - if (var->isValidStateIndex (stateIndex)) { + if (var->isValidState (stateIndex)) { var->setEvidence (stateIndex); } else { cerr << "error: `" << stateIndex << "' " ; cerr << "is not a valid state index for variable " ; - cerr << "`" << var->getVarId() << "'" ; + cerr << "`" << var->varId() << "'" ; cerr << endl; + fg.freeDistributions(); exit (0); } } else { cerr << "error: there isn't a variable with " ; cerr << "`" << vid << "' as id" ; cerr << endl; + fg.freeDistributions(); exit (0); } } } - Solver* solver = new SPSolver (fg); + Solver* solver = 0; + switch (InfAlgorithms::infAlgorithm) { + case InfAlgorithms::VE: + solver = new VarElimSolver (fg); + break; + case InfAlgorithms::BN_BP: + case InfAlgorithms::FG_BP: + solver = new FgBpSolver (fg); + break; + case InfAlgorithms::CBP: + solver = new CbpSolver (fg); + break; + default: + assert (false); + } runSolver (solver, queryVars); fg.freeDistributions(); } @@ -226,11 +245,11 @@ markovNetwork (int argc, const char* argv[]) void -runSolver (Solver* solver, const VarSet& queryVars) +runSolver (Solver* solver, const VarNodes& queryVars) { - VidSet vids; + VarIdSet vids; for (unsigned i = 0; i < queryVars.size(); i++) { - vids.push_back (queryVars[i]->getVarId()); + vids.push_back (queryVars[i]->varId()); } if (queryVars.size() == 0) { solver->runSolver(); @@ -239,6 +258,7 @@ runSolver (Solver* solver, const VarSet& queryVars) solver->runSolver(); solver->printPosterioriOf (vids[0]); } else { + solver->runSolver(); solver->printJointDistributionOf (vids); } delete solver; diff --git a/packages/CLPBN/clpbn/bp/HorusYap.cpp b/packages/CLPBN/clpbn/bp/HorusYap.cpp index c8bd18529..cb364e8be 100644 --- a/packages/CLPBN/clpbn/bp/HorusYap.cpp +++ b/packages/CLPBN/clpbn/bp/HorusYap.cpp @@ -8,9 +8,11 @@ #include "BayesNet.h" #include "FactorGraph.h" -#include "BPSolver.h" -#include "SPSolver.h" -#include "CountingBP.h" +#include "VarElimSolver.h" +#include "BnBpSolver.h" +#include "FgBpSolver.h" +#include "CbpSolver.h" +#include "ElimGraph.h" using namespace std; @@ -18,26 +20,27 @@ using namespace std; int createNetwork (void) { - //Statistics::numCreatedNets ++; - //cout << "creating network number " << Statistics::numCreatedNets << endl; - + Statistics::incrementPrimaryNetworksCounting(); + // cout << "creating network number " ; + // cout << Statistics::getPrimaryNetworksCounting() << endl; + // if (Statistics::getPrimaryNetworksCounting() > 98) { + // Statistics::writeStatisticsToFile ("../../compressing.stats"); + // } BayesNet* bn = new BayesNet(); YAP_Term varList = YAP_ARG1; + BnNodeSet nodes; + vector parents; while (varList != YAP_TermNil()) { - YAP_Term var = YAP_HeadOfTerm (varList); - Vid vid = (Vid) YAP_IntOfTerm (YAP_ArgOfTerm (1, var)); - unsigned dsize = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, var)); - int evidence = (int) YAP_IntOfTerm (YAP_ArgOfTerm (3, var)); - YAP_Term parentL = YAP_ArgOfTerm (4, var); - unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (5, var)); - BnNodeSet parents; + YAP_Term var = YAP_HeadOfTerm (varList); + VarId vid = (VarId) YAP_IntOfTerm (YAP_ArgOfTerm (1, var)); + unsigned dsize = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, var)); + int evidence = (int) YAP_IntOfTerm (YAP_ArgOfTerm (3, var)); + YAP_Term parentL = YAP_ArgOfTerm (4, var); + unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (5, var)); + parents.push_back (VarIdSet()); while (parentL != YAP_TermNil()) { unsigned parentId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (parentL)); - BayesNode* parent = bn->getBayesNode (parentId); - if (!parent) { - parent = bn->addNode (parentId); - } - parents.push_back (parent); + parents.back().push_back (parentId); parentL = YAP_TailOfTerm (parentL); } Distribution* dist = bn->getDistribution (distId); @@ -45,20 +48,19 @@ createNetwork (void) dist = new Distribution (distId); bn->addDistribution (dist); } - BayesNode* node = bn->getBayesNode (vid); - if (node) { - node->setData (dsize, evidence, parents, dist); - } else { - bn->addNode (vid, dsize, evidence, parents, dist); - } + assert (bn->getBayesNode (vid) == 0); + nodes.push_back (bn->addNode (vid, dsize, evidence, dist)); varList = YAP_TailOfTerm (varList); } + for (unsigned i = 0; i < nodes.size(); i++) { + BnNodeSet ps; + for (unsigned j = 0; j < parents[i].size(); j++) { + assert (bn->getBayesNode (parents[i][j]) != 0); + ps.push_back (bn->getBayesNode (parents[i][j])); + } + nodes[i]->setParents (ps); + } bn->setIndexes(); - - // if (Statistics::numCreatedNets == 1688) { - // Statistics::writeStats(); - // exit (0); - // } YAP_Int p = (YAP_Int) (bn); return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG2); } @@ -68,23 +70,22 @@ createNetwork (void) int setExtraVarsInfo (void) { - BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1); + // BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1); + GraphicalModel::clearVariablesInformation(); YAP_Term varsInfoL = YAP_ARG2; while (varsInfoL != YAP_TermNil()) { YAP_Term head = YAP_HeadOfTerm (varsInfoL); - Vid vid = YAP_IntOfTerm (YAP_ArgOfTerm (1, head)); + VarId vid = YAP_IntOfTerm (YAP_ArgOfTerm (1, head)); YAP_Atom label = YAP_AtomOfTerm (YAP_ArgOfTerm (2, head)); - YAP_Term domainL = YAP_ArgOfTerm (3, head); - Domain domain; - while (domainL != YAP_TermNil()) { - YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (domainL)); - domain.push_back ((char*) YAP_AtomName (atom)); - domainL = YAP_TailOfTerm (domainL); + YAP_Term statesL = YAP_ArgOfTerm (3, head); + States states; + while (statesL != YAP_TermNil()) { + YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (statesL)); + states.push_back ((char*) YAP_AtomName (atom)); + statesL = YAP_TailOfTerm (statesL); } - BayesNode* node = bn->getBayesNode (vid); - assert (node); - node->setLabel ((char*) YAP_AtomName (label)); - node->setDomain (domain); + GraphicalModel::addVariableInformation (vid, + (char*) YAP_AtomName (label), states); varsInfoL = YAP_TailOfTerm (varsInfoL); } return TRUE; @@ -106,12 +107,10 @@ setParameters (void) params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL))); paramL = YAP_TailOfTerm (paramL); } - bn->getDistribution(distId)->updateParameters(params); - if (Statistics::numCreatedNets == 4) { - cout << "dist " << distId << " parameters:" ; - cout << Util::parametersToString (params); - cout << endl; + if (NSPACE == NumberSpace::LOGARITHM) { + Util::toLog (params); } + bn->getDistribution(distId)->updateParameters (params); distList = YAP_TailOfTerm (distList); } return TRUE; @@ -124,113 +123,73 @@ runSolver (void) { BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1); YAP_Term taskList = YAP_ARG2; - vector tasks; - VidSet marginalVids; - + vector tasks; + std::set vids; while (taskList != YAP_TermNil()) { if (YAP_IsPairTerm (YAP_HeadOfTerm (taskList))) { - VidSet jointVids; + tasks.push_back (VarIdSet()); YAP_Term jointList = YAP_HeadOfTerm (taskList); while (jointList != YAP_TermNil()) { - Vid vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList)); + VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList)); assert (bn->getBayesNode (vid)); - jointVids.push_back (vid); + tasks.back().push_back (vid); + vids.insert (vid); jointList = YAP_TailOfTerm (jointList); } - tasks.push_back (jointVids); } else { - Vid vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (taskList)); + VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (taskList)); assert (bn->getBayesNode (vid)); - tasks.push_back (VidSet() = {vid}); - marginalVids.push_back (vid); + tasks.push_back (VarIdSet() = {vid}); + vids.insert (vid); } taskList = YAP_TailOfTerm (taskList); } - // cout << "inference tasks:" << endl; - // for (unsigned i = 0; i < tasks.size(); i++) { - // cout << "i" << ": " ; - // if (tasks[i].size() == 1) { - // cout << tasks[i][0] << endl; - // } else { - // for (unsigned j = 0; j < tasks[i].size(); j++) { - // cout << tasks[i][j] << " " ; - // } - // cout << endl; - // } - // } - - Solver* solver = 0; - GraphicalModel* gm = 0; - VidSet vids; - const BnNodeSet& nodes = bn->getBayesNodes(); - for (unsigned i = 0; i < nodes.size(); i++) { - vids.push_back (nodes[i]->getVarId()); - } - if (marginalVids.size() != 0) { - bn->exportToDotFormat ("bn unbayes.dot"); - BayesNet* mrn = bn->getMinimalRequesiteNetwork (marginalVids); - mrn->exportToDotFormat ("bn bayes.dot"); - //BayesNet* mrn = bn->getMinimalRequesiteNetwork (vids); - if (SolverOptions::convertBn2Fg) { - gm = new FactorGraph (*mrn); - if (SolverOptions::compressFactorGraph) { - solver = new CountingBP (*static_cast (gm)); - } else { - solver = new SPSolver (*static_cast (gm)); - } - if (SolverOptions::runBayesBall) { - delete mrn; - } - } else { - gm = mrn; - solver = new BPSolver (*static_cast (gm)); + Solver* bpSolver = 0; + GraphicalModel* graphicalModel = 0; + CFactorGraph::disableCheckForIdenticalFactors(); + if (InfAlgorithms::infAlgorithm != InfAlgorithms::VE) { + BayesNet* mrn = bn->getMinimalRequesiteNetwork ( + VarIdSet (vids.begin(), vids.end())); + if (InfAlgorithms::infAlgorithm == InfAlgorithms::BN_BP) { + graphicalModel = mrn; + bpSolver = new BnBpSolver (*static_cast (graphicalModel)); + } else if (InfAlgorithms::infAlgorithm == InfAlgorithms::FG_BP) { + graphicalModel = new FactorGraph (*mrn); + bpSolver = new FgBpSolver (*static_cast (graphicalModel)); + delete mrn; + } else if (InfAlgorithms::infAlgorithm == InfAlgorithms::CBP) { + graphicalModel = new FactorGraph (*mrn); + bpSolver = new CbpSolver (*static_cast (graphicalModel)); + delete mrn; } - solver->runSolver(); + bpSolver->runSolver(); } vector results; - results.reserve (tasks.size()); + results.reserve (tasks.size()); for (unsigned i = 0; i < tasks.size(); i++) { - if (tasks[i].size() == 1) { - results.push_back (solver->getPosterioriOf (tasks[i][0])); + //if (i == 1) exit (0); + if (InfAlgorithms::infAlgorithm == InfAlgorithms::VE) { + BayesNet* mrn = bn->getMinimalRequesiteNetwork (tasks[i]); + VarElimSolver* veSolver = new VarElimSolver (*mrn); + if (tasks[i].size() == 1) { + results.push_back (veSolver->getPosterioriOf (tasks[i][0])); + } else { + results.push_back (veSolver->getJointDistributionOf (tasks[i])); + } + delete mrn; + delete veSolver; } else { - static int count = 0; - cout << "calculating joint... " << count ++ << endl; - //if (count == 5225) { - // Statistics::printCompressingStats ("compressing.stats"); - //} - Solver* solver2 = 0; - GraphicalModel* gm2 = 0; - bn->exportToDotFormat ("joint.dot"); - BayesNet* mrn2; - if (SolverOptions::runBayesBall) { - mrn2 = bn->getMinimalRequesiteNetwork (tasks[i]); + if (tasks[i].size() == 1) { + results.push_back (bpSolver->getPosterioriOf (tasks[i][0])); } else { - mrn2 = bn; + results.push_back (bpSolver->getJointDistributionOf (tasks[i])); } - if (SolverOptions::convertBn2Fg) { - gm2 = new FactorGraph (*mrn2); - if (SolverOptions::compressFactorGraph) { - solver2 = new CountingBP (*static_cast (gm2)); - } else { - solver2 = new SPSolver (*static_cast (gm2)); - } - if (SolverOptions::runBayesBall) { - delete mrn2; - } - } else { - gm2 = mrn2; - solver2 = new BPSolver (*static_cast (gm2)); - } - results.push_back (solver2->getJointDistributionOf (tasks[i])); - delete solver2; - delete gm2; } } - - delete solver; - delete gm; + delete bpSolver; + delete graphicalModel; YAP_Term list = YAP_TermNil(); for (int i = results.size() - 1; i >= 0; i--) { @@ -251,10 +210,91 @@ runSolver (void) +int +setSolverParameter (void) +{ + string key ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1))); + if (key == "inf_alg") { + string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2))); + if ( value == "ve") { + InfAlgorithms::infAlgorithm = InfAlgorithms::VE; + } else if (value == "bn_bp") { + InfAlgorithms::infAlgorithm = InfAlgorithms::BN_BP; + } else if (value == "fg_bp") { + InfAlgorithms::infAlgorithm = InfAlgorithms::FG_BP; + } else if (value == "cbp") { + InfAlgorithms::infAlgorithm = InfAlgorithms::CBP; + } else { + cerr << "warning: invalid value `" << value << "' " ; + cerr << "for `" << key << "'" << endl; + return FALSE; + } + } else if (key == "elim_heuristic") { + string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2))); + if ( value == "min_neighbors") { + ElimGraph::setEliminationHeuristic (ElimHeuristic::MIN_NEIGHBORS); + } else if (value == "min_weight") { + ElimGraph::setEliminationHeuristic (ElimHeuristic::MIN_WEIGHT); + } else if (value == "min_fill") { + ElimGraph::setEliminationHeuristic (ElimHeuristic::MIN_FILL); + } else if (value == "weighted_min_fill") { + ElimGraph::setEliminationHeuristic (ElimHeuristic::WEIGHTED_MIN_FILL); + } else { + cerr << "warning: invalid value `" << value << "' " ; + cerr << "for `" << key << "'" << endl; + return FALSE; + } + } else if (key == "schedule") { + string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2))); + if ( value == "seq_fixed") { + BpOptions::schedule = BpOptions::Schedule::SEQ_FIXED; + } else if (value == "seq_random") { + BpOptions::schedule = BpOptions::Schedule::SEQ_RANDOM; + } else if (value == "parallel") { + BpOptions::schedule = BpOptions::Schedule::PARALLEL; + } else if (value == "max_residual") { + BpOptions::schedule = BpOptions::Schedule::MAX_RESIDUAL; + } else { + cerr << "warning: invalid value `" << value << "' " ; + cerr << "for `" << key << "'" << endl; + return FALSE; + } + } else if (key == "accuracy") { + BpOptions::accuracy = (double) YAP_FloatOfTerm (YAP_ARG2); + } else if (key == "max_iter") { + BpOptions::maxIter = (int) YAP_IntOfTerm (YAP_ARG2); + } else if (key == "always_loopy_solver") { + string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2))); + if (value == "true") { + BpOptions::useAlwaysLoopySolver = true; + } else if (value == "false") { + BpOptions::useAlwaysLoopySolver = false; + } else { + cerr << "warning: invalid value `" << value << "' " ; + cerr << "for `" << key << "'" << endl; + return FALSE; + } + } else { + cerr << "warning: invalid key `" << key << "'" << endl; + return FALSE; + } + return TRUE; +} + + + +int useLogSpace (void) +{ + NSPACE = NumberSpace::LOGARITHM; + return TRUE; +} + + + int freeBayesNetwork (void) { - //Statistics::printCompressingStats ("../../compressing.stats"); + //Statistics::writeStatisticsToFile ("stats.txt"); BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1); bn->freeDistributions(); delete bn; @@ -266,10 +306,12 @@ freeBayesNetwork (void) extern "C" void init_predicates (void) { - YAP_UserCPredicate ("create_network", createNetwork, 2); - YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2); - YAP_UserCPredicate ("set_parameters", setParameters, 2); - YAP_UserCPredicate ("run_solver", runSolver, 3); - YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1); + YAP_UserCPredicate ("create_network", createNetwork, 2); + YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2); + YAP_UserCPredicate ("set_parameters", setParameters, 2); + YAP_UserCPredicate ("run_solver", runSolver, 3); + YAP_UserCPredicate ("set_solver_parameter", setSolverParameter, 2); + YAP_UserCPredicate ("use_log_space", useLogSpace, 0); + YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1); } diff --git a/packages/CLPBN/clpbn/bp/LiftedFG.cpp b/packages/CLPBN/clpbn/bp/LiftedFG.cpp deleted file mode 100644 index add8610cc..000000000 --- a/packages/CLPBN/clpbn/bp/LiftedFG.cpp +++ /dev/null @@ -1,278 +0,0 @@ - -#include "LiftedFG.h" -#include "FgVarNode.h" -#include "Factor.h" -#include "Distribution.h" - -LiftedFG::LiftedFG (const FactorGraph& fg) -{ - groundFg_ = &fg; - freeColor_ = 0; - - const FgVarSet& varNodes = fg.getFgVarNodes(); - const FactorSet& factors = fg.getFactors(); - varColors_.resize (varNodes.size()); - factorColors_.resize (factors.size()); - for (unsigned i = 0; i < factors.size(); i++) { - factors[i]->setIndex (i); - } - - // create the initial variable colors - VarColorMap colorMap; - for (unsigned i = 0; i < varNodes.size(); i++) { - unsigned dsize = varNodes[i]->getDomainSize(); - VarColorMap::iterator it = colorMap.find (dsize); - if (it == colorMap.end()) { - it = colorMap.insert (make_pair ( - dsize, vector (dsize + 1,-1))).first; - } - unsigned idx; - if (varNodes[i]->hasEvidence()) { - idx = varNodes[i]->getEvidence(); - } else { - idx = dsize; - } - vector& stateColors = it->second; - if (stateColors[idx] == -1) { - stateColors[idx] = getFreeColor(); - } - setColor (varNodes[i], stateColors[idx]); - } - - // create the initial factor colors - DistColorMap distColors; - for (unsigned i = 0; i < factors.size(); i++) { - Distribution* dist = factors[i]->getDistribution(); - DistColorMap::iterator it = distColors.find (dist); - if (it == distColors.end()) { - it = distColors.insert (make_pair (dist, getFreeColor())).first; - } - setColor (factors[i], it->second); - } - - VarSignMap varGroups; - FactorSignMap factorGroups; - bool groupsHaveChanged = true; - unsigned nIter = 0; - while (groupsHaveChanged || nIter == 1) { - nIter ++; - if (Statistics::numCreatedNets == 4) { - cout << "--------------------------------------------" << endl; - cout << "Iteration " << nIter << endl; - cout << "--------------------------------------------" << endl; - } - - unsigned prevFactorGroupsSize = factorGroups.size(); - factorGroups.clear(); - // set a new color to the factors with the same signature - for (unsigned i = 0; i < factors.size(); i++) { - const string& signatureId = getSignatureId (factors[i]); - // cout << factors[i]->getLabel() << " signature: " ; - // cout<< signatureId << endl; - FactorSignMap::iterator it = factorGroups.find (signatureId); - if (it == factorGroups.end()) { - it = factorGroups.insert (make_pair (signatureId, FactorSet())).first; - } - it->second.push_back (factors[i]); - } - if (nIter > 0) - for (FactorSignMap::iterator it = factorGroups.begin(); - it != factorGroups.end(); it++) { - Color newColor = getFreeColor(); - FactorSet& groupMembers = it->second; - for (unsigned i = 0; i < groupMembers.size(); i++) { - setColor (groupMembers[i], newColor); - } - } - - // set a new color to the variables with the same signature - unsigned prevVarGroupsSize = varGroups.size(); - varGroups.clear(); - for (unsigned i = 0; i < varNodes.size(); i++) { - const string& signatureId = getSignatureId (varNodes[i]); - VarSignMap::iterator it = varGroups.find (signatureId); - // cout << varNodes[i]->getLabel() << " signature: " ; - // cout << signatureId << endl; - if (it == varGroups.end()) { - it = varGroups.insert (make_pair (signatureId, FgVarSet())).first; - } - it->second.push_back (varNodes[i]); - } - if (nIter > 0) - for (VarSignMap::iterator it = varGroups.begin(); - it != varGroups.end(); it++) { - Color newColor = getFreeColor(); - FgVarSet& groupMembers = it->second; - for (unsigned i = 0; i < groupMembers.size(); i++) { - setColor (groupMembers[i], newColor); - } - } - - //if (nIter >= 3) cout << "bigger than three: " << nIter << endl; - groupsHaveChanged = prevVarGroupsSize != varGroups.size() - || prevFactorGroupsSize != factorGroups.size(); - } - - printGroups (varGroups, factorGroups); - for (VarSignMap::iterator it = varGroups.begin(); - it != varGroups.end(); it++) { - CFgVarSet vars = it->second; - VarCluster* vc = new VarCluster (vars); - for (unsigned i = 0; i < vars.size(); i++) { - vid2VarCluster_.insert (make_pair (vars[i]->getVarId(), vc)); - } - varClusters_.push_back (vc); - } - - for (FactorSignMap::iterator it = factorGroups.begin(); - it != factorGroups.end(); it++) { - VarClusterSet varClusters; - Factor* groundFactor = it->second[0]; - FgVarSet groundVars = groundFactor->getFgVarNodes(); - for (unsigned i = 0; i < groundVars.size(); i++) { - Vid vid = groundVars[i]->getVarId(); - varClusters.push_back (vid2VarCluster_.find (vid)->second); - } - factorClusters_.push_back (new FactorCluster (it->second, varClusters)); - } -} - - - -LiftedFG::~LiftedFG (void) -{ - for (unsigned i = 0; i < varClusters_.size(); i++) { - delete varClusters_[i]; - } - for (unsigned i = 0; i < factorClusters_.size(); i++) { - delete factorClusters_[i]; - } -} - - - -string -LiftedFG::getSignatureId (FgVarNode* var) const -{ - stringstream ss; - CFactorSet myFactors = var->getFactors(); - ss << myFactors.size(); - for (unsigned i = 0; i < myFactors.size(); i++) { - ss << "." << getColor (myFactors[i]); - ss << "." << myFactors[i]->getIndexOf(var); - } - ss << "." << getColor (var); - return ss.str(); -} - - - -string -LiftedFG::getSignatureId (Factor* factor) const -{ - stringstream ss; - CFgVarSet myVars = factor->getFgVarNodes(); - ss << myVars.size(); - for (unsigned i = 0; i < myVars.size(); i++) { - ss << "." << getColor (myVars[i]); - } - ss << "." << getColor (factor); - return ss.str(); -} - - - -FactorGraph* -LiftedFG::getCompressedFactorGraph (void) -{ - FactorGraph* fg = new FactorGraph(); - for (unsigned i = 0; i < varClusters_.size(); i++) { - FgVarNode* var = varClusters_[i]->getGroundFgVarNodes()[0]; - FgVarNode* newVar = new FgVarNode (var); - newVar->setIndex (i); - varClusters_[i]->setRepresentativeVariable (newVar); - fg->addVariable (newVar); - } - - for (unsigned i = 0; i < factorClusters_.size(); i++) { - FgVarSet myGroundVars; - const VarClusterSet& myVarClusters = factorClusters_[i]->getVarClusters(); - for (unsigned j = 0; j < myVarClusters.size(); j++) { - myGroundVars.push_back (myVarClusters[j]->getRepresentativeVariable()); - } - Factor* newFactor = new Factor (myGroundVars, - factorClusters_[i]->getGroundFactors()[0]->getDistribution()); - factorClusters_[i]->setRepresentativeFactor (newFactor); - fg->addFactor (newFactor); - } - return fg; -} - - - -unsigned -LiftedFG::getGroundEdgeCount (FactorCluster* fc, VarCluster* vc) const -{ - CFactorSet clusterGroundFactors = fc->getGroundFactors(); - FgVarNode* var = vc->getGroundFgVarNodes()[0]; - unsigned count = 0; - for (unsigned i = 0; i < clusterGroundFactors.size(); i++) { - if (clusterGroundFactors[i]->getIndexOf (var) != -1) { - count ++; - } - } - /* - CFgVarSet vars = vc->getGroundFgVarNodes(); - for (unsigned i = 1; i < vars.size(); i++) { - FgVarNode* var = vc->getGroundFgVarNodes()[i]; - unsigned count2 = 0; - for (unsigned i = 0; i < clusterGroundFactors.size(); i++) { - if (clusterGroundFactors[i]->getIndexOf (var) != -1) { - count2 ++; - } - } - if (count != count2) { cout << "oops!" << endl; abort(); } - } - */ - return count; -} - - - -void -LiftedFG::printGroups (const VarSignMap& varGroups, - const FactorSignMap& factorGroups) const -{ - cout << "variable groups:" << endl; - unsigned count = 0; - for (VarSignMap::const_iterator it = varGroups.begin(); - it != varGroups.end(); it++) { - const FgVarSet& groupMembers = it->second; - if (groupMembers.size() > 0) { - cout << ++count << ": " ; - //if (groupMembers.size() > 1) { - for (unsigned i = 0; i < groupMembers.size(); i++) { - cout << groupMembers[i]->getLabel() << " " ; - } - //} - cout << endl; - } - } - cout << endl; - cout << "factor groups:" << endl; - count = 0; - for (FactorSignMap::const_iterator it = factorGroups.begin(); - it != factorGroups.end(); it++) { - const FactorSet& groupMembers = it->second; - if (groupMembers.size() > 0) { - cout << ++count << ": " ; - //if (groupMembers.size() > 1) { - for (unsigned i = 0; i < groupMembers.size(); i++) { - cout << groupMembers[i]->getLabel() << " " ; - } - //} - cout << endl; - } - } -} - diff --git a/packages/CLPBN/clpbn/bp/LiftedFG.h b/packages/CLPBN/clpbn/bp/LiftedFG.h deleted file mode 100644 index 4a2518fc1..000000000 --- a/packages/CLPBN/clpbn/bp/LiftedFG.h +++ /dev/null @@ -1,152 +0,0 @@ -#ifndef BP_LIFTED_FG_H -#define BP_LIFTED_FG_H - -#include - -#include "FactorGraph.h" -#include "FgVarNode.h" -#include "Factor.h" -#include "Shared.h" - -class VarCluster; -class FactorCluster; -class Distribution; - -typedef long Color; -typedef vector Signature; -typedef vector VarClusterSet; -typedef vector FactorClusterSet; - -typedef map VarSignMap; -typedef map FactorSignMap; - -typedef map > VarColorMap; -typedef map DistColorMap; - -typedef map Vid2VarCluster; - - -class VarCluster -{ - public: - VarCluster (CFgVarSet vs) - { - for (unsigned i = 0; i < vs.size(); i++) { - groundVars_.push_back (vs[i]); - } - } - - void addFactorCluster (FactorCluster* fc) - { - factorClusters_.push_back (fc); - } - - const FactorClusterSet& getFactorClusters (void) const - { - return factorClusters_; - } - - FgVarNode* getRepresentativeVariable (void) const { return representVar_; } - void setRepresentativeVariable (FgVarNode* v) { representVar_ = v; } - CFgVarSet getGroundFgVarNodes (void) const { return groundVars_; } - - private: - FgVarSet groundVars_; - FactorClusterSet factorClusters_; - FgVarNode* representVar_; -}; - - -class FactorCluster -{ - public: - FactorCluster (CFactorSet groundFactors, const VarClusterSet& vcs) - { - groundFactors_ = groundFactors; - varClusters_ = vcs; - for (unsigned i = 0; i < varClusters_.size(); i++) { - varClusters_[i]->addFactorCluster (this); - } - } - - const VarClusterSet& getVarClusters (void) const - { - return varClusters_; - } - - bool containsGround (const Factor* f) - { - for (unsigned i = 0; i < groundFactors_.size(); i++) { - if (groundFactors_[i] == f) { - return true; - } - } - return false; - } - - Factor* getRepresentativeFactor (void) const { return representFactor_; } - void setRepresentativeFactor (Factor* f) { representFactor_ = f; } - CFactorSet getGroundFactors (void) const { return groundFactors_; } - - - private: - FactorSet groundFactors_; - VarClusterSet varClusters_; - Factor* representFactor_; -}; - - -class LiftedFG -{ - public: - LiftedFG (const FactorGraph&); - ~LiftedFG (void); - - FactorGraph* getCompressedFactorGraph (void); - unsigned getGroundEdgeCount (FactorCluster*, VarCluster*) const; - void printGroups (const VarSignMap& varGroups, - const FactorSignMap& factorGroups) const; - - FgVarNode* getEquivalentVariable (Vid vid) - { - VarCluster* vc = vid2VarCluster_.find (vid)->second; - return vc->getRepresentativeVariable(); - } - - const VarClusterSet& getVariableClusters (void) { return varClusters_; } - const FactorClusterSet& getFactorClusters (void) { return factorClusters_; } - - private: - string getSignatureId (FgVarNode*) const; - string getSignatureId (Factor*) const; - - Color getFreeColor (void) { return ++freeColor_ -1; } - Color getColor (FgVarNode* v) const { return varColors_[v->getIndex()]; } - Color getColor (Factor* f) const { return factorColors_[f->getIndex()]; } - - void setColor (FgVarNode* v, Color c) - { - varColors_[v->getIndex()] = c; - } - - void setColor (Factor* f, Color c) - { - factorColors_[f->getIndex()] = c; - } - - VarCluster* getVariableCluster (Vid vid) const - { - return vid2VarCluster_.find (vid)->second; - } - - Color freeColor_; - vector varColors_; - vector factorColors_; - VarClusterSet varClusters_; - FactorClusterSet factorClusters_; - Vid2VarCluster vid2VarCluster_; - const FactorGraph* groundFg_; -}; - -#endif // BP_LIFTED_FG_H - diff --git a/packages/CLPBN/clpbn/bp/Makefile.in b/packages/CLPBN/clpbn/bp/Makefile.in index e21e6ff79..7b90e2642 100644 --- a/packages/CLPBN/clpbn/bp/Makefile.in +++ b/packages/CLPBN/clpbn/bp/Makefile.in @@ -26,10 +26,7 @@ CXX=@CXX@ CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../../.. -I$(srcdir)/../../../../include @CPPFLAGS@ -DNDEBUG # debug -#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../../.. -I$(srcdir)/../../../../include @CPPFLAGS@ -g -O0 - -# profiling (callgrind) -#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../../.. -I$(srcdir)/../../../../include @CPPFLAGS@ -g -DNDEBUG +#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../../.. -I$(srcdir)/../../../../include @CPPFLAGS@ -g -O0 -Wextra # @@ -49,33 +46,37 @@ CWD=$(PWD) HEADERS = \ $(srcdir)/GraphicalModel.h \ - $(srcdir)/Variable.h \ + $(srcdir)/VarNode.h \ $(srcdir)/Distribution.h \ $(srcdir)/BayesNet.h \ $(srcdir)/BayesNode.h \ - $(srcdir)/LiftedFG.h \ + $(srcdir)/ElimGraph.h \ + $(srcdir)/CFactorGraph.h \ $(srcdir)/CptEntry.h \ $(srcdir)/FactorGraph.h \ - $(srcdir)/FgVarNode.h \ $(srcdir)/Factor.h \ $(srcdir)/Solver.h \ - $(srcdir)/BPSolver.h \ - $(srcdir)/BPNodeInfo.h \ - $(srcdir)/SPSolver.h \ - $(srcdir)/CountingBP.h \ + $(srcdir)/VarElimSolver.h \ + $(srcdir)/BnBpSolver.h \ + $(srcdir)/FgBpSolver.h \ + $(srcdir)/CbpSolver.h \ $(srcdir)/Shared.h \ + $(srcdir)/StatesIndexer.h \ $(srcdir)/xmlParser/xmlParser.h - + CPP_SOURCES = \ $(srcdir)/BayesNet.cpp \ $(srcdir)/BayesNode.cpp \ + $(srcdir)/ElimGraph.cpp \ $(srcdir)/FactorGraph.cpp \ $(srcdir)/Factor.cpp \ - $(srcdir)/LiftedFG.cpp \ - $(srcdir)/BPSolver.cpp \ - $(srcdir)/BPNodeInfo.cpp \ - $(srcdir)/SPSolver.cpp \ - $(srcdir)/CountingBP.cpp \ + $(srcdir)/CFactorGraph.cpp \ + $(srcdir)/VarNode.cpp \ + $(srcdir)/Solver.cpp \ + $(srcdir)/VarElimSolver.cpp \ + $(srcdir)/BnBpSolver.cpp \ + $(srcdir)/FgBpSolver.cpp \ + $(srcdir)/CbpSolver.cpp \ $(srcdir)/Util.cpp \ $(srcdir)/HorusYap.cpp \ $(srcdir)/HorusCli.cpp \ @@ -84,29 +85,35 @@ CPP_SOURCES = \ OBJS = \ BayesNet.o \ BayesNode.o \ + ElimGraph.o \ FactorGraph.o \ Factor.o \ - BPSolver.o \ - BPNodeInfo.o \ - SPSolver.o \ + CFactorGraph.o \ + VarNode.o \ + Solver.o \ + VarElimSolver.o \ + BnBpSolver.o \ + FgBpSolver.o \ + CbpSolver.o \ Util.o \ - LiftedFG.o \ - CountingBP.o \ HorusYap.o HCLI_OBJS = \ BayesNet.o \ BayesNode.o \ + ElimGraph.o \ FactorGraph.o \ Factor.o \ - BPSolver.o \ - BPNodeInfo.o \ - SPSolver.o \ + CFactorGraph.o \ + VarNode.o \ + Solver.o \ + VarElimSolver.o \ + BnBpSolver.o \ + FgBpSolver.o \ + CbpSolver.o \ Util.o \ - LiftedFG.o \ - CountingBP.o \ - HorusCli.o \ - xmlParser/xmlParser.o + xmlParser/xmlParser.o \ + HorusCli.o SOBJS=horus.@SO@ diff --git a/packages/CLPBN/clpbn/bp/SPSolver.cpp b/packages/CLPBN/clpbn/bp/SPSolver.cpp deleted file mode 100644 index f2ec74fba..000000000 --- a/packages/CLPBN/clpbn/bp/SPSolver.cpp +++ /dev/null @@ -1,470 +0,0 @@ -#include -#include - -#include - -#include "SPSolver.h" -#include "FactorGraph.h" -#include "FgVarNode.h" -#include "Factor.h" -#include "Shared.h" - - -SPSolver::SPSolver (FactorGraph& fg) : Solver (&fg) -{ - fg_ = &fg; -} - - - -SPSolver::~SPSolver (void) -{ - for (unsigned i = 0; i < varsI_.size(); i++) { - delete varsI_[i]; - } - for (unsigned i = 0; i < factorsI_.size(); i++) { - delete factorsI_[i]; - } - for (unsigned i = 0; i < links_.size(); i++) { - delete links_[i]; - } -} - - - -void -SPSolver::runTreeSolver (void) -{ - CFactorSet factors = fg_->getFactors(); - bool finish = false; - while (!finish) { - finish = true; - for (unsigned i = 0; i < factors.size(); i++) { - CLinkSet links = factorsI_[factors[i]->getIndex()]->getLinks(); - for (unsigned j = 0; j < links.size(); j++) { - if (!links[j]->messageWasSended()) { - if (readyToSendMessage(links[j])) { - links[j]->setNextMessage (getFactor2VarMsg (links[j])); - links[j]->updateMessage(); - } - finish = false; - } - } - } - } -} - - - -bool -SPSolver::readyToSendMessage (const Link* link) const -{ - CFgVarSet factorVars = link->getFactor()->getFgVarNodes(); - for (unsigned i = 0; i < factorVars.size(); i++) { - if (factorVars[i] != link->getVariable()) { - CLinkSet links = varsI_[factorVars[i]->getIndex()]->getLinks(); - for (unsigned j = 0; j < links.size(); j++) { - if (links[j]->getFactor() != link->getFactor() && - !links[j]->messageWasSended()) { - return false; - } - } - } - } - return true; -} - - - -void -SPSolver::runSolver (void) -{ - initializeSolver(); - runTreeSolver(); - return; - nIter_ = 0; - while (!converged() && nIter_ < SolverOptions::maxIter) { - - nIter_ ++; - if (DL >= 2) { - cout << endl; - cout << "****************************************" ; - cout << "****************************************" ; - cout << endl; - cout << " Iteration " << nIter_ << endl; - cout << "****************************************" ; - cout << "****************************************" ; - cout << endl; - } - - switch (SolverOptions::schedule) { - case SolverOptions::S_SEQ_RANDOM: - random_shuffle (links_.begin(), links_.end()); - // no break - - case SolverOptions::S_SEQ_FIXED: - for (unsigned i = 0; i < links_.size(); i++) { - links_[i]->setNextMessage (getFactor2VarMsg (links_[i])); - links_[i]->updateMessage(); - } - break; - - case SolverOptions::S_PARALLEL: - for (unsigned i = 0; i < links_.size(); i++) { - links_[i]->setNextMessage (getFactor2VarMsg (links_[i])); - } - for (unsigned i = 0; i < links_.size(); i++) { - links_[i]->updateMessage(); - } - break; - - case SolverOptions::S_MAX_RESIDUAL: - maxResidualSchedule(); - break; - } - } - - if (DL >= 2) { - cout << endl; - if (nIter_ < SolverOptions::maxIter) { - cout << "Loopy Sum-Product converged in " ; - cout << nIter_ << " iterations" << endl; - } else { - cout << "The maximum number of iterations was hit, terminating..." ; - cout << endl; - } - } -} - - - -ParamSet -SPSolver::getPosterioriOf (Vid vid) const -{ - assert (fg_->getFgVarNode (vid)); - FgVarNode* var = fg_->getFgVarNode (vid); - ParamSet probs; - - if (var->hasEvidence()) { - probs.resize (var->getDomainSize(), 0.0); - probs[var->getEvidence()] = 1.0; - } else { - probs.resize (var->getDomainSize(), 1.0); - CLinkSet links = varsI_[var->getIndex()]->getLinks(); - for (unsigned i = 0; i < links.size(); i++) { - CParamSet msg = links[i]->getMessage(); - for (unsigned j = 0; j < msg.size(); j++) { - probs[j] *= msg[j]; - } - } - Util::normalize (probs); - } - return probs; -} - - - -ParamSet -SPSolver::getJointDistributionOf (const VidSet& jointVids) -{ - FgVarSet jointVars; - unsigned dsize = 1; - for (unsigned i = 0; i < jointVids.size(); i++) { - FgVarNode* varNode = fg_->getFgVarNode (jointVids[i]); - dsize *= varNode->getDomainSize(); - jointVars.push_back (varNode); - } - - unsigned maxVid = std::numeric_limits::max(); - FgVarNode* junctionVar = new FgVarNode (maxVid, dsize); - FgVarSet factorVars = { junctionVar }; - for (unsigned i = 0; i < jointVars.size(); i++) { - factorVars.push_back (jointVars[i]); - } - - unsigned nParams = dsize * dsize; - ParamSet params (nParams); - for (unsigned i = 0; i < nParams; i++) { - unsigned row = i / dsize; - unsigned col = i % dsize; - if (row == col) { - params[i] = 1; - } else { - params[i] = 0; - } - } - - Distribution* dist = new Distribution (params, maxVid); - Factor* newFactor = new Factor (factorVars, dist); - fg_->addVariable (junctionVar); - fg_->addFactor (newFactor); - - runSolver(); - ParamSet results = getPosterioriOf (maxVid); - deleteJunction (newFactor, junctionVar); - - return results; -} - - - -void -SPSolver::initializeSolver (void) -{ - fg_->setIndexes(); - - CFgVarSet vars = fg_->getFgVarNodes(); - for (unsigned i = 0; i < varsI_.size(); i++) { - delete varsI_[i]; - } - varsI_.reserve (vars.size()); - for (unsigned i = 0; i < vars.size(); i++) { - varsI_.push_back (new SPNodeInfo()); - } - - CFactorSet factors = fg_->getFactors(); - for (unsigned i = 0; i < factorsI_.size(); i++) { - delete factorsI_[i]; - } - factorsI_.reserve (factors.size()); - for (unsigned i = 0; i < factors.size(); i++) { - factorsI_.push_back (new SPNodeInfo()); - } - - for (unsigned i = 0; i < links_.size(); i++) { - delete links_[i]; - } - createLinks(); - - for (unsigned i = 0; i < links_.size(); i++) { - Factor* source = links_[i]->getFactor(); - FgVarNode* dest = links_[i]->getVariable(); - varsI_[dest->getIndex()]->addLink (links_[i]); - factorsI_[source->getIndex()]->addLink (links_[i]); - } -} - - - -void -SPSolver::createLinks (void) -{ - CFactorSet factors = fg_->getFactors(); - for (unsigned i = 0; i < factors.size(); i++) { - CFgVarSet neighbors = factors[i]->getFgVarNodes(); - for (unsigned j = 0; j < neighbors.size(); j++) { - links_.push_back (new Link (factors[i], neighbors[j])); - } - } -} - - - -void -SPSolver::deleteJunction (Factor* f, FgVarNode* v) -{ - fg_->removeFactor (f); - f->freeDistribution(); - delete f; - fg_->removeVariable (v); - delete v; -} - - - -bool -SPSolver::converged (void) -{ - // this can happen if the graph is fully disconnected - if (links_.size() == 0) { - return true; - } - if (nIter_ == 0 || nIter_ == 1) { - return false; - } - bool converged = true; - if (SolverOptions::schedule == SolverOptions::S_MAX_RESIDUAL) { - Param maxResidual = (*(sortedOrder_.begin()))->getResidual(); - if (maxResidual < SolverOptions::accuracy) { - converged = true; - } else { - converged = false; - } - } else { - for (unsigned i = 0; i < links_.size(); i++) { - double residual = links_[i]->getResidual(); - if (DL >= 2) { - cout << links_[i]->toString() + " residual = " << residual << endl; - } - if (residual > SolverOptions::accuracy) { - converged = false; - if (DL == 0) break; - } - } - } - return converged; -} - - - -void -SPSolver::maxResidualSchedule (void) -{ - if (nIter_ == 1) { - for (unsigned i = 0; i < links_.size(); i++) { - links_[i]->setNextMessage (getFactor2VarMsg (links_[i])); - SortedOrder::iterator it = sortedOrder_.insert (links_[i]); - linkMap_.insert (make_pair (links_[i], it)); - if (DL >= 2 && DL < 5) { - cout << "calculating " << links_[i]->toString() << endl; - } - } - return; - } - - for (unsigned c = 0; c < links_.size(); c++) { - if (DL >= 2) { - cout << endl << "current residuals:" << endl; - for (SortedOrder::iterator it = sortedOrder_.begin(); - it != sortedOrder_.end(); it ++) { - cout << " " << setw (30) << left << (*it)->toString(); - cout << "residual = " << (*it)->getResidual() << endl; - } - } - - SortedOrder::iterator it = sortedOrder_.begin(); - Link* link = *it; - if (DL >= 2) { - cout << "updating " << (*sortedOrder_.begin())->toString() << endl; - } - if (link->getResidual() < SolverOptions::accuracy) { - return; - } - link->updateMessage(); - link->clearResidual(); - sortedOrder_.erase (it); - linkMap_.find (link)->second = sortedOrder_.insert (link); - - // update the messages that depend on message source --> destin - CFactorSet factorNeighbors = link->getVariable()->getFactors(); - for (unsigned i = 0; i < factorNeighbors.size(); i++) { - if (factorNeighbors[i] != link->getFactor()) { - CLinkSet links = factorsI_[factorNeighbors[i]->getIndex()]->getLinks(); - for (unsigned j = 0; j < links.size(); j++) { - if (links[j]->getVariable() != link->getVariable()) { - if (DL >= 2 && DL < 5) { - cout << " calculating " << links[j]->toString() << endl; - } - links[j]->setNextMessage (getFactor2VarMsg (links[j])); - LinkMap::iterator iter = linkMap_.find (links[j]); - sortedOrder_.erase (iter->second); - iter->second = sortedOrder_.insert (links[j]); - } - } - } - } - } -} - - - -ParamSet -SPSolver::getFactor2VarMsg (const Link* link) const -{ - const Factor* src = link->getFactor(); - const FgVarNode* dest = link->getVariable(); - CFgVarSet neighbors = src->getFgVarNodes(); - CLinkSet links = factorsI_[src->getIndex()]->getLinks(); - // calculate the product of messages that were sent - // to factor `src', except from var `dest' - Factor result (*src); - Factor temp; - if (DL >= 5) { - cout << "calculating " ; - cout << src->getLabel() << " --> " << dest->getLabel(); - cout << endl; - } - for (unsigned i = 0; i < neighbors.size(); i++) { - if (links[i]->getVariable() != dest) { - if (DL >= 5) { - cout << " message from " << links[i]->getVariable()->getLabel(); - cout << ": " ; - ParamSet p = getVar2FactorMsg (links[i]); - cout << endl; - Factor temp2 (links[i]->getVariable(), p); - temp.multiplyByFactor (temp2); - temp2.freeDistribution(); - } else { - Factor temp2 (links[i]->getVariable(), getVar2FactorMsg (links[i])); - temp.multiplyByFactor (temp2); - temp2.freeDistribution(); - } - } - } - if (links.size() >= 2) { - result.multiplyByFactor (temp, &(src->getCptEntries())); - if (DL >= 5) { - cout << " message product: " ; - cout << Util::parametersToString (temp.getParameters()) << endl; - cout << " factor product: " ; - cout << Util::parametersToString (src->getParameters()); - cout << " x " ; - cout << Util::parametersToString (temp.getParameters()); - cout << " = " ; - cout << Util::parametersToString (result.getParameters()) << endl; - } - temp.freeDistribution(); - } - - for (unsigned i = 0; i < links.size(); i++) { - if (links[i]->getVariable() != dest) { - result.removeVariable (links[i]->getVariable()); - } - } - if (DL >= 5) { - cout << " final message: " ; - cout << Util::parametersToString (result.getParameters()) << endl << endl; - } - ParamSet msg = result.getParameters(); - result.freeDistribution(); - return msg; -} - - - -ParamSet -SPSolver::getVar2FactorMsg (const Link* link) const -{ - const FgVarNode* src = link->getVariable(); - const Factor* dest = link->getFactor(); - ParamSet msg; - if (src->hasEvidence()) { - msg.resize (src->getDomainSize(), 0.0); - msg[src->getEvidence()] = 1.0; - if (DL >= 5) { - cout << Util::parametersToString (msg); - } - } else { - msg.resize (src->getDomainSize(), 1.0); - } - if (DL >= 5) { - cout << Util::parametersToString (msg); - } - CLinkSet links = varsI_[src->getIndex()]->getLinks(); - for (unsigned i = 0; i < links.size(); i++) { - if (links[i]->getFactor() != dest) { - CParamSet msgFromFactor = links[i]->getMessage(); - for (unsigned j = 0; j < msgFromFactor.size(); j++) { - msg[j] *= msgFromFactor[j]; - } - if (DL >= 5) { - cout << " x " << Util::parametersToString (msgFromFactor); - } - } - } - if (DL >= 5) { - cout << " = " << Util::parametersToString (msg); - } - return msg; -} - diff --git a/packages/CLPBN/clpbn/bp/SPSolver.h b/packages/CLPBN/clpbn/bp/SPSolver.h deleted file mode 100644 index bc4d64ade..000000000 --- a/packages/CLPBN/clpbn/bp/SPSolver.h +++ /dev/null @@ -1,130 +0,0 @@ -#ifndef BP_SP_SOLVER_H -#define BP_SP_SOLVER_H - -#include -#include - -#include "Solver.h" -#include "FgVarNode.h" -#include "Factor.h" - -using namespace std; - -class FactorGraph; -class SPSolver; - - -class Link -{ - public: - Link (Factor* f, FgVarNode* v) - { - factor_ = f; - var_ = v; - currMsg_.resize (v->getDomainSize(), 1); - nextMsg_.resize (v->getDomainSize(), 1); - msgSended_ = false; - residual_ = 0.0; - } - - void setMessage (ParamSet msg) - { - Util::normalize (msg); - residual_ = Util::getMaxNorm (currMsg_, msg); - currMsg_ = msg; - } - - void setNextMessage (CParamSet msg) - { - nextMsg_ = msg; - Util::normalize (nextMsg_); - residual_ = Util::getMaxNorm (currMsg_, nextMsg_); - } - - void updateMessage (void) - { - currMsg_ = nextMsg_; - msgSended_ = true; - } - - string toString (void) const - { - stringstream ss; - ss << factor_->getLabel(); - ss << " -- " ; - ss << var_->getLabel(); - return ss.str(); - } - - Factor* getFactor (void) const { return factor_; } - FgVarNode* getVariable (void) const { return var_; } - CParamSet getMessage (void) const { return currMsg_; } - bool messageWasSended (void) const { return msgSended_; } - double getResidual (void) const { return residual_; } - void clearResidual (void) { residual_ = 0.0; } - - private: - Factor* factor_; - FgVarNode* var_; - ParamSet currMsg_; - ParamSet nextMsg_; - bool msgSended_; - double residual_; -}; - - -class SPNodeInfo -{ - public: - void addLink (Link* link) { links_.push_back (link); } - CLinkSet getLinks (void) { return links_; } - - private: - LinkSet links_; -}; - - -class SPSolver : public Solver -{ - public: - SPSolver (FactorGraph&); - virtual ~SPSolver (void); - - void runSolver (void); - virtual ParamSet getPosterioriOf (Vid) const; - ParamSet getJointDistributionOf (CVidSet); - - protected: - virtual void initializeSolver (void); - void runTreeSolver (void); - bool readyToSendMessage (const Link*) const; - virtual void createLinks (void); - virtual void deleteJunction (Factor*, FgVarNode*); - bool converged (void); - virtual void maxResidualSchedule (void); - virtual ParamSet getFactor2VarMsg (const Link*) const; - virtual ParamSet getVar2FactorMsg (const Link*) const; - - struct CompareResidual { - inline bool operator() (const Link* link1, const Link* link2) - { - return link1->getResidual() > link2->getResidual(); - } - }; - - FactorGraph* fg_; - LinkSet links_; - vector varsI_; - vector factorsI_; - unsigned nIter_; - - typedef multiset SortedOrder; - SortedOrder sortedOrder_; - - typedef map LinkMap; - LinkMap linkMap_; - -}; - -#endif // BP_SP_SOLVER_H - diff --git a/packages/CLPBN/clpbn/bp/Shared.h b/packages/CLPBN/clpbn/bp/Shared.h index a30d3edc7..1b87607a8 100644 --- a/packages/CLPBN/clpbn/bp/Shared.h +++ b/packages/CLPBN/clpbn/bp/Shared.h @@ -1,15 +1,15 @@ -#ifndef BP_SHARED_H -#define BP_SHARED_H +#ifndef HORUS_SHARED_H +#define HORUS_SHARED_H #include #include +#include + #include -#include #include #include #include -#include #define DISALLOW_COPY_AND_ASSIGN(TypeName) \ TypeName(const TypeName&); \ @@ -17,34 +17,29 @@ using namespace std; -class Variable; +class VarNode; +class BayesNet; class BayesNode; -class FgVarNode; class Factor; -class Link; -class Edge; +class FgVarNode; +class FgFacNode; +class SpLink; +class BpLink; + +typedef double Param; +typedef vector ParamSet; +typedef unsigned VarId; +typedef vector VarIdSet; +typedef vector VarNodes; +typedef vector BnNodeSet; +typedef vector FgVarSet; +typedef vector FgFacSet; +typedef vector FactorSet; +typedef vector States; +typedef vector Ranges; +typedef vector DConf; +typedef pair DConstraint; -typedef double Param; -typedef vector ParamSet; -typedef const ParamSet& CParamSet; -typedef unsigned Vid; -typedef vector VidSet; -typedef const VidSet& CVidSet; -typedef vector VarSet; -typedef vector BnNodeSet; -typedef const BnNodeSet& CBnNodeSet; -typedef vector FgVarSet; -typedef const FgVarSet& CFgVarSet; -typedef vector FactorSet; -typedef const FactorSet& CFactorSet; -typedef vector LinkSet; -typedef const LinkSet& CLinkSet; -typedef vector EdgeSet; -typedef const EdgeSet& CEdgeSet; -typedef vector Domain; -typedef vector DConf; -typedef pair DConstraint; -typedef map IndexMap; // level of debug information static const unsigned DL = 0; @@ -54,197 +49,260 @@ static const int NO_EVIDENCE = -1; // number of digits to show when printing a parameter static const unsigned PRECISION = 5; -static const bool EXPORT_TO_DOT = false; -static const unsigned EXPORT_MIN_SIZE = 30; +static const bool COLLECT_STATISTICS = false; + +static const bool EXPORT_TO_GRAPHVIZ = false; +static const unsigned EXPORT_MINIMAL_SIZE = 100; + +static const double INF = -numeric_limits::infinity(); -namespace SolverOptions -{ - enum Schedule - { - S_SEQ_FIXED, - S_SEQ_RANDOM, - S_PARALLEL, - S_MAX_RESIDUAL +namespace NumberSpace { + enum ns { + NORMAL, + LOGARITHM + }; +}; + + + +extern NumberSpace::ns NSPACE; + + +namespace InfAlgorithms { + enum InfAlgs + { + VE, // variable elimination + BN_BP, // bayesian network belief propagation + FG_BP, // factor graph belief propagation + CBP // counting bp solver + }; + extern InfAlgs infAlgorithm; +}; + + +namespace BpOptions +{ + enum Schedule { + SEQ_FIXED, + SEQ_RANDOM, + PARALLEL, + MAX_RESIDUAL }; - extern bool runBayesBall; - extern bool convertBn2Fg; - extern bool compressFactorGraph; extern Schedule schedule; extern double accuracy; extern unsigned maxIter; + extern bool useAlwaysLoopySolver; } namespace Util { - void normalize (ParamSet&); - void pow (ParamSet&, unsigned); - double getL1dist (CParamSet, CParamSet); - double getMaxNorm (CParamSet, CParamSet); - bool isInteger (const string&); - string parametersToString (CParamSet); - vector getDomainConfigurations (const VarSet&); - vector getInstantiations (const VarSet&); + void toLog (ParamSet&); + void fromLog (ParamSet&); + void normalize (ParamSet&); + void logSum (Param&, Param); + void multiply (ParamSet&, const ParamSet&); + void multiply (ParamSet&, const ParamSet&, unsigned); + void add (ParamSet&, const ParamSet&); + void add (ParamSet&, const ParamSet&, unsigned); + void pow (ParamSet&, unsigned); + Param pow (Param, unsigned); + double getL1Distance (const ParamSet&, const ParamSet&); + double getMaxNorm (const ParamSet&, const ParamSet&); + unsigned getNumberOfDigits (int); + bool isInteger (const string&); + string parametersToString (const ParamSet&, unsigned = PRECISION); + BayesNet* generateBayesianNetworkTreeWithLevel (unsigned); + vector getDomainConfigurations (const VarNodes&); + vector getJointStateStrings (const VarNodes&); + double tl (Param v); + double fl (Param v); + double multIdenty(); + double addIdenty(); + double withEvidence(); + double noEvidence(); + double one(); + double zero(); }; + +inline void +Util::logSum (Param& x, Param y) +{ + // x = log (exp (x) + exp (y)); return; + assert (isfinite (x) && finite (y)); + // If one value is much smaller than the other, keep the larger value. + if (x < (y - log (1e200))) { + x = y; + return; + } + if (y < (x - log (1e200))) { + return; + } + double diff = x - y; + assert (isfinite (diff) && finite (x) && finite (y)); + if (!isfinite (exp (diff))) { // difference is too large + x = x > y ? x : y; + } else { // otherwise return the sum. + x = y + log (static_cast(1.0) + exp (diff)); + } +} + + + +inline void +Util::multiply (ParamSet& v1, const ParamSet& v2) +{ + assert (v1.size() == v2.size()); + for (unsigned i = 0; i < v1.size(); i++) { + v1[i] *= v2[i]; + } +} + + + +inline void +Util::multiply (ParamSet& v1, const ParamSet& v2, unsigned repetitions) +{ + for (unsigned count = 0; count < v1.size(); ) { + for (unsigned i = 0; i < v2.size(); i++) { + for (unsigned r = 0; r < repetitions; r++) { + v1[count] *= v2[i]; + count ++; + } + } + } +} + + + +inline void +Util::add (ParamSet& v1, const ParamSet& v2) +{ + assert (v1.size() == v2.size()); + for (unsigned i = 0; i < v1.size(); i++) { + v1[i] += v2[i]; + } +} + + + +inline void +Util::add (ParamSet& v1, const ParamSet& v2, unsigned repetitions) +{ + for (unsigned count = 0; count < v1.size(); ) { + for (unsigned i = 0; i < v2.size(); i++) { + for (unsigned r = 0; r < repetitions; r++) { + v1[count] += v2[i]; + count ++; + } + } + } +} + + + +inline double +Util::tl (Param v) +{ + return NSPACE == NumberSpace::NORMAL ? v : log(v); +} + +inline double +Util::fl (Param v) +{ + return NSPACE == NumberSpace::NORMAL ? v : exp(v); +} + +inline double +Util::multIdenty() { + return NSPACE == NumberSpace::NORMAL ? 1.0 : 0.0; +} + +inline double +Util::addIdenty() +{ + return NSPACE == NumberSpace::NORMAL ? 0.0 : INF; +} + +inline double +Util::withEvidence() +{ + return NSPACE == NumberSpace::NORMAL ? 1.0 : 0.0; +} + +inline double +Util::noEvidence() { + return NSPACE == NumberSpace::NORMAL ? 0.0 : INF; +} + +inline double +Util::one() +{ + return NSPACE == NumberSpace::NORMAL ? 1.0 : 0.0; +} + +inline double +Util::zero() { + return NSPACE == NumberSpace::NORMAL ? 0.0 : INF; +} + + struct NetInfo { - NetInfo (void) + NetInfo (unsigned size, bool loopy, unsigned nIters, double time) { - counting = 0; - nIters = 0; - solvingTime = 0.0; + this->size = size; + this->loopy = loopy; + this->nIters = nIters; + this->time = time; } - unsigned counting; - double solvingTime; + unsigned size; + bool loopy; unsigned nIters; + double time; }; struct CompressInfo { - CompressInfo (unsigned a, unsigned b, unsigned c, - unsigned d, unsigned e) { - nUncVars = a; - nUncFactors = b; - nCompVars = c; - nCompFactors = d; - nNeighborlessVars = e; + CompressInfo (unsigned a, unsigned b, unsigned c, unsigned d, unsigned e) + { + nGroundVars = a; + nGroundFactors = b; + nClusterVars = c; + nClusterFactors = d; + nWithoutNeighs = e; } - unsigned nUncVars; - unsigned nUncFactors; - unsigned nCompVars; - unsigned nCompFactors; - unsigned nNeighborlessVars; + unsigned nGroundVars; + unsigned nGroundFactors; + unsigned nClusterVars; + unsigned nClusterFactors; + unsigned nWithoutNeighs; }; -typedef map StatisticMap; class Statistics { public: - - static void updateStats (unsigned size, unsigned nIters, double time) - { - StatisticMap::iterator it = stats_.find (size); - if (it == stats_.end()) { - it = (stats_.insert (make_pair (size, NetInfo()))).first; - } else { - it->second.counting ++; - it->second.nIters += nIters; - it->second.solvingTime += time; - totalOfIterations += nIters; - if (nIters > maxIterations) { - maxIterations = nIters; - } - } - } - - static void updateCompressingStats (unsigned nUncVars, - unsigned nUncFactors, - unsigned nCompVars, - unsigned nCompFactors, - unsigned nNeighborlessVars) { - compressInfo_.push_back (CompressInfo ( - nUncVars, nUncFactors, nCompVars, nCompFactors, nNeighborlessVars)); - } - - static void printCompressingStats (const char* fileName) - { - ofstream out (fileName); - if (!out.is_open()) { - cerr << "error: cannot open file to write at " ; - cerr << "BayesNet::printCompressingStats()" << endl; - abort(); - } - out << "--------------------------------------" ; - out << "--------------------------------------" << endl; - out << " Compression Stats" << endl; - out << "--------------------------------------" ; - out << "--------------------------------------" << endl; - out << left; - out << "Uncompress Compressed Uncompress Compressed Neighborless"; - out << endl; - out << "Vars Vars Factors Factors Vars" ; - out << endl; - for (unsigned i = 0; i < compressInfo_.size(); i++) { - out << setw (13) << compressInfo_[i].nUncVars; - out << setw (13) << compressInfo_[i].nCompVars; - out << setw (13) << compressInfo_[i].nUncFactors; - out << setw (13) << compressInfo_[i].nCompFactors; - out << setw (13) << compressInfo_[i].nNeighborlessVars; - out << endl; - } - } - - static unsigned getCounting (unsigned size) - { - StatisticMap::iterator it = stats_.find(size); - assert (it != stats_.end()); - return it->second.counting; - } - - static void writeStats (void) - { - ofstream out ("../../stats.txt"); - if (!out.is_open()) { - cerr << "error: cannot open file to write at " ; - cerr << "Statistics::updateStats()" << endl; - abort(); - } - unsigned avgIterations = 0; - if (numSolvedLoopyNets > 0) { - avgIterations = totalOfIterations / numSolvedLoopyNets; - } - double totalSolvingTime = 0.0; - for (StatisticMap::iterator it = stats_.begin(); - it != stats_.end(); it++) { - totalSolvingTime += it->second.solvingTime; - } - out << "created networks: " << numCreatedNets << endl; - out << "solver runs on polytrees: " << numSolvedPolyTrees << endl; - out << "solver runs on loopy networks: " << numSolvedLoopyNets << endl; - out << " unconverged: " << numUnconvergedRuns << endl; - out << " max iterations: " << maxIterations << endl; - out << " average iterations: " << avgIterations << endl; - out << "total solving time " << totalSolvingTime << endl; - out << endl; - out << left << endl; - out << setw (15) << "Network Size" ; - out << setw (15) << "Counting" ; - out << setw (15) << "Solving Time" ; - out << setw (15) << "Average Time" ; - out << setw (15) << "#Iterations" ; - out << endl; - for (StatisticMap::iterator it = stats_.begin(); - it != stats_.end(); it++) { - out << setw (15) << it->first; - out << setw (15) << it->second.counting; - out << setw (15) << it->second.solvingTime; - if (it->second.counting > 0) { - out << setw (15) << it->second.solvingTime / it->second.counting; - } else { - out << setw (15) << "0.0" ; - } - out << setw (15) << it->second.nIters; - out << endl; - } - out.close(); - } - - static unsigned numCreatedNets; - static unsigned numSolvedPolyTrees; - static unsigned numSolvedLoopyNets; - static unsigned numUnconvergedRuns; + static unsigned getSolvedNetworksCounting (void); + static void incrementPrimaryNetworksCounting (void); + static unsigned getPrimaryNetworksCounting (void); + static void updateStatistics (unsigned, bool, unsigned, double); + static void printStatistics (void); + static void writeStatisticsToFile (const char*); + static void updateCompressingStatistics ( + unsigned, unsigned, unsigned, unsigned, unsigned); private: - static StatisticMap stats_; - static unsigned maxIterations; - static unsigned totalOfIterations; + static string getStatisticString (void); + + static vector netInfo_; static vector compressInfo_; + static unsigned primaryNetCount_; }; -#endif //BP_SHARED_H +#endif // HORUS_SHARED_H diff --git a/packages/CLPBN/clpbn/bp/Solver.cpp b/packages/CLPBN/clpbn/bp/Solver.cpp new file mode 100644 index 000000000..f84f98d7b --- /dev/null +++ b/packages/CLPBN/clpbn/bp/Solver.cpp @@ -0,0 +1,53 @@ +#include "Solver.h" + + +void +Solver::printAllPosterioris (void) +{ + const VarNodes& vars = gm_->getVariableNodes(); + for (unsigned i = 0; i < vars.size(); i++) { + printPosterioriOf (vars[i]->varId()); + cout << endl; + } +} + + + +void +Solver::printPosterioriOf (VarId vid) +{ + VarNode* var = gm_->getVariableNode (vid); + const ParamSet& posterioriDist = getPosterioriOf (vid); + const States& states = var->states(); + for (unsigned i = 0; i < states.size(); i++) { + cout << "P(" << var->label() << "=" << states[i] << ") = " ; + cout << setprecision (PRECISION) << posterioriDist[i]; + cout << endl; + } + cout << endl; +} + + + +void +Solver::printJointDistributionOf (const VarIdSet& vids) +{ + VarNodes vars; + VarIdSet vidsWithoutEvidence; + for (unsigned i = 0; i < vids.size(); i++) { + VarNode* var = gm_->getVariableNode (vids[i]); + if (var->hasEvidence() == false) { + vars.push_back (var); + vidsWithoutEvidence.push_back (vids[i]); + } + } + const ParamSet& jointDist = getJointDistributionOf (vidsWithoutEvidence); + vector jointStrings = Util::getJointStateStrings (vars); + for (unsigned i = 0; i < jointDist.size(); i++) { + cout << "P(" << jointStrings[i] << ") = " ; + cout << setprecision (PRECISION) << jointDist[i]; + cout << endl; + } + cout << endl; +} + diff --git a/packages/CLPBN/clpbn/bp/Solver.h b/packages/CLPBN/clpbn/bp/Solver.h index 170f3cb40..a8ce5dc3b 100644 --- a/packages/CLPBN/clpbn/bp/Solver.h +++ b/packages/CLPBN/clpbn/bp/Solver.h @@ -1,10 +1,10 @@ -#ifndef BP_SOLVER_H -#define BP_SOLVER_H +#ifndef HORUS_SOLVER_H +#define HORUS_SOLVER_H #include #include "GraphicalModel.h" -#include "Variable.h" +#include "VarNode.h" using namespace std; @@ -15,66 +15,18 @@ class Solver { gm_ = gm; } - virtual ~Solver() {} // to call subclass destructor - virtual void runSolver (void) = 0; - virtual ParamSet getPosterioriOf (Vid) const = 0; - virtual ParamSet getJointDistributionOf (const VidSet&) = 0; - - void printAllPosterioris (void) const - { - VarSet vars = gm_->getVariables(); - for (unsigned i = 0; i < vars.size(); i++) { - printPosterioriOf (vars[i]->getVarId()); - } - } - - void printPosterioriOf (Vid vid) const - { - Variable* var = gm_->getVariable (vid); - cout << endl; - cout << setw (20) << left << var->getLabel() << "posteriori" ; - cout << endl; - cout << "------------------------------" ; - cout << endl; - const Domain& domain = var->getDomain(); - ParamSet results = getPosterioriOf (vid); - for (unsigned xi = 0; xi < var->getDomainSize(); xi++) { - cout << setw (20) << domain[xi]; - cout << setprecision (PRECISION) << results[xi]; - cout << endl; - } - cout << endl; - } - - void printJointDistributionOf (const VidSet& vids) - { - const ParamSet& jointDist = getJointDistributionOf (vids); - cout << endl; - cout << "joint distribution of " ; - VarSet vars; - for (unsigned i = 0; i < vids.size() - 1; i++) { - Variable* var = gm_->getVariable (vids[i]); - cout << var->getLabel() << ", " ; - vars.push_back (var); - } - Variable* var = gm_->getVariable (vids[vids.size() - 1]); - cout << var->getLabel() ; - vars.push_back (var); - cout << endl; - cout << "------------------------------" ; - cout << endl; - const vector& domainConfs = Util::getInstantiations (vars); - for (unsigned i = 0; i < jointDist.size(); i++) { - cout << left << setw (20) << domainConfs[i]; - cout << setprecision (PRECISION) << jointDist[i]; - cout << endl; - } - cout << endl; - } + virtual ~Solver() {} // to ensure that subclass destructor is called + virtual void runSolver (void) = 0; + virtual ParamSet getPosterioriOf (VarId) = 0; + virtual ParamSet getJointDistributionOf (const VarIdSet&) = 0; + void printAllPosterioris (void); + void printPosterioriOf (VarId vid); + void printJointDistributionOf (const VarIdSet& vids); + private: const GraphicalModel* gm_; }; -#endif //BP_SOLVER_H +#endif // HORUS_SOLVER_H diff --git a/packages/CLPBN/clpbn/bp/StatesIndexer.h b/packages/CLPBN/clpbn/bp/StatesIndexer.h new file mode 100644 index 000000000..37afaabf7 --- /dev/null +++ b/packages/CLPBN/clpbn/bp/StatesIndexer.h @@ -0,0 +1,246 @@ +#ifndef HORUS_STATESINDEXER_H +#define HORUS_STATESINDEXER_H + +#include + +class StatesIndexer { + public: + + StatesIndexer (const Ranges& ranges) + { + maxIndex_ = 1; + states_.resize (ranges.size(), 0); + ranges_ = ranges; + for (unsigned i = 0; i < ranges.size(); i++) { + maxIndex_ *= ranges[i]; + } + linearIndex_ = 0; + } + + + StatesIndexer (const VarNodes& vars) + { + maxIndex_ = 1; + states_.resize (vars.size(), 0); + ranges_.reserve (vars.size()); + for (unsigned i = 0; i < vars.size(); i++) { + ranges_.push_back (vars[i]->nrStates()); + maxIndex_ *= vars[i]->nrStates(); + } + linearIndex_ = 0; + } + + StatesIndexer& operator++ (void) { + for (int i = ranges_.size() - 1; i >= 0; i--) { + states_[i] ++; + if (states_[i] == (int)ranges_[i]) { + states_[i] = 0; + } else { + break; + } + } + linearIndex_ ++; + return *this; + } + + StatesIndexer& operator-- (void) { + for (int i = ranges_.size() - 1; i >= 0; i--) { + states_[i] --; + if (states_[i] == -1) { + states_[i] = ranges_[i] - 1; + } else { + break; + } + } + linearIndex_ --; + return *this; + } + + void incrementState (unsigned whichVar) + { + for (int i = whichVar; i >= 0; i--) { + states_[i] ++; + if (states_[i] == (int)ranges_[i] && i != 0) { + if (i == 0) { + linearIndex_ = maxIndex_; + } else { + states_[i] = 0; + } + } else { + linearIndex_ = getLinearIndexFromStates(); + return; + } + } + } + + void decrementState (unsigned whichVar) + { + for (int i = whichVar; i >= 0; i--) { + states_[i] --; + if (states_[i] == -1) { + if (i == 0) { + linearIndex_ = -1; + } else { + states_[i] = ranges_[i] - 1; + } + } else { + linearIndex_ = getLinearIndexFromStates(); + return; + } + } + } + + void nextSameState (unsigned whichVar) + { + for (int i = ranges_.size() - 1; i >= 0; i--) { + if (i != (int)whichVar) { + states_[i] ++; + if (states_[i] == (int)ranges_[i]) { + if (i == 0 || (i-1 == (int)whichVar && whichVar == 0)) { + linearIndex_ = maxIndex_; + } else { + states_[i] = 0; + } + } else { + linearIndex_ = getLinearIndexFromStates(); + return; + } + } + } + } + + void previousSameState (unsigned whichVar) + { + for (int i = ranges_.size() - 1; i >= 0; i--) { + if (i != (int)whichVar) { + states_[i] --; + if (states_[i] == - 1) { + if (i == 0 || (i-1 == (int)whichVar && whichVar == 0)) { + linearIndex_ = -1; + } else { + states_[i] = ranges_[i] - 1; + } + } else { + linearIndex_ = getLinearIndexFromStates(); + return; + } + } + } + } + + void moveToBegin (void) + { + std::fill (states_.begin(), states_.end(), 0); + linearIndex_ = 0; + } + + void moveToEnd (void) + { + for (unsigned i = 0; i < states_.size(); i++) { + states_[i] = ranges_[i] - 1; + } + linearIndex_ = maxIndex_ - 1; + } + + bool valid (void) const + { + return linearIndex_ >= 0 && linearIndex_ < (int)maxIndex_; + } + + unsigned getLinearIndex (void) const + { + return linearIndex_; + } + + const vector& getStates (void) const + { + return states_; + } + + unsigned operator[] (unsigned whichVar) const + { + assert (valid()); + assert (whichVar < states_.size()); + return states_[whichVar]; + } + + string toString (void) const + { + stringstream ss; + ss << "linear index=" << setw (3) << linearIndex_ << " " ; + ss << "states= [" << states_[0] ; + for (unsigned i = 1; i < states_.size(); i++) { + ss << ", " << states_[i]; + } + ss << "]" ; + return ss.str(); + } + + private: + unsigned getLinearIndexFromStates (void) + { + unsigned prod = 1; + unsigned linearIndex = 0; + for (int i = states_.size() - 1; i >= 0; i--) { + linearIndex += states_[i] * prod; + prod *= ranges_[i]; + } + return linearIndex; + } + + int linearIndex_; + int maxIndex_; + vector states_; + vector ranges_; +}; + + +/* + FgVarNode* v1 = new FgVarNode (0, 4); + FgVarNode* v2 = new FgVarNode (1, 3); + FgVarNode* v3 = new FgVarNode (2, 2); + FgVarSet vars = {v1,v2,v3}; + ParamSet params = { + 0.2, 0.44, 0.1, 0.88, 0.22,0.62,0.32, 0.42, 0.11, 0.88, 0.8,0.5, + 0.22, 0.4, 0.11, 0.8, 0.224,0.6,0.21, 0.44, 0.14, 0.68, 0.41,0.6 + }; + Factor f (vars,params); + StatesIndexer idx (vars); + while (idx.valid()) + { + cout << idx.toString() << " p=" << params[idx.getLinearIndex()] << endl; + idx.incrementVariableState (0); + idx.nextSameState (1); + ++idx; + } + cout << endl; + idx.moveToEnd(); + while (idx.valid()) + { + cout << idx.toString() << " p=" << params[idx.getLinearIndex()] << endl; + idx.decrementVariableState (0); + idx.previousSameState (1); + --idx; + } +*/ + + +/* + FgVarNode* x0 = new FgVarNode (0, 2); + FgVarNode* x1 = new FgVarNode (1, 2); + FgVarNode* x2 = new FgVarNode (2, 2); + FgVarNode* x3 = new FgVarNode (2, 2); + FgVarNode* x4 = new FgVarNode (2, 2); + FgVarSet vars_ = {x0,x1,x2,x3,x4}; + ParamSet params_ = { + 0.2, 0.44, 0.1, 0.88, 0.11, 0.88, 0.8, 0.5, + 0.2, 0.44, 0.1, 0.88, 0.11, 0.88, 0.8, 0.5, + 0.2, 0.44, 0.1, 0.88, 0.11, 0.88, 0.8, 0.5, + 0.2, 0.44, 0.1, 0.88, 0.11, 0.88, 0.8, 0.5 + }; + Factor ff (vars_,params_); + ff.printFactor(); +*/ + +#endif // HORUS_STATESINDEXER_H + diff --git a/packages/CLPBN/clpbn/bp/Util.cpp b/packages/CLPBN/clpbn/bp/Util.cpp index 2b1a86adc..3994e29ae 100644 --- a/packages/CLPBN/clpbn/bp/Util.cpp +++ b/packages/CLPBN/clpbn/bp/Util.cpp @@ -1,91 +1,191 @@ #include -#include "Variable.h" +#include "BayesNet.h" +#include "VarNode.h" #include "Shared.h" +#include "StatesIndexer.h" -namespace SolverOptions { - -bool runBayesBall = false; -bool convertBn2Fg = true; -bool compressFactorGraph = true; -Schedule schedule = S_SEQ_FIXED; -//Schedule schedule = S_SEQ_RANDOM; -//Schedule schedule = S_PARALLEL; -//Schedule schedule = S_MAX_RESIDUAL; -double accuracy = 0.0001; -unsigned maxIter = 1000; //FIXME - +namespace InfAlgorithms { +InfAlgs infAlgorithm = InfAlgorithms::VE; +//InfAlgs infAlgorithm = InfAlgorithms::BN_BP; +//InfAlgs infAlgorithm = InfAlgorithms::FG_BP; +//InfAlgs infAlgorithm = InfAlgorithms::CBP; } +namespace BpOptions { +Schedule schedule = BpOptions::Schedule::SEQ_FIXED; +//Schedule schedule = BpOptions::Schedule::SEQ_RANDOM; +//Schedule schedule = BpOptions::Schedule::PARALLEL; +//Schedule schedule = BpOptions::Schedule::MAX_RESIDUAL; +double accuracy = 0.0001; +unsigned maxIter = 1000; +bool useAlwaysLoopySolver = true; +} -unsigned Statistics::numCreatedNets = 0; -unsigned Statistics::numSolvedPolyTrees = 0; -unsigned Statistics::numSolvedLoopyNets = 0; -unsigned Statistics::numUnconvergedRuns = 0; -unsigned Statistics::maxIterations = 0; -unsigned Statistics::totalOfIterations = 0; +NumberSpace::ns NSPACE = NumberSpace::NORMAL; + +unordered_map GraphicalModel::varsInfo_; + +vector Statistics::netInfo_; vector Statistics::compressInfo_; -StatisticMap Statistics::stats_; - +unsigned Statistics::primaryNetCount_; namespace Util { void -normalize (ParamSet& v) +toLog (ParamSet& v) { - double sum = 0.0; for (unsigned i = 0; i < v.size(); i++) { - sum += v[i]; - } - assert (sum != 0.0); - for (unsigned i = 0; i < v.size(); i++) { - v[i] /= sum; + v[i] = log (v[i]); } } + +void +fromLog (ParamSet& v) +{ + for (unsigned i = 0; i < v.size(); i++) { + v[i] = exp (v[i]); + } +} + + + +void +normalize (ParamSet& v) +{ + double sum; + switch (NSPACE) { + case NumberSpace::NORMAL: + sum = 0.0; + for (unsigned i = 0; i < v.size(); i++) { + sum += v[i]; + } + assert (sum != 0.0); + for (unsigned i = 0; i < v.size(); i++) { + v[i] /= sum; + } + break; + case NumberSpace::LOGARITHM: + sum = addIdenty(); + for (unsigned i = 0; i < v.size(); i++) { + logSum (sum, v[i]); + } + assert (sum != -numeric_limits::infinity()); + for (unsigned i = 0; i < v.size(); i++) { + v[i] -= sum; + } + } +} + + + void pow (ParamSet& v, unsigned expoent) { - for (unsigned i = 0; i < v.size(); i++) { - double value = 1; - for (unsigned j = 0; j < expoent; j++) { - value *= v[i]; - } - v[i] = value; + if (expoent == 1) { + return; // optimization + } + switch (NSPACE) { + case NumberSpace::NORMAL: + for (unsigned i = 0; i < v.size(); i++) { + double value = 1.0; + for (unsigned j = 0; j < expoent; j++) { + value *= v[i]; + } + v[i] = value; + } + break; + case NumberSpace::LOGARITHM: + for (unsigned i = 0; i < v.size(); i++) { + v[i] *= expoent; + } } } + +Param +pow (Param p, unsigned expoent) +{ + double value = 1.0; + switch (NSPACE) { + case NumberSpace::NORMAL: + for (unsigned i = 0; i < expoent; i++) { + value *= p; + } + break; + case NumberSpace::LOGARITHM: + value = p * expoent; + } + return value; +} + + + double -getL1dist (const ParamSet& v1, const ParamSet& v2) +getL1Distance (const ParamSet& v1, const ParamSet& v2) { assert (v1.size() == v2.size()); double dist = 0.0; - for (unsigned i = 0; i < v1.size(); i++) { - dist += abs (v1[i] - v2[i]); + switch (NSPACE) { + case NumberSpace::NORMAL: + for (unsigned i = 0; i < v1.size(); i++) { + dist += abs (v1[i] - v2[i]); + } + break; + case NumberSpace::LOGARITHM: + for (unsigned i = 0; i < v1.size(); i++) { + dist += abs (exp(v1[i]) - exp(v2[i])); + } } return dist; } + double getMaxNorm (const ParamSet& v1, const ParamSet& v2) { assert (v1.size() == v2.size()); double max = 0.0; - for (unsigned i = 0; i < v1.size(); i++) { - double diff = abs (v1[i] - v2[i]); - if (diff > max) { - max = diff; - } + switch (NSPACE) { + case NumberSpace::NORMAL: + for (unsigned i = 0; i < v1.size(); i++) { + double diff = abs (v1[i] - v2[i]); + if (diff > max) { + max = diff; + } + } + break; + case NumberSpace::LOGARITHM: + for (unsigned i = 0; i < v1.size(); i++) { + double diff = abs (exp(v1[i]) - exp(v2[i])); + if (diff > max) { + max = diff; + } + } } return max; } + +unsigned +getNumberOfDigits (int number) { + unsigned count = 1; + while (number >= 10) { + number /= 10; + count ++; + } + return count; +} + + + bool isInteger (const string& s) { @@ -100,9 +200,10 @@ isInteger (const string& s) string -parametersToString (CParamSet v) +parametersToString (const ParamSet& v, unsigned precision) { - stringstream ss; + stringstream ss; + ss.precision (precision); ss << "[" ; for (unsigned i = 0; i < v.size() - 1; i++) { ss << v[i] << ", " ; @@ -116,12 +217,44 @@ parametersToString (CParamSet v) -vector -getDomainConfigurations (const VarSet& vars) +BayesNet* +generateBayesianNetworkTreeWithLevel (unsigned level) { + BayesNet* bn = new BayesNet(); + Distribution* dist = new Distribution (ParamSet() = {0.1, 0.5, 0.2, 0.7}); + BayesNode* root = bn->addNode (0, 2, -1, BnNodeSet() = {}, + new Distribution (ParamSet() = {0.1, 0.5})); + BnNodeSet prevLevel = { root }; + BnNodeSet currLevel; + VarId vidCount = 1; + for (unsigned l = 1; l < level; l++) { + currLevel.clear(); + for (unsigned i = 0; i < prevLevel.size(); i++) { + currLevel.push_back ( + bn->addNode (vidCount, 2, -1, BnNodeSet() = {prevLevel[i]}, dist)); + vidCount ++; + currLevel.push_back ( + bn->addNode (vidCount, 2, -1, BnNodeSet() = {prevLevel[i]}, dist)); + vidCount ++; + } + prevLevel = currLevel; + } + for (unsigned i = 0; i < prevLevel.size(); i++) { + prevLevel[i]->setEvidence (0); + } + bn->setIndexes(); + return bn; +} + + + +vector +getDomainConfigurations (const VarNodes& vars) +{ + // TODO this method must die unsigned nConfs = 1; for (unsigned i = 0; i < vars.size(); i++) { - nConfs *= vars[i]->getDomainSize(); + nConfs *= vars[i]->nrStates(); } vector confs (nConfs); @@ -133,59 +266,213 @@ getDomainConfigurations (const VarSet& vars) for (int i = vars.size() - 1; i >= 0; i--) { unsigned index = 0; while (index < nConfs) { - for (unsigned j = 0; j < vars[i]->getDomainSize(); j++) { + for (unsigned j = 0; j < vars[i]->nrStates(); j++) { for (unsigned r = 0; r < nReps; r++) { confs[index][i] = j; index++; } } } - nReps *= vars[i]->getDomainSize(); + nReps *= vars[i]->nrStates(); } return confs; } + vector -getInstantiations (const VarSet& vars) +getJointStateStrings (const VarNodes& vars) { - //FIXME handle variables without domain - /* - char c = 'a' ; - const DConf& conf = entries[i].getDomainConfiguration(); - for (unsigned j = 0; j < conf.size(); j++) { - if (j != 0) ss << "," ; - ss << c << conf[j] + 1; - c ++; + StatesIndexer idx (vars); + vector jointStrings; + while (idx.valid()) { + stringstream ss; + for (unsigned i = 0; i < vars.size(); i++) { + if (i != 0) ss << ", " ; + ss << vars[i]->label() << "=" << vars[i]->states()[(idx[i])]; + } + jointStrings.push_back (ss.str()); + ++ idx; } - */ - unsigned rowSize = 1; - for (unsigned i = 0; i < vars.size(); i++) { - rowSize *= vars[i]->getDomainSize(); + return jointStrings; +} + + +} + + + +unsigned +Statistics::getSolvedNetworksCounting (void) +{ + return netInfo_.size(); +} + + + +void +Statistics::incrementPrimaryNetworksCounting (void) +{ + primaryNetCount_ ++; +} + + + +unsigned +Statistics::getPrimaryNetworksCounting (void) +{ + return primaryNetCount_; +} + + + +void +Statistics::updateStatistics (unsigned size, bool loopy, + unsigned nIters, double time) +{ + netInfo_.push_back (NetInfo (size, loopy, nIters, time)); +} + + + +void +Statistics::printStatistics (void) +{ + cout << getStatisticString(); +} + + + +void +Statistics::writeStatisticsToFile (const char* fileName) +{ + ofstream out (fileName); + if (!out.is_open()) { + cerr << "error: cannot open file to write at " ; + cerr << "Statistics::writeStatisticsToFile()" << endl; + abort(); } + out << getStatisticString(); + out.close(); +} - vector headers (rowSize); - unsigned nReps = 1; - for (int i = vars.size() - 1; i >= 0; i--) { - Domain domain = vars[i]->getDomain(); - unsigned index = 0; - while (index < rowSize) { - for (unsigned j = 0; j < vars[i]->getDomainSize(); j++) { - for (unsigned r = 0; r < nReps; r++) { - if (headers[index] != "") { - headers[index] = domain[j] + ", " + headers[index]; - } else { - headers[index] = domain[j]; - } - index++; - } + +void +Statistics::updateCompressingStatistics (unsigned nGroundVars, + unsigned nGroundFactors, + unsigned nClusterVars, + unsigned nClusterFactors, + unsigned nWithoutNeighs) { + compressInfo_.push_back (CompressInfo (nGroundVars, nGroundFactors, + nClusterVars, nClusterFactors, nWithoutNeighs)); +} + + + +string +Statistics::getStatisticString (void) +{ + stringstream ss2, ss3, ss4, ss1; + ss1 << "running mode: " ; + switch (InfAlgorithms::infAlgorithm) { + case InfAlgorithms::VE: ss1 << "ve" << endl; break; + case InfAlgorithms::BN_BP: ss1 << "bn_bp" << endl; break; + case InfAlgorithms::FG_BP: ss1 << "fg_bp" << endl; break; + case InfAlgorithms::CBP: ss1 << "cbp" << endl; break; + } + ss1 << "message schedule: " ; + switch (BpOptions::schedule) { + case BpOptions::Schedule::SEQ_FIXED: ss1 << "sequential fixed" << endl; break; + case BpOptions::Schedule::SEQ_RANDOM: ss1 << "sequential random" << endl; break; + case BpOptions::Schedule::PARALLEL: ss1 << "parallel" << endl; break; + case BpOptions::Schedule::MAX_RESIDUAL: ss1 << "max residual" << endl; break; + } + ss1 << "max iterations: " << BpOptions::maxIter << endl; + ss1 << "accuracy " << BpOptions::accuracy << endl; + if (BpOptions::useAlwaysLoopySolver) { + ss1 << "always loopy solver: yes" << endl; + } else { + ss1 << "always loopy solver: no" << endl; + } + ss1 << endl << endl; + + ss2 << "---------------------------------------------------" << endl; + ss2 << " Network information" << endl; + ss2 << "---------------------------------------------------" << endl; + ss2 << left; + ss2 << setw (15) << "Network Size" ; + ss2 << setw (9) << "Loopy" ; + ss2 << setw (15) << "Iterations" ; + ss2 << setw (15) << "Solving Time" ; + ss2 << endl; + unsigned nLoopyNets = 0; + unsigned nUnconvergedRuns = 0; + double totalSolvingTime = 0.0; + for (unsigned i = 0; i < netInfo_.size(); i++) { + ss2 << setw (15) << netInfo_[i].size; + if (netInfo_[i].loopy) { + ss2 << setw (9) << "yes"; + nLoopyNets ++; + } else { + ss2 << setw (9) << "no"; + } + if (netInfo_[i].nIters == 0) { + ss2 << setw (15) << "n/a" ; + } else { + ss2 << setw (15) << netInfo_[i].nIters; + if (netInfo_[i].nIters > BpOptions::maxIter) { + nUnconvergedRuns ++; } } - nReps *= vars[i]->getDomainSize(); + ss2 << setw (15) << netInfo_[i].time; + totalSolvingTime += netInfo_[i].time; + ss2 << endl; } - return headers; -} - + ss2 << endl << endl; + + unsigned c1 = 0, c2 = 0, c3 = 0, c4 = 0; + if (compressInfo_.size() > 0) { + ss3 << "---------------------------------------------------" << endl; + ss3 << " Compression information" << endl; + ss3 << "---------------------------------------------------" << endl; + ss3 << left; + ss3 << "Ground Cluster Ground Cluster Neighborless" << endl; + ss3 << "Vars Vars Factors Factors Vars" << endl; + for (unsigned i = 0; i < compressInfo_.size(); i++) { + ss3 << setw (9) << compressInfo_[i].nGroundVars; + ss3 << setw (10) << compressInfo_[i].nClusterVars; + ss3 << setw (10) << compressInfo_[i].nGroundFactors; + ss3 << setw (10) << compressInfo_[i].nClusterFactors; + ss3 << setw (10) << compressInfo_[i].nWithoutNeighs; + ss3 << endl; + c1 += compressInfo_[i].nGroundVars - compressInfo_[i].nWithoutNeighs; + c2 += compressInfo_[i].nClusterVars; + c3 += compressInfo_[i].nGroundFactors - compressInfo_[i].nWithoutNeighs; + c4 += compressInfo_[i].nClusterFactors; + if (compressInfo_[i].nWithoutNeighs != 0) { + c2 --; + c4 --; + } + } + ss3 << endl << endl; + } + + ss4 << "primary networks: " << primaryNetCount_ << endl; + ss4 << "solved networks: " << netInfo_.size() << endl; + ss4 << "loopy networks: " << nLoopyNets << endl; + ss4 << "unconverged runs: " << nUnconvergedRuns << endl; + ss4 << "total solving time: " << totalSolvingTime << endl; + if (compressInfo_.size() > 0) { + double pc1 = (1.0 - (c2 / (double)c1)) * 100.0; + double pc2 = (1.0 - (c4 / (double)c3)) * 100.0; + ss4 << setprecision (5); + ss4 << "variable compression: " << pc1 << "%" << endl; + ss4 << "factor compression: " << pc2 << "%" << endl; + } + ss4 << endl << endl; + + ss1 << ss4.str() << ss2.str() << ss3.str(); + return ss1.str(); } diff --git a/packages/CLPBN/clpbn/bp/VarElimSolver.cpp b/packages/CLPBN/clpbn/bp/VarElimSolver.cpp new file mode 100644 index 000000000..09e1227e3 --- /dev/null +++ b/packages/CLPBN/clpbn/bp/VarElimSolver.cpp @@ -0,0 +1,211 @@ +#include + +#include "VarElimSolver.h" +#include "ElimGraph.h" +#include "Factor.h" + + +VarElimSolver::VarElimSolver (const BayesNet& bn) : Solver (&bn) +{ + bayesNet_ = &bn; + factorGraph_ = new FactorGraph (bn); +} + + + +VarElimSolver::VarElimSolver (const FactorGraph& fg) : Solver (&fg) +{ + bayesNet_ = 0; + factorGraph_ = &fg; +} + + + +VarElimSolver::~VarElimSolver (void) +{ + if (bayesNet_) { + delete factorGraph_; + } +} + + + +ParamSet +VarElimSolver::getPosterioriOf (VarId vid) +{ + FgVarNode* vn = factorGraph_->getFgVarNode (vid); + assert (vn); + if (vn->hasEvidence()) { + ParamSet params (vn->nrStates(), 0.0); + params[vn->getEvidence()] = 1.0; + return params; + } + return getJointDistributionOf (VarIdSet() = {vid}); +} + + + +ParamSet +VarElimSolver::getJointDistributionOf (const VarIdSet& vids) +{ + factorList_.clear(); + varFactors_.clear(); + elimOrder_.clear(); + createFactorList(); + introduceEvidence(); + chooseEliminationOrder (vids); + processFactorList (vids); + ParamSet params = factorList_.back()->getParameters(); + factorList_.back()->freeDistribution(); + delete factorList_.back(); + Util::normalize (params); + return params; +} + + + +void +VarElimSolver::createFactorList (void) +{ + const FgFacSet& factorNodes = factorGraph_->getFactorNodes(); + factorList_.reserve (factorNodes.size() * 2); + for (unsigned i = 0; i < factorNodes.size(); i++) { + factorList_.push_back (new Factor (*factorNodes[i]->factor())); + const FgVarSet& neighs = factorNodes[i]->neighbors(); + for (unsigned j = 0; j < neighs.size(); j++) { + unordered_map >::iterator it + = varFactors_.find (neighs[j]->varId()); + if (it == varFactors_.end()) { + it = varFactors_.insert (make_pair ( + neighs[j]->varId(), vector())).first; + } + it->second.push_back (i); + } + } +} + + + +void +VarElimSolver::introduceEvidence (void) +{ + const FgVarSet& varNodes = factorGraph_->getVarNodes(); + for (unsigned i = 0; i < varNodes.size(); i++) { + if (varNodes[i]->hasEvidence()) { + const vector& idxs = + varFactors_.find (varNodes[i]->varId())->second; + for (unsigned j = 0; j < idxs.size(); j++) { + Factor* factor = factorList_[idxs[j]]; + if (factor->nrVariables() == 1) { + factorList_[idxs[j]] = 0; + } else { + factorList_[idxs[j]]->removeInconsistentEntries ( + varNodes[i]->varId(), varNodes[i]->getEvidence()); + } + } + } + } +} + + + +void +VarElimSolver::chooseEliminationOrder (const VarIdSet& vids) +{ + if (bayesNet_) { + ElimGraph graph = ElimGraph (*bayesNet_); + elimOrder_ = graph.getEliminatingOrder (vids); + } else { + const FgVarSet& varNodes = factorGraph_->getVarNodes(); + for (unsigned i = 0; i < varNodes.size(); i++) { + VarId vid = varNodes[i]->varId(); + if (std::find (vids.begin(), vids.end(), vid) == vids.end() + && !varNodes[i]->hasEvidence()) { + elimOrder_.push_back (vid); + } + } + } +} + + + +void +VarElimSolver::processFactorList (const VarIdSet& vids) +{ + for (unsigned i = 0; i < elimOrder_.size(); i++) { + // cout << "-----------------------------------------" << endl; + // cout << "Eliminating " << elimOrder_[i]; + // cout << " in the following factors:" << endl; + // printActiveFactors(); + eliminate (elimOrder_[i]); + } + Factor* thisIsTheEnd = new Factor(); + + for (unsigned i = 0; i < factorList_.size(); i++) { + if (factorList_[i]) { + thisIsTheEnd->multiplyByFactor (*factorList_[i]); + factorList_[i]->freeDistribution(); + delete factorList_[i]; + factorList_[i] = 0; + } + } + VarIdSet vidsWithoutEvidence; + for (unsigned i = 0; i < vids.size(); i++) { + if (factorGraph_->getFgVarNode (vids[i])->hasEvidence() == false) { + vidsWithoutEvidence.push_back (vids[i]); + } + } + thisIsTheEnd->orderVariables (vidsWithoutEvidence); + factorList_.push_back (thisIsTheEnd); +} + + + +void +VarElimSolver::eliminate (VarId elimVar) +{ + FgVarNode* vn = factorGraph_->getFgVarNode (elimVar); + Factor* result = 0; + vector& idxs = varFactors_.find (elimVar)->second; + //cout << "eliminating " << setw (5) << elimVar << ":" ; + for (unsigned i = 0; i < idxs.size(); i++) { + unsigned idx = idxs[i]; + if (factorList_[idx]) { + if (result == 0) { + result = new Factor(*factorList_[idx]); + //cout << " " << factorList_[idx]->label(); + } else { + result->multiplyByFactor (*factorList_[idx]); + //cout << " x " << factorList_[idx]->label(); + } + factorList_[idx]->freeDistribution(); + delete factorList_[idx]; + factorList_[idx] = 0; + } + } + if (result != 0 && result->nrVariables() != 1) { + result->removeVariable (vn->varId()); + factorList_.push_back (result); + // cout << endl <<" factor size=" << result->size() << endl; + const VarIdSet& resultVarIds = result->getVarIds(); + for (unsigned i = 0; i < resultVarIds.size(); i++) { + vector& idxs = + varFactors_.find (resultVarIds[i])->second; + idxs.push_back (factorList_.size() - 1); + } + } +} + + + +void +VarElimSolver::printActiveFactors (void) +{ + for (unsigned i = 0; i < factorList_.size(); i++) { + if (factorList_[i] != 0) { + factorList_[i]->printFactor(); + cout << endl; + } + } +} + diff --git a/packages/CLPBN/clpbn/bp/VarElimSolver.h b/packages/CLPBN/clpbn/bp/VarElimSolver.h new file mode 100644 index 000000000..5cf6d0660 --- /dev/null +++ b/packages/CLPBN/clpbn/bp/VarElimSolver.h @@ -0,0 +1,41 @@ +#ifndef HORUS_VARELIMSOLVER_H +#define HORUS_VARELIMSOLVER_H + +#include "unordered_map" + +#include "Solver.h" +#include "FactorGraph.h" +#include "BayesNet.h" +#include "Shared.h" + + +using namespace std; + + +class VarElimSolver : public Solver +{ + public: + VarElimSolver (const BayesNet&); + VarElimSolver (const FactorGraph&); + ~VarElimSolver (void); + void runSolver (void) { } + ParamSet getPosterioriOf (VarId); + ParamSet getJointDistributionOf (const VarIdSet&); + + private: + void createFactorList (void); + void introduceEvidence (void); + void chooseEliminationOrder (const VarIdSet&); + void processFactorList (const VarIdSet&); + void eliminate (VarId); + void printActiveFactors (void); + + const BayesNet* bayesNet_; + const FactorGraph* factorGraph_; + vector factorList_; + VarIdSet elimOrder_; + unordered_map> varFactors_; +}; + +#endif // HORUS_VARELIMSOLVER_H + diff --git a/packages/CLPBN/clpbn/bp/VarNode.cpp b/packages/CLPBN/clpbn/bp/VarNode.cpp new file mode 100644 index 000000000..452befb4a --- /dev/null +++ b/packages/CLPBN/clpbn/bp/VarNode.cpp @@ -0,0 +1,100 @@ +#include +#include + +#include "VarNode.h" +#include "GraphicalModel.h" + +using namespace std; + + +VarNode::VarNode (const VarNode* v) +{ + varId_ = v->varId(); + nrStates_ = v->nrStates(); + evidence_ = v->getEvidence(); + index_ = std::numeric_limits::max(); +} + + + +VarNode::VarNode (VarId varId, unsigned nrStates, int evidence) +{ + assert (nrStates != 0); + assert (evidence < (int) nrStates); + varId_ = varId; + nrStates_ = nrStates; + evidence_ = evidence; + index_ = std::numeric_limits::max(); +} + + + +bool +VarNode::isValidState (int stateIndex) +{ + return stateIndex >= 0 && stateIndex < (int) nrStates_; +} + + + +bool +VarNode::isValidState (const string& stateName) +{ + States states = GraphicalModel::getVariableInformation (varId_).states; + return find (states.begin(), states.end(), stateName) != states.end(); +} + + + +void +VarNode::setEvidence (int ev) +{ + assert (ev < (int) nrStates_); + evidence_ = ev; +} + + + +void +VarNode::setEvidence (const string& ev) +{ + States states = GraphicalModel::getVariableInformation (varId_).states; + for (unsigned i = 0; i < states.size(); i++) { + if (states[i] == ev) { + evidence_ = i; + return; + } + } + assert (false); +} + + + +string +VarNode::label (void) const +{ + if (GraphicalModel::variablesHaveInformation()) { + return GraphicalModel::getVariableInformation (varId_).label; + } + stringstream ss; + ss << "x" << varId_; + return ss.str(); +} + + + +States +VarNode::states (void) const +{ + if (GraphicalModel::variablesHaveInformation()) { + return GraphicalModel::getVariableInformation (varId_).states; + } + States states; + for (unsigned i = 0; i < nrStates_; i++) { + stringstream ss; + ss << i ; + states.push_back (ss.str()); + } + return states; +} + diff --git a/packages/CLPBN/clpbn/bp/VarNode.h b/packages/CLPBN/clpbn/bp/VarNode.h new file mode 100644 index 000000000..ac7253324 --- /dev/null +++ b/packages/CLPBN/clpbn/bp/VarNode.h @@ -0,0 +1,52 @@ +#ifndef HORUS_VARNODE_H +#define HORUS_VARNODE_H + +#include "Shared.h" + +using namespace std; + +class VarNode +{ + public: + VarNode (const VarNode*); + VarNode (VarId, unsigned, int = NO_EVIDENCE); + virtual ~VarNode (void) {}; + + bool isValidState (int); + bool isValidState (const string&); + void setEvidence (int); + void setEvidence (const string&); + string label (void) const; + States states (void) const; + + unsigned varId (void) const { return varId_; } + unsigned nrStates (void) const { return nrStates_; } + bool hasEvidence (void) const { return evidence_ != NO_EVIDENCE; } + int getEvidence (void) const { return evidence_; } + unsigned getIndex (void) const { return index_; } + void setIndex (unsigned idx) { index_ = idx; } + + operator unsigned () const { return index_; } + + bool operator== (const VarNode& var) const + { + assert (!(varId_ == var.varId() && nrStates_ != var.nrStates())); + return varId_ == var.varId(); + } + + bool operator!= (const VarNode& var) const + { + assert (!(varId_ == var.varId() && nrStates_ != var.nrStates())); + return varId_ != var.varId(); + } + + private: + VarId varId_; + unsigned nrStates_; + int evidence_; + unsigned index_; + +}; + +#endif // BP_VARNODE_H + diff --git a/packages/CLPBN/clpbn/bp/Variable.h b/packages/CLPBN/clpbn/bp/Variable.h deleted file mode 100644 index e3e0a95db..000000000 --- a/packages/CLPBN/clpbn/bp/Variable.h +++ /dev/null @@ -1,172 +0,0 @@ -#ifndef BP_VARIABLE_H -#define BP_VARIABLE_H - -#include - -#include - -#include "Shared.h" - -using namespace std; - -class Variable -{ - public: - - Variable (const Variable* v) - { - vid_ = v->getVarId(); - dsize_ = v->getDomainSize(); - if (v->hasDomain()) { - domain_ = v->getDomain(); - dsize_ = domain_.size(); - } else { - dsize_ = v->getDomainSize(); - } - evidence_ = v->getEvidence(); - if (v->hasLabel()) { - label_ = new string (v->getLabel()); - } else { - label_ = 0; - } - } - - Variable (Vid vid) - { - this->vid_ = vid; - this->dsize_ = 0; - this->evidence_ = NO_EVIDENCE; - this->label_ = 0; - } - - Variable (Vid vid, unsigned dsize, int evidence = NO_EVIDENCE, - const string& lbl = string()) - { - assert (dsize != 0); - assert (evidence < (int)dsize); - this->vid_ = vid; - this->dsize_ = dsize; - this->evidence_ = evidence; - if (!lbl.empty()) { - this->label_ = new string (lbl); - } else { - this->label_ = 0; - } - } - - Variable (Vid vid, const Domain& domain, int evidence = NO_EVIDENCE, - const string& lbl = string()) - { - assert (!domain.empty()); - assert (evidence < (int)domain.size()); - this->vid_ = vid; - this->dsize_ = domain.size(); - this->domain_ = domain; - this->evidence_ = evidence; - if (!lbl.empty()) { - this->label_ = new string (lbl); - } else { - this->label_ = 0; - } - } - - ~Variable (void) - { - delete label_; - } - - unsigned getVarId (void) const { return vid_; } - unsigned getIndex (void) const { return index_; } - void setIndex (unsigned idx) { index_ = idx; } - unsigned getDomainSize (void) const { return dsize_; } - bool hasEvidence (void) const { return evidence_ != NO_EVIDENCE; } - int getEvidence (void) const { return evidence_; } - bool hasDomain (void) const { return !domain_.empty(); } - bool hasLabel (void) const { return label_ != 0; } - - bool isValidStateIndex (int index) - { - return index >= 0 && index < (int)dsize_; - } - - bool isValidState (const string& state) - { - return find (domain_.begin(), domain_.end(), state) != domain_.end(); - } - - Domain getDomain (void) const - { - assert (dsize_ != 0); - if (domain_.size() == 0) { - Domain d; - for (unsigned i = 0; i < dsize_; i++) { - stringstream ss; - ss << "x" << i ; - d.push_back (ss.str()); - } - return d; - } else { - return domain_; - } - } - - void setDomainSize (unsigned dsize) - { - assert (dsize != 0); - dsize_ = dsize; - } - - void setDomain (const Domain& domain) - { - assert (!domain.empty()); - domain_ = domain; - dsize_ = domain.size(); - } - - void setEvidence (int ev) - { - assert (ev < dsize_); - evidence_ = ev; - } - - void setEvidence (const string& ev) - { - assert (isValidState (ev)); - for (unsigned i = 0; i < domain_.size(); i++) { - if (domain_[i] == ev) { - evidence_ = i; - } - } - } - - void setLabel (const string& label) - { - label_ = new string (label); - } - - string getLabel (void) const - { - if (label_ == 0) { - stringstream ss; - ss << "v" << vid_; - return ss.str(); - } else { - return *label_; - } - } - - - private: - DISALLOW_COPY_AND_ASSIGN (Variable); - - Vid vid_; - unsigned dsize_; - int evidence_; - Domain domain_; - string* label_; - unsigned index_; - -}; - -#endif // BP_VARIABLE_H - diff --git a/packages/CLPBN/clpbn/bp/callgrind.h b/packages/CLPBN/clpbn/bp/callgrind.h deleted file mode 100644 index d36b6f4eb..000000000 --- a/packages/CLPBN/clpbn/bp/callgrind.h +++ /dev/null @@ -1,147 +0,0 @@ - -/* - ---------------------------------------------------------------- - - Notice that the following BSD-style license applies to this one - file (callgrind.h) only. The rest of Valgrind is licensed under the - terms of the GNU General Public License, version 2, unless - otherwise indicated. See the COPYING file in the source - distribution for details. - - ---------------------------------------------------------------- - - This file is part of callgrind, a valgrind tool for cache simulation - and call tree tracing. - - Copyright (C) 2003-2010 Josef Weidendorfer. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions - are met: - - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - 2. The origin of this software must not be misrepresented; you must - not claim that you wrote the original software. If you use this - software in a product, an acknowledgment in the product - documentation would be appreciated but is not required. - - 3. Altered source versions must be plainly marked as such, and must - not be misrepresented as being the original software. - - 4. The name of the author may not be used to endorse or promote - products derived from this software without specific prior written - permission. - - THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS - OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY - DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE - GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, - WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - ---------------------------------------------------------------- - - Notice that the above BSD-style license applies to this one file - (callgrind.h) only. The entire rest of Valgrind is licensed under - the terms of the GNU General Public License, version 2. See the - COPYING file in the source distribution for details. - - ---------------------------------------------------------------- -*/ - -#ifndef __CALLGRIND_H -#define __CALLGRIND_H - -#include "valgrind.h" - -/* !! ABIWARNING !! ABIWARNING !! ABIWARNING !! ABIWARNING !! - This enum comprises an ABI exported by Valgrind to programs - which use client requests. DO NOT CHANGE THE ORDER OF THESE - ENTRIES, NOR DELETE ANY -- add new ones at the end. - - The identification ('C','T') for Callgrind has historical - reasons: it was called "Calltree" before. Besides, ('C','G') would - clash with cachegrind. - */ - -typedef - enum { - VG_USERREQ__DUMP_STATS = VG_USERREQ_TOOL_BASE('C','T'), - VG_USERREQ__ZERO_STATS, - VG_USERREQ__TOGGLE_COLLECT, - VG_USERREQ__DUMP_STATS_AT, - VG_USERREQ__START_INSTRUMENTATION, - VG_USERREQ__STOP_INSTRUMENTATION - } Vg_CallgrindClientRequest; - -/* Dump current state of cost centers, and zero them afterwards */ -#define CALLGRIND_DUMP_STATS \ - {unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \ - VG_USERREQ__DUMP_STATS, \ - 0, 0, 0, 0, 0); \ - } - -/* Dump current state of cost centers, and zero them afterwards. - The argument is appended to a string stating the reason which triggered - the dump. This string is written as a description field into the - profile data dump. */ -#define CALLGRIND_DUMP_STATS_AT(pos_str) \ - {unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \ - VG_USERREQ__DUMP_STATS_AT, \ - pos_str, 0, 0, 0, 0); \ - } - -/* Zero cost centers */ -#define CALLGRIND_ZERO_STATS \ - {unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \ - VG_USERREQ__ZERO_STATS, \ - 0, 0, 0, 0, 0); \ - } - -/* Toggles collection state. - The collection state specifies whether the happening of events - should be noted or if they are to be ignored. Events are noted - by increment of counters in a cost center */ -#define CALLGRIND_TOGGLE_COLLECT \ - {unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \ - VG_USERREQ__TOGGLE_COLLECT, \ - 0, 0, 0, 0, 0); \ - } - -/* Start full callgrind instrumentation if not already switched on. - When cache simulation is done, it will flush the simulated cache; - this will lead to an artifical cache warmup phase afterwards with - cache misses which would not have happened in reality. */ -#define CALLGRIND_START_INSTRUMENTATION \ - {unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \ - VG_USERREQ__START_INSTRUMENTATION, \ - 0, 0, 0, 0, 0); \ - } - -/* Stop full callgrind instrumentation if not already switched off. - This flushes Valgrinds translation cache, and does no additional - instrumentation afterwards, which effectivly will run at the same - speed as the "none" tool (ie. at minimal slowdown). - Use this to bypass Callgrind aggregation for uninteresting code parts. - To start Callgrind in this mode to ignore the setup phase, use - the option "--instr-atstart=no". */ -#define CALLGRIND_STOP_INSTRUMENTATION \ - {unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \ - VG_USERREQ__STOP_INSTRUMENTATION, \ - 0, 0, 0, 0, 0); \ - } - -#endif /* __CALLGRIND_H */ diff --git a/packages/CLPBN/clpbn/bp/examples/cbp_example.uai b/packages/CLPBN/clpbn/bp/examples/cbp_example.uai new file mode 100644 index 000000000..239ec4ab1 --- /dev/null +++ b/packages/CLPBN/clpbn/bp/examples/cbp_example.uai @@ -0,0 +1,14 @@ +MARKOV +3 +2 2 2 +2 +2 0 1 +2 2 1 + + +4 + 1.2 1.4 2.0 0.4 + +4 + 1.2 1.4 2.0 0.4 + diff --git a/packages/CLPBN/clpbn/bp/examples/loop.yap b/packages/CLPBN/clpbn/bp/examples/loop.yap deleted file mode 100644 index c18784975..000000000 --- a/packages/CLPBN/clpbn/bp/examples/loop.yap +++ /dev/null @@ -1,53 +0,0 @@ - -:- use_module(library(clpbn)). - -:- set_clpbn_flag(solver, bp). - -% -% A E -% / \ / -% / \ / -% B C -% \ / -% \ / -% D -% - -a(A) :- - a_table(ADist), - { A = a with p([a1, a2], ADist) }. - -b(B) :- - a(A), - b_table(BDist), - { B = b with p([b1, b2], BDist, [A]) }. - -c(C) :- - a(A), - c_table(CDist), - { C = c with p([c1, c2], CDist, [A]) }. - -d(D) :- - b(B), - c(C), - d_table(DDist), - { D = d with p([d1, d2], DDist, [B, C]) }. - -e(E) :- - e_table(EDist), - { E = e with p([e1, e2], EDist) }. - - -a_table([0.005, 0.995]). - -b_table([0.02, 0.97, - 0.88, 0.03]). - -c_table([0.55, 0.94, - 0.45, 0.06]). - -d_table([0.192, 0.98, 0.33, 0.013, - 0.908, 0.02, 0.77, 0.987]). - -e_table([0.055, 0.945]). - diff --git a/packages/CLPBN/clpbn/bp/examples/town/run_town_bp_compress.sh b/packages/CLPBN/clpbn/bp/examples/town/run_town_bp_compress.sh new file mode 100644 index 000000000..0261b7a35 --- /dev/null +++ b/packages/CLPBN/clpbn/bp/examples/town/run_town_bp_compress.sh @@ -0,0 +1,60 @@ +#!/bin/bash + +cp ~/bin/yap ~/bin/town_comp +YAP=~/bin/town_comp + +#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log +OUT_FILE_NAME=bp_compress.log +rm -f $OUT_FILE_NAME +rm -f ignore.$OUT_FILE_NAME + + +function run_solver +{ +if [ $2 = bp ] +then + extra_flag1=clpbn_bp:set_solver_parameter\(run_mode,$4\) + extra_flag2=clpbn_bp:set_solver_parameter\(schedule,$5\) + extra_flag3=clpbn_bp:set_solver_parameter\(always_loopy_solver,$6\) +else + extra_flag1=true + extra_flag2=true + extra_flag3=true +fi +/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME +[$1]. +clpbn:set_clpbn_flag(solver,$2), + clpbn_bp:use_log_space, + $extra_flag1, $extra_flag2, $extra_flag3, + run_query(_R), + open("$OUT_FILE_NAME", 'append',S), + format(S, '$3: ~15+ ',[]), + close(S). +EOF +} + + +function run_all_graphs +{ + echo "*******************************************************************" >> "$OUT_FILE_NAME" + echo "results for solver $2" >> $OUT_FILE_NAME + echo "*******************************************************************" >> "$OUT_FILE_NAME" + run_solver town_1000 $1 town_1000 $3 $4 $5 + run_solver town_5000 $1 town_5000 $3 $4 $5 + run_solver town_10000 $1 town_10000 $3 $4 $5 + run_solver town_50000 $1 town_50000 $3 $4 $5 + run_solver town_100000 $1 town_100000 $3 $4 $5 + run_solver town_500000 $1 town_500000 $3 $4 $5 + run_solver town_1000000 $1 town_1000000 $3 $4 $5 + run_solver town_2500000 $1 town_2500000 $3 $4 $5 + run_solver town_5000000 $1 town_5000000 $3 $4 $5 + run_solver town_7500000 $1 town_7500000 $3 $4 $5 + run_solver town_10000000 $1 town_10000000 $3 $4 $5 +} + +run_solver town_10000 "bp(compress,seq_fixed)" town_10000 compress seq_fixed true +exit + +########## +run_all_graphs bp "bp(compress,seq_fixed) " compress seq_fixed true + diff --git a/packages/CLPBN/clpbn/bp/examples/town/run_town_bp_convert.sh b/packages/CLPBN/clpbn/bp/examples/town/run_town_bp_convert.sh new file mode 100644 index 000000000..133cf71fb --- /dev/null +++ b/packages/CLPBN/clpbn/bp/examples/town/run_town_bp_convert.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +YAP=~/bin/town_conv + +#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log +OUT_FILE_NAME=bp_convert.log +rm -f $OUT_FILE_NAME +rm -f ignore.$OUT_FILE_NAME + + +function run_solver +{ +if [ $2 = bp ] +then + extra_flag1=clpbn_bp:set_solver_parameter\(run_mode,$4\) + extra_flag2=clpbn_bp:set_solver_parameter\(schedule,$5\) + extra_flag3=clpbn_bp:set_solver_parameter\(always_loopy_solver,$6\) +else + extra_flag1=true + extra_flag2=true + extra_flag3=true +fi +/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME +[$1]. +clpbn:set_clpbn_flag(solver,$2), + clpbn_bp:use_log_space, + $extra_flag1, $extra_flag2, $extra_flag3, + run_query(_R), + open("$OUT_FILE_NAME", 'append',S), + format(S, '$3: ~15+ ',[]), + close(S). +EOF +} + + +function run_all_graphs +{ + echo "*******************************************************************" >> "$OUT_FILE_NAME" + echo "results for solver $2" >> $OUT_FILE_NAME + echo "*******************************************************************" >> "$OUT_FILE_NAME" + run_solver town_1000 $1 town_1000 $3 $4 $5 + run_solver town_5000 $1 town_5000 $3 $4 $5 + run_solver town_10000 $1 town_10000 $3 $4 $5 + run_solver town_50000 $1 town_50000 $3 $4 $5 + run_solver town_100000 $1 town_100000 $3 $4 $5 + run_solver town_500000 $1 town_500000 $3 $4 $5 + run_solver town_1000000 $1 town_1000000 $3 $4 $5 +} + +run_all_graphs bp "bp(convert,seq_fixed) " convert seq_fixed false + diff --git a/packages/CLPBN/clpbn/bp/examples/town/run_town_bp_normal.sh b/packages/CLPBN/clpbn/bp/examples/town/run_town_bp_normal.sh new file mode 100644 index 000000000..6284282cb --- /dev/null +++ b/packages/CLPBN/clpbn/bp/examples/town/run_town_bp_normal.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +YAP=~/bin/town_norm + +#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log +OUT_FILE_NAME=bp_normal.log +rm -f $OUT_FILE_NAME +rm -f ignore.$OUT_FILE_NAME + + +function run_solver +{ +if [ $2 = bp ] +then + extra_flag1=clpbn_bp:set_solver_parameter\(run_mode,$4\) + extra_flag2=clpbn_bp:set_solver_parameter\(schedule,$5\) + extra_flag3=clpbn_bp:set_solver_parameter\(always_loopy_solver,$6\) +else + extra_flag1=true + extra_flag2=true + extra_flag3=true +fi +/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME +[$1]. +clpbn:set_clpbn_flag(solver,$2), + clpbn_bp:use_log_space, + $extra_flag1, $extra_flag2, $extra_flag3, + run_query(_R), + open("$OUT_FILE_NAME", 'append',S), + format(S, '$3: ~15+ ',[]), + close(S). +EOF +} + + +function run_all_graphs +{ + echo "*******************************************************************" >> "$OUT_FILE_NAME" + echo "results for solver $2" >> $OUT_FILE_NAME + echo "*******************************************************************" >> "$OUT_FILE_NAME" + run_solver town_1000 $1 town_1000 $3 $4 $5 + run_solver town_5000 $1 town_5000 $3 $4 $5 + run_solver town_10000 $1 town_10000 $3 $4 $5 + run_solver town_50000 $1 town_50000 $3 $4 $5 + run_solver town_100000 $1 town_100000 $3 $4 $5 + run_solver town_500000 $1 town_500000 $3 $4 $5 + run_solver town_1000000 $1 town_1000000 $3 $4 $5 +} + +run_all_graphs bp "bp(normal,seq_fixed) " normal seq_fixed false diff --git a/packages/CLPBN/clpbn/bp/examples/town/run_town_gibbs_tests.sh b/packages/CLPBN/clpbn/bp/examples/town/run_town_gibbs_tests.sh new file mode 100644 index 000000000..ac1f10781 --- /dev/null +++ b/packages/CLPBN/clpbn/bp/examples/town/run_town_gibbs_tests.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +YAP=~/bin/town_gibbs + +#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log +OUT_FILE_NAME=gibbs.log +rm -f $OUT_FILE_NAME +rm -f ignore.$OUT_FILE_NAME + + +function run_solver +{ +if [ $2 = bp ] +then + extra_flag1=clpbn_bp:set_solver_parameter\(run_mode,$4\) + extra_flag2=clpbn_bp:set_solver_parameter\(schedule,$5\) + extra_flag3=clpbn_bp:set_solver_parameter\(always_loopy_solver,$6\) +else + extra_flag1=true + extra_flag2=true + extra_flag3=true +fi +/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME +[$1]. +clpbn:set_clpbn_flag(solver,$2), + clpbn_bp:use_log_space, + $extra_flag1, $extra_flag2, $extra_flag3, + run_query(_R), + open("$OUT_FILE_NAME", 'append',S), + format(S, '$3: ~15+ ',[]), + close(S). +EOF +} + + +function run_all_graphs +{ + echo "*******************************************************************" >> "$OUT_FILE_NAME" + echo "results for solver $2" >> $OUT_FILE_NAME + echo "*******************************************************************" >> "$OUT_FILE_NAME" + run_solver town_1000 $1 town_1000 $3 $4 $5 + run_solver town_5000 $1 town_5000 $3 $4 $5 + run_solver town_10000 $1 town_10000 $3 $4 $5 + run_solver town_50000 $1 town_50000 $3 $4 $5 + run_solver town_100000 $1 town_100000 $3 $4 $5 + run_solver town_500000 $1 town_500000 $3 $4 $5 + run_solver town_1000000 $1 town_1000000 $3 $4 $5 +} + +run_all_graphs gibbs "gibbs " + diff --git a/packages/CLPBN/clpbn/bp/examples/town/run_town_jt_tests.sh b/packages/CLPBN/clpbn/bp/examples/town/run_town_jt_tests.sh new file mode 100644 index 000000000..a75483517 --- /dev/null +++ b/packages/CLPBN/clpbn/bp/examples/town/run_town_jt_tests.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +YAP=~/bin/town_jt + +#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log +OUT_FILE_NAME=jt.log +rm -f $OUT_FILE_NAME +rm -f ignore.$OUT_FILE_NAME + + +function run_solver +{ +if [ $2 = bp ] +then + extra_flag1=clpbn_bp:set_solver_parameter\(run_mode,$4\) + extra_flag2=clpbn_bp:set_solver_parameter\(schedule,$5\) + extra_flag3=clpbn_bp:set_solver_parameter\(always_loopy_solver,$6\) +else + extra_flag1=true + extra_flag2=true + extra_flag3=true +fi +/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME +[$1]. +clpbn:set_clpbn_flag(solver,$2), + clpbn_bp:use_log_space, + $extra_flag1, $extra_flag2, $extra_flag3, + run_query(_R), + open("$OUT_FILE_NAME", 'append',S), + format(S, '$3: ~15+ ',[]), + close(S). +EOF +} + + +function run_all_graphs +{ + echo "*******************************************************************" >> "$OUT_FILE_NAME" + echo "results for solver $2" >> $OUT_FILE_NAME + echo "*******************************************************************" >> "$OUT_FILE_NAME" + run_solver town_1000 $1 town_1000 $3 $4 $5 + run_solver town_5000 $1 town_5000 $3 $4 $5 + run_solver town_10000 $1 town_10000 $3 $4 $5 + run_solver town_50000 $1 town_50000 $3 $4 $5 + run_solver town_100000 $1 town_100000 $3 $4 $5 + run_solver town_500000 $1 town_500000 $3 $4 $5 + run_solver town_1000000 $1 town_1000000 $3 $4 $5 +} + +run_all_graphs jt "jt " + diff --git a/packages/CLPBN/clpbn/bp/examples/town/run_town_ve_tests.sh b/packages/CLPBN/clpbn/bp/examples/town/run_town_ve_tests.sh new file mode 100644 index 000000000..db5762a82 --- /dev/null +++ b/packages/CLPBN/clpbn/bp/examples/town/run_town_ve_tests.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +YAP=~/bin/town_ve + +#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log +OUT_FILE_NAME=ve.log +rm -f $OUT_FILE_NAME +rm -f ignore.$OUT_FILE_NAME + + +function run_solver +{ +if [ $2 = bp ] +then + extra_flag1=clpbn_bp:set_solver_parameter\(run_mode,$4\) + extra_flag2=clpbn_bp:set_solver_parameter\(schedule,$5\) + extra_flag3=clpbn_bp:set_solver_parameter\(always_loopy_solver,$6\) +else + extra_flag1=true + extra_flag2=true + extra_flag3=true +fi +/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME +[$1]. +clpbn:set_clpbn_flag(solver,$2), + clpbn_bp:use_log_space, + $extra_flag1, $extra_flag2, $extra_flag3, + run_query(_R), + open("$OUT_FILE_NAME", 'append',S), + format(S, '$3: ~15+ ',[]), + close(S). +EOF +} + + +function run_all_graphs +{ + echo "*******************************************************************" >> "$OUT_FILE_NAME" + echo "results for solver $2" >> $OUT_FILE_NAME + echo "*******************************************************************" >> "$OUT_FILE_NAME" + run_solver town_1000 $1 town_1000 $3 $4 $5 + run_solver town_5000 $1 town_5000 $3 $4 $5 + run_solver town_10000 $1 town_10000 $3 $4 $5 + run_solver town_50000 $1 town_50000 $3 $4 $5 + run_solver town_100000 $1 town_100000 $3 $4 $5 + run_solver town_500000 $1 town_500000 $3 $4 $5 + run_solver town_1000000 $1 town_1000000 $3 $4 $5 +} + +run_all_graphs ve "ve " + diff --git a/packages/CLPBN/clpbn/bp/examples/town/schema.yap b/packages/CLPBN/clpbn/bp/examples/town/schema.yap new file mode 100644 index 000000000..f33db80f7 --- /dev/null +++ b/packages/CLPBN/clpbn/bp/examples/town/schema.yap @@ -0,0 +1,65 @@ + +conservative_city(City, Cons) :- + cons_table(City, ConsDist), + { Cons = conservative_city(City) with p([y,n], ConsDist) }. + + +gender(X, Gender) :- + gender_table(X, GenderDist), + { Gender = gender(X) with p([m,f], GenderDist) }. + + +hair_color(X, Color) :- + lives(X, City), + conservative_city(City, Cons), + hair_color_table(X,ColorTable), + { Color = hair_color(X) with + p([t,f], ColorTable,[Cons]) }. + + +car_color(X, Color) :- + hair_color(X, HColor), + car_color_table(X,CColorTable), + { Color = car_color(X) with + p([t,f], CColorTable,[HColor]) }. + + +height(X, Height) :- + gender(X, Gender), + height_table(X,HeightTable), + { Height = height(X) with + p([t,f], HeightTable,[Gender]) }. + + +shoe_size(X, Shoesize) :- + height(X, Height), + shoe_size_table(X,ShoesizeTable), + { Shoesize = shoe_size(X) with + p([t,f], ShoesizeTable,[Height]) }. + + +guilty(X, Guilt) :- + guilty_table(X, GuiltDist), + { Guilt = guilty(X) with p([y,n], GuiltDist) }. + + +descn(X, Descn) :- + car_color(X, Car), + hair_color(X, Hair), + height(X, Height), + guilty(X, Guilt), + descn_table(X, DescTable), + { Descn = descn(X) with + p([t,f], DescTable,[Car,Hair,Height,Guilt]) }. + + +witness(City, Witness) :- + descn(joe, DescnJ), + descn(p2, Descn2), + wit_table(WitTable), + { Witness = witness(City) with + p([t,f], WitTable,[DescnJ, Descn2]) }. + + +:- ensure_loaded(tables). + diff --git a/packages/CLPBN/clpbn/bp/examples/town/tables.yap b/packages/CLPBN/clpbn/bp/examples/town/tables.yap new file mode 100644 index 000000000..e4503cbef --- /dev/null +++ b/packages/CLPBN/clpbn/bp/examples/town/tables.yap @@ -0,0 +1,46 @@ + +cons_table(amsterdam, [0.2, 0.8]) :- !. +cons_table(_, [0.8, 0.2]). + + +gender_table(_, [0.55, 0.44]). + + +hair_color_table(_, + /* conservative_city */ + /* y n */ + [ 0.05, 0.1, + 0.95, 0.9 ]). + + +car_color_table(_, + /* t f */ + [ 0.9, 0.2, + 0.1, 0.8 ]). + + +height_table(_, + /* m f */ + [ 0.6, 0.4, + 0.4, 0.6 ]). + + +shoe_size_table(_, + /* t f */ + [ 0.9, 0.1, + 0.1, 0.9 ]). + + +guilty_table(_, [0.23, 0.77]). + + +descn_table(_, + /* color, hair, height, guilt */ + /* ttttt tttf ttft ttff tfttt tftf tfft tfff ttttt fttf ftft ftff ffttt fftf ffft ffff */ + [ 0.99, 0.5, 0.23, 0.88, 0.41, 0.3, 0.76, 0.87, 0.44, 0.43, 0.29, 0.72, 0.33, 0.91, 0.95, 0.92, + 0.01, 0.5, 0.77, 0.12, 0.59, 0.7, 0.24, 0.13, 0.56, 0.57, 0.61, 0.28, 0.77, 0.09, 0.05, 0.08]). + + +wit_table([0.2, 0.45, 0.24, 0.34, + 0.8, 0.55, 0.76, 0.66]). + diff --git a/packages/CLPBN/clpbn/bp/examples/town/town_generator.sh b/packages/CLPBN/clpbn/bp/examples/town/town_generator.sh new file mode 100644 index 000000000..2feb9a397 --- /dev/null +++ b/packages/CLPBN/clpbn/bp/examples/town/town_generator.sh @@ -0,0 +1,59 @@ +#!/home/tgomes/bin/yap -L -- + +/* +Steps: +1. generate N facts lives(I, nyc), 0 <= I < N. +2. generate evidence on descn for N people, *** except for 1 *** +3. Run query ?- guilty(joe, Guilty), witness(joe, t), descn(2,t), descn(3, f), descn(4, f) ... +*/ + +:- initialization(main). + + +main :- + unix(argv([H])), + generate_town(H). + + +generate_town(N) :- + atomic_concat(['town_', N, '.yap'], FileName), + open(FileName, 'write', S), + write(S, ':- source.\n'), + write(S, ':- style_check(all).\n'), + write(S, ':- yap_flag(unknown,error).\n'), + write(S, ':- yap_flag(write_strings,on).\n'), + write(S, ':- use_module(library(clpbn)).\n'), + write(S, ':- set_clpbn_flag(solver, bp).\n'), + write(S, ':- [-schema].\n\n'), + write(S, 'lives(_joe, nyc).\n'), + atom_number(N, N2), + generate_people(S, N2, 2), + write(S, '\nrun_query(Guilty) :- \n'), + write(S, '\tguilty(joe, Guilty),\n'), + write(S, '\twitness(nyc, t),\n'), + write(S, '\trunall(X, ev(X)).\n\n\n'), + write(S, 'runall(G, Wrapper) :-\n'), + write(S, '\tfindall(G, Wrapper, L),\n'), + write(S, '\texecute_all(L).\n\n\n'), + write(S, 'execute_all([]).\n'), + write(S, 'execute_all(G.L) :-\n'), + write(S, '\tcall(G),\n'), + write(S, '\texecute_all(L).\n\n\n'), + generate_query(S, N2, 2), + close(S). + + +generate_people(_, N, Counting1) :- !. +generate_people(S, N, Counting) :- + format(S, 'lives(p~w, nyc).~n', [Counting]), + Counting1 is Counting + 1, + generate_people(S, N, Counting1). + + +generate_query(S, N, Counting) :- + Counting > N, !. +generate_query(S, N, Counting) :- !, + format(S, 'ev(descn(p~w, t)).~n', [Counting]), + Counting1 is Counting + 1, + generate_query(S, N, Counting1). + diff --git a/packages/CLPBN/clpbn/bp/valgrind.h b/packages/CLPBN/clpbn/bp/valgrind.h deleted file mode 100644 index 0f5b37662..000000000 --- a/packages/CLPBN/clpbn/bp/valgrind.h +++ /dev/null @@ -1,4536 +0,0 @@ -/* -*- c -*- - ---------------------------------------------------------------- - - Notice that the following BSD-style license applies to this one - file (valgrind.h) only. The rest of Valgrind is licensed under the - terms of the GNU General Public License, version 2, unless - otherwise indicated. See the COPYING file in the source - distribution for details. - - ---------------------------------------------------------------- - - This file is part of Valgrind, a dynamic binary instrumentation - framework. - - Copyright (C) 2000-2010 Julian Seward. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions - are met: - - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - 2. The origin of this software must not be misrepresented; you must - not claim that you wrote the original software. If you use this - software in a product, an acknowledgment in the product - documentation would be appreciated but is not required. - - 3. Altered source versions must be plainly marked as such, and must - not be misrepresented as being the original software. - - 4. The name of the author may not be used to endorse or promote - products derived from this software without specific prior written - permission. - - THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS - OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY - DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE - GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, - WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - ---------------------------------------------------------------- - - Notice that the above BSD-style license applies to this one file - (valgrind.h) only. The entire rest of Valgrind is licensed under - the terms of the GNU General Public License, version 2. See the - COPYING file in the source distribution for details. - - ---------------------------------------------------------------- -*/ - - -/* This file is for inclusion into client (your!) code. - - You can use these macros to manipulate and query Valgrind's - execution inside your own programs. - - The resulting executables will still run without Valgrind, just a - little bit more slowly than they otherwise would, but otherwise - unchanged. When not running on valgrind, each client request - consumes very few (eg. 7) instructions, so the resulting performance - loss is negligible unless you plan to execute client requests - millions of times per second. Nevertheless, if that is still a - problem, you can compile with the NVALGRIND symbol defined (gcc - -DNVALGRIND) so that client requests are not even compiled in. */ - -#ifndef __VALGRIND_H -#define __VALGRIND_H - - -/* ------------------------------------------------------------------ */ -/* VERSION NUMBER OF VALGRIND */ -/* ------------------------------------------------------------------ */ - -/* Specify Valgrind's version number, so that user code can - conditionally compile based on our version number. Note that these - were introduced at version 3.6 and so do not exist in version 3.5 - or earlier. The recommended way to use them to check for "version - X.Y or later" is (eg) - -#if defined(__VALGRIND_MAJOR__) && defined(__VALGRIND_MINOR__) \ - && (__VALGRIND_MAJOR__ > 3 \ - || (__VALGRIND_MAJOR__ == 3 && __VALGRIND_MINOR__ >= 6)) -*/ -#define __VALGRIND_MAJOR__ 3 -#define __VALGRIND_MINOR__ 6 - - -#include - -/* Nb: this file might be included in a file compiled with -ansi. So - we can't use C++ style "//" comments nor the "asm" keyword (instead - use "__asm__"). */ - -/* Derive some tags indicating what the target platform is. Note - that in this file we're using the compiler's CPP symbols for - identifying architectures, which are different to the ones we use - within the rest of Valgrind. Note, __powerpc__ is active for both - 32 and 64-bit PPC, whereas __powerpc64__ is only active for the - latter (on Linux, that is). - - Misc note: how to find out what's predefined in gcc by default: - gcc -Wp,-dM somefile.c -*/ -#undef PLAT_ppc64_aix5 -#undef PLAT_ppc32_aix5 -#undef PLAT_x86_darwin -#undef PLAT_amd64_darwin -#undef PLAT_x86_linux -#undef PLAT_amd64_linux -#undef PLAT_ppc32_linux -#undef PLAT_ppc64_linux -#undef PLAT_arm_linux - -#if defined(_AIX) && defined(__64BIT__) -# define PLAT_ppc64_aix5 1 -#elif defined(_AIX) && !defined(__64BIT__) -# define PLAT_ppc32_aix5 1 -#elif defined(__APPLE__) && defined(__i386__) -# define PLAT_x86_darwin 1 -#elif defined(__APPLE__) && defined(__x86_64__) -# define PLAT_amd64_darwin 1 -#elif defined(__linux__) && defined(__i386__) -# define PLAT_x86_linux 1 -#elif defined(__linux__) && defined(__x86_64__) -# define PLAT_amd64_linux 1 -#elif defined(__linux__) && defined(__powerpc__) && !defined(__powerpc64__) -# define PLAT_ppc32_linux 1 -#elif defined(__linux__) && defined(__powerpc__) && defined(__powerpc64__) -# define PLAT_ppc64_linux 1 -#elif defined(__linux__) && defined(__arm__) -# define PLAT_arm_linux 1 -#else -/* If we're not compiling for our target platform, don't generate - any inline asms. */ -# if !defined(NVALGRIND) -# define NVALGRIND 1 -# endif -#endif - - -/* ------------------------------------------------------------------ */ -/* ARCHITECTURE SPECIFICS for SPECIAL INSTRUCTIONS. There is nothing */ -/* in here of use to end-users -- skip to the next section. */ -/* ------------------------------------------------------------------ */ - -#if defined(NVALGRIND) - -/* Define NVALGRIND to completely remove the Valgrind magic sequence - from the compiled code (analogous to NDEBUG's effects on - assert()) */ -#define VALGRIND_DO_CLIENT_REQUEST( \ - _zzq_rlval, _zzq_default, _zzq_request, \ - _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ - { \ - (_zzq_rlval) = (_zzq_default); \ - } - -#else /* ! NVALGRIND */ - -/* The following defines the magic code sequences which the JITter - spots and handles magically. Don't look too closely at them as - they will rot your brain. - - The assembly code sequences for all architectures is in this one - file. This is because this file must be stand-alone, and we don't - want to have multiple files. - - For VALGRIND_DO_CLIENT_REQUEST, we must ensure that the default - value gets put in the return slot, so that everything works when - this is executed not under Valgrind. Args are passed in a memory - block, and so there's no intrinsic limit to the number that could - be passed, but it's currently five. - - The macro args are: - _zzq_rlval result lvalue - _zzq_default default value (result returned when running on real CPU) - _zzq_request request code - _zzq_arg1..5 request params - - The other two macros are used to support function wrapping, and are - a lot simpler. VALGRIND_GET_NR_CONTEXT returns the value of the - guest's NRADDR pseudo-register and whatever other information is - needed to safely run the call original from the wrapper: on - ppc64-linux, the R2 value at the divert point is also needed. This - information is abstracted into a user-visible type, OrigFn. - - VALGRIND_CALL_NOREDIR_* behaves the same as the following on the - guest, but guarantees that the branch instruction will not be - redirected: x86: call *%eax, amd64: call *%rax, ppc32/ppc64: - branch-and-link-to-r11. VALGRIND_CALL_NOREDIR is just text, not a - complete inline asm, since it needs to be combined with more magic - inline asm stuff to be useful. -*/ - -/* ------------------------- x86-{linux,darwin} ---------------- */ - -#if defined(PLAT_x86_linux) || defined(PLAT_x86_darwin) - -typedef - struct { - unsigned int nraddr; /* where's the code? */ - } - OrigFn; - -#define __SPECIAL_INSTRUCTION_PREAMBLE \ - "roll $3, %%edi ; roll $13, %%edi\n\t" \ - "roll $29, %%edi ; roll $19, %%edi\n\t" - -#define VALGRIND_DO_CLIENT_REQUEST( \ - _zzq_rlval, _zzq_default, _zzq_request, \ - _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ - { volatile unsigned int _zzq_args[6]; \ - volatile unsigned int _zzq_result; \ - _zzq_args[0] = (unsigned int)(_zzq_request); \ - _zzq_args[1] = (unsigned int)(_zzq_arg1); \ - _zzq_args[2] = (unsigned int)(_zzq_arg2); \ - _zzq_args[3] = (unsigned int)(_zzq_arg3); \ - _zzq_args[4] = (unsigned int)(_zzq_arg4); \ - _zzq_args[5] = (unsigned int)(_zzq_arg5); \ - __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ - /* %EDX = client_request ( %EAX ) */ \ - "xchgl %%ebx,%%ebx" \ - : "=d" (_zzq_result) \ - : "a" (&_zzq_args[0]), "0" (_zzq_default) \ - : "cc", "memory" \ - ); \ - _zzq_rlval = _zzq_result; \ - } - -#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ - { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ - volatile unsigned int __addr; \ - __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ - /* %EAX = guest_NRADDR */ \ - "xchgl %%ecx,%%ecx" \ - : "=a" (__addr) \ - : \ - : "cc", "memory" \ - ); \ - _zzq_orig->nraddr = __addr; \ - } - -#define VALGRIND_CALL_NOREDIR_EAX \ - __SPECIAL_INSTRUCTION_PREAMBLE \ - /* call-noredir *%EAX */ \ - "xchgl %%edx,%%edx\n\t" -#endif /* PLAT_x86_linux || PLAT_x86_darwin */ - -/* ------------------------ amd64-{linux,darwin} --------------- */ - -#if defined(PLAT_amd64_linux) || defined(PLAT_amd64_darwin) - -typedef - struct { - unsigned long long int nraddr; /* where's the code? */ - } - OrigFn; - -#define __SPECIAL_INSTRUCTION_PREAMBLE \ - "rolq $3, %%rdi ; rolq $13, %%rdi\n\t" \ - "rolq $61, %%rdi ; rolq $51, %%rdi\n\t" - -#define VALGRIND_DO_CLIENT_REQUEST( \ - _zzq_rlval, _zzq_default, _zzq_request, \ - _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ - { volatile unsigned long long int _zzq_args[6]; \ - volatile unsigned long long int _zzq_result; \ - _zzq_args[0] = (unsigned long long int)(_zzq_request); \ - _zzq_args[1] = (unsigned long long int)(_zzq_arg1); \ - _zzq_args[2] = (unsigned long long int)(_zzq_arg2); \ - _zzq_args[3] = (unsigned long long int)(_zzq_arg3); \ - _zzq_args[4] = (unsigned long long int)(_zzq_arg4); \ - _zzq_args[5] = (unsigned long long int)(_zzq_arg5); \ - __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ - /* %RDX = client_request ( %RAX ) */ \ - "xchgq %%rbx,%%rbx" \ - : "=d" (_zzq_result) \ - : "a" (&_zzq_args[0]), "0" (_zzq_default) \ - : "cc", "memory" \ - ); \ - _zzq_rlval = _zzq_result; \ - } - -#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ - { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ - volatile unsigned long long int __addr; \ - __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ - /* %RAX = guest_NRADDR */ \ - "xchgq %%rcx,%%rcx" \ - : "=a" (__addr) \ - : \ - : "cc", "memory" \ - ); \ - _zzq_orig->nraddr = __addr; \ - } - -#define VALGRIND_CALL_NOREDIR_RAX \ - __SPECIAL_INSTRUCTION_PREAMBLE \ - /* call-noredir *%RAX */ \ - "xchgq %%rdx,%%rdx\n\t" -#endif /* PLAT_amd64_linux || PLAT_amd64_darwin */ - -/* ------------------------ ppc32-linux ------------------------ */ - -#if defined(PLAT_ppc32_linux) - -typedef - struct { - unsigned int nraddr; /* where's the code? */ - } - OrigFn; - -#define __SPECIAL_INSTRUCTION_PREAMBLE \ - "rlwinm 0,0,3,0,0 ; rlwinm 0,0,13,0,0\n\t" \ - "rlwinm 0,0,29,0,0 ; rlwinm 0,0,19,0,0\n\t" - -#define VALGRIND_DO_CLIENT_REQUEST( \ - _zzq_rlval, _zzq_default, _zzq_request, \ - _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ - \ - { unsigned int _zzq_args[6]; \ - unsigned int _zzq_result; \ - unsigned int* _zzq_ptr; \ - _zzq_args[0] = (unsigned int)(_zzq_request); \ - _zzq_args[1] = (unsigned int)(_zzq_arg1); \ - _zzq_args[2] = (unsigned int)(_zzq_arg2); \ - _zzq_args[3] = (unsigned int)(_zzq_arg3); \ - _zzq_args[4] = (unsigned int)(_zzq_arg4); \ - _zzq_args[5] = (unsigned int)(_zzq_arg5); \ - _zzq_ptr = _zzq_args; \ - __asm__ volatile("mr 3,%1\n\t" /*default*/ \ - "mr 4,%2\n\t" /*ptr*/ \ - __SPECIAL_INSTRUCTION_PREAMBLE \ - /* %R3 = client_request ( %R4 ) */ \ - "or 1,1,1\n\t" \ - "mr %0,3" /*result*/ \ - : "=b" (_zzq_result) \ - : "b" (_zzq_default), "b" (_zzq_ptr) \ - : "cc", "memory", "r3", "r4"); \ - _zzq_rlval = _zzq_result; \ - } - -#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ - { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ - unsigned int __addr; \ - __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ - /* %R3 = guest_NRADDR */ \ - "or 2,2,2\n\t" \ - "mr %0,3" \ - : "=b" (__addr) \ - : \ - : "cc", "memory", "r3" \ - ); \ - _zzq_orig->nraddr = __addr; \ - } - -#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - __SPECIAL_INSTRUCTION_PREAMBLE \ - /* branch-and-link-to-noredir *%R11 */ \ - "or 3,3,3\n\t" -#endif /* PLAT_ppc32_linux */ - -/* ------------------------ ppc64-linux ------------------------ */ - -#if defined(PLAT_ppc64_linux) - -typedef - struct { - unsigned long long int nraddr; /* where's the code? */ - unsigned long long int r2; /* what tocptr do we need? */ - } - OrigFn; - -#define __SPECIAL_INSTRUCTION_PREAMBLE \ - "rotldi 0,0,3 ; rotldi 0,0,13\n\t" \ - "rotldi 0,0,61 ; rotldi 0,0,51\n\t" - -#define VALGRIND_DO_CLIENT_REQUEST( \ - _zzq_rlval, _zzq_default, _zzq_request, \ - _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ - \ - { unsigned long long int _zzq_args[6]; \ - register unsigned long long int _zzq_result __asm__("r3"); \ - register unsigned long long int* _zzq_ptr __asm__("r4"); \ - _zzq_args[0] = (unsigned long long int)(_zzq_request); \ - _zzq_args[1] = (unsigned long long int)(_zzq_arg1); \ - _zzq_args[2] = (unsigned long long int)(_zzq_arg2); \ - _zzq_args[3] = (unsigned long long int)(_zzq_arg3); \ - _zzq_args[4] = (unsigned long long int)(_zzq_arg4); \ - _zzq_args[5] = (unsigned long long int)(_zzq_arg5); \ - _zzq_ptr = _zzq_args; \ - __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ - /* %R3 = client_request ( %R4 ) */ \ - "or 1,1,1" \ - : "=r" (_zzq_result) \ - : "0" (_zzq_default), "r" (_zzq_ptr) \ - : "cc", "memory"); \ - _zzq_rlval = _zzq_result; \ - } - -#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ - { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ - register unsigned long long int __addr __asm__("r3"); \ - __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ - /* %R3 = guest_NRADDR */ \ - "or 2,2,2" \ - : "=r" (__addr) \ - : \ - : "cc", "memory" \ - ); \ - _zzq_orig->nraddr = __addr; \ - __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ - /* %R3 = guest_NRADDR_GPR2 */ \ - "or 4,4,4" \ - : "=r" (__addr) \ - : \ - : "cc", "memory" \ - ); \ - _zzq_orig->r2 = __addr; \ - } - -#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - __SPECIAL_INSTRUCTION_PREAMBLE \ - /* branch-and-link-to-noredir *%R11 */ \ - "or 3,3,3\n\t" - -#endif /* PLAT_ppc64_linux */ - -/* ------------------------- arm-linux ------------------------- */ - -#if defined(PLAT_arm_linux) - -typedef - struct { - unsigned int nraddr; /* where's the code? */ - } - OrigFn; - -#define __SPECIAL_INSTRUCTION_PREAMBLE \ - "mov r12, r12, ror #3 ; mov r12, r12, ror #13 \n\t" \ - "mov r12, r12, ror #29 ; mov r12, r12, ror #19 \n\t" - -#define VALGRIND_DO_CLIENT_REQUEST( \ - _zzq_rlval, _zzq_default, _zzq_request, \ - _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ - \ - { volatile unsigned int _zzq_args[6]; \ - volatile unsigned int _zzq_result; \ - _zzq_args[0] = (unsigned int)(_zzq_request); \ - _zzq_args[1] = (unsigned int)(_zzq_arg1); \ - _zzq_args[2] = (unsigned int)(_zzq_arg2); \ - _zzq_args[3] = (unsigned int)(_zzq_arg3); \ - _zzq_args[4] = (unsigned int)(_zzq_arg4); \ - _zzq_args[5] = (unsigned int)(_zzq_arg5); \ - __asm__ volatile("mov r3, %1\n\t" /*default*/ \ - "mov r4, %2\n\t" /*ptr*/ \ - __SPECIAL_INSTRUCTION_PREAMBLE \ - /* R3 = client_request ( R4 ) */ \ - "orr r10, r10, r10\n\t" \ - "mov %0, r3" /*result*/ \ - : "=r" (_zzq_result) \ - : "r" (_zzq_default), "r" (&_zzq_args[0]) \ - : "cc","memory", "r3", "r4"); \ - _zzq_rlval = _zzq_result; \ - } - -#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ - { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ - unsigned int __addr; \ - __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ - /* R3 = guest_NRADDR */ \ - "orr r11, r11, r11\n\t" \ - "mov %0, r3" \ - : "=r" (__addr) \ - : \ - : "cc", "memory", "r3" \ - ); \ - _zzq_orig->nraddr = __addr; \ - } - -#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ - __SPECIAL_INSTRUCTION_PREAMBLE \ - /* branch-and-link-to-noredir *%R4 */ \ - "orr r12, r12, r12\n\t" - -#endif /* PLAT_arm_linux */ - -/* ------------------------ ppc32-aix5 ------------------------- */ - -#if defined(PLAT_ppc32_aix5) - -typedef - struct { - unsigned int nraddr; /* where's the code? */ - unsigned int r2; /* what tocptr do we need? */ - } - OrigFn; - -#define __SPECIAL_INSTRUCTION_PREAMBLE \ - "rlwinm 0,0,3,0,0 ; rlwinm 0,0,13,0,0\n\t" \ - "rlwinm 0,0,29,0,0 ; rlwinm 0,0,19,0,0\n\t" - -#define VALGRIND_DO_CLIENT_REQUEST( \ - _zzq_rlval, _zzq_default, _zzq_request, \ - _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ - \ - { unsigned int _zzq_args[7]; \ - register unsigned int _zzq_result; \ - register unsigned int* _zzq_ptr; \ - _zzq_args[0] = (unsigned int)(_zzq_request); \ - _zzq_args[1] = (unsigned int)(_zzq_arg1); \ - _zzq_args[2] = (unsigned int)(_zzq_arg2); \ - _zzq_args[3] = (unsigned int)(_zzq_arg3); \ - _zzq_args[4] = (unsigned int)(_zzq_arg4); \ - _zzq_args[5] = (unsigned int)(_zzq_arg5); \ - _zzq_args[6] = (unsigned int)(_zzq_default); \ - _zzq_ptr = _zzq_args; \ - __asm__ volatile("mr 4,%1\n\t" \ - "lwz 3, 24(4)\n\t" \ - __SPECIAL_INSTRUCTION_PREAMBLE \ - /* %R3 = client_request ( %R4 ) */ \ - "or 1,1,1\n\t" \ - "mr %0,3" \ - : "=b" (_zzq_result) \ - : "b" (_zzq_ptr) \ - : "r3", "r4", "cc", "memory"); \ - _zzq_rlval = _zzq_result; \ - } - -#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ - { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ - register unsigned int __addr; \ - __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ - /* %R3 = guest_NRADDR */ \ - "or 2,2,2\n\t" \ - "mr %0,3" \ - : "=b" (__addr) \ - : \ - : "r3", "cc", "memory" \ - ); \ - _zzq_orig->nraddr = __addr; \ - __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ - /* %R3 = guest_NRADDR_GPR2 */ \ - "or 4,4,4\n\t" \ - "mr %0,3" \ - : "=b" (__addr) \ - : \ - : "r3", "cc", "memory" \ - ); \ - _zzq_orig->r2 = __addr; \ - } - -#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - __SPECIAL_INSTRUCTION_PREAMBLE \ - /* branch-and-link-to-noredir *%R11 */ \ - "or 3,3,3\n\t" - -#endif /* PLAT_ppc32_aix5 */ - -/* ------------------------ ppc64-aix5 ------------------------- */ - -#if defined(PLAT_ppc64_aix5) - -typedef - struct { - unsigned long long int nraddr; /* where's the code? */ - unsigned long long int r2; /* what tocptr do we need? */ - } - OrigFn; - -#define __SPECIAL_INSTRUCTION_PREAMBLE \ - "rotldi 0,0,3 ; rotldi 0,0,13\n\t" \ - "rotldi 0,0,61 ; rotldi 0,0,51\n\t" - -#define VALGRIND_DO_CLIENT_REQUEST( \ - _zzq_rlval, _zzq_default, _zzq_request, \ - _zzq_arg1, _zzq_arg2, _zzq_arg3, _zzq_arg4, _zzq_arg5) \ - \ - { unsigned long long int _zzq_args[7]; \ - register unsigned long long int _zzq_result; \ - register unsigned long long int* _zzq_ptr; \ - _zzq_args[0] = (unsigned int long long)(_zzq_request); \ - _zzq_args[1] = (unsigned int long long)(_zzq_arg1); \ - _zzq_args[2] = (unsigned int long long)(_zzq_arg2); \ - _zzq_args[3] = (unsigned int long long)(_zzq_arg3); \ - _zzq_args[4] = (unsigned int long long)(_zzq_arg4); \ - _zzq_args[5] = (unsigned int long long)(_zzq_arg5); \ - _zzq_args[6] = (unsigned int long long)(_zzq_default); \ - _zzq_ptr = _zzq_args; \ - __asm__ volatile("mr 4,%1\n\t" \ - "ld 3, 48(4)\n\t" \ - __SPECIAL_INSTRUCTION_PREAMBLE \ - /* %R3 = client_request ( %R4 ) */ \ - "or 1,1,1\n\t" \ - "mr %0,3" \ - : "=b" (_zzq_result) \ - : "b" (_zzq_ptr) \ - : "r3", "r4", "cc", "memory"); \ - _zzq_rlval = _zzq_result; \ - } - -#define VALGRIND_GET_NR_CONTEXT(_zzq_rlval) \ - { volatile OrigFn* _zzq_orig = &(_zzq_rlval); \ - register unsigned long long int __addr; \ - __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ - /* %R3 = guest_NRADDR */ \ - "or 2,2,2\n\t" \ - "mr %0,3" \ - : "=b" (__addr) \ - : \ - : "r3", "cc", "memory" \ - ); \ - _zzq_orig->nraddr = __addr; \ - __asm__ volatile(__SPECIAL_INSTRUCTION_PREAMBLE \ - /* %R3 = guest_NRADDR_GPR2 */ \ - "or 4,4,4\n\t" \ - "mr %0,3" \ - : "=b" (__addr) \ - : \ - : "r3", "cc", "memory" \ - ); \ - _zzq_orig->r2 = __addr; \ - } - -#define VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - __SPECIAL_INSTRUCTION_PREAMBLE \ - /* branch-and-link-to-noredir *%R11 */ \ - "or 3,3,3\n\t" - -#endif /* PLAT_ppc64_aix5 */ - -/* Insert assembly code for other platforms here... */ - -#endif /* NVALGRIND */ - - -/* ------------------------------------------------------------------ */ -/* PLATFORM SPECIFICS for FUNCTION WRAPPING. This is all very */ -/* ugly. It's the least-worst tradeoff I can think of. */ -/* ------------------------------------------------------------------ */ - -/* This section defines magic (a.k.a appalling-hack) macros for doing - guaranteed-no-redirection macros, so as to get from function - wrappers to the functions they are wrapping. The whole point is to - construct standard call sequences, but to do the call itself with a - special no-redirect call pseudo-instruction that the JIT - understands and handles specially. This section is long and - repetitious, and I can't see a way to make it shorter. - - The naming scheme is as follows: - - CALL_FN_{W,v}_{v,W,WW,WWW,WWWW,5W,6W,7W,etc} - - 'W' stands for "word" and 'v' for "void". Hence there are - different macros for calling arity 0, 1, 2, 3, 4, etc, functions, - and for each, the possibility of returning a word-typed result, or - no result. -*/ - -/* Use these to write the name of your wrapper. NOTE: duplicates - VG_WRAP_FUNCTION_Z{U,Z} in pub_tool_redir.h. */ - -/* Use an extra level of macroisation so as to ensure the soname/fnname - args are fully macro-expanded before pasting them together. */ -#define VG_CONCAT4(_aa,_bb,_cc,_dd) _aa##_bb##_cc##_dd - -#define I_WRAP_SONAME_FNNAME_ZU(soname,fnname) \ - VG_CONCAT4(_vgwZU_,soname,_,fnname) - -#define I_WRAP_SONAME_FNNAME_ZZ(soname,fnname) \ - VG_CONCAT4(_vgwZZ_,soname,_,fnname) - -/* Use this macro from within a wrapper function to collect the - context (address and possibly other info) of the original function. - Once you have that you can then use it in one of the CALL_FN_ - macros. The type of the argument _lval is OrigFn. */ -#define VALGRIND_GET_ORIG_FN(_lval) VALGRIND_GET_NR_CONTEXT(_lval) - -/* Derivatives of the main macros below, for calling functions - returning void. */ - -#define CALL_FN_v_v(fnptr) \ - do { volatile unsigned long _junk; \ - CALL_FN_W_v(_junk,fnptr); } while (0) - -#define CALL_FN_v_W(fnptr, arg1) \ - do { volatile unsigned long _junk; \ - CALL_FN_W_W(_junk,fnptr,arg1); } while (0) - -#define CALL_FN_v_WW(fnptr, arg1,arg2) \ - do { volatile unsigned long _junk; \ - CALL_FN_W_WW(_junk,fnptr,arg1,arg2); } while (0) - -#define CALL_FN_v_WWW(fnptr, arg1,arg2,arg3) \ - do { volatile unsigned long _junk; \ - CALL_FN_W_WWW(_junk,fnptr,arg1,arg2,arg3); } while (0) - -#define CALL_FN_v_WWWW(fnptr, arg1,arg2,arg3,arg4) \ - do { volatile unsigned long _junk; \ - CALL_FN_W_WWWW(_junk,fnptr,arg1,arg2,arg3,arg4); } while (0) - -#define CALL_FN_v_5W(fnptr, arg1,arg2,arg3,arg4,arg5) \ - do { volatile unsigned long _junk; \ - CALL_FN_W_5W(_junk,fnptr,arg1,arg2,arg3,arg4,arg5); } while (0) - -#define CALL_FN_v_6W(fnptr, arg1,arg2,arg3,arg4,arg5,arg6) \ - do { volatile unsigned long _junk; \ - CALL_FN_W_6W(_junk,fnptr,arg1,arg2,arg3,arg4,arg5,arg6); } while (0) - -#define CALL_FN_v_7W(fnptr, arg1,arg2,arg3,arg4,arg5,arg6,arg7) \ - do { volatile unsigned long _junk; \ - CALL_FN_W_7W(_junk,fnptr,arg1,arg2,arg3,arg4,arg5,arg6,arg7); } while (0) - -/* ------------------------- x86-{linux,darwin} ---------------- */ - -#if defined(PLAT_x86_linux) || defined(PLAT_x86_darwin) - -/* These regs are trashed by the hidden call. No need to mention eax - as gcc can already see that, plus causes gcc to bomb. */ -#define __CALLER_SAVED_REGS /*"eax"*/ "ecx", "edx" - -/* These CALL_FN_ macros assume that on x86-linux, sizeof(unsigned - long) == 4. */ - -#define CALL_FN_W_v(lval, orig) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[1]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - __asm__ volatile( \ - "movl (%%eax), %%eax\n\t" /* target->%eax */ \ - VALGRIND_CALL_NOREDIR_EAX \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_W(lval, orig, arg1) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[2]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - __asm__ volatile( \ - "pushl 4(%%eax)\n\t" \ - "movl (%%eax), %%eax\n\t" /* target->%eax */ \ - VALGRIND_CALL_NOREDIR_EAX \ - "addl $4, %%esp\n" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - __asm__ volatile( \ - "pushl 8(%%eax)\n\t" \ - "pushl 4(%%eax)\n\t" \ - "movl (%%eax), %%eax\n\t" /* target->%eax */ \ - VALGRIND_CALL_NOREDIR_EAX \ - "addl $8, %%esp\n" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[4]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - __asm__ volatile( \ - "pushl 12(%%eax)\n\t" \ - "pushl 8(%%eax)\n\t" \ - "pushl 4(%%eax)\n\t" \ - "movl (%%eax), %%eax\n\t" /* target->%eax */ \ - VALGRIND_CALL_NOREDIR_EAX \ - "addl $12, %%esp\n" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[5]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - __asm__ volatile( \ - "pushl 16(%%eax)\n\t" \ - "pushl 12(%%eax)\n\t" \ - "pushl 8(%%eax)\n\t" \ - "pushl 4(%%eax)\n\t" \ - "movl (%%eax), %%eax\n\t" /* target->%eax */ \ - VALGRIND_CALL_NOREDIR_EAX \ - "addl $16, %%esp\n" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[6]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - __asm__ volatile( \ - "pushl 20(%%eax)\n\t" \ - "pushl 16(%%eax)\n\t" \ - "pushl 12(%%eax)\n\t" \ - "pushl 8(%%eax)\n\t" \ - "pushl 4(%%eax)\n\t" \ - "movl (%%eax), %%eax\n\t" /* target->%eax */ \ - VALGRIND_CALL_NOREDIR_EAX \ - "addl $20, %%esp\n" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[7]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - _argvec[6] = (unsigned long)(arg6); \ - __asm__ volatile( \ - "pushl 24(%%eax)\n\t" \ - "pushl 20(%%eax)\n\t" \ - "pushl 16(%%eax)\n\t" \ - "pushl 12(%%eax)\n\t" \ - "pushl 8(%%eax)\n\t" \ - "pushl 4(%%eax)\n\t" \ - "movl (%%eax), %%eax\n\t" /* target->%eax */ \ - VALGRIND_CALL_NOREDIR_EAX \ - "addl $24, %%esp\n" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[8]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - _argvec[6] = (unsigned long)(arg6); \ - _argvec[7] = (unsigned long)(arg7); \ - __asm__ volatile( \ - "pushl 28(%%eax)\n\t" \ - "pushl 24(%%eax)\n\t" \ - "pushl 20(%%eax)\n\t" \ - "pushl 16(%%eax)\n\t" \ - "pushl 12(%%eax)\n\t" \ - "pushl 8(%%eax)\n\t" \ - "pushl 4(%%eax)\n\t" \ - "movl (%%eax), %%eax\n\t" /* target->%eax */ \ - VALGRIND_CALL_NOREDIR_EAX \ - "addl $28, %%esp\n" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[9]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - _argvec[6] = (unsigned long)(arg6); \ - _argvec[7] = (unsigned long)(arg7); \ - _argvec[8] = (unsigned long)(arg8); \ - __asm__ volatile( \ - "pushl 32(%%eax)\n\t" \ - "pushl 28(%%eax)\n\t" \ - "pushl 24(%%eax)\n\t" \ - "pushl 20(%%eax)\n\t" \ - "pushl 16(%%eax)\n\t" \ - "pushl 12(%%eax)\n\t" \ - "pushl 8(%%eax)\n\t" \ - "pushl 4(%%eax)\n\t" \ - "movl (%%eax), %%eax\n\t" /* target->%eax */ \ - VALGRIND_CALL_NOREDIR_EAX \ - "addl $32, %%esp\n" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[10]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - _argvec[6] = (unsigned long)(arg6); \ - _argvec[7] = (unsigned long)(arg7); \ - _argvec[8] = (unsigned long)(arg8); \ - _argvec[9] = (unsigned long)(arg9); \ - __asm__ volatile( \ - "pushl 36(%%eax)\n\t" \ - "pushl 32(%%eax)\n\t" \ - "pushl 28(%%eax)\n\t" \ - "pushl 24(%%eax)\n\t" \ - "pushl 20(%%eax)\n\t" \ - "pushl 16(%%eax)\n\t" \ - "pushl 12(%%eax)\n\t" \ - "pushl 8(%%eax)\n\t" \ - "pushl 4(%%eax)\n\t" \ - "movl (%%eax), %%eax\n\t" /* target->%eax */ \ - VALGRIND_CALL_NOREDIR_EAX \ - "addl $36, %%esp\n" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9,arg10) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[11]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - _argvec[6] = (unsigned long)(arg6); \ - _argvec[7] = (unsigned long)(arg7); \ - _argvec[8] = (unsigned long)(arg8); \ - _argvec[9] = (unsigned long)(arg9); \ - _argvec[10] = (unsigned long)(arg10); \ - __asm__ volatile( \ - "pushl 40(%%eax)\n\t" \ - "pushl 36(%%eax)\n\t" \ - "pushl 32(%%eax)\n\t" \ - "pushl 28(%%eax)\n\t" \ - "pushl 24(%%eax)\n\t" \ - "pushl 20(%%eax)\n\t" \ - "pushl 16(%%eax)\n\t" \ - "pushl 12(%%eax)\n\t" \ - "pushl 8(%%eax)\n\t" \ - "pushl 4(%%eax)\n\t" \ - "movl (%%eax), %%eax\n\t" /* target->%eax */ \ - VALGRIND_CALL_NOREDIR_EAX \ - "addl $40, %%esp\n" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ - arg6,arg7,arg8,arg9,arg10, \ - arg11) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[12]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - _argvec[6] = (unsigned long)(arg6); \ - _argvec[7] = (unsigned long)(arg7); \ - _argvec[8] = (unsigned long)(arg8); \ - _argvec[9] = (unsigned long)(arg9); \ - _argvec[10] = (unsigned long)(arg10); \ - _argvec[11] = (unsigned long)(arg11); \ - __asm__ volatile( \ - "pushl 44(%%eax)\n\t" \ - "pushl 40(%%eax)\n\t" \ - "pushl 36(%%eax)\n\t" \ - "pushl 32(%%eax)\n\t" \ - "pushl 28(%%eax)\n\t" \ - "pushl 24(%%eax)\n\t" \ - "pushl 20(%%eax)\n\t" \ - "pushl 16(%%eax)\n\t" \ - "pushl 12(%%eax)\n\t" \ - "pushl 8(%%eax)\n\t" \ - "pushl 4(%%eax)\n\t" \ - "movl (%%eax), %%eax\n\t" /* target->%eax */ \ - VALGRIND_CALL_NOREDIR_EAX \ - "addl $44, %%esp\n" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ - arg6,arg7,arg8,arg9,arg10, \ - arg11,arg12) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[13]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - _argvec[6] = (unsigned long)(arg6); \ - _argvec[7] = (unsigned long)(arg7); \ - _argvec[8] = (unsigned long)(arg8); \ - _argvec[9] = (unsigned long)(arg9); \ - _argvec[10] = (unsigned long)(arg10); \ - _argvec[11] = (unsigned long)(arg11); \ - _argvec[12] = (unsigned long)(arg12); \ - __asm__ volatile( \ - "pushl 48(%%eax)\n\t" \ - "pushl 44(%%eax)\n\t" \ - "pushl 40(%%eax)\n\t" \ - "pushl 36(%%eax)\n\t" \ - "pushl 32(%%eax)\n\t" \ - "pushl 28(%%eax)\n\t" \ - "pushl 24(%%eax)\n\t" \ - "pushl 20(%%eax)\n\t" \ - "pushl 16(%%eax)\n\t" \ - "pushl 12(%%eax)\n\t" \ - "pushl 8(%%eax)\n\t" \ - "pushl 4(%%eax)\n\t" \ - "movl (%%eax), %%eax\n\t" /* target->%eax */ \ - VALGRIND_CALL_NOREDIR_EAX \ - "addl $48, %%esp\n" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#endif /* PLAT_x86_linux || PLAT_x86_darwin */ - -/* ------------------------ amd64-{linux,darwin} --------------- */ - -#if defined(PLAT_amd64_linux) || defined(PLAT_amd64_darwin) - -/* ARGREGS: rdi rsi rdx rcx r8 r9 (the rest on stack in R-to-L order) */ - -/* These regs are trashed by the hidden call. */ -#define __CALLER_SAVED_REGS /*"rax",*/ "rcx", "rdx", "rsi", \ - "rdi", "r8", "r9", "r10", "r11" - -/* These CALL_FN_ macros assume that on amd64-linux, sizeof(unsigned - long) == 8. */ - -/* NB 9 Sept 07. There is a nasty kludge here in all these CALL_FN_ - macros. In order not to trash the stack redzone, we need to drop - %rsp by 128 before the hidden call, and restore afterwards. The - nastyness is that it is only by luck that the stack still appears - to be unwindable during the hidden call - since then the behaviour - of any routine using this macro does not match what the CFI data - says. Sigh. - - Why is this important? Imagine that a wrapper has a stack - allocated local, and passes to the hidden call, a pointer to it. - Because gcc does not know about the hidden call, it may allocate - that local in the redzone. Unfortunately the hidden call may then - trash it before it comes to use it. So we must step clear of the - redzone, for the duration of the hidden call, to make it safe. - - Probably the same problem afflicts the other redzone-style ABIs too - (ppc64-linux, ppc32-aix5, ppc64-aix5); but for those, the stack is - self describing (none of this CFI nonsense) so at least messing - with the stack pointer doesn't give a danger of non-unwindable - stack. */ - -#define CALL_FN_W_v(lval, orig) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[1]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - __asm__ volatile( \ - "subq $128,%%rsp\n\t" \ - "movq (%%rax), %%rax\n\t" /* target->%rax */ \ - VALGRIND_CALL_NOREDIR_RAX \ - "addq $128,%%rsp\n\t" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_W(lval, orig, arg1) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[2]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - __asm__ volatile( \ - "subq $128,%%rsp\n\t" \ - "movq 8(%%rax), %%rdi\n\t" \ - "movq (%%rax), %%rax\n\t" /* target->%rax */ \ - VALGRIND_CALL_NOREDIR_RAX \ - "addq $128,%%rsp\n\t" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - __asm__ volatile( \ - "subq $128,%%rsp\n\t" \ - "movq 16(%%rax), %%rsi\n\t" \ - "movq 8(%%rax), %%rdi\n\t" \ - "movq (%%rax), %%rax\n\t" /* target->%rax */ \ - VALGRIND_CALL_NOREDIR_RAX \ - "addq $128,%%rsp\n\t" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[4]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - __asm__ volatile( \ - "subq $128,%%rsp\n\t" \ - "movq 24(%%rax), %%rdx\n\t" \ - "movq 16(%%rax), %%rsi\n\t" \ - "movq 8(%%rax), %%rdi\n\t" \ - "movq (%%rax), %%rax\n\t" /* target->%rax */ \ - VALGRIND_CALL_NOREDIR_RAX \ - "addq $128,%%rsp\n\t" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[5]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - __asm__ volatile( \ - "subq $128,%%rsp\n\t" \ - "movq 32(%%rax), %%rcx\n\t" \ - "movq 24(%%rax), %%rdx\n\t" \ - "movq 16(%%rax), %%rsi\n\t" \ - "movq 8(%%rax), %%rdi\n\t" \ - "movq (%%rax), %%rax\n\t" /* target->%rax */ \ - VALGRIND_CALL_NOREDIR_RAX \ - "addq $128,%%rsp\n\t" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[6]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - __asm__ volatile( \ - "subq $128,%%rsp\n\t" \ - "movq 40(%%rax), %%r8\n\t" \ - "movq 32(%%rax), %%rcx\n\t" \ - "movq 24(%%rax), %%rdx\n\t" \ - "movq 16(%%rax), %%rsi\n\t" \ - "movq 8(%%rax), %%rdi\n\t" \ - "movq (%%rax), %%rax\n\t" /* target->%rax */ \ - VALGRIND_CALL_NOREDIR_RAX \ - "addq $128,%%rsp\n\t" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[7]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - _argvec[6] = (unsigned long)(arg6); \ - __asm__ volatile( \ - "subq $128,%%rsp\n\t" \ - "movq 48(%%rax), %%r9\n\t" \ - "movq 40(%%rax), %%r8\n\t" \ - "movq 32(%%rax), %%rcx\n\t" \ - "movq 24(%%rax), %%rdx\n\t" \ - "movq 16(%%rax), %%rsi\n\t" \ - "movq 8(%%rax), %%rdi\n\t" \ - "movq (%%rax), %%rax\n\t" /* target->%rax */ \ - "addq $128,%%rsp\n\t" \ - VALGRIND_CALL_NOREDIR_RAX \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[8]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - _argvec[6] = (unsigned long)(arg6); \ - _argvec[7] = (unsigned long)(arg7); \ - __asm__ volatile( \ - "subq $128,%%rsp\n\t" \ - "pushq 56(%%rax)\n\t" \ - "movq 48(%%rax), %%r9\n\t" \ - "movq 40(%%rax), %%r8\n\t" \ - "movq 32(%%rax), %%rcx\n\t" \ - "movq 24(%%rax), %%rdx\n\t" \ - "movq 16(%%rax), %%rsi\n\t" \ - "movq 8(%%rax), %%rdi\n\t" \ - "movq (%%rax), %%rax\n\t" /* target->%rax */ \ - VALGRIND_CALL_NOREDIR_RAX \ - "addq $8, %%rsp\n" \ - "addq $128,%%rsp\n\t" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[9]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - _argvec[6] = (unsigned long)(arg6); \ - _argvec[7] = (unsigned long)(arg7); \ - _argvec[8] = (unsigned long)(arg8); \ - __asm__ volatile( \ - "subq $128,%%rsp\n\t" \ - "pushq 64(%%rax)\n\t" \ - "pushq 56(%%rax)\n\t" \ - "movq 48(%%rax), %%r9\n\t" \ - "movq 40(%%rax), %%r8\n\t" \ - "movq 32(%%rax), %%rcx\n\t" \ - "movq 24(%%rax), %%rdx\n\t" \ - "movq 16(%%rax), %%rsi\n\t" \ - "movq 8(%%rax), %%rdi\n\t" \ - "movq (%%rax), %%rax\n\t" /* target->%rax */ \ - VALGRIND_CALL_NOREDIR_RAX \ - "addq $16, %%rsp\n" \ - "addq $128,%%rsp\n\t" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[10]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - _argvec[6] = (unsigned long)(arg6); \ - _argvec[7] = (unsigned long)(arg7); \ - _argvec[8] = (unsigned long)(arg8); \ - _argvec[9] = (unsigned long)(arg9); \ - __asm__ volatile( \ - "subq $128,%%rsp\n\t" \ - "pushq 72(%%rax)\n\t" \ - "pushq 64(%%rax)\n\t" \ - "pushq 56(%%rax)\n\t" \ - "movq 48(%%rax), %%r9\n\t" \ - "movq 40(%%rax), %%r8\n\t" \ - "movq 32(%%rax), %%rcx\n\t" \ - "movq 24(%%rax), %%rdx\n\t" \ - "movq 16(%%rax), %%rsi\n\t" \ - "movq 8(%%rax), %%rdi\n\t" \ - "movq (%%rax), %%rax\n\t" /* target->%rax */ \ - VALGRIND_CALL_NOREDIR_RAX \ - "addq $24, %%rsp\n" \ - "addq $128,%%rsp\n\t" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9,arg10) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[11]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - _argvec[6] = (unsigned long)(arg6); \ - _argvec[7] = (unsigned long)(arg7); \ - _argvec[8] = (unsigned long)(arg8); \ - _argvec[9] = (unsigned long)(arg9); \ - _argvec[10] = (unsigned long)(arg10); \ - __asm__ volatile( \ - "subq $128,%%rsp\n\t" \ - "pushq 80(%%rax)\n\t" \ - "pushq 72(%%rax)\n\t" \ - "pushq 64(%%rax)\n\t" \ - "pushq 56(%%rax)\n\t" \ - "movq 48(%%rax), %%r9\n\t" \ - "movq 40(%%rax), %%r8\n\t" \ - "movq 32(%%rax), %%rcx\n\t" \ - "movq 24(%%rax), %%rdx\n\t" \ - "movq 16(%%rax), %%rsi\n\t" \ - "movq 8(%%rax), %%rdi\n\t" \ - "movq (%%rax), %%rax\n\t" /* target->%rax */ \ - VALGRIND_CALL_NOREDIR_RAX \ - "addq $32, %%rsp\n" \ - "addq $128,%%rsp\n\t" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9,arg10,arg11) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[12]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - _argvec[6] = (unsigned long)(arg6); \ - _argvec[7] = (unsigned long)(arg7); \ - _argvec[8] = (unsigned long)(arg8); \ - _argvec[9] = (unsigned long)(arg9); \ - _argvec[10] = (unsigned long)(arg10); \ - _argvec[11] = (unsigned long)(arg11); \ - __asm__ volatile( \ - "subq $128,%%rsp\n\t" \ - "pushq 88(%%rax)\n\t" \ - "pushq 80(%%rax)\n\t" \ - "pushq 72(%%rax)\n\t" \ - "pushq 64(%%rax)\n\t" \ - "pushq 56(%%rax)\n\t" \ - "movq 48(%%rax), %%r9\n\t" \ - "movq 40(%%rax), %%r8\n\t" \ - "movq 32(%%rax), %%rcx\n\t" \ - "movq 24(%%rax), %%rdx\n\t" \ - "movq 16(%%rax), %%rsi\n\t" \ - "movq 8(%%rax), %%rdi\n\t" \ - "movq (%%rax), %%rax\n\t" /* target->%rax */ \ - VALGRIND_CALL_NOREDIR_RAX \ - "addq $40, %%rsp\n" \ - "addq $128,%%rsp\n\t" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9,arg10,arg11,arg12) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[13]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - _argvec[6] = (unsigned long)(arg6); \ - _argvec[7] = (unsigned long)(arg7); \ - _argvec[8] = (unsigned long)(arg8); \ - _argvec[9] = (unsigned long)(arg9); \ - _argvec[10] = (unsigned long)(arg10); \ - _argvec[11] = (unsigned long)(arg11); \ - _argvec[12] = (unsigned long)(arg12); \ - __asm__ volatile( \ - "subq $128,%%rsp\n\t" \ - "pushq 96(%%rax)\n\t" \ - "pushq 88(%%rax)\n\t" \ - "pushq 80(%%rax)\n\t" \ - "pushq 72(%%rax)\n\t" \ - "pushq 64(%%rax)\n\t" \ - "pushq 56(%%rax)\n\t" \ - "movq 48(%%rax), %%r9\n\t" \ - "movq 40(%%rax), %%r8\n\t" \ - "movq 32(%%rax), %%rcx\n\t" \ - "movq 24(%%rax), %%rdx\n\t" \ - "movq 16(%%rax), %%rsi\n\t" \ - "movq 8(%%rax), %%rdi\n\t" \ - "movq (%%rax), %%rax\n\t" /* target->%rax */ \ - VALGRIND_CALL_NOREDIR_RAX \ - "addq $48, %%rsp\n" \ - "addq $128,%%rsp\n\t" \ - : /*out*/ "=a" (_res) \ - : /*in*/ "a" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#endif /* PLAT_amd64_linux || PLAT_amd64_darwin */ - -/* ------------------------ ppc32-linux ------------------------ */ - -#if defined(PLAT_ppc32_linux) - -/* This is useful for finding out about the on-stack stuff: - - extern int f9 ( int,int,int,int,int,int,int,int,int ); - extern int f10 ( int,int,int,int,int,int,int,int,int,int ); - extern int f11 ( int,int,int,int,int,int,int,int,int,int,int ); - extern int f12 ( int,int,int,int,int,int,int,int,int,int,int,int ); - - int g9 ( void ) { - return f9(11,22,33,44,55,66,77,88,99); - } - int g10 ( void ) { - return f10(11,22,33,44,55,66,77,88,99,110); - } - int g11 ( void ) { - return f11(11,22,33,44,55,66,77,88,99,110,121); - } - int g12 ( void ) { - return f12(11,22,33,44,55,66,77,88,99,110,121,132); - } -*/ - -/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */ - -/* These regs are trashed by the hidden call. */ -#define __CALLER_SAVED_REGS \ - "lr", "ctr", "xer", \ - "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7", \ - "r0", "r2", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", \ - "r11", "r12", "r13" - -/* These CALL_FN_ macros assume that on ppc32-linux, - sizeof(unsigned long) == 4. */ - -#define CALL_FN_W_v(lval, orig) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[1]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "lwz 11,0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr %0,3" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_W(lval, orig, arg1) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[2]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)arg1; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "lwz 3,4(11)\n\t" /* arg1->r3 */ \ - "lwz 11,0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr %0,3" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)arg1; \ - _argvec[2] = (unsigned long)arg2; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "lwz 3,4(11)\n\t" /* arg1->r3 */ \ - "lwz 4,8(11)\n\t" \ - "lwz 11,0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr %0,3" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[4]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)arg1; \ - _argvec[2] = (unsigned long)arg2; \ - _argvec[3] = (unsigned long)arg3; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "lwz 3,4(11)\n\t" /* arg1->r3 */ \ - "lwz 4,8(11)\n\t" \ - "lwz 5,12(11)\n\t" \ - "lwz 11,0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr %0,3" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[5]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)arg1; \ - _argvec[2] = (unsigned long)arg2; \ - _argvec[3] = (unsigned long)arg3; \ - _argvec[4] = (unsigned long)arg4; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "lwz 3,4(11)\n\t" /* arg1->r3 */ \ - "lwz 4,8(11)\n\t" \ - "lwz 5,12(11)\n\t" \ - "lwz 6,16(11)\n\t" /* arg4->r6 */ \ - "lwz 11,0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr %0,3" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[6]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)arg1; \ - _argvec[2] = (unsigned long)arg2; \ - _argvec[3] = (unsigned long)arg3; \ - _argvec[4] = (unsigned long)arg4; \ - _argvec[5] = (unsigned long)arg5; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "lwz 3,4(11)\n\t" /* arg1->r3 */ \ - "lwz 4,8(11)\n\t" \ - "lwz 5,12(11)\n\t" \ - "lwz 6,16(11)\n\t" /* arg4->r6 */ \ - "lwz 7,20(11)\n\t" \ - "lwz 11,0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr %0,3" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[7]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)arg1; \ - _argvec[2] = (unsigned long)arg2; \ - _argvec[3] = (unsigned long)arg3; \ - _argvec[4] = (unsigned long)arg4; \ - _argvec[5] = (unsigned long)arg5; \ - _argvec[6] = (unsigned long)arg6; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "lwz 3,4(11)\n\t" /* arg1->r3 */ \ - "lwz 4,8(11)\n\t" \ - "lwz 5,12(11)\n\t" \ - "lwz 6,16(11)\n\t" /* arg4->r6 */ \ - "lwz 7,20(11)\n\t" \ - "lwz 8,24(11)\n\t" \ - "lwz 11,0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr %0,3" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[8]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)arg1; \ - _argvec[2] = (unsigned long)arg2; \ - _argvec[3] = (unsigned long)arg3; \ - _argvec[4] = (unsigned long)arg4; \ - _argvec[5] = (unsigned long)arg5; \ - _argvec[6] = (unsigned long)arg6; \ - _argvec[7] = (unsigned long)arg7; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "lwz 3,4(11)\n\t" /* arg1->r3 */ \ - "lwz 4,8(11)\n\t" \ - "lwz 5,12(11)\n\t" \ - "lwz 6,16(11)\n\t" /* arg4->r6 */ \ - "lwz 7,20(11)\n\t" \ - "lwz 8,24(11)\n\t" \ - "lwz 9,28(11)\n\t" \ - "lwz 11,0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr %0,3" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[9]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)arg1; \ - _argvec[2] = (unsigned long)arg2; \ - _argvec[3] = (unsigned long)arg3; \ - _argvec[4] = (unsigned long)arg4; \ - _argvec[5] = (unsigned long)arg5; \ - _argvec[6] = (unsigned long)arg6; \ - _argvec[7] = (unsigned long)arg7; \ - _argvec[8] = (unsigned long)arg8; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "lwz 3,4(11)\n\t" /* arg1->r3 */ \ - "lwz 4,8(11)\n\t" \ - "lwz 5,12(11)\n\t" \ - "lwz 6,16(11)\n\t" /* arg4->r6 */ \ - "lwz 7,20(11)\n\t" \ - "lwz 8,24(11)\n\t" \ - "lwz 9,28(11)\n\t" \ - "lwz 10,32(11)\n\t" /* arg8->r10 */ \ - "lwz 11,0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr %0,3" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[10]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)arg1; \ - _argvec[2] = (unsigned long)arg2; \ - _argvec[3] = (unsigned long)arg3; \ - _argvec[4] = (unsigned long)arg4; \ - _argvec[5] = (unsigned long)arg5; \ - _argvec[6] = (unsigned long)arg6; \ - _argvec[7] = (unsigned long)arg7; \ - _argvec[8] = (unsigned long)arg8; \ - _argvec[9] = (unsigned long)arg9; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "addi 1,1,-16\n\t" \ - /* arg9 */ \ - "lwz 3,36(11)\n\t" \ - "stw 3,8(1)\n\t" \ - /* args1-8 */ \ - "lwz 3,4(11)\n\t" /* arg1->r3 */ \ - "lwz 4,8(11)\n\t" \ - "lwz 5,12(11)\n\t" \ - "lwz 6,16(11)\n\t" /* arg4->r6 */ \ - "lwz 7,20(11)\n\t" \ - "lwz 8,24(11)\n\t" \ - "lwz 9,28(11)\n\t" \ - "lwz 10,32(11)\n\t" /* arg8->r10 */ \ - "lwz 11,0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "addi 1,1,16\n\t" \ - "mr %0,3" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9,arg10) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[11]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)arg1; \ - _argvec[2] = (unsigned long)arg2; \ - _argvec[3] = (unsigned long)arg3; \ - _argvec[4] = (unsigned long)arg4; \ - _argvec[5] = (unsigned long)arg5; \ - _argvec[6] = (unsigned long)arg6; \ - _argvec[7] = (unsigned long)arg7; \ - _argvec[8] = (unsigned long)arg8; \ - _argvec[9] = (unsigned long)arg9; \ - _argvec[10] = (unsigned long)arg10; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "addi 1,1,-16\n\t" \ - /* arg10 */ \ - "lwz 3,40(11)\n\t" \ - "stw 3,12(1)\n\t" \ - /* arg9 */ \ - "lwz 3,36(11)\n\t" \ - "stw 3,8(1)\n\t" \ - /* args1-8 */ \ - "lwz 3,4(11)\n\t" /* arg1->r3 */ \ - "lwz 4,8(11)\n\t" \ - "lwz 5,12(11)\n\t" \ - "lwz 6,16(11)\n\t" /* arg4->r6 */ \ - "lwz 7,20(11)\n\t" \ - "lwz 8,24(11)\n\t" \ - "lwz 9,28(11)\n\t" \ - "lwz 10,32(11)\n\t" /* arg8->r10 */ \ - "lwz 11,0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "addi 1,1,16\n\t" \ - "mr %0,3" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9,arg10,arg11) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[12]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)arg1; \ - _argvec[2] = (unsigned long)arg2; \ - _argvec[3] = (unsigned long)arg3; \ - _argvec[4] = (unsigned long)arg4; \ - _argvec[5] = (unsigned long)arg5; \ - _argvec[6] = (unsigned long)arg6; \ - _argvec[7] = (unsigned long)arg7; \ - _argvec[8] = (unsigned long)arg8; \ - _argvec[9] = (unsigned long)arg9; \ - _argvec[10] = (unsigned long)arg10; \ - _argvec[11] = (unsigned long)arg11; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "addi 1,1,-32\n\t" \ - /* arg11 */ \ - "lwz 3,44(11)\n\t" \ - "stw 3,16(1)\n\t" \ - /* arg10 */ \ - "lwz 3,40(11)\n\t" \ - "stw 3,12(1)\n\t" \ - /* arg9 */ \ - "lwz 3,36(11)\n\t" \ - "stw 3,8(1)\n\t" \ - /* args1-8 */ \ - "lwz 3,4(11)\n\t" /* arg1->r3 */ \ - "lwz 4,8(11)\n\t" \ - "lwz 5,12(11)\n\t" \ - "lwz 6,16(11)\n\t" /* arg4->r6 */ \ - "lwz 7,20(11)\n\t" \ - "lwz 8,24(11)\n\t" \ - "lwz 9,28(11)\n\t" \ - "lwz 10,32(11)\n\t" /* arg8->r10 */ \ - "lwz 11,0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "addi 1,1,32\n\t" \ - "mr %0,3" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9,arg10,arg11,arg12) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[13]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)arg1; \ - _argvec[2] = (unsigned long)arg2; \ - _argvec[3] = (unsigned long)arg3; \ - _argvec[4] = (unsigned long)arg4; \ - _argvec[5] = (unsigned long)arg5; \ - _argvec[6] = (unsigned long)arg6; \ - _argvec[7] = (unsigned long)arg7; \ - _argvec[8] = (unsigned long)arg8; \ - _argvec[9] = (unsigned long)arg9; \ - _argvec[10] = (unsigned long)arg10; \ - _argvec[11] = (unsigned long)arg11; \ - _argvec[12] = (unsigned long)arg12; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "addi 1,1,-32\n\t" \ - /* arg12 */ \ - "lwz 3,48(11)\n\t" \ - "stw 3,20(1)\n\t" \ - /* arg11 */ \ - "lwz 3,44(11)\n\t" \ - "stw 3,16(1)\n\t" \ - /* arg10 */ \ - "lwz 3,40(11)\n\t" \ - "stw 3,12(1)\n\t" \ - /* arg9 */ \ - "lwz 3,36(11)\n\t" \ - "stw 3,8(1)\n\t" \ - /* args1-8 */ \ - "lwz 3,4(11)\n\t" /* arg1->r3 */ \ - "lwz 4,8(11)\n\t" \ - "lwz 5,12(11)\n\t" \ - "lwz 6,16(11)\n\t" /* arg4->r6 */ \ - "lwz 7,20(11)\n\t" \ - "lwz 8,24(11)\n\t" \ - "lwz 9,28(11)\n\t" \ - "lwz 10,32(11)\n\t" /* arg8->r10 */ \ - "lwz 11,0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "addi 1,1,32\n\t" \ - "mr %0,3" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#endif /* PLAT_ppc32_linux */ - -/* ------------------------ ppc64-linux ------------------------ */ - -#if defined(PLAT_ppc64_linux) - -/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */ - -/* These regs are trashed by the hidden call. */ -#define __CALLER_SAVED_REGS \ - "lr", "ctr", "xer", \ - "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7", \ - "r0", "r2", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", \ - "r11", "r12", "r13" - -/* These CALL_FN_ macros assume that on ppc64-linux, sizeof(unsigned - long) == 8. */ - -#define CALL_FN_W_v(lval, orig) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+0]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)" /* restore tocptr */ \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_W(lval, orig, arg1) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+1]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)" /* restore tocptr */ \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+2]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)" /* restore tocptr */ \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+3]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 5, 24(11)\n\t" /* arg3->r5 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)" /* restore tocptr */ \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+4]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 5, 24(11)\n\t" /* arg3->r5 */ \ - "ld 6, 32(11)\n\t" /* arg4->r6 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)" /* restore tocptr */ \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+5]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 5, 24(11)\n\t" /* arg3->r5 */ \ - "ld 6, 32(11)\n\t" /* arg4->r6 */ \ - "ld 7, 40(11)\n\t" /* arg5->r7 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)" /* restore tocptr */ \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+6]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - _argvec[2+6] = (unsigned long)arg6; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 5, 24(11)\n\t" /* arg3->r5 */ \ - "ld 6, 32(11)\n\t" /* arg4->r6 */ \ - "ld 7, 40(11)\n\t" /* arg5->r7 */ \ - "ld 8, 48(11)\n\t" /* arg6->r8 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)" /* restore tocptr */ \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+7]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - _argvec[2+6] = (unsigned long)arg6; \ - _argvec[2+7] = (unsigned long)arg7; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 5, 24(11)\n\t" /* arg3->r5 */ \ - "ld 6, 32(11)\n\t" /* arg4->r6 */ \ - "ld 7, 40(11)\n\t" /* arg5->r7 */ \ - "ld 8, 48(11)\n\t" /* arg6->r8 */ \ - "ld 9, 56(11)\n\t" /* arg7->r9 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)" /* restore tocptr */ \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+8]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - _argvec[2+6] = (unsigned long)arg6; \ - _argvec[2+7] = (unsigned long)arg7; \ - _argvec[2+8] = (unsigned long)arg8; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 5, 24(11)\n\t" /* arg3->r5 */ \ - "ld 6, 32(11)\n\t" /* arg4->r6 */ \ - "ld 7, 40(11)\n\t" /* arg5->r7 */ \ - "ld 8, 48(11)\n\t" /* arg6->r8 */ \ - "ld 9, 56(11)\n\t" /* arg7->r9 */ \ - "ld 10, 64(11)\n\t" /* arg8->r10 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)" /* restore tocptr */ \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+9]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - _argvec[2+6] = (unsigned long)arg6; \ - _argvec[2+7] = (unsigned long)arg7; \ - _argvec[2+8] = (unsigned long)arg8; \ - _argvec[2+9] = (unsigned long)arg9; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "addi 1,1,-128\n\t" /* expand stack frame */ \ - /* arg9 */ \ - "ld 3,72(11)\n\t" \ - "std 3,112(1)\n\t" \ - /* args1-8 */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 5, 24(11)\n\t" /* arg3->r5 */ \ - "ld 6, 32(11)\n\t" /* arg4->r6 */ \ - "ld 7, 40(11)\n\t" /* arg5->r7 */ \ - "ld 8, 48(11)\n\t" /* arg6->r8 */ \ - "ld 9, 56(11)\n\t" /* arg7->r9 */ \ - "ld 10, 64(11)\n\t" /* arg8->r10 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)\n\t" /* restore tocptr */ \ - "addi 1,1,128" /* restore frame */ \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9,arg10) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+10]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - _argvec[2+6] = (unsigned long)arg6; \ - _argvec[2+7] = (unsigned long)arg7; \ - _argvec[2+8] = (unsigned long)arg8; \ - _argvec[2+9] = (unsigned long)arg9; \ - _argvec[2+10] = (unsigned long)arg10; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "addi 1,1,-128\n\t" /* expand stack frame */ \ - /* arg10 */ \ - "ld 3,80(11)\n\t" \ - "std 3,120(1)\n\t" \ - /* arg9 */ \ - "ld 3,72(11)\n\t" \ - "std 3,112(1)\n\t" \ - /* args1-8 */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 5, 24(11)\n\t" /* arg3->r5 */ \ - "ld 6, 32(11)\n\t" /* arg4->r6 */ \ - "ld 7, 40(11)\n\t" /* arg5->r7 */ \ - "ld 8, 48(11)\n\t" /* arg6->r8 */ \ - "ld 9, 56(11)\n\t" /* arg7->r9 */ \ - "ld 10, 64(11)\n\t" /* arg8->r10 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)\n\t" /* restore tocptr */ \ - "addi 1,1,128" /* restore frame */ \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9,arg10,arg11) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+11]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - _argvec[2+6] = (unsigned long)arg6; \ - _argvec[2+7] = (unsigned long)arg7; \ - _argvec[2+8] = (unsigned long)arg8; \ - _argvec[2+9] = (unsigned long)arg9; \ - _argvec[2+10] = (unsigned long)arg10; \ - _argvec[2+11] = (unsigned long)arg11; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "addi 1,1,-144\n\t" /* expand stack frame */ \ - /* arg11 */ \ - "ld 3,88(11)\n\t" \ - "std 3,128(1)\n\t" \ - /* arg10 */ \ - "ld 3,80(11)\n\t" \ - "std 3,120(1)\n\t" \ - /* arg9 */ \ - "ld 3,72(11)\n\t" \ - "std 3,112(1)\n\t" \ - /* args1-8 */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 5, 24(11)\n\t" /* arg3->r5 */ \ - "ld 6, 32(11)\n\t" /* arg4->r6 */ \ - "ld 7, 40(11)\n\t" /* arg5->r7 */ \ - "ld 8, 48(11)\n\t" /* arg6->r8 */ \ - "ld 9, 56(11)\n\t" /* arg7->r9 */ \ - "ld 10, 64(11)\n\t" /* arg8->r10 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)\n\t" /* restore tocptr */ \ - "addi 1,1,144" /* restore frame */ \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9,arg10,arg11,arg12) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+12]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - _argvec[2+6] = (unsigned long)arg6; \ - _argvec[2+7] = (unsigned long)arg7; \ - _argvec[2+8] = (unsigned long)arg8; \ - _argvec[2+9] = (unsigned long)arg9; \ - _argvec[2+10] = (unsigned long)arg10; \ - _argvec[2+11] = (unsigned long)arg11; \ - _argvec[2+12] = (unsigned long)arg12; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "addi 1,1,-144\n\t" /* expand stack frame */ \ - /* arg12 */ \ - "ld 3,96(11)\n\t" \ - "std 3,136(1)\n\t" \ - /* arg11 */ \ - "ld 3,88(11)\n\t" \ - "std 3,128(1)\n\t" \ - /* arg10 */ \ - "ld 3,80(11)\n\t" \ - "std 3,120(1)\n\t" \ - /* arg9 */ \ - "ld 3,72(11)\n\t" \ - "std 3,112(1)\n\t" \ - /* args1-8 */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 5, 24(11)\n\t" /* arg3->r5 */ \ - "ld 6, 32(11)\n\t" /* arg4->r6 */ \ - "ld 7, 40(11)\n\t" /* arg5->r7 */ \ - "ld 8, 48(11)\n\t" /* arg6->r8 */ \ - "ld 9, 56(11)\n\t" /* arg7->r9 */ \ - "ld 10, 64(11)\n\t" /* arg8->r10 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)\n\t" /* restore tocptr */ \ - "addi 1,1,144" /* restore frame */ \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#endif /* PLAT_ppc64_linux */ - -/* ------------------------- arm-linux ------------------------- */ - -#if defined(PLAT_arm_linux) - -/* These regs are trashed by the hidden call. */ -#define __CALLER_SAVED_REGS "r0", "r1", "r2", "r3","r4","r14" - -/* These CALL_FN_ macros assume that on arm-linux, sizeof(unsigned - long) == 4. */ - -#define CALL_FN_W_v(lval, orig) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[1]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - __asm__ volatile( \ - "ldr r4, [%1] \n\t" /* target->r4 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ - "mov %0, r0\n" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "0" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_W(lval, orig, arg1) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[2]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - __asm__ volatile( \ - "ldr r0, [%1, #4] \n\t" \ - "ldr r4, [%1] \n\t" /* target->r4 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ - "mov %0, r0\n" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "0" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - __asm__ volatile( \ - "ldr r0, [%1, #4] \n\t" \ - "ldr r1, [%1, #8] \n\t" \ - "ldr r4, [%1] \n\t" /* target->r4 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ - "mov %0, r0\n" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "0" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[4]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - __asm__ volatile( \ - "ldr r0, [%1, #4] \n\t" \ - "ldr r1, [%1, #8] \n\t" \ - "ldr r2, [%1, #12] \n\t" \ - "ldr r4, [%1] \n\t" /* target->r4 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ - "mov %0, r0\n" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "0" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[5]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - __asm__ volatile( \ - "ldr r0, [%1, #4] \n\t" \ - "ldr r1, [%1, #8] \n\t" \ - "ldr r2, [%1, #12] \n\t" \ - "ldr r3, [%1, #16] \n\t" \ - "ldr r4, [%1] \n\t" /* target->r4 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ - "mov %0, r0" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "0" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[6]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - __asm__ volatile( \ - "ldr r0, [%1, #20] \n\t" \ - "push {r0} \n\t" \ - "ldr r0, [%1, #4] \n\t" \ - "ldr r1, [%1, #8] \n\t" \ - "ldr r2, [%1, #12] \n\t" \ - "ldr r3, [%1, #16] \n\t" \ - "ldr r4, [%1] \n\t" /* target->r4 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ - "add sp, sp, #4 \n\t" \ - "mov %0, r0" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "0" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[7]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - _argvec[6] = (unsigned long)(arg6); \ - __asm__ volatile( \ - "ldr r0, [%1, #20] \n\t" \ - "ldr r1, [%1, #24] \n\t" \ - "push {r0, r1} \n\t" \ - "ldr r0, [%1, #4] \n\t" \ - "ldr r1, [%1, #8] \n\t" \ - "ldr r2, [%1, #12] \n\t" \ - "ldr r3, [%1, #16] \n\t" \ - "ldr r4, [%1] \n\t" /* target->r4 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ - "add sp, sp, #8 \n\t" \ - "mov %0, r0" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "0" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[8]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - _argvec[6] = (unsigned long)(arg6); \ - _argvec[7] = (unsigned long)(arg7); \ - __asm__ volatile( \ - "ldr r0, [%1, #20] \n\t" \ - "ldr r1, [%1, #24] \n\t" \ - "ldr r2, [%1, #28] \n\t" \ - "push {r0, r1, r2} \n\t" \ - "ldr r0, [%1, #4] \n\t" \ - "ldr r1, [%1, #8] \n\t" \ - "ldr r2, [%1, #12] \n\t" \ - "ldr r3, [%1, #16] \n\t" \ - "ldr r4, [%1] \n\t" /* target->r4 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ - "add sp, sp, #12 \n\t" \ - "mov %0, r0" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "0" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[9]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - _argvec[6] = (unsigned long)(arg6); \ - _argvec[7] = (unsigned long)(arg7); \ - _argvec[8] = (unsigned long)(arg8); \ - __asm__ volatile( \ - "ldr r0, [%1, #20] \n\t" \ - "ldr r1, [%1, #24] \n\t" \ - "ldr r2, [%1, #28] \n\t" \ - "ldr r3, [%1, #32] \n\t" \ - "push {r0, r1, r2, r3} \n\t" \ - "ldr r0, [%1, #4] \n\t" \ - "ldr r1, [%1, #8] \n\t" \ - "ldr r2, [%1, #12] \n\t" \ - "ldr r3, [%1, #16] \n\t" \ - "ldr r4, [%1] \n\t" /* target->r4 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ - "add sp, sp, #16 \n\t" \ - "mov %0, r0" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "0" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[10]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - _argvec[6] = (unsigned long)(arg6); \ - _argvec[7] = (unsigned long)(arg7); \ - _argvec[8] = (unsigned long)(arg8); \ - _argvec[9] = (unsigned long)(arg9); \ - __asm__ volatile( \ - "ldr r0, [%1, #20] \n\t" \ - "ldr r1, [%1, #24] \n\t" \ - "ldr r2, [%1, #28] \n\t" \ - "ldr r3, [%1, #32] \n\t" \ - "ldr r4, [%1, #36] \n\t" \ - "push {r0, r1, r2, r3, r4} \n\t" \ - "ldr r0, [%1, #4] \n\t" \ - "ldr r1, [%1, #8] \n\t" \ - "ldr r2, [%1, #12] \n\t" \ - "ldr r3, [%1, #16] \n\t" \ - "ldr r4, [%1] \n\t" /* target->r4 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ - "add sp, sp, #20 \n\t" \ - "mov %0, r0" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "0" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9,arg10) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[11]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - _argvec[6] = (unsigned long)(arg6); \ - _argvec[7] = (unsigned long)(arg7); \ - _argvec[8] = (unsigned long)(arg8); \ - _argvec[9] = (unsigned long)(arg9); \ - _argvec[10] = (unsigned long)(arg10); \ - __asm__ volatile( \ - "ldr r0, [%1, #40] \n\t" \ - "push {r0} \n\t" \ - "ldr r0, [%1, #20] \n\t" \ - "ldr r1, [%1, #24] \n\t" \ - "ldr r2, [%1, #28] \n\t" \ - "ldr r3, [%1, #32] \n\t" \ - "ldr r4, [%1, #36] \n\t" \ - "push {r0, r1, r2, r3, r4} \n\t" \ - "ldr r0, [%1, #4] \n\t" \ - "ldr r1, [%1, #8] \n\t" \ - "ldr r2, [%1, #12] \n\t" \ - "ldr r3, [%1, #16] \n\t" \ - "ldr r4, [%1] \n\t" /* target->r4 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ - "add sp, sp, #24 \n\t" \ - "mov %0, r0" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "0" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ - arg6,arg7,arg8,arg9,arg10, \ - arg11) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[12]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - _argvec[6] = (unsigned long)(arg6); \ - _argvec[7] = (unsigned long)(arg7); \ - _argvec[8] = (unsigned long)(arg8); \ - _argvec[9] = (unsigned long)(arg9); \ - _argvec[10] = (unsigned long)(arg10); \ - _argvec[11] = (unsigned long)(arg11); \ - __asm__ volatile( \ - "ldr r0, [%1, #40] \n\t" \ - "ldr r1, [%1, #44] \n\t" \ - "push {r0, r1} \n\t" \ - "ldr r0, [%1, #20] \n\t" \ - "ldr r1, [%1, #24] \n\t" \ - "ldr r2, [%1, #28] \n\t" \ - "ldr r3, [%1, #32] \n\t" \ - "ldr r4, [%1, #36] \n\t" \ - "push {r0, r1, r2, r3, r4} \n\t" \ - "ldr r0, [%1, #4] \n\t" \ - "ldr r1, [%1, #8] \n\t" \ - "ldr r2, [%1, #12] \n\t" \ - "ldr r3, [%1, #16] \n\t" \ - "ldr r4, [%1] \n\t" /* target->r4 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ - "add sp, sp, #28 \n\t" \ - "mov %0, r0" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "0" (&_argvec[0]) \ - : /*trash*/ "cc", "memory",__CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5, \ - arg6,arg7,arg8,arg9,arg10, \ - arg11,arg12) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[13]; \ - volatile unsigned long _res; \ - _argvec[0] = (unsigned long)_orig.nraddr; \ - _argvec[1] = (unsigned long)(arg1); \ - _argvec[2] = (unsigned long)(arg2); \ - _argvec[3] = (unsigned long)(arg3); \ - _argvec[4] = (unsigned long)(arg4); \ - _argvec[5] = (unsigned long)(arg5); \ - _argvec[6] = (unsigned long)(arg6); \ - _argvec[7] = (unsigned long)(arg7); \ - _argvec[8] = (unsigned long)(arg8); \ - _argvec[9] = (unsigned long)(arg9); \ - _argvec[10] = (unsigned long)(arg10); \ - _argvec[11] = (unsigned long)(arg11); \ - _argvec[12] = (unsigned long)(arg12); \ - __asm__ volatile( \ - "ldr r0, [%1, #40] \n\t" \ - "ldr r1, [%1, #44] \n\t" \ - "ldr r2, [%1, #48] \n\t" \ - "push {r0, r1, r2} \n\t" \ - "ldr r0, [%1, #20] \n\t" \ - "ldr r1, [%1, #24] \n\t" \ - "ldr r2, [%1, #28] \n\t" \ - "ldr r3, [%1, #32] \n\t" \ - "ldr r4, [%1, #36] \n\t" \ - "push {r0, r1, r2, r3, r4} \n\t" \ - "ldr r0, [%1, #4] \n\t" \ - "ldr r1, [%1, #8] \n\t" \ - "ldr r2, [%1, #12] \n\t" \ - "ldr r3, [%1, #16] \n\t" \ - "ldr r4, [%1] \n\t" /* target->r4 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R4 \ - "add sp, sp, #32 \n\t" \ - "mov %0, r0" \ - : /*out*/ "=r" (_res) \ - : /*in*/ "0" (&_argvec[0]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#endif /* PLAT_arm_linux */ - -/* ------------------------ ppc32-aix5 ------------------------- */ - -#if defined(PLAT_ppc32_aix5) - -/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */ - -/* These regs are trashed by the hidden call. */ -#define __CALLER_SAVED_REGS \ - "lr", "ctr", "xer", \ - "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7", \ - "r0", "r2", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", \ - "r11", "r12", "r13" - -/* Expand the stack frame, copying enough info that unwinding - still works. Trashes r3. */ - -#define VG_EXPAND_FRAME_BY_trashes_r3(_n_fr) \ - "addi 1,1,-" #_n_fr "\n\t" \ - "lwz 3," #_n_fr "(1)\n\t" \ - "stw 3,0(1)\n\t" - -#define VG_CONTRACT_FRAME_BY(_n_fr) \ - "addi 1,1," #_n_fr "\n\t" - -/* These CALL_FN_ macros assume that on ppc32-aix5, sizeof(unsigned - long) == 4. */ - -#define CALL_FN_W_v(lval, orig) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+0]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "stw 2,-8(11)\n\t" /* save tocptr */ \ - "lwz 2,-4(11)\n\t" /* use nraddr's tocptr */ \ - "lwz 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "lwz 2,-8(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_W(lval, orig, arg1) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+1]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "stw 2,-8(11)\n\t" /* save tocptr */ \ - "lwz 2,-4(11)\n\t" /* use nraddr's tocptr */ \ - "lwz 3, 4(11)\n\t" /* arg1->r3 */ \ - "lwz 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "lwz 2,-8(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+2]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "stw 2,-8(11)\n\t" /* save tocptr */ \ - "lwz 2,-4(11)\n\t" /* use nraddr's tocptr */ \ - "lwz 3, 4(11)\n\t" /* arg1->r3 */ \ - "lwz 4, 8(11)\n\t" /* arg2->r4 */ \ - "lwz 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "lwz 2,-8(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+3]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "stw 2,-8(11)\n\t" /* save tocptr */ \ - "lwz 2,-4(11)\n\t" /* use nraddr's tocptr */ \ - "lwz 3, 4(11)\n\t" /* arg1->r3 */ \ - "lwz 4, 8(11)\n\t" /* arg2->r4 */ \ - "lwz 5, 12(11)\n\t" /* arg3->r5 */ \ - "lwz 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "lwz 2,-8(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+4]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "stw 2,-8(11)\n\t" /* save tocptr */ \ - "lwz 2,-4(11)\n\t" /* use nraddr's tocptr */ \ - "lwz 3, 4(11)\n\t" /* arg1->r3 */ \ - "lwz 4, 8(11)\n\t" /* arg2->r4 */ \ - "lwz 5, 12(11)\n\t" /* arg3->r5 */ \ - "lwz 6, 16(11)\n\t" /* arg4->r6 */ \ - "lwz 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "lwz 2,-8(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+5]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "stw 2,-8(11)\n\t" /* save tocptr */ \ - "lwz 2,-4(11)\n\t" /* use nraddr's tocptr */ \ - "lwz 3, 4(11)\n\t" /* arg1->r3 */ \ - "lwz 4, 8(11)\n\t" /* arg2->r4 */ \ - "lwz 5, 12(11)\n\t" /* arg3->r5 */ \ - "lwz 6, 16(11)\n\t" /* arg4->r6 */ \ - "lwz 7, 20(11)\n\t" /* arg5->r7 */ \ - "lwz 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "lwz 2,-8(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+6]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - _argvec[2+6] = (unsigned long)arg6; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "stw 2,-8(11)\n\t" /* save tocptr */ \ - "lwz 2,-4(11)\n\t" /* use nraddr's tocptr */ \ - "lwz 3, 4(11)\n\t" /* arg1->r3 */ \ - "lwz 4, 8(11)\n\t" /* arg2->r4 */ \ - "lwz 5, 12(11)\n\t" /* arg3->r5 */ \ - "lwz 6, 16(11)\n\t" /* arg4->r6 */ \ - "lwz 7, 20(11)\n\t" /* arg5->r7 */ \ - "lwz 8, 24(11)\n\t" /* arg6->r8 */ \ - "lwz 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "lwz 2,-8(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+7]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - _argvec[2+6] = (unsigned long)arg6; \ - _argvec[2+7] = (unsigned long)arg7; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "stw 2,-8(11)\n\t" /* save tocptr */ \ - "lwz 2,-4(11)\n\t" /* use nraddr's tocptr */ \ - "lwz 3, 4(11)\n\t" /* arg1->r3 */ \ - "lwz 4, 8(11)\n\t" /* arg2->r4 */ \ - "lwz 5, 12(11)\n\t" /* arg3->r5 */ \ - "lwz 6, 16(11)\n\t" /* arg4->r6 */ \ - "lwz 7, 20(11)\n\t" /* arg5->r7 */ \ - "lwz 8, 24(11)\n\t" /* arg6->r8 */ \ - "lwz 9, 28(11)\n\t" /* arg7->r9 */ \ - "lwz 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "lwz 2,-8(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+8]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - _argvec[2+6] = (unsigned long)arg6; \ - _argvec[2+7] = (unsigned long)arg7; \ - _argvec[2+8] = (unsigned long)arg8; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "stw 2,-8(11)\n\t" /* save tocptr */ \ - "lwz 2,-4(11)\n\t" /* use nraddr's tocptr */ \ - "lwz 3, 4(11)\n\t" /* arg1->r3 */ \ - "lwz 4, 8(11)\n\t" /* arg2->r4 */ \ - "lwz 5, 12(11)\n\t" /* arg3->r5 */ \ - "lwz 6, 16(11)\n\t" /* arg4->r6 */ \ - "lwz 7, 20(11)\n\t" /* arg5->r7 */ \ - "lwz 8, 24(11)\n\t" /* arg6->r8 */ \ - "lwz 9, 28(11)\n\t" /* arg7->r9 */ \ - "lwz 10, 32(11)\n\t" /* arg8->r10 */ \ - "lwz 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "lwz 2,-8(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+9]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - _argvec[2+6] = (unsigned long)arg6; \ - _argvec[2+7] = (unsigned long)arg7; \ - _argvec[2+8] = (unsigned long)arg8; \ - _argvec[2+9] = (unsigned long)arg9; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "stw 2,-8(11)\n\t" /* save tocptr */ \ - "lwz 2,-4(11)\n\t" /* use nraddr's tocptr */ \ - VG_EXPAND_FRAME_BY_trashes_r3(64) \ - /* arg9 */ \ - "lwz 3,36(11)\n\t" \ - "stw 3,56(1)\n\t" \ - /* args1-8 */ \ - "lwz 3, 4(11)\n\t" /* arg1->r3 */ \ - "lwz 4, 8(11)\n\t" /* arg2->r4 */ \ - "lwz 5, 12(11)\n\t" /* arg3->r5 */ \ - "lwz 6, 16(11)\n\t" /* arg4->r6 */ \ - "lwz 7, 20(11)\n\t" /* arg5->r7 */ \ - "lwz 8, 24(11)\n\t" /* arg6->r8 */ \ - "lwz 9, 28(11)\n\t" /* arg7->r9 */ \ - "lwz 10, 32(11)\n\t" /* arg8->r10 */ \ - "lwz 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "lwz 2,-8(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(64) \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9,arg10) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+10]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - _argvec[2+6] = (unsigned long)arg6; \ - _argvec[2+7] = (unsigned long)arg7; \ - _argvec[2+8] = (unsigned long)arg8; \ - _argvec[2+9] = (unsigned long)arg9; \ - _argvec[2+10] = (unsigned long)arg10; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "stw 2,-8(11)\n\t" /* save tocptr */ \ - "lwz 2,-4(11)\n\t" /* use nraddr's tocptr */ \ - VG_EXPAND_FRAME_BY_trashes_r3(64) \ - /* arg10 */ \ - "lwz 3,40(11)\n\t" \ - "stw 3,60(1)\n\t" \ - /* arg9 */ \ - "lwz 3,36(11)\n\t" \ - "stw 3,56(1)\n\t" \ - /* args1-8 */ \ - "lwz 3, 4(11)\n\t" /* arg1->r3 */ \ - "lwz 4, 8(11)\n\t" /* arg2->r4 */ \ - "lwz 5, 12(11)\n\t" /* arg3->r5 */ \ - "lwz 6, 16(11)\n\t" /* arg4->r6 */ \ - "lwz 7, 20(11)\n\t" /* arg5->r7 */ \ - "lwz 8, 24(11)\n\t" /* arg6->r8 */ \ - "lwz 9, 28(11)\n\t" /* arg7->r9 */ \ - "lwz 10, 32(11)\n\t" /* arg8->r10 */ \ - "lwz 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "lwz 2,-8(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(64) \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9,arg10,arg11) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+11]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - _argvec[2+6] = (unsigned long)arg6; \ - _argvec[2+7] = (unsigned long)arg7; \ - _argvec[2+8] = (unsigned long)arg8; \ - _argvec[2+9] = (unsigned long)arg9; \ - _argvec[2+10] = (unsigned long)arg10; \ - _argvec[2+11] = (unsigned long)arg11; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "stw 2,-8(11)\n\t" /* save tocptr */ \ - "lwz 2,-4(11)\n\t" /* use nraddr's tocptr */ \ - VG_EXPAND_FRAME_BY_trashes_r3(72) \ - /* arg11 */ \ - "lwz 3,44(11)\n\t" \ - "stw 3,64(1)\n\t" \ - /* arg10 */ \ - "lwz 3,40(11)\n\t" \ - "stw 3,60(1)\n\t" \ - /* arg9 */ \ - "lwz 3,36(11)\n\t" \ - "stw 3,56(1)\n\t" \ - /* args1-8 */ \ - "lwz 3, 4(11)\n\t" /* arg1->r3 */ \ - "lwz 4, 8(11)\n\t" /* arg2->r4 */ \ - "lwz 5, 12(11)\n\t" /* arg3->r5 */ \ - "lwz 6, 16(11)\n\t" /* arg4->r6 */ \ - "lwz 7, 20(11)\n\t" /* arg5->r7 */ \ - "lwz 8, 24(11)\n\t" /* arg6->r8 */ \ - "lwz 9, 28(11)\n\t" /* arg7->r9 */ \ - "lwz 10, 32(11)\n\t" /* arg8->r10 */ \ - "lwz 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "lwz 2,-8(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(72) \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9,arg10,arg11,arg12) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+12]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - _argvec[2+6] = (unsigned long)arg6; \ - _argvec[2+7] = (unsigned long)arg7; \ - _argvec[2+8] = (unsigned long)arg8; \ - _argvec[2+9] = (unsigned long)arg9; \ - _argvec[2+10] = (unsigned long)arg10; \ - _argvec[2+11] = (unsigned long)arg11; \ - _argvec[2+12] = (unsigned long)arg12; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "stw 2,-8(11)\n\t" /* save tocptr */ \ - "lwz 2,-4(11)\n\t" /* use nraddr's tocptr */ \ - VG_EXPAND_FRAME_BY_trashes_r3(72) \ - /* arg12 */ \ - "lwz 3,48(11)\n\t" \ - "stw 3,68(1)\n\t" \ - /* arg11 */ \ - "lwz 3,44(11)\n\t" \ - "stw 3,64(1)\n\t" \ - /* arg10 */ \ - "lwz 3,40(11)\n\t" \ - "stw 3,60(1)\n\t" \ - /* arg9 */ \ - "lwz 3,36(11)\n\t" \ - "stw 3,56(1)\n\t" \ - /* args1-8 */ \ - "lwz 3, 4(11)\n\t" /* arg1->r3 */ \ - "lwz 4, 8(11)\n\t" /* arg2->r4 */ \ - "lwz 5, 12(11)\n\t" /* arg3->r5 */ \ - "lwz 6, 16(11)\n\t" /* arg4->r6 */ \ - "lwz 7, 20(11)\n\t" /* arg5->r7 */ \ - "lwz 8, 24(11)\n\t" /* arg6->r8 */ \ - "lwz 9, 28(11)\n\t" /* arg7->r9 */ \ - "lwz 10, 32(11)\n\t" /* arg8->r10 */ \ - "lwz 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "lwz 2,-8(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(72) \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#endif /* PLAT_ppc32_aix5 */ - -/* ------------------------ ppc64-aix5 ------------------------- */ - -#if defined(PLAT_ppc64_aix5) - -/* ARGREGS: r3 r4 r5 r6 r7 r8 r9 r10 (the rest on stack somewhere) */ - -/* These regs are trashed by the hidden call. */ -#define __CALLER_SAVED_REGS \ - "lr", "ctr", "xer", \ - "cr0", "cr1", "cr2", "cr3", "cr4", "cr5", "cr6", "cr7", \ - "r0", "r2", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", \ - "r11", "r12", "r13" - -/* Expand the stack frame, copying enough info that unwinding - still works. Trashes r3. */ - -#define VG_EXPAND_FRAME_BY_trashes_r3(_n_fr) \ - "addi 1,1,-" #_n_fr "\n\t" \ - "ld 3," #_n_fr "(1)\n\t" \ - "std 3,0(1)\n\t" - -#define VG_CONTRACT_FRAME_BY(_n_fr) \ - "addi 1,1," #_n_fr "\n\t" - -/* These CALL_FN_ macros assume that on ppc64-aix5, sizeof(unsigned - long) == 8. */ - -#define CALL_FN_W_v(lval, orig) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+0]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_W(lval, orig, arg1) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+1]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_WW(lval, orig, arg1,arg2) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+2]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_WWW(lval, orig, arg1,arg2,arg3) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+3]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 5, 24(11)\n\t" /* arg3->r5 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_WWWW(lval, orig, arg1,arg2,arg3,arg4) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+4]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 5, 24(11)\n\t" /* arg3->r5 */ \ - "ld 6, 32(11)\n\t" /* arg4->r6 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_5W(lval, orig, arg1,arg2,arg3,arg4,arg5) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+5]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 5, 24(11)\n\t" /* arg3->r5 */ \ - "ld 6, 32(11)\n\t" /* arg4->r6 */ \ - "ld 7, 40(11)\n\t" /* arg5->r7 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_6W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+6]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - _argvec[2+6] = (unsigned long)arg6; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 5, 24(11)\n\t" /* arg3->r5 */ \ - "ld 6, 32(11)\n\t" /* arg4->r6 */ \ - "ld 7, 40(11)\n\t" /* arg5->r7 */ \ - "ld 8, 48(11)\n\t" /* arg6->r8 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_7W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+7]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - _argvec[2+6] = (unsigned long)arg6; \ - _argvec[2+7] = (unsigned long)arg7; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 5, 24(11)\n\t" /* arg3->r5 */ \ - "ld 6, 32(11)\n\t" /* arg4->r6 */ \ - "ld 7, 40(11)\n\t" /* arg5->r7 */ \ - "ld 8, 48(11)\n\t" /* arg6->r8 */ \ - "ld 9, 56(11)\n\t" /* arg7->r9 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_8W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+8]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - _argvec[2+6] = (unsigned long)arg6; \ - _argvec[2+7] = (unsigned long)arg7; \ - _argvec[2+8] = (unsigned long)arg8; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 5, 24(11)\n\t" /* arg3->r5 */ \ - "ld 6, 32(11)\n\t" /* arg4->r6 */ \ - "ld 7, 40(11)\n\t" /* arg5->r7 */ \ - "ld 8, 48(11)\n\t" /* arg6->r8 */ \ - "ld 9, 56(11)\n\t" /* arg7->r9 */ \ - "ld 10, 64(11)\n\t" /* arg8->r10 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_9W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+9]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - _argvec[2+6] = (unsigned long)arg6; \ - _argvec[2+7] = (unsigned long)arg7; \ - _argvec[2+8] = (unsigned long)arg8; \ - _argvec[2+9] = (unsigned long)arg9; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - VG_EXPAND_FRAME_BY_trashes_r3(128) \ - /* arg9 */ \ - "ld 3,72(11)\n\t" \ - "std 3,112(1)\n\t" \ - /* args1-8 */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 5, 24(11)\n\t" /* arg3->r5 */ \ - "ld 6, 32(11)\n\t" /* arg4->r6 */ \ - "ld 7, 40(11)\n\t" /* arg5->r7 */ \ - "ld 8, 48(11)\n\t" /* arg6->r8 */ \ - "ld 9, 56(11)\n\t" /* arg7->r9 */ \ - "ld 10, 64(11)\n\t" /* arg8->r10 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(128) \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_10W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9,arg10) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+10]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - _argvec[2+6] = (unsigned long)arg6; \ - _argvec[2+7] = (unsigned long)arg7; \ - _argvec[2+8] = (unsigned long)arg8; \ - _argvec[2+9] = (unsigned long)arg9; \ - _argvec[2+10] = (unsigned long)arg10; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - VG_EXPAND_FRAME_BY_trashes_r3(128) \ - /* arg10 */ \ - "ld 3,80(11)\n\t" \ - "std 3,120(1)\n\t" \ - /* arg9 */ \ - "ld 3,72(11)\n\t" \ - "std 3,112(1)\n\t" \ - /* args1-8 */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 5, 24(11)\n\t" /* arg3->r5 */ \ - "ld 6, 32(11)\n\t" /* arg4->r6 */ \ - "ld 7, 40(11)\n\t" /* arg5->r7 */ \ - "ld 8, 48(11)\n\t" /* arg6->r8 */ \ - "ld 9, 56(11)\n\t" /* arg7->r9 */ \ - "ld 10, 64(11)\n\t" /* arg8->r10 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(128) \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_11W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9,arg10,arg11) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+11]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - _argvec[2+6] = (unsigned long)arg6; \ - _argvec[2+7] = (unsigned long)arg7; \ - _argvec[2+8] = (unsigned long)arg8; \ - _argvec[2+9] = (unsigned long)arg9; \ - _argvec[2+10] = (unsigned long)arg10; \ - _argvec[2+11] = (unsigned long)arg11; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - VG_EXPAND_FRAME_BY_trashes_r3(144) \ - /* arg11 */ \ - "ld 3,88(11)\n\t" \ - "std 3,128(1)\n\t" \ - /* arg10 */ \ - "ld 3,80(11)\n\t" \ - "std 3,120(1)\n\t" \ - /* arg9 */ \ - "ld 3,72(11)\n\t" \ - "std 3,112(1)\n\t" \ - /* args1-8 */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 5, 24(11)\n\t" /* arg3->r5 */ \ - "ld 6, 32(11)\n\t" /* arg4->r6 */ \ - "ld 7, 40(11)\n\t" /* arg5->r7 */ \ - "ld 8, 48(11)\n\t" /* arg6->r8 */ \ - "ld 9, 56(11)\n\t" /* arg7->r9 */ \ - "ld 10, 64(11)\n\t" /* arg8->r10 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(144) \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#define CALL_FN_W_12W(lval, orig, arg1,arg2,arg3,arg4,arg5,arg6, \ - arg7,arg8,arg9,arg10,arg11,arg12) \ - do { \ - volatile OrigFn _orig = (orig); \ - volatile unsigned long _argvec[3+12]; \ - volatile unsigned long _res; \ - /* _argvec[0] holds current r2 across the call */ \ - _argvec[1] = (unsigned long)_orig.r2; \ - _argvec[2] = (unsigned long)_orig.nraddr; \ - _argvec[2+1] = (unsigned long)arg1; \ - _argvec[2+2] = (unsigned long)arg2; \ - _argvec[2+3] = (unsigned long)arg3; \ - _argvec[2+4] = (unsigned long)arg4; \ - _argvec[2+5] = (unsigned long)arg5; \ - _argvec[2+6] = (unsigned long)arg6; \ - _argvec[2+7] = (unsigned long)arg7; \ - _argvec[2+8] = (unsigned long)arg8; \ - _argvec[2+9] = (unsigned long)arg9; \ - _argvec[2+10] = (unsigned long)arg10; \ - _argvec[2+11] = (unsigned long)arg11; \ - _argvec[2+12] = (unsigned long)arg12; \ - __asm__ volatile( \ - "mr 11,%1\n\t" \ - VG_EXPAND_FRAME_BY_trashes_r3(512) \ - "std 2,-16(11)\n\t" /* save tocptr */ \ - "ld 2,-8(11)\n\t" /* use nraddr's tocptr */ \ - VG_EXPAND_FRAME_BY_trashes_r3(144) \ - /* arg12 */ \ - "ld 3,96(11)\n\t" \ - "std 3,136(1)\n\t" \ - /* arg11 */ \ - "ld 3,88(11)\n\t" \ - "std 3,128(1)\n\t" \ - /* arg10 */ \ - "ld 3,80(11)\n\t" \ - "std 3,120(1)\n\t" \ - /* arg9 */ \ - "ld 3,72(11)\n\t" \ - "std 3,112(1)\n\t" \ - /* args1-8 */ \ - "ld 3, 8(11)\n\t" /* arg1->r3 */ \ - "ld 4, 16(11)\n\t" /* arg2->r4 */ \ - "ld 5, 24(11)\n\t" /* arg3->r5 */ \ - "ld 6, 32(11)\n\t" /* arg4->r6 */ \ - "ld 7, 40(11)\n\t" /* arg5->r7 */ \ - "ld 8, 48(11)\n\t" /* arg6->r8 */ \ - "ld 9, 56(11)\n\t" /* arg7->r9 */ \ - "ld 10, 64(11)\n\t" /* arg8->r10 */ \ - "ld 11, 0(11)\n\t" /* target->r11 */ \ - VALGRIND_BRANCH_AND_LINK_TO_NOREDIR_R11 \ - "mr 11,%1\n\t" \ - "mr %0,3\n\t" \ - "ld 2,-16(11)\n\t" /* restore tocptr */ \ - VG_CONTRACT_FRAME_BY(144) \ - VG_CONTRACT_FRAME_BY(512) \ - : /*out*/ "=r" (_res) \ - : /*in*/ "r" (&_argvec[2]) \ - : /*trash*/ "cc", "memory", __CALLER_SAVED_REGS \ - ); \ - lval = (__typeof__(lval)) _res; \ - } while (0) - -#endif /* PLAT_ppc64_aix5 */ - - -/* ------------------------------------------------------------------ */ -/* ARCHITECTURE INDEPENDENT MACROS for CLIENT REQUESTS. */ -/* */ -/* ------------------------------------------------------------------ */ - -/* Some request codes. There are many more of these, but most are not - exposed to end-user view. These are the public ones, all of the - form 0x1000 + small_number. - - Core ones are in the range 0x00000000--0x0000ffff. The non-public - ones start at 0x2000. -*/ - -/* These macros are used by tools -- they must be public, but don't - embed them into other programs. */ -#define VG_USERREQ_TOOL_BASE(a,b) \ - ((unsigned int)(((a)&0xff) << 24 | ((b)&0xff) << 16)) -#define VG_IS_TOOL_USERREQ(a, b, v) \ - (VG_USERREQ_TOOL_BASE(a,b) == ((v) & 0xffff0000)) - -/* !! ABIWARNING !! ABIWARNING !! ABIWARNING !! ABIWARNING !! - This enum comprises an ABI exported by Valgrind to programs - which use client requests. DO NOT CHANGE THE ORDER OF THESE - ENTRIES, NOR DELETE ANY -- add new ones at the end. */ -typedef - enum { VG_USERREQ__RUNNING_ON_VALGRIND = 0x1001, - VG_USERREQ__DISCARD_TRANSLATIONS = 0x1002, - - /* These allow any function to be called from the simulated - CPU but run on the real CPU. Nb: the first arg passed to - the function is always the ThreadId of the running - thread! So CLIENT_CALL0 actually requires a 1 arg - function, etc. */ - VG_USERREQ__CLIENT_CALL0 = 0x1101, - VG_USERREQ__CLIENT_CALL1 = 0x1102, - VG_USERREQ__CLIENT_CALL2 = 0x1103, - VG_USERREQ__CLIENT_CALL3 = 0x1104, - - /* Can be useful in regression testing suites -- eg. can - send Valgrind's output to /dev/null and still count - errors. */ - VG_USERREQ__COUNT_ERRORS = 0x1201, - - /* These are useful and can be interpreted by any tool that - tracks malloc() et al, by using vg_replace_malloc.c. */ - VG_USERREQ__MALLOCLIKE_BLOCK = 0x1301, - VG_USERREQ__FREELIKE_BLOCK = 0x1302, - /* Memory pool support. */ - VG_USERREQ__CREATE_MEMPOOL = 0x1303, - VG_USERREQ__DESTROY_MEMPOOL = 0x1304, - VG_USERREQ__MEMPOOL_ALLOC = 0x1305, - VG_USERREQ__MEMPOOL_FREE = 0x1306, - VG_USERREQ__MEMPOOL_TRIM = 0x1307, - VG_USERREQ__MOVE_MEMPOOL = 0x1308, - VG_USERREQ__MEMPOOL_CHANGE = 0x1309, - VG_USERREQ__MEMPOOL_EXISTS = 0x130a, - - /* Allow printfs to valgrind log. */ - /* The first two pass the va_list argument by value, which - assumes it is the same size as or smaller than a UWord, - which generally isn't the case. Hence are deprecated. - The second two pass the vargs by reference and so are - immune to this problem. */ - /* both :: char* fmt, va_list vargs (DEPRECATED) */ - VG_USERREQ__PRINTF = 0x1401, - VG_USERREQ__PRINTF_BACKTRACE = 0x1402, - /* both :: char* fmt, va_list* vargs */ - VG_USERREQ__PRINTF_VALIST_BY_REF = 0x1403, - VG_USERREQ__PRINTF_BACKTRACE_VALIST_BY_REF = 0x1404, - - /* Stack support. */ - VG_USERREQ__STACK_REGISTER = 0x1501, - VG_USERREQ__STACK_DEREGISTER = 0x1502, - VG_USERREQ__STACK_CHANGE = 0x1503, - - /* Wine support */ - VG_USERREQ__LOAD_PDB_DEBUGINFO = 0x1601 - } Vg_ClientRequest; - -#if !defined(__GNUC__) -# define __extension__ /* */ -#endif - -/* Returns the number of Valgrinds this code is running under. That - is, 0 if running natively, 1 if running under Valgrind, 2 if - running under Valgrind which is running under another Valgrind, - etc. */ -#define RUNNING_ON_VALGRIND __extension__ \ - ({unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0 /* if not */, \ - VG_USERREQ__RUNNING_ON_VALGRIND, \ - 0, 0, 0, 0, 0); \ - _qzz_res; \ - }) - - -/* Discard translation of code in the range [_qzz_addr .. _qzz_addr + - _qzz_len - 1]. Useful if you are debugging a JITter or some such, - since it provides a way to make sure valgrind will retranslate the - invalidated area. Returns no value. */ -#define VALGRIND_DISCARD_TRANSLATIONS(_qzz_addr,_qzz_len) \ - {unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \ - VG_USERREQ__DISCARD_TRANSLATIONS, \ - _qzz_addr, _qzz_len, 0, 0, 0); \ - } - - -/* These requests are for getting Valgrind itself to print something. - Possibly with a backtrace. This is a really ugly hack. The return value - is the number of characters printed, excluding the "**** " part at the - start and the backtrace (if present). */ - -#if defined(NVALGRIND) - -# define VALGRIND_PRINTF(...) -# define VALGRIND_PRINTF_BACKTRACE(...) - -#else /* NVALGRIND */ - -/* Modern GCC will optimize the static routine out if unused, - and unused attribute will shut down warnings about it. */ -static int VALGRIND_PRINTF(const char *format, ...) - __attribute__((format(__printf__, 1, 2), __unused__)); -static int -VALGRIND_PRINTF(const char *format, ...) -{ - unsigned long _qzz_res; - va_list vargs; - va_start(vargs, format); - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, - VG_USERREQ__PRINTF_VALIST_BY_REF, - (unsigned long)format, - (unsigned long)&vargs, - 0, 0, 0); - va_end(vargs); - return (int)_qzz_res; -} - -static int VALGRIND_PRINTF_BACKTRACE(const char *format, ...) - __attribute__((format(__printf__, 1, 2), __unused__)); -static int -VALGRIND_PRINTF_BACKTRACE(const char *format, ...) -{ - unsigned long _qzz_res; - va_list vargs; - va_start(vargs, format); - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, - VG_USERREQ__PRINTF_BACKTRACE_VALIST_BY_REF, - (unsigned long)format, - (unsigned long)&vargs, - 0, 0, 0); - va_end(vargs); - return (int)_qzz_res; -} - -#endif /* NVALGRIND */ - - -/* These requests allow control to move from the simulated CPU to the - real CPU, calling an arbitary function. - - Note that the current ThreadId is inserted as the first argument. - So this call: - - VALGRIND_NON_SIMD_CALL2(f, arg1, arg2) - - requires f to have this signature: - - Word f(Word tid, Word arg1, Word arg2) - - where "Word" is a word-sized type. - - Note that these client requests are not entirely reliable. For example, - if you call a function with them that subsequently calls printf(), - there's a high chance Valgrind will crash. Generally, your prospects of - these working are made higher if the called function does not refer to - any global variables, and does not refer to any libc or other functions - (printf et al). Any kind of entanglement with libc or dynamic linking is - likely to have a bad outcome, for tricky reasons which we've grappled - with a lot in the past. -*/ -#define VALGRIND_NON_SIMD_CALL0(_qyy_fn) \ - __extension__ \ - ({unsigned long _qyy_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qyy_res, 0 /* default return */, \ - VG_USERREQ__CLIENT_CALL0, \ - _qyy_fn, \ - 0, 0, 0, 0); \ - _qyy_res; \ - }) - -#define VALGRIND_NON_SIMD_CALL1(_qyy_fn, _qyy_arg1) \ - __extension__ \ - ({unsigned long _qyy_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qyy_res, 0 /* default return */, \ - VG_USERREQ__CLIENT_CALL1, \ - _qyy_fn, \ - _qyy_arg1, 0, 0, 0); \ - _qyy_res; \ - }) - -#define VALGRIND_NON_SIMD_CALL2(_qyy_fn, _qyy_arg1, _qyy_arg2) \ - __extension__ \ - ({unsigned long _qyy_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qyy_res, 0 /* default return */, \ - VG_USERREQ__CLIENT_CALL2, \ - _qyy_fn, \ - _qyy_arg1, _qyy_arg2, 0, 0); \ - _qyy_res; \ - }) - -#define VALGRIND_NON_SIMD_CALL3(_qyy_fn, _qyy_arg1, _qyy_arg2, _qyy_arg3) \ - __extension__ \ - ({unsigned long _qyy_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qyy_res, 0 /* default return */, \ - VG_USERREQ__CLIENT_CALL3, \ - _qyy_fn, \ - _qyy_arg1, _qyy_arg2, \ - _qyy_arg3, 0); \ - _qyy_res; \ - }) - - -/* Counts the number of errors that have been recorded by a tool. Nb: - the tool must record the errors with VG_(maybe_record_error)() or - VG_(unique_error)() for them to be counted. */ -#define VALGRIND_COUNT_ERRORS \ - __extension__ \ - ({unsigned int _qyy_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qyy_res, 0 /* default return */, \ - VG_USERREQ__COUNT_ERRORS, \ - 0, 0, 0, 0, 0); \ - _qyy_res; \ - }) - -/* Several Valgrind tools (Memcheck, Massif, Helgrind, DRD) rely on knowing - when heap blocks are allocated in order to give accurate results. This - happens automatically for the standard allocator functions such as - malloc(), calloc(), realloc(), memalign(), new, new[], free(), delete, - delete[], etc. - - But if your program uses a custom allocator, this doesn't automatically - happen, and Valgrind will not do as well. For example, if you allocate - superblocks with mmap() and then allocates chunks of the superblocks, all - Valgrind's observations will be at the mmap() level and it won't know that - the chunks should be considered separate entities. In Memcheck's case, - that means you probably won't get heap block overrun detection (because - there won't be redzones marked as unaddressable) and you definitely won't - get any leak detection. - - The following client requests allow a custom allocator to be annotated so - that it can be handled accurately by Valgrind. - - VALGRIND_MALLOCLIKE_BLOCK marks a region of memory as having been allocated - by a malloc()-like function. For Memcheck (an illustrative case), this - does two things: - - - It records that the block has been allocated. This means any addresses - within the block mentioned in error messages will be - identified as belonging to the block. It also means that if the block - isn't freed it will be detected by the leak checker. - - - It marks the block as being addressable and undefined (if 'is_zeroed' is - not set), or addressable and defined (if 'is_zeroed' is set). This - controls how accesses to the block by the program are handled. - - 'addr' is the start of the usable block (ie. after any - redzone), 'sizeB' is its size. 'rzB' is the redzone size if the allocator - can apply redzones -- these are blocks of padding at the start and end of - each block. Adding redzones is recommended as it makes it much more likely - Valgrind will spot block overruns. `is_zeroed' indicates if the memory is - zeroed (or filled with another predictable value), as is the case for - calloc(). - - VALGRIND_MALLOCLIKE_BLOCK should be put immediately after the point where a - heap block -- that will be used by the client program -- is allocated. - It's best to put it at the outermost level of the allocator if possible; - for example, if you have a function my_alloc() which calls - internal_alloc(), and the client request is put inside internal_alloc(), - stack traces relating to the heap block will contain entries for both - my_alloc() and internal_alloc(), which is probably not what you want. - - For Memcheck users: if you use VALGRIND_MALLOCLIKE_BLOCK to carve out - custom blocks from within a heap block, B, that has been allocated with - malloc/calloc/new/etc, then block B will be *ignored* during leak-checking - -- the custom blocks will take precedence. - - VALGRIND_FREELIKE_BLOCK is the partner to VALGRIND_MALLOCLIKE_BLOCK. For - Memcheck, it does two things: - - - It records that the block has been deallocated. This assumes that the - block was annotated as having been allocated via - VALGRIND_MALLOCLIKE_BLOCK. Otherwise, an error will be issued. - - - It marks the block as being unaddressable. - - VALGRIND_FREELIKE_BLOCK should be put immediately after the point where a - heap block is deallocated. - - In many cases, these two client requests will not be enough to get your - allocator working well with Memcheck. More specifically, if your allocator - writes to freed blocks in any way then a VALGRIND_MAKE_MEM_UNDEFINED call - will be necessary to mark the memory as addressable just before the zeroing - occurs, otherwise you'll get a lot of invalid write errors. For example, - you'll need to do this if your allocator recycles freed blocks, but it - zeroes them before handing them back out (via VALGRIND_MALLOCLIKE_BLOCK). - Alternatively, if your allocator reuses freed blocks for allocator-internal - data structures, VALGRIND_MAKE_MEM_UNDEFINED calls will also be necessary. - - Really, what's happening is a blurring of the lines between the client - program and the allocator... after VALGRIND_FREELIKE_BLOCK is called, the - memory should be considered unaddressable to the client program, but the - allocator knows more than the rest of the client program and so may be able - to safely access it. Extra client requests are necessary for Valgrind to - understand the distinction between the allocator and the rest of the - program. - - Note: there is currently no VALGRIND_REALLOCLIKE_BLOCK client request; it - has to be emulated with MALLOCLIKE/FREELIKE and memory copying. - - Ignored if addr == 0. -*/ -#define VALGRIND_MALLOCLIKE_BLOCK(addr, sizeB, rzB, is_zeroed) \ - {unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \ - VG_USERREQ__MALLOCLIKE_BLOCK, \ - addr, sizeB, rzB, is_zeroed, 0); \ - } - -/* See the comment for VALGRIND_MALLOCLIKE_BLOCK for details. - Ignored if addr == 0. -*/ -#define VALGRIND_FREELIKE_BLOCK(addr, rzB) \ - {unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \ - VG_USERREQ__FREELIKE_BLOCK, \ - addr, rzB, 0, 0, 0); \ - } - -/* Create a memory pool. */ -#define VALGRIND_CREATE_MEMPOOL(pool, rzB, is_zeroed) \ - {unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \ - VG_USERREQ__CREATE_MEMPOOL, \ - pool, rzB, is_zeroed, 0, 0); \ - } - -/* Destroy a memory pool. */ -#define VALGRIND_DESTROY_MEMPOOL(pool) \ - {unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \ - VG_USERREQ__DESTROY_MEMPOOL, \ - pool, 0, 0, 0, 0); \ - } - -/* Associate a piece of memory with a memory pool. */ -#define VALGRIND_MEMPOOL_ALLOC(pool, addr, size) \ - {unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \ - VG_USERREQ__MEMPOOL_ALLOC, \ - pool, addr, size, 0, 0); \ - } - -/* Disassociate a piece of memory from a memory pool. */ -#define VALGRIND_MEMPOOL_FREE(pool, addr) \ - {unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \ - VG_USERREQ__MEMPOOL_FREE, \ - pool, addr, 0, 0, 0); \ - } - -/* Disassociate any pieces outside a particular range. */ -#define VALGRIND_MEMPOOL_TRIM(pool, addr, size) \ - {unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \ - VG_USERREQ__MEMPOOL_TRIM, \ - pool, addr, size, 0, 0); \ - } - -/* Resize and/or move a piece associated with a memory pool. */ -#define VALGRIND_MOVE_MEMPOOL(poolA, poolB) \ - {unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \ - VG_USERREQ__MOVE_MEMPOOL, \ - poolA, poolB, 0, 0, 0); \ - } - -/* Resize and/or move a piece associated with a memory pool. */ -#define VALGRIND_MEMPOOL_CHANGE(pool, addrA, addrB, size) \ - {unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \ - VG_USERREQ__MEMPOOL_CHANGE, \ - pool, addrA, addrB, size, 0); \ - } - -/* Return 1 if a mempool exists, else 0. */ -#define VALGRIND_MEMPOOL_EXISTS(pool) \ - __extension__ \ - ({unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \ - VG_USERREQ__MEMPOOL_EXISTS, \ - pool, 0, 0, 0, 0); \ - _qzz_res; \ - }) - -/* Mark a piece of memory as being a stack. Returns a stack id. */ -#define VALGRIND_STACK_REGISTER(start, end) \ - __extension__ \ - ({unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \ - VG_USERREQ__STACK_REGISTER, \ - start, end, 0, 0, 0); \ - _qzz_res; \ - }) - -/* Unmark the piece of memory associated with a stack id as being a - stack. */ -#define VALGRIND_STACK_DEREGISTER(id) \ - {unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \ - VG_USERREQ__STACK_DEREGISTER, \ - id, 0, 0, 0, 0); \ - } - -/* Change the start and end address of the stack id. */ -#define VALGRIND_STACK_CHANGE(id, start, end) \ - {unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \ - VG_USERREQ__STACK_CHANGE, \ - id, start, end, 0, 0); \ - } - -/* Load PDB debug info for Wine PE image_map. */ -#define VALGRIND_LOAD_PDB_DEBUGINFO(fd, ptr, total_size, delta) \ - {unsigned int _qzz_res; \ - VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \ - VG_USERREQ__LOAD_PDB_DEBUGINFO, \ - fd, ptr, total_size, delta, 0); \ - } - - -#undef PLAT_x86_linux -#undef PLAT_amd64_linux -#undef PLAT_ppc32_linux -#undef PLAT_ppc64_linux -#undef PLAT_arm_linux -#undef PLAT_ppc32_aix5 -#undef PLAT_ppc64_aix5 - -#endif /* __VALGRIND_H */