Merge branch 'master' of git://yap.git.sourceforge.net/gitroot/yap/yap-6.3
This commit is contained in:
commit
213c0780d3
@ -344,7 +344,6 @@
|
||||
#endif
|
||||
|
||||
#if defined(YAPOR) || defined(THREADS)
|
||||
#undef MODE_DIRECTED_TABLING
|
||||
#undef TABLING_EARLY_COMPLETION
|
||||
#undef INCOMPLETE_TABLING
|
||||
#undef LIMIT_TABLING
|
||||
|
@ -74,26 +74,32 @@ static inline void answer_to_stdout(char *answer);
|
||||
#endif /* YAPOR */
|
||||
|
||||
#ifdef TABLING
|
||||
static inline long show_statistics_table_entries(IOSTREAM *out);
|
||||
static inline long show_statistics_subgoal_frames(IOSTREAM *out);
|
||||
static inline long show_statistics_dependency_frames(IOSTREAM *out);
|
||||
static inline long show_statistics_subgoal_trie_nodes(IOSTREAM *out);
|
||||
static inline long show_statistics_subgoal_trie_hashes(IOSTREAM *out);
|
||||
static inline long show_statistics_answer_trie_nodes(IOSTREAM *out);
|
||||
static inline long show_statistics_answer_trie_hashes(IOSTREAM *out);
|
||||
static inline long show_statistics_global_trie_nodes(IOSTREAM *out);
|
||||
static inline long show_statistics_global_trie_hashes(IOSTREAM *out);
|
||||
static inline struct page_statistics show_statistics_table_entries(IOSTREAM *out);
|
||||
#if defined(THREADS_FULL_SHARING) || defined(THREADS_CONSUMER_SHARING)
|
||||
static inline struct page_statistics show_statistics_subgoal_entries(IOSTREAM *out);
|
||||
#endif /* THREADS_FULL_SHARING || THREADS_CONSUMER_SHARING */
|
||||
static inline struct page_statistics show_statistics_subgoal_frames(IOSTREAM *out);
|
||||
static inline struct page_statistics show_statistics_dependency_frames(IOSTREAM *out);
|
||||
static inline struct page_statistics show_statistics_subgoal_trie_nodes(IOSTREAM *out);
|
||||
static inline struct page_statistics show_statistics_subgoal_trie_hashes(IOSTREAM *out);
|
||||
static inline struct page_statistics show_statistics_answer_trie_nodes(IOSTREAM *out);
|
||||
static inline struct page_statistics show_statistics_answer_trie_hashes(IOSTREAM *out);
|
||||
#if defined(THREADS_FULL_SHARING)
|
||||
static inline struct page_statistics show_statistics_answer_ref_nodes(IOSTREAM *out);
|
||||
#endif /* THREADS_FULL_SHARING */
|
||||
static inline struct page_statistics show_statistics_global_trie_nodes(IOSTREAM *out);
|
||||
static inline struct page_statistics show_statistics_global_trie_hashes(IOSTREAM *out);
|
||||
#endif /* TABLING */
|
||||
#ifdef YAPOR
|
||||
static inline long show_statistics_or_frames(IOSTREAM *out);
|
||||
static inline long show_statistics_query_goal_solution_frames(IOSTREAM *out);
|
||||
static inline long show_statistics_query_goal_answer_frames(IOSTREAM *out);
|
||||
static inline struct page_statistics show_statistics_or_frames(IOSTREAM *out);
|
||||
static inline struct page_statistics show_statistics_query_goal_solution_frames(IOSTREAM *out);
|
||||
static inline struct page_statistics show_statistics_query_goal_answer_frames(IOSTREAM *out);
|
||||
#endif /* YAPOR */
|
||||
#if defined(YAPOR) && defined(TABLING)
|
||||
static inline long show_statistics_suspension_frames(IOSTREAM *out);
|
||||
static inline struct page_statistics show_statistics_suspension_frames(IOSTREAM *out);
|
||||
#ifdef TABLING_INNER_CUTS
|
||||
static inline long show_statistics_table_subgoal_solution_frames(IOSTREAM *out);
|
||||
static inline long show_statistics_table_subgoal_answer_frames(IOSTREAM *out);
|
||||
static inline struct page_statistics show_statistics_table_subgoal_solution_frames(IOSTREAM *out);
|
||||
static inline struct page_statistics show_statistics_table_subgoal_answer_frames(IOSTREAM *out);
|
||||
#endif /* TABLING_INNER_CUTS */
|
||||
#endif /* YAPOR && TABLING */
|
||||
|
||||
@ -114,15 +120,18 @@ struct page_statistics {
|
||||
long pages_allocated; /* same as struct pages (opt.structs.h) */
|
||||
#endif /* USE_PAGES_MALLOC */
|
||||
long structs_in_use; /* same as struct pages (opt.structs.h) */
|
||||
long bytes_in_use;
|
||||
};
|
||||
|
||||
#define Pg_bytes_in_use(STATS) STATS.bytes_in_use
|
||||
|
||||
#ifdef USE_PAGES_MALLOC
|
||||
#ifdef DEBUG_TABLING
|
||||
#define CHECK_PAGE_FREE_STRUCTS(STR_TYPE, STR_PAGES) \
|
||||
#define CHECK_PAGE_FREE_STRUCTS(STR_TYPE, PAGE) \
|
||||
{ pg_hd_ptr pg_hd; \
|
||||
STR_TYPE *aux_ptr; \
|
||||
long cont = 0; \
|
||||
pg_hd = Pg_free_pg(STR_PAGES); \
|
||||
pg_hd = Pg_free_pg(PAGE); \
|
||||
while (pg_hd) { \
|
||||
aux_ptr = PgHd_free_str(pg_hd); \
|
||||
while (aux_ptr) { \
|
||||
@ -131,59 +140,66 @@ struct page_statistics {
|
||||
} \
|
||||
pg_hd = PgHd_next(pg_hd); \
|
||||
} \
|
||||
if(Pg_str_free(STR_PAGES) != cont)printf("ERRRO!!!!!!!!\n");\
|
||||
TABLING_ERROR_CHECKING(CHECK_PAGE_FREE_STRUCTS, Pg_str_free(STR_PAGES) != cont); \
|
||||
TABLING_ERROR_CHECKING(CHECK_PAGE_FREE_STRUCTS, Pg_str_free(PAGE) != cont); \
|
||||
}
|
||||
#else
|
||||
#define CHECK_PAGE_FREE_STRUCTS(STR_TYPE,STR_PAGES)
|
||||
#define CHECK_PAGE_FREE_STRUCTS(STR_TYPE, PAGE)
|
||||
#endif /* DEBUG_TABLING */
|
||||
#define INIT_PAGE_STATS(STATS) \
|
||||
Pg_pg_alloc(STATS) = 0; \
|
||||
Pg_str_in_use(STATS) = 0
|
||||
#define INCREMENT_PAGE_STATS(STATS,PAGE) \
|
||||
#define INCREMENT_PAGE_STATS(STATS, PAGE) \
|
||||
Pg_pg_alloc(STATS) += Pg_pg_alloc(PAGE); \
|
||||
Pg_str_in_use(STATS) += Pg_str_in_use(PAGE)
|
||||
#define INCREMENT_AUX_STATS(STATS, BYTES, PAGES) \
|
||||
BYTES += Pg_bytes_in_use(STATS); \
|
||||
PAGES += Pg_pg_alloc(STATS)
|
||||
#define SHOW_PAGE_STATS_MSG(STR_NAME) " " STR_NAME " %10ld bytes (%ld pages and %ld structs in use)\n"
|
||||
#define SHOW_PAGE_STATS_ARGS(STATS,STR_TYPE) Pg_str_in_use(STATS) * sizeof(STR_TYPE), Pg_pg_alloc(STATS), Pg_str_in_use(STATS)
|
||||
#define SHOW_PAGE_STATS_ARGS(STATS, STR_TYPE) Pg_str_in_use(STATS) * sizeof(STR_TYPE), Pg_pg_alloc(STATS), Pg_str_in_use(STATS)
|
||||
#else /* !USE_PAGES_MALLOC */
|
||||
#define CHECK_PAGE_FREE_STRUCTS(STR_TYPE, PAGE)
|
||||
#define INIT_PAGE_STATS(STATS) \
|
||||
Pg_str_in_use(STATS) = 0
|
||||
#define INCREMENT_PAGE_STATS(STATS,PAGE) \
|
||||
#define INCREMENT_PAGE_STATS(STATS, PAGE) \
|
||||
Pg_str_in_use(STATS) += Pg_str_in_use(PAGE)
|
||||
#define CHECK_PAGE_FREE_STRUCTS(STR_TYPE,STR_PAGES)
|
||||
#define INCREMENT_AUX_STATS(STATS, BYTES, PAGES) \
|
||||
BYTES += Pg_bytes_in_use(STATS)
|
||||
#define SHOW_PAGE_STATS_MSG(STR_NAME) " " STR_NAME " %10ld bytes (%ld structs in use)\n"
|
||||
#define SHOW_PAGE_STATS_ARGS(STATS,STR_TYPE) Pg_str_in_use(STATS) * sizeof(STR_TYPE), Pg_str_in_use(STATS)
|
||||
#endif
|
||||
#define SHOW_PAGE_STATS_ARGS(STATS, STR_TYPE) Pg_str_in_use(STATS) * sizeof(STR_TYPE), Pg_str_in_use(STATS)
|
||||
#endif /* USE_PAGES_MALLOC */
|
||||
|
||||
#define GET_GLOBAL_PAGE_STATS(STATS,STR_PAGES) \
|
||||
#define GET_GLOBAL_PAGE_STATS(STATS, STR_TYPE, STR_PAGES) \
|
||||
INIT_PAGE_STATS(STATS); \
|
||||
CHECK_PAGE_FREE_STRUCTS(STR_TYPE,STR_PAGES); \
|
||||
INCREMENT_PAGE_STATS(STATS,STR_PAGES)
|
||||
#define GET_REMOTE_PAGE_STATS(STATS,STR_PAGES) \
|
||||
CHECK_PAGE_FREE_STRUCTS(STR_TYPE, STR_PAGES); \
|
||||
INCREMENT_PAGE_STATS(STATS, STR_PAGES); \
|
||||
Pg_bytes_in_use(STATS) = Pg_str_in_use(STATS) * sizeof(STR_TYPE)
|
||||
#define GET_REMOTE_PAGE_STATS(STATS, STR_TYPE, STR_PAGES) \
|
||||
INIT_PAGE_STATS(STATS); \
|
||||
LOCK(GLOBAL_ThreadHandlesLock); \
|
||||
{ int wid; \
|
||||
for (wid = 0; wid < MAX_THREADS; wid++) { \
|
||||
if (!Yap_local[wid]) \
|
||||
if (! Yap_local[wid]) \
|
||||
break; \
|
||||
if (REMOTE_ThreadHandle(wid).in_use) { \
|
||||
CHECK_PAGE_FREE_STRUCTS(STR_TYPE,STR_PAGES); \
|
||||
INCREMENT_PAGE_STATS(STATS,STR_PAGES(wid)); \
|
||||
CHECK_PAGE_FREE_STRUCTS(STR_TYPE, STR_PAGES(wid)); \
|
||||
INCREMENT_PAGE_STATS(STATS, STR_PAGES(wid)); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
UNLOCK(GLOBAL_ThreadHandlesLock)
|
||||
#define SHOW_GLOBAL_PAGE_STATS(OUT_STREAM,STR_TYPE,STR_PAGES,STR_NAME) \
|
||||
UNLOCK(GLOBAL_ThreadHandlesLock); \
|
||||
Pg_bytes_in_use(STATS) = Pg_str_in_use(STATS) * sizeof(STR_TYPE)
|
||||
|
||||
#define SHOW_GLOBAL_PAGE_STATS(OUT_STREAM, STR_TYPE, STR_PAGES, STR_NAME) \
|
||||
{ struct page_statistics stats; \
|
||||
GET_GLOBAL_PAGE_STATS(stats,STR_PAGES); \
|
||||
Sfprintf(OUT_STREAM,SHOW_PAGE_STATS_MSG(STR_NAME),SHOW_PAGE_STATS_ARGS(stats,STR_TYPE)); \
|
||||
return Pg_str_in_use(stats) * sizeof(STR_TYPE); \
|
||||
GET_GLOBAL_PAGE_STATS(stats, STR_TYPE, STR_PAGES); \
|
||||
Sfprintf(OUT_STREAM, SHOW_PAGE_STATS_MSG(STR_NAME), SHOW_PAGE_STATS_ARGS(stats, STR_TYPE)); \
|
||||
return stats; \
|
||||
}
|
||||
#define SHOW_REMOTE_PAGE_STATS(OUT_STREAM,STR_TYPE,STR_PAGES,STR_NAME) \
|
||||
#define SHOW_REMOTE_PAGE_STATS(OUT_STREAM, STR_TYPE, STR_PAGES, STR_NAME) \
|
||||
{ struct page_statistics stats; \
|
||||
GET_REMOTE_PAGE_STATS(stats,STR_PAGES); \
|
||||
Sfprintf(OUT_STREAM,SHOW_PAGE_STATS_MSG(STR_NAME),SHOW_PAGE_STATS_ARGS(stats,STR_TYPE)); \
|
||||
return Pg_str_in_use(stats) * sizeof(STR_TYPE); \
|
||||
GET_REMOTE_PAGE_STATS(stats, STR_TYPE, STR_PAGES); \
|
||||
Sfprintf(OUT_STREAM, SHOW_PAGE_STATS_MSG(STR_NAME), SHOW_PAGE_STATS_ARGS(stats, STR_TYPE)); \
|
||||
return stats; \
|
||||
}
|
||||
|
||||
|
||||
@ -636,35 +652,56 @@ static Int p_show_statistics_table( USES_REGS1 ) {
|
||||
|
||||
|
||||
static Int p_show_statistics_tabling( USES_REGS1 ) {
|
||||
struct page_statistics stats;
|
||||
long bytes, total_bytes = 0;
|
||||
#ifdef USE_PAGES_MALLOC
|
||||
long total_pages = 0;
|
||||
#endif /* USE_PAGES_MALLOC */
|
||||
IOSTREAM *out;
|
||||
long total_bytes = 0, aux_bytes;
|
||||
|
||||
if (!PL_get_stream_handle(Yap_InitSlot(Deref(ARG1) PASS_REGS), &out))
|
||||
return (FALSE);
|
||||
aux_bytes = 0;
|
||||
bytes = 0;
|
||||
Sfprintf(out, "Execution data structures\n");
|
||||
aux_bytes += show_statistics_table_entries(out);
|
||||
aux_bytes += show_statistics_subgoal_frames(out);
|
||||
aux_bytes += show_statistics_dependency_frames(out);
|
||||
Sfprintf(out, " Memory in use (I): %10ld bytes\n\n", aux_bytes);
|
||||
total_bytes += aux_bytes;
|
||||
aux_bytes = 0;
|
||||
stats = show_statistics_table_entries(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
#if defined(THREADS_FULL_SHARING) || defined(THREADS_CONSUMER_SHARING)
|
||||
stats = show_statistics_subgoal_entries(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
#endif /* THREADS_FULL_SHARING || THREADS_CONSUMER_SHARING */
|
||||
stats = show_statistics_subgoal_frames(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
stats = show_statistics_dependency_frames(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
Sfprintf(out, " Memory in use (I): %10ld bytes\n\n", bytes);
|
||||
total_bytes += bytes;
|
||||
bytes = 0;
|
||||
Sfprintf(out, "Local trie data structures\n");
|
||||
aux_bytes += show_statistics_subgoal_trie_nodes(out);
|
||||
aux_bytes += show_statistics_answer_trie_nodes(out);
|
||||
aux_bytes += show_statistics_subgoal_trie_hashes(out);
|
||||
aux_bytes += show_statistics_answer_trie_hashes(out);
|
||||
Sfprintf(out, " Memory in use (II): %10ld bytes\n\n", aux_bytes);
|
||||
total_bytes += aux_bytes;
|
||||
aux_bytes = 0;
|
||||
stats = show_statistics_subgoal_trie_nodes(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
stats = show_statistics_answer_trie_nodes(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
stats = show_statistics_subgoal_trie_hashes(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
stats = show_statistics_answer_trie_hashes(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
#if defined(THREADS_FULL_SHARING)
|
||||
stats = show_statistics_answer_ref_nodes(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
#endif /* THREADS_FULL_SHARING */
|
||||
Sfprintf(out, " Memory in use (II): %10ld bytes\n\n", bytes);
|
||||
total_bytes += bytes;
|
||||
bytes = 0;
|
||||
Sfprintf(out, "Global trie data structures\n");
|
||||
aux_bytes += show_statistics_global_trie_nodes(out);
|
||||
aux_bytes += show_statistics_global_trie_hashes(out);
|
||||
Sfprintf(out, " Memory in use (III): %10ld bytes\n\n", aux_bytes);
|
||||
total_bytes += aux_bytes;
|
||||
stats = show_statistics_global_trie_nodes(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
stats = show_statistics_global_trie_hashes(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
Sfprintf(out, " Memory in use (III): %10ld bytes\n\n", bytes);
|
||||
total_bytes += bytes;
|
||||
#ifdef USE_PAGES_MALLOC
|
||||
Sfprintf(out, "Total memory in use (I+II+III): %10ld bytes (%ld pages in use)\n",
|
||||
total_bytes, Pg_str_in_use(GLOBAL_pages_void));
|
||||
total_bytes, total_pages);
|
||||
Sfprintf(out, "Total memory allocated: %10ld bytes (%ld pages in total)\n",
|
||||
Pg_pg_alloc(GLOBAL_pages_void) * Yap_page_size, Pg_pg_alloc(GLOBAL_pages_void));
|
||||
#else
|
||||
@ -674,6 +711,7 @@ static Int p_show_statistics_tabling( USES_REGS1 ) {
|
||||
return (TRUE);
|
||||
}
|
||||
|
||||
|
||||
static Int p_show_statistics_global_trie( USES_REGS1 ) {
|
||||
IOSTREAM *out;
|
||||
|
||||
@ -686,6 +724,7 @@ static Int p_show_statistics_global_trie( USES_REGS1 ) {
|
||||
#endif /* TABLING */
|
||||
|
||||
|
||||
|
||||
/*********************************
|
||||
** YapOr C Predicates **
|
||||
*********************************/
|
||||
@ -779,25 +818,32 @@ static Int p_parallel_new_answer( USES_REGS1 ) {
|
||||
|
||||
|
||||
static Int p_show_statistics_or( USES_REGS1 ) {
|
||||
struct page_statistics stats;
|
||||
long bytes, total_bytes = 0;
|
||||
#ifdef USE_PAGES_MALLOC
|
||||
long total_pages = 0;
|
||||
#endif /* USE_PAGES_MALLOC */
|
||||
IOSTREAM *out;
|
||||
long total_bytes = 0, aux_bytes;
|
||||
|
||||
if (!PL_get_stream_handle(Yap_InitSlot(Deref(ARG1) PASS_REGS), &out))
|
||||
return (FALSE);
|
||||
aux_bytes = 0;
|
||||
bytes = 0;
|
||||
Sfprintf(out, "Execution data structures\n");
|
||||
aux_bytes += show_statistics_or_frames(out);
|
||||
Sfprintf(out, " Memory in use (I): %10ld bytes\n\n", aux_bytes);
|
||||
total_bytes += aux_bytes;
|
||||
aux_bytes = 0;
|
||||
stats = show_statistics_or_frames(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
Sfprintf(out, " Memory in use (I): %10ld bytes\n\n", bytes);
|
||||
total_bytes += bytes;
|
||||
bytes = 0;
|
||||
Sfprintf(out, "Cut support data structures\n");
|
||||
aux_bytes += show_statistics_query_goal_solution_frames(out);
|
||||
aux_bytes += show_statistics_query_goal_answer_frames(out);
|
||||
Sfprintf(out, " Memory in use (II): %10ld bytes\n\n", aux_bytes);
|
||||
total_bytes += aux_bytes;
|
||||
stats = show_statistics_query_goal_solution_frames(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
stats = show_statistics_query_goal_answer_frames(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
Sfprintf(out, " Memory in use (II): %10ld bytes\n\n", bytes);
|
||||
total_bytes += bytes;
|
||||
#ifdef USE_PAGES_MALLOC
|
||||
Sfprintf(out, "Total memory in use (I+II): %10ld bytes (%ld pages in use)\n",
|
||||
total_bytes, Pg_str_in_use(GLOBAL_pages_void));
|
||||
total_bytes, total_pages);
|
||||
Sfprintf(out, "Total memory allocated: %10ld bytes (%ld pages in total)\n",
|
||||
Pg_pg_alloc(GLOBAL_pages_void) * Yap_page_size, Pg_pg_alloc(GLOBAL_pages_void));
|
||||
#else
|
||||
@ -816,47 +862,74 @@ static Int p_show_statistics_or( USES_REGS1 ) {
|
||||
|
||||
#if defined(YAPOR) && defined(TABLING)
|
||||
static Int p_show_statistics_opt( USES_REGS1 ) {
|
||||
struct page_statistics stats;
|
||||
long bytes, total_bytes = 0;
|
||||
#ifdef USE_PAGES_MALLOC
|
||||
long total_pages = 0;
|
||||
#endif /* USE_PAGES_MALLOC */
|
||||
IOSTREAM *out;
|
||||
long total_bytes = 0, aux_bytes;
|
||||
|
||||
if (!PL_get_stream_handle(Yap_InitSlot(Deref(ARG1) PASS_REGS), &out))
|
||||
return (FALSE);
|
||||
aux_bytes = 0;
|
||||
bytes = 0;
|
||||
Sfprintf(out, "Execution data structures\n");
|
||||
aux_bytes += show_statistics_table_entries(out);
|
||||
aux_bytes += show_statistics_subgoal_frames(out);
|
||||
aux_bytes += show_statistics_dependency_frames(out);
|
||||
aux_bytes += show_statistics_or_frames(out);
|
||||
aux_bytes += show_statistics_suspension_frames(out);
|
||||
Sfprintf(out, " Memory in use (I): %10ld bytes\n\n", aux_bytes);
|
||||
total_bytes += aux_bytes;
|
||||
aux_bytes = 0;
|
||||
stats = show_statistics_table_entries(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
#if defined(THREADS_FULL_SHARING) || defined(THREADS_CONSUMER_SHARING)
|
||||
stats = show_statistics_subgoal_entries(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
#endif /* THREADS_FULL_SHARING || THREADS_CONSUMER_SHARING */
|
||||
stats = show_statistics_subgoal_frames(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
stats = show_statistics_dependency_frames(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
stats = show_statistics_or_frames(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
stats = show_statistics_suspension_frames(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
Sfprintf(out, " Memory in use (I): %10ld bytes\n\n", bytes);
|
||||
total_bytes += bytes;
|
||||
bytes = 0;
|
||||
Sfprintf(out, "Local trie data structures\n");
|
||||
aux_bytes += show_statistics_subgoal_trie_nodes(out);
|
||||
aux_bytes += show_statistics_answer_trie_nodes(out);
|
||||
aux_bytes += show_statistics_subgoal_trie_hashes(out);
|
||||
aux_bytes += show_statistics_answer_trie_hashes(out);
|
||||
Sfprintf(out, " Memory in use (II): %10ld bytes\n\n", aux_bytes);
|
||||
total_bytes += aux_bytes;
|
||||
aux_bytes = 0;
|
||||
stats = show_statistics_subgoal_trie_nodes(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
stats = show_statistics_answer_trie_nodes(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
stats = show_statistics_subgoal_trie_hashes(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
stats = show_statistics_answer_trie_hashes(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
#if defined(THREADS_FULL_SHARING)
|
||||
stats = show_statistics_answer_ref_nodes(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
#endif /* THREADS_FULL_SHARING */
|
||||
Sfprintf(out, " Memory in use (II): %10ld bytes\n\n", bytes);
|
||||
total_bytes += bytes;
|
||||
bytes = 0;
|
||||
Sfprintf(out, "Global trie data structures\n");
|
||||
aux_bytes += show_statistics_global_trie_nodes(out);
|
||||
aux_bytes += show_statistics_global_trie_hashes(out);
|
||||
Sfprintf(out, " Memory in use (III): %10ld bytes\n\n", aux_bytes);
|
||||
total_bytes += aux_bytes;
|
||||
aux_bytes = 0;
|
||||
stats = show_statistics_global_trie_nodes(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
stats = show_statistics_global_trie_hashes(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
Sfprintf(out, " Memory in use (III): %10ld bytes\n\n", bytes);
|
||||
total_bytes += bytes;
|
||||
bytes = 0;
|
||||
Sfprintf(out, "Cut support data structures\n");
|
||||
aux_bytes += show_statistics_query_goal_solution_frames(out);
|
||||
aux_bytes += show_statistics_query_goal_answer_frames(out);
|
||||
stats = show_statistics_query_goal_solution_frames(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
stats = show_statistics_query_goal_answer_frames(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
#ifdef TABLING_INNER_CUTS
|
||||
aux_bytes += show_statistics_table_subgoal_solution_frames(out);
|
||||
aux_bytes += show_statistics_table_subgoal_answer_frames(out);
|
||||
stats = show_statistics_table_subgoal_solution_frames(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
stats = show_statistics_table_subgoal_answer_frames(out);
|
||||
INCREMENT_AUX_STATS(stats, bytes, total_pages);
|
||||
#endif /* TABLING_INNER_CUTS */
|
||||
Sfprintf(out, " Memory in use (IV): %10ld bytes\n\n", aux_bytes);
|
||||
total_bytes += aux_bytes;
|
||||
Sfprintf(out, " Memory in use (IV): %10ld bytes\n\n", bytes);
|
||||
total_bytes += bytes;
|
||||
#ifdef USE_PAGES_MALLOC
|
||||
Sfprintf(out, "Total memory in use (I+II+III+IV): %10ld bytes (%ld pages in use)\n",
|
||||
total_bytes, Pg_str_in_use(GLOBAL_pages_void));
|
||||
total_bytes, total_pages);
|
||||
Sfprintf(out, "Total memory allocated: %10ld bytes (%ld pages in total)\n",
|
||||
Pg_pg_alloc(GLOBAL_pages_void) * Yap_page_size, Pg_pg_alloc(GLOBAL_pages_void));
|
||||
#else
|
||||
@ -876,121 +949,121 @@ static Int p_get_optyap_statistics( USES_REGS1 ) {
|
||||
value = IntOfTerm(Deref(ARG1));
|
||||
#ifdef TABLING
|
||||
if (value == 0 || value == 1) { /* table_entries */
|
||||
GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_tab_ent);
|
||||
bytes += Pg_str_in_use(stats) * sizeof(struct table_entry);
|
||||
GET_GLOBAL_PAGE_STATS(stats, struct table_entry, GLOBAL_pages_tab_ent);
|
||||
bytes += Pg_bytes_in_use(stats);
|
||||
if (value != 0) structs = Pg_str_in_use(stats);
|
||||
}
|
||||
#if defined(THREADS_FULL_SHARING) || defined(THREADS_CONSUMER_SHARING)
|
||||
if (value == 0 || value == 16) { /* subgoal_entries */
|
||||
GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_sg_entry);
|
||||
bytes += Pg_str_in_use(stats) * sizeof(struct subgoal_entry);
|
||||
GET_GLOBAL_PAGE_STATS(stats, struct subgoal_entry, GLOBAL_pages_sg_ent);
|
||||
bytes += Pg_bytes_in_use(stats);
|
||||
if (value != 0) structs = Pg_str_in_use(stats);
|
||||
}
|
||||
#endif /* THREADS_FULL_SHARING || THREADS_CONSUMER_SHARING */
|
||||
if (value == 0 || value == 2) { /* subgoal_frames */
|
||||
#if !defined(THREADS_NO_SHARING) && !defined(THREADS_SUBGOAL_SHARING) && !defined(THREADS_FULL_SHARING) && !defined(THREADS_CONSUMER_SHARING)
|
||||
GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_sg_fr);
|
||||
GET_GLOBAL_PAGE_STATS(stats, struct subgoal_frame, GLOBAL_pages_sg_fr);
|
||||
#else
|
||||
GET_REMOTE_PAGE_STATS(stats, REMOTE_pages_sg_fr);
|
||||
GET_REMOTE_PAGE_STATS(stats, struct subgoal_frame, REMOTE_pages_sg_fr);
|
||||
#endif
|
||||
bytes += Pg_str_in_use(stats) * sizeof(struct subgoal_frame);
|
||||
bytes += Pg_bytes_in_use(stats);
|
||||
if (value != 0) structs = Pg_str_in_use(stats);
|
||||
}
|
||||
if (value == 0 || value == 3) { /* dependency_frames */
|
||||
#if !defined(THREADS_NO_SHARING) && !defined(THREADS_SUBGOAL_SHARING) && !defined(THREADS_FULL_SHARING) && !defined(THREADS_CONSUMER_SHARING)
|
||||
GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_dep_fr);
|
||||
GET_GLOBAL_PAGE_STATS(stats, struct dependency_frame, GLOBAL_pages_dep_fr);
|
||||
#else
|
||||
GET_REMOTE_PAGE_STATS(stats, REMOTE_pages_dep_fr);
|
||||
GET_REMOTE_PAGE_STATS(stats, struct dependency_frame, REMOTE_pages_dep_fr);
|
||||
#endif
|
||||
bytes += Pg_str_in_use(stats) * sizeof(struct dependency_frame);
|
||||
bytes += Pg_bytes_in_use(stats);
|
||||
if (value != 0) structs = Pg_str_in_use(stats);
|
||||
}
|
||||
if (value == 0 || value == 6) { /* subgoal_trie_nodes */
|
||||
#if !defined(THREADS_NO_SHARING)
|
||||
GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_sg_node);
|
||||
GET_GLOBAL_PAGE_STATS(stats, struct subgoal_trie_node, GLOBAL_pages_sg_node);
|
||||
#else
|
||||
GET_REMOTE_PAGE_STATS(stats, REMOTE_pages_sg_node);
|
||||
GET_REMOTE_PAGE_STATS(stats, struct subgoal_trie_node, REMOTE_pages_sg_node);
|
||||
#endif
|
||||
bytes += Pg_str_in_use(stats) * sizeof(struct subgoal_trie_node);
|
||||
bytes += Pg_bytes_in_use(stats);
|
||||
if (value != 0) structs = Pg_str_in_use(stats);
|
||||
}
|
||||
if (value == 0 || value == 8) { /* subgoal_trie_hashes */
|
||||
#if !defined(THREADS_NO_SHARING)
|
||||
GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_sg_hash);
|
||||
GET_GLOBAL_PAGE_STATS(stats, struct subgoal_trie_hash, GLOBAL_pages_sg_hash);
|
||||
#else
|
||||
GET_REMOTE_PAGE_STATS(stats, REMOTE_pages_sg_hash);
|
||||
GET_REMOTE_PAGE_STATS(stats, struct subgoal_trie_hash, REMOTE_pages_sg_hash);
|
||||
#endif
|
||||
bytes += Pg_str_in_use(stats) * sizeof(struct subgoal_trie_hash);
|
||||
bytes += Pg_bytes_in_use(stats);
|
||||
if (value != 0) structs = Pg_str_in_use(stats);
|
||||
}
|
||||
if (value == 0 || value == 7) { /* answer_trie_nodes */
|
||||
#if !defined(THREADS_NO_SHARING) && !defined(THREADS_SUBGOAL_SHARING)
|
||||
GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_ans_node);
|
||||
GET_GLOBAL_PAGE_STATS(stats, struct answer_trie_node, GLOBAL_pages_ans_node);
|
||||
#else
|
||||
GET_REMOTE_PAGE_STATS(stats, REMOTE_pages_ans_node);
|
||||
GET_REMOTE_PAGE_STATS(stats, struct answer_trie_node, REMOTE_pages_ans_node);
|
||||
#endif
|
||||
bytes += Pg_str_in_use(stats) * sizeof(struct answer_trie_node);
|
||||
bytes += Pg_bytes_in_use(stats);
|
||||
if (value != 0) structs = Pg_str_in_use(stats);
|
||||
}
|
||||
if (value == 0 || value == 9) { /* answer_trie_hashes */
|
||||
#if !defined(THREADS_NO_SHARING) && !defined(THREADS_SUBGOAL_SHARING)
|
||||
GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_ans_hash);
|
||||
GET_GLOBAL_PAGE_STATS(stats, struct answer_trie_hash, GLOBAL_pages_ans_hash);
|
||||
#else
|
||||
GET_REMOTE_PAGE_STATS(stats, REMOTE_pages_ans_hash);
|
||||
GET_REMOTE_PAGE_STATS(stats, struct answer_trie_hash, REMOTE_pages_ans_hash);
|
||||
#endif
|
||||
bytes += Pg_str_in_use(stats) * sizeof(struct answer_trie_hash);
|
||||
bytes += Pg_bytes_in_use(stats);
|
||||
if (value != 0) structs = Pg_str_in_use(stats);
|
||||
}
|
||||
#if defined(THREADS_FULL_SHARING)
|
||||
if (value == 0 || value == 17) { /* answer_ref_nodes */
|
||||
GET_REMOTE_PAGE_STATS(stats, REMOTE_pages_ans_ref_node);
|
||||
bytes += Pg_str_in_use(stats) * sizeof(struct answer_ref_node);
|
||||
GET_REMOTE_PAGE_STATS(stats, struct answer_ref_node, REMOTE_pages_ans_ref_node);
|
||||
bytes += Pg_bytes_in_use(stats);
|
||||
if (value != 0) structs = Pg_str_in_use(stats);
|
||||
}
|
||||
#endif /* THREADS_FULL_SHARING */
|
||||
if (value == 0 || value == 10) { /* global_trie_nodes */
|
||||
GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_gt_node);
|
||||
bytes += Pg_str_in_use(stats) * sizeof(struct global_trie_node);
|
||||
GET_GLOBAL_PAGE_STATS(stats, struct global_trie_node, GLOBAL_pages_gt_node);
|
||||
bytes += Pg_bytes_in_use(stats);
|
||||
if (value != 0) structs = Pg_str_in_use(stats);
|
||||
}
|
||||
if (value == 0 || value == 11) { /* global_trie_hashes */
|
||||
GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_gt_hash);
|
||||
bytes += Pg_str_in_use(stats) * sizeof(struct global_trie_hash);
|
||||
GET_GLOBAL_PAGE_STATS(stats, struct global_trie_hash, GLOBAL_pages_gt_hash);
|
||||
bytes += Pg_bytes_in_use(stats);
|
||||
if (value != 0) structs = Pg_str_in_use(stats);
|
||||
}
|
||||
#endif /* TABLING */
|
||||
#ifdef YAPOR
|
||||
if (value == 0 || value == 4) { /* or_frames */
|
||||
GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_or_fr);
|
||||
bytes += Pg_str_in_use(stats) * sizeof(struct or_frame);
|
||||
GET_GLOBAL_PAGE_STATS(stats, struct or_frame, GLOBAL_pages_or_fr);
|
||||
bytes += Pg_bytes_in_use(stats);
|
||||
if (value != 0) structs = Pg_str_in_use(stats);
|
||||
}
|
||||
if (value == 0 || value == 12) { /* query_goal_solution_frames */
|
||||
GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_qg_sol_fr);
|
||||
bytes += Pg_str_in_use(stats) * sizeof(struct query_goal_solution_frame);
|
||||
GET_GLOBAL_PAGE_STATS(stats, struct query_goal_solution_frame, GLOBAL_pages_qg_sol_fr);
|
||||
bytes += Pg_bytes_in_use(stats);
|
||||
if (value != 0) structs = Pg_str_in_use(stats);
|
||||
}
|
||||
if (value == 0 || value == 13) { /* query_goal_answer_frames */
|
||||
GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_qg_ans_fr);
|
||||
bytes += Pg_str_in_use(stats) * sizeof(struct query_goal_answer_frame);
|
||||
GET_GLOBAL_PAGE_STATS(stats, struct query_goal_answer_frame, GLOBAL_pages_qg_ans_fr);
|
||||
bytes += Pg_bytes_in_use(stats);
|
||||
if (value != 0) structs = Pg_str_in_use(stats);
|
||||
}
|
||||
#endif /* YAPOR */
|
||||
#if defined(YAPOR) && defined(TABLING)
|
||||
if (value == 0 || value == 5) { /* suspension_frames */
|
||||
GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_susp_fr);
|
||||
bytes += Pg_str_in_use(stats) * sizeof(struct suspension_frame);
|
||||
GET_GLOBAL_PAGE_STATS(stats, struct suspension_frame, GLOBAL_pages_susp_fr);
|
||||
bytes += Pg_bytes_in_use(stats);
|
||||
if (value != 0) structs = Pg_str_in_use(stats);
|
||||
}
|
||||
#ifdef TABLING_INNER_CUTS
|
||||
if (value == 0 || value == 14) { /* table_subgoal_solution_frames */
|
||||
GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_tg_sol_fr);
|
||||
bytes += Pg_str_in_use(stats) * sizeof(struct table_subgoal_solution_frame);
|
||||
GET_GLOBAL_PAGE_STATS(stats, struct table_subgoal_solution_frame, GLOBAL_pages_tg_sol_fr);
|
||||
bytes += Pg_bytes_in_use(stats);
|
||||
if (value != 0) structs = Pg_str_in_use(stats);
|
||||
}
|
||||
if (value == 0 || value == 15) { /* table_subgoal_answer_frames */
|
||||
GET_GLOBAL_PAGE_STATS(stats, GLOBAL_pages_tg_ans_fr);
|
||||
bytes += Pg_str_in_use(stats) * sizeof(struct table_subgoal_answer_frame);
|
||||
GET_GLOBAL_PAGE_STATS(stats, struct table_subgoal_answer_frame, GLOBAL_pages_tg_ans_fr);
|
||||
bytes += Pg_bytes_in_use(stats);
|
||||
if (value != 0) structs = Pg_str_in_use(stats);
|
||||
}
|
||||
#endif /* TABLING_INNER_CUTS */
|
||||
@ -1120,11 +1193,17 @@ static inline void answer_to_stdout(char *answer) {
|
||||
|
||||
|
||||
#ifdef TABLING
|
||||
static inline long show_statistics_table_entries(IOSTREAM *out) {
|
||||
static inline struct page_statistics show_statistics_table_entries(IOSTREAM *out) {
|
||||
SHOW_GLOBAL_PAGE_STATS(out, struct table_entry, GLOBAL_pages_tab_ent, "Table entries: ");
|
||||
}
|
||||
|
||||
static inline long show_statistics_subgoal_frames(IOSTREAM *out) {
|
||||
#if defined(THREADS_FULL_SHARING) || defined(THREADS_CONSUMER_SHARING)
|
||||
static inline struct page_statistics show_statistics_subgoal_entries(IOSTREAM *out) {
|
||||
SHOW_GLOBAL_PAGE_STATS(out, struct subgoal_entry, GLOBAL_pages_sg_ent, "Subgoal entries: ");
|
||||
}
|
||||
#endif /* THREADS_FULL_SHARING || THREADS_CONSUMER_SHARING */
|
||||
|
||||
static inline struct page_statistics show_statistics_subgoal_frames(IOSTREAM *out) {
|
||||
#if !defined(THREADS_NO_SHARING) && !defined(THREADS_SUBGOAL_SHARING) && !defined(THREADS_FULL_SHARING) && !defined(THREADS_CONSUMER_SHARING)
|
||||
SHOW_GLOBAL_PAGE_STATS(out, struct subgoal_frame, GLOBAL_pages_sg_fr, "Subgoal frames: ");
|
||||
#else
|
||||
@ -1132,7 +1211,7 @@ static inline long show_statistics_subgoal_frames(IOSTREAM *out) {
|
||||
#endif
|
||||
}
|
||||
|
||||
static inline long show_statistics_dependency_frames(IOSTREAM *out) {
|
||||
static inline struct page_statistics show_statistics_dependency_frames(IOSTREAM *out) {
|
||||
#if !defined(THREADS_NO_SHARING) && !defined(THREADS_SUBGOAL_SHARING) && !defined(THREADS_FULL_SHARING) && !defined(THREADS_CONSUMER_SHARING)
|
||||
SHOW_GLOBAL_PAGE_STATS(out, struct dependency_frame, GLOBAL_pages_dep_fr, "Dependency frames: ");
|
||||
#else
|
||||
@ -1140,7 +1219,7 @@ static inline long show_statistics_dependency_frames(IOSTREAM *out) {
|
||||
#endif
|
||||
}
|
||||
|
||||
static inline long show_statistics_subgoal_trie_nodes(IOSTREAM *out) {
|
||||
static inline struct page_statistics show_statistics_subgoal_trie_nodes(IOSTREAM *out) {
|
||||
#if !defined(THREADS_NO_SHARING)
|
||||
SHOW_GLOBAL_PAGE_STATS(out, struct subgoal_trie_node, GLOBAL_pages_sg_node, "Subgoal trie nodes: ");
|
||||
#else
|
||||
@ -1148,7 +1227,7 @@ static inline long show_statistics_subgoal_trie_nodes(IOSTREAM *out) {
|
||||
#endif
|
||||
}
|
||||
|
||||
static inline long show_statistics_subgoal_trie_hashes(IOSTREAM *out) {
|
||||
static inline struct page_statistics show_statistics_subgoal_trie_hashes(IOSTREAM *out) {
|
||||
#if !defined(THREADS_NO_SHARING)
|
||||
SHOW_GLOBAL_PAGE_STATS(out, struct subgoal_trie_hash, GLOBAL_pages_sg_hash, "Subgoal trie hashes: ");
|
||||
#else
|
||||
@ -1156,7 +1235,7 @@ static inline long show_statistics_subgoal_trie_hashes(IOSTREAM *out) {
|
||||
#endif
|
||||
}
|
||||
|
||||
static inline long show_statistics_answer_trie_nodes(IOSTREAM *out) {
|
||||
static inline struct page_statistics show_statistics_answer_trie_nodes(IOSTREAM *out) {
|
||||
#if !defined(THREADS_NO_SHARING) && !defined(THREADS_SUBGOAL_SHARING)
|
||||
SHOW_GLOBAL_PAGE_STATS(out, struct answer_trie_node, GLOBAL_pages_ans_node, "Answer trie nodes: ");
|
||||
#else
|
||||
@ -1164,7 +1243,7 @@ static inline long show_statistics_answer_trie_nodes(IOSTREAM *out) {
|
||||
#endif
|
||||
}
|
||||
|
||||
static inline long show_statistics_answer_trie_hashes(IOSTREAM *out) {
|
||||
static inline struct page_statistics show_statistics_answer_trie_hashes(IOSTREAM *out) {
|
||||
#if !defined(THREADS_NO_SHARING) && !defined(THREADS_SUBGOAL_SHARING)
|
||||
SHOW_GLOBAL_PAGE_STATS(out, struct answer_trie_hash, GLOBAL_pages_ans_hash, "Answer trie hashes: ");
|
||||
#else
|
||||
@ -1172,42 +1251,48 @@ static inline long show_statistics_answer_trie_hashes(IOSTREAM *out) {
|
||||
#endif
|
||||
}
|
||||
|
||||
static inline long show_statistics_global_trie_nodes(IOSTREAM *out) {
|
||||
#if defined(THREADS_FULL_SHARING)
|
||||
static inline struct page_statistics show_statistics_answer_ref_nodes(IOSTREAM *out) {
|
||||
SHOW_GLOBAL_PAGE_STATS(out, struct answer_ref_node, REMOTE_pages_ans_ref_node, "Answer ref nodes: ");
|
||||
}
|
||||
#endif /* THREADS_FULL_SHARING */
|
||||
|
||||
static inline struct page_statistics show_statistics_global_trie_nodes(IOSTREAM *out) {
|
||||
SHOW_GLOBAL_PAGE_STATS(out, struct global_trie_node, GLOBAL_pages_gt_node, "Global trie nodes: ");
|
||||
}
|
||||
|
||||
static inline long show_statistics_global_trie_hashes(IOSTREAM *out) {
|
||||
static inline struct page_statistics show_statistics_global_trie_hashes(IOSTREAM *out) {
|
||||
SHOW_GLOBAL_PAGE_STATS(out, struct global_trie_hash, GLOBAL_pages_gt_hash, "Global trie hashes: ");
|
||||
}
|
||||
#endif /* TABLING */
|
||||
|
||||
|
||||
#ifdef YAPOR
|
||||
static inline long show_statistics_or_frames(IOSTREAM *out) {
|
||||
static inline struct page_statistics show_statistics_or_frames(IOSTREAM *out) {
|
||||
SHOW_GLOBAL_PAGE_STATS(out, struct or_frame, GLOBAL_pages_or_fr, "Or-frames: ");
|
||||
}
|
||||
|
||||
static inline long show_statistics_query_goal_solution_frames(IOSTREAM *out) {
|
||||
static inline struct page_statistics show_statistics_query_goal_solution_frames(IOSTREAM *out) {
|
||||
SHOW_GLOBAL_PAGE_STATS(out, struct query_goal_solution_frame, GLOBAL_pages_qg_sol_fr, "Query goal solution frames: ");
|
||||
}
|
||||
|
||||
static inline long show_statistics_query_goal_answer_frames(IOSTREAM *out) {
|
||||
static inline struct page_statistics show_statistics_query_goal_answer_frames(IOSTREAM *out) {
|
||||
SHOW_GLOBAL_PAGE_STATS(out, struct query_goal_answer_frame, GLOBAL_pages_qg_ans_fr, "Query goal answer frames: ");
|
||||
}
|
||||
#endif /* YAPOR */
|
||||
|
||||
|
||||
#if defined(YAPOR) && defined(TABLING)
|
||||
static inline long show_statistics_suspension_frames(IOSTREAM *out) {
|
||||
static inline struct page_statistics show_statistics_suspension_frames(IOSTREAM *out) {
|
||||
SHOW_GLOBAL_PAGE_STATS(out, struct suspension_frame, GLOBAL_pages_susp_fr, "Suspension frames: ");
|
||||
}
|
||||
|
||||
#ifdef TABLING_INNER_CUTS
|
||||
static inline long show_statistics_table_subgoal_solution_frames(IOSTREAM *out) {
|
||||
static inline struct page_statistics show_statistics_table_subgoal_solution_frames(IOSTREAM *out) {
|
||||
SHOW_GLOBAL_PAGE_STATS(out, struct table_subgoal_solution_frame, GLOBAL_pages_tg_sol_fr, "Table subgoal solution frames:");
|
||||
}
|
||||
|
||||
static inline long show_statistics_table_subgoal_answer_frames(IOSTREAM *out) {
|
||||
static inline struct page_statistics show_statistics_table_subgoal_answer_frames(IOSTREAM *out) {
|
||||
SHOW_GLOBAL_PAGE_STATS(out, struct table_subgoal_answer_frame, GLOBAL_pages_tg_ans_fr, "Table subgoal answer frames: ");
|
||||
}
|
||||
#endif /* TABLING_INNER_CUTS */
|
||||
|
@ -606,6 +606,7 @@ static inline void adjust_freeze_registers(void) {
|
||||
|
||||
|
||||
static inline void mark_as_completed(sg_fr_ptr sg_fr) {
|
||||
CACHE_REGS
|
||||
LOCK_SG_FR(sg_fr);
|
||||
SgFr_state(sg_fr) = complete;
|
||||
UNLOCK_SG_FR(sg_fr);
|
||||
|
@ -1388,6 +1388,7 @@ static inline ans_node_ptr answer_search_loop(sg_fr_ptr sg_fr, ans_node_ptr curr
|
||||
}
|
||||
|
||||
static inline ans_node_ptr answer_search_min_max(sg_fr_ptr sg_fr, ans_node_ptr current_node, Term t, int mode) {
|
||||
CACHE_REGS
|
||||
ans_node_ptr child_node;
|
||||
Term child_term;
|
||||
Float trie_value, term_value;
|
||||
@ -1486,6 +1487,8 @@ static inline ans_node_ptr answer_search_min_max(sg_fr_ptr sg_fr, ans_node_ptr c
|
||||
SgFr_invalid_chain(SG_FR) = NODE
|
||||
|
||||
static void invalidate_answer_trie(ans_node_ptr current_node, sg_fr_ptr sg_fr, int position) {
|
||||
CACHE_REGS
|
||||
|
||||
if (IS_ANSWER_TRIE_HASH(current_node)) {
|
||||
ans_hash_ptr hash;
|
||||
ans_node_ptr *bucket, *last_bucket;
|
||||
|
@ -8,6 +8,8 @@
|
||||
:- module(clpbn_bp,
|
||||
[bp/3,
|
||||
check_if_bp_done/1,
|
||||
set_solver_parameter/2,
|
||||
use_log_space/0,
|
||||
init_bp_solver/4,
|
||||
run_bp_solver/3,
|
||||
finalize_bp_solver/1
|
||||
@ -34,12 +36,30 @@
|
||||
|
||||
:- attribute id/1.
|
||||
|
||||
:- dynamic num_bayes_nets/1.
|
||||
:- dynamic network_counting/1.
|
||||
|
||||
|
||||
check_if_bp_done(_Var).
|
||||
|
||||
num_bayes_nets(0).
|
||||
network_counting(0).
|
||||
|
||||
|
||||
:- set_solver_parameter(run_mode, normal).
|
||||
%:- set_solver_parameter(run_mode, convert).
|
||||
%: -set_solver_parameter(run_mode, compress).
|
||||
|
||||
:- set_solver_parameter(schedule, seq_fixed).
|
||||
%:- set_solver_parameter(schedule, seq_random).
|
||||
%:- set_solver_parameter(schedule, parallel).
|
||||
%:- set_solver_parameter(schedule, max_residual).
|
||||
|
||||
:- set_solver_parameter(accuracy, 0.0001).
|
||||
|
||||
:- set_solver_parameter(max_iter, 1000).
|
||||
|
||||
:- set_solver_parameter(always_loopy_solver, false).
|
||||
|
||||
% :- use_log_space.
|
||||
|
||||
|
||||
bp([[]],_,_) :- !.
|
||||
@ -51,14 +71,15 @@ bp([QueryVars], AllVars, Output) :-
|
||||
|
||||
|
||||
init_bp_solver(_, AllVars, _, (BayesNet, DistIds)) :-
|
||||
%inc_num_bayes_nets,
|
||||
%(showprofres(50) -> true ; true),
|
||||
%inc_network_counting,
|
||||
process_ids(AllVars, 0, DistIds0),
|
||||
get_vars_info(AllVars, VarsInfo),
|
||||
sort(DistIds0, DistIds),
|
||||
%(num_bayes_nets(0) -> writeln(vars:VarsInfo) ; true),
|
||||
%(num_bayes_nets(0) -> writeln(dists:DistsInfo) ; true),
|
||||
%(network_counting(0) -> writeln(vars:VarsInfo) ; true),
|
||||
%(network_counting(0) -> writeln(distsids:DistIds) ; true),
|
||||
create_network(VarsInfo, BayesNet).
|
||||
%get_extra_vars_info(AllVars, ExtraVarsInfo),
|
||||
%(network_counting(0) -> writeln(extra:ExtraVarsInfo) ; true),
|
||||
%set_extra_vars_info(BayesNet, ExtraVarsInfo).
|
||||
|
||||
|
||||
@ -103,6 +124,8 @@ get_extra_vars_info([V|Vs], [v(VarId, Label, Domain)|VarsInfo]) :-
|
||||
get_dist_domain(DistId, Domain0),
|
||||
numbers2atoms(Domain0, Domain),
|
||||
get_extra_vars_info(Vs, VarsInfo).
|
||||
get_extra_vars_info([_|Vs], VarsInfo) :-
|
||||
get_extra_vars_info(Vs, VarsInfo).
|
||||
|
||||
|
||||
numbers2atoms([], []).
|
||||
@ -118,8 +141,7 @@ run_bp_solver(QVsL0, LPs, (BayesNet, DistIds)) :-
|
||||
get_dists_parameters(DistIds, DistsParams),
|
||||
set_parameters(BayesNet, DistsParams),
|
||||
process_query_list(QVsL0, QVsL),
|
||||
%writeln(' qvs':QVsL),
|
||||
%(num_bayes_nets(1506) -> writeln(qvs:QVsL) ; true),
|
||||
%(network_counting(0) -> writeln(qvs:QVsL) ; true),
|
||||
run_solver(BayesNet, QVsL, LPs).
|
||||
|
||||
|
||||
@ -139,11 +161,11 @@ get_dists_parameters([Id|Ids], [dist(Id, Params)|DistsInfo]) :-
|
||||
|
||||
|
||||
finalize_bp_solver((BayesNet, _)) :-
|
||||
delete_bayes_net(BayesNet).
|
||||
free_bayesian_network(BayesNet).
|
||||
|
||||
|
||||
inc_num_bayes_nets :-
|
||||
retract(num_bayes_nets(Count0)),
|
||||
inc_network_counting :-
|
||||
retract(network_counting(Count0)),
|
||||
Count is Count0 + 1,
|
||||
assert(num_bayes_nets(Count)).
|
||||
assert(network_counting(Count)).
|
||||
|
||||
|
@ -1,149 +0,0 @@
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "BPNodeInfo.h"
|
||||
#include "BPSolver.h"
|
||||
|
||||
BPNodeInfo::BPNodeInfo (BayesNode* node)
|
||||
{
|
||||
node_ = node;
|
||||
ds_ = node->getDomainSize();
|
||||
piValsCalc_ = false;
|
||||
ldValsCalc_ = false;
|
||||
nPiMsgsRcv_ = 0;
|
||||
nLdMsgsRcv_ = 0;
|
||||
piVals_.resize (ds_, 1);
|
||||
ldVals_.resize (ds_, 1);
|
||||
const BnNodeSet& childs = node->getChilds();
|
||||
for (unsigned i = 0; i < childs.size(); i++) {
|
||||
cmsgs_.insert (make_pair (childs[i], false));
|
||||
}
|
||||
const BnNodeSet& parents = node->getParents();
|
||||
for (unsigned i = 0; i < parents.size(); i++) {
|
||||
pmsgs_.insert (make_pair (parents[i], false));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
BPNodeInfo::getBeliefs (void) const
|
||||
{
|
||||
double sum = 0.0;
|
||||
ParamSet beliefs (ds_);
|
||||
for (unsigned xi = 0; xi < ds_; xi++) {
|
||||
double prod = piVals_[xi] * ldVals_[xi];
|
||||
beliefs[xi] = prod;
|
||||
sum += prod;
|
||||
}
|
||||
assert (sum);
|
||||
//normalize the beliefs
|
||||
for (unsigned xi = 0; xi < ds_; xi++) {
|
||||
beliefs[xi] /= sum;
|
||||
}
|
||||
return beliefs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BPNodeInfo::readyToSendPiMsgTo (const BayesNode* child) const
|
||||
{
|
||||
for (unsigned i = 0; i < inChildLinks_.size(); i++) {
|
||||
if (inChildLinks_[i]->getSource() != child
|
||||
&& !inChildLinks_[i]->messageWasSended()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BPNodeInfo::readyToSendLambdaMsgTo (const BayesNode* parent) const
|
||||
{
|
||||
for (unsigned i = 0; i < inParentLinks_.size(); i++) {
|
||||
if (inParentLinks_[i]->getSource() != parent
|
||||
&& !inParentLinks_[i]->messageWasSended()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
BPNodeInfo::getPiValue (unsigned idx) const
|
||||
{
|
||||
assert (idx >=0 && idx < ds_);
|
||||
return piVals_[idx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BPNodeInfo::setPiValue (unsigned idx, Param value)
|
||||
{
|
||||
assert (idx >=0 && idx < ds_);
|
||||
piVals_[idx] = value;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
BPNodeInfo::getLambdaValue (unsigned idx) const
|
||||
{
|
||||
assert (idx >=0 && idx < ds_);
|
||||
return ldVals_[idx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BPNodeInfo::setLambdaValue (unsigned idx, Param value)
|
||||
{
|
||||
assert (idx >=0 && idx < ds_);
|
||||
ldVals_[idx] = value;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
BPNodeInfo::getBeliefChange (void)
|
||||
{
|
||||
double change = 0.0;
|
||||
if (oldBeliefs_.size() == 0) {
|
||||
oldBeliefs_ = getBeliefs();
|
||||
change = 9999999999.0;
|
||||
} else {
|
||||
ParamSet currentBeliefs = getBeliefs();
|
||||
for (unsigned xi = 0; xi < ds_; xi++) {
|
||||
change += abs (currentBeliefs[xi] - oldBeliefs_[xi]);
|
||||
}
|
||||
oldBeliefs_ = currentBeliefs;
|
||||
}
|
||||
return change;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BPNodeInfo::receivedBottomInfluence (void) const
|
||||
{
|
||||
// if all lambda values are equal, then neither
|
||||
// this node neither its descendents have evidence,
|
||||
// we can use this to don't send lambda messages his parents
|
||||
bool childInfluenced = false;
|
||||
for (unsigned xi = 1; xi < ds_; xi++) {
|
||||
if (ldVals_[xi] != ldVals_[0]) {
|
||||
childInfluenced = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return childInfluenced;
|
||||
}
|
||||
|
@ -1,82 +0,0 @@
|
||||
#ifndef BP_BP_NODE_H
|
||||
#define BP_BP_NODE_H
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
#include "BPSolver.h"
|
||||
#include "BayesNode.h"
|
||||
#include "Shared.h"
|
||||
|
||||
//class Edge;
|
||||
|
||||
using namespace std;
|
||||
|
||||
class BPNodeInfo
|
||||
{
|
||||
public:
|
||||
BPNodeInfo (int);
|
||||
BPNodeInfo (BayesNode*);
|
||||
|
||||
ParamSet getBeliefs (void) const;
|
||||
double getPiValue (unsigned) const;
|
||||
void setPiValue (unsigned, Param);
|
||||
double getLambdaValue (unsigned) const;
|
||||
void setLambdaValue (unsigned, Param);
|
||||
double getBeliefChange (void);
|
||||
bool receivedBottomInfluence (void) const;
|
||||
|
||||
ParamSet& getPiValues (void) { return piVals_; }
|
||||
ParamSet& getLambdaValues (void) { return ldVals_; }
|
||||
bool arePiValuesCalculated (void) { return piValsCalc_; }
|
||||
bool areLambdaValuesCalculated (void) { return ldValsCalc_; }
|
||||
void markPiValuesAsCalculated (void) { piValsCalc_ = true; }
|
||||
void markLambdaValuesAsCalculated (void) { ldValsCalc_ = true; }
|
||||
void incNumPiMsgsRcv (void) { nPiMsgsRcv_ ++; }
|
||||
void incNumLambdaMsgsRcv (void) { nLdMsgsRcv_ ++; }
|
||||
|
||||
bool receivedAllPiMessages (void)
|
||||
{
|
||||
return node_->getParents().size() == nPiMsgsRcv_;
|
||||
}
|
||||
|
||||
bool receivedAllLambdaMessages (void)
|
||||
{
|
||||
return node_->getChilds().size() == nLdMsgsRcv_;
|
||||
}
|
||||
|
||||
bool readyToSendPiMsgTo (const BayesNode*) const ;
|
||||
bool readyToSendLambdaMsgTo (const BayesNode*) const;
|
||||
|
||||
CEdgeSet getIncomingParentLinks (void) { return inParentLinks_; }
|
||||
CEdgeSet getIncomingChildLinks (void) { return inChildLinks_; }
|
||||
CEdgeSet getOutcomingParentLinks (void) { return outParentLinks_; }
|
||||
CEdgeSet getOutcomingChildLinks (void) { return outChildLinks_; }
|
||||
|
||||
void addIncomingParentLink (Edge* l) { inParentLinks_.push_back (l); }
|
||||
void addIncomingChildLink (Edge* l) { inChildLinks_.push_back (l); }
|
||||
void addOutcomingParentLink (Edge* l) { outParentLinks_.push_back (l); }
|
||||
void addOutcomingChildLink (Edge* l) { outChildLinks_.push_back (l); }
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BPNodeInfo);
|
||||
|
||||
ParamSet piVals_; // pi values
|
||||
ParamSet ldVals_; // lambda values
|
||||
ParamSet oldBeliefs_;
|
||||
unsigned nPiMsgsRcv_;
|
||||
unsigned nLdMsgsRcv_;
|
||||
bool piValsCalc_;
|
||||
bool ldValsCalc_;
|
||||
EdgeSet inParentLinks_;
|
||||
EdgeSet inChildLinks_;
|
||||
EdgeSet outParentLinks_;
|
||||
EdgeSet outChildLinks_;
|
||||
unsigned ds_;
|
||||
const BayesNode* node_;
|
||||
map<const BayesNode*, bool> pmsgs_;
|
||||
map<const BayesNode*, bool> cmsgs_;
|
||||
};
|
||||
|
||||
#endif //BP_BP_NODE_H
|
||||
|
@ -1,905 +0,0 @@
|
||||
#include <cstdlib>
|
||||
#include <limits>
|
||||
#include <time.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "BPSolver.h"
|
||||
|
||||
BPSolver::BPSolver (const BayesNet& bn) : Solver (&bn)
|
||||
{
|
||||
bn_ = &bn;
|
||||
useAlwaysLoopySolver_ = false;
|
||||
//jointCalcType_ = CHAIN_RULE;
|
||||
jointCalcType_ = JUNCTION_NODE;
|
||||
}
|
||||
|
||||
|
||||
|
||||
BPSolver::~BPSolver (void)
|
||||
{
|
||||
for (unsigned i = 0; i < nodesI_.size(); i++) {
|
||||
delete nodesI_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BPSolver::runSolver (void)
|
||||
{
|
||||
clock_t start_ = clock();
|
||||
unsigned size = bn_->getNumberOfNodes();
|
||||
unsigned nIters = 0;
|
||||
initializeSolver();
|
||||
if (bn_->isSingleConnected() && !useAlwaysLoopySolver_) {
|
||||
runPolyTreeSolver();
|
||||
Statistics::numSolvedPolyTrees ++;
|
||||
} else {
|
||||
runLoopySolver();
|
||||
Statistics::numSolvedLoopyNets ++;
|
||||
if (nIter_ >= SolverOptions::maxIter) {
|
||||
Statistics::numUnconvergedRuns ++;
|
||||
} else {
|
||||
nIters = nIter_;
|
||||
}
|
||||
if (DL >= 2) {
|
||||
cout << endl;
|
||||
if (nIter_ < SolverOptions::maxIter) {
|
||||
cout << "Belief propagation converged in " ;
|
||||
cout << nIter_ << " iterations" << endl;
|
||||
} else {
|
||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
double time = (double (clock() - start_)) / CLOCKS_PER_SEC;
|
||||
Statistics::updateStats (size, nIters, time);
|
||||
if (EXPORT_TO_DOT && size > EXPORT_MIN_SIZE) {
|
||||
stringstream ss;
|
||||
ss << size << "." ;
|
||||
ss << Statistics::getCounting (size) << ".dot" ;
|
||||
bn_->exportToDotFormat (ss.str().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
BPSolver::getPosterioriOf (Vid vid) const
|
||||
{
|
||||
BayesNode* node = bn_->getBayesNode (vid);
|
||||
assert (node);
|
||||
return nodesI_[node->getIndex()]->getBeliefs();
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
BPSolver::getJointDistributionOf (const VidSet& jointVids)
|
||||
{
|
||||
if (DL >= 2) {
|
||||
cout << "calculating joint distribution on: " ;
|
||||
for (unsigned i = 0; i < jointVids.size(); i++) {
|
||||
Variable* var = bn_->getBayesNode (jointVids[i]);
|
||||
cout << var->getLabel() << " " ;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
if (jointCalcType_ == JUNCTION_NODE) {
|
||||
return getJointByJunctionNode (jointVids);
|
||||
} else {
|
||||
return getJointByChainRule (jointVids);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BPSolver::initializeSolver (void)
|
||||
{
|
||||
if (DL >= 2) {
|
||||
if (!useAlwaysLoopySolver_) {
|
||||
cout << "-> solver type = polytree solver" << endl;
|
||||
cout << "-> schedule = n/a";
|
||||
} else {
|
||||
cout << "-> solver = loopy solver" << endl;
|
||||
cout << "-> schedule = ";
|
||||
switch (SolverOptions::schedule) {
|
||||
case SolverOptions::S_SEQ_FIXED: cout << "sequential fixed" ; break;
|
||||
case SolverOptions::S_SEQ_RANDOM: cout << "sequential random" ; break;
|
||||
case SolverOptions::S_PARALLEL: cout << "parallel" ; break;
|
||||
case SolverOptions::S_MAX_RESIDUAL: cout << "max residual" ; break;
|
||||
}
|
||||
}
|
||||
cout << endl;
|
||||
cout << "-> joint method = " ;
|
||||
if (jointCalcType_ == JUNCTION_NODE) {
|
||||
cout << "junction node" << endl;
|
||||
} else {
|
||||
cout << "chain rule " << endl;
|
||||
}
|
||||
cout << "-> max iters = " << SolverOptions::maxIter << endl;
|
||||
cout << "-> accuracy = " << SolverOptions::accuracy << endl;
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
CBnNodeSet nodes = bn_->getBayesNodes();
|
||||
for (unsigned i = 0; i < nodesI_.size(); i++) {
|
||||
delete nodesI_[i];
|
||||
}
|
||||
nodesI_.clear();
|
||||
nodesI_.reserve (nodes.size());
|
||||
links_.clear();
|
||||
sortedOrder_.clear();
|
||||
edgeMap_.clear();
|
||||
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
nodesI_.push_back (new BPNodeInfo (nodes[i]));
|
||||
}
|
||||
|
||||
BnNodeSet roots = bn_->getRootNodes();
|
||||
for (unsigned i = 0; i < roots.size(); i++) {
|
||||
const ParamSet& params = roots[i]->getParameters();
|
||||
ParamSet& piVals = M(roots[i])->getPiValues();
|
||||
for (unsigned ri = 0; ri < roots[i]->getDomainSize(); ri++) {
|
||||
piVals[ri] = params[ri];
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
CBnNodeSet parents = nodes[i]->getParents();
|
||||
for (unsigned j = 0; j < parents.size(); j++) {
|
||||
Edge* newLink = new Edge (parents[j], nodes[i], PI_MSG);
|
||||
links_.push_back (newLink);
|
||||
M(nodes[i])->addIncomingParentLink (newLink);
|
||||
M(parents[j])->addOutcomingChildLink (newLink);
|
||||
}
|
||||
CBnNodeSet childs = nodes[i]->getChilds();
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
Edge* newLink = new Edge (childs[j], nodes[i], LAMBDA_MSG);
|
||||
links_.push_back (newLink);
|
||||
M(nodes[i])->addIncomingChildLink (newLink);
|
||||
M(childs[j])->addOutcomingParentLink (newLink);
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
if (nodes[i]->hasEvidence()) {
|
||||
ParamSet& piVals = M(nodes[i])->getPiValues();
|
||||
ParamSet& ldVals = M(nodes[i])->getLambdaValues();
|
||||
for (unsigned xi = 0; xi < nodes[i]->getDomainSize(); xi++) {
|
||||
piVals[xi] = 0.0;
|
||||
ldVals[xi] = 0.0;
|
||||
}
|
||||
piVals[nodes[i]->getEvidence()] = 1.0;
|
||||
ldVals[nodes[i]->getEvidence()] = 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BPSolver::runPolyTreeSolver (void)
|
||||
{
|
||||
CBnNodeSet nodes = bn_->getBayesNodes();
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
if (nodes[i]->isRoot()) {
|
||||
M(nodes[i])->markPiValuesAsCalculated();
|
||||
}
|
||||
if (nodes[i]->isLeaf()) {
|
||||
M(nodes[i])->markLambdaValuesAsCalculated();
|
||||
}
|
||||
}
|
||||
|
||||
bool finish = false;
|
||||
while (!finish) {
|
||||
finish = true;
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
if (M(nodes[i])->arePiValuesCalculated() == false
|
||||
&& M(nodes[i])->receivedAllPiMessages()) {
|
||||
if (!nodes[i]->hasEvidence()) {
|
||||
updatePiValues (nodes[i]);
|
||||
}
|
||||
M(nodes[i])->markPiValuesAsCalculated();
|
||||
finish = false;
|
||||
}
|
||||
|
||||
if (M(nodes[i])->areLambdaValuesCalculated() == false
|
||||
&& M(nodes[i])->receivedAllLambdaMessages()) {
|
||||
if (!nodes[i]->hasEvidence()) {
|
||||
updateLambdaValues (nodes[i]);
|
||||
}
|
||||
M(nodes[i])->markLambdaValuesAsCalculated();
|
||||
finish = false;
|
||||
}
|
||||
|
||||
if (M(nodes[i])->arePiValuesCalculated()) {
|
||||
CEdgeSet outChildLinks = M(nodes[i])->getOutcomingChildLinks();
|
||||
for (unsigned j = 0; j < outChildLinks.size(); j++) {
|
||||
BayesNode* child = outChildLinks[j]->getDestination();
|
||||
if (!outChildLinks[j]->messageWasSended()) {
|
||||
if (M(nodes[i])->readyToSendPiMsgTo (child)) {
|
||||
outChildLinks[j]->setNextMessage (getMessage (outChildLinks[j]));
|
||||
outChildLinks[j]->updateMessage();
|
||||
M(child)->incNumPiMsgsRcv();
|
||||
}
|
||||
finish = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (M(nodes[i])->areLambdaValuesCalculated()) {
|
||||
CEdgeSet outParentLinks = M(nodes[i])->getOutcomingParentLinks();
|
||||
for (unsigned j = 0; j < outParentLinks.size(); j++) {
|
||||
BayesNode* parent = outParentLinks[j]->getDestination();
|
||||
if (!outParentLinks[j]->messageWasSended()) {
|
||||
if (M(nodes[i])->readyToSendLambdaMsgTo (parent)) {
|
||||
outParentLinks[j]->setNextMessage (getMessage (outParentLinks[j]));
|
||||
outParentLinks[j]->updateMessage();
|
||||
M(parent)->incNumLambdaMsgsRcv();
|
||||
}
|
||||
finish = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BPSolver::runLoopySolver()
|
||||
{
|
||||
nIter_ = 0;
|
||||
while (!converged() && nIter_ < SolverOptions::maxIter) {
|
||||
|
||||
nIter_++;
|
||||
if (DL >= 2) {
|
||||
cout << endl;
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
cout << endl;
|
||||
cout << " Iteration " << nIter_ << endl;
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
switch (SolverOptions::schedule) {
|
||||
|
||||
case SolverOptions::S_SEQ_RANDOM:
|
||||
random_shuffle (links_.begin(), links_.end());
|
||||
// no break
|
||||
|
||||
case SolverOptions::S_SEQ_FIXED:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
links_[i]->setNextMessage (getMessage (links_[i]));
|
||||
links_[i]->updateMessage();
|
||||
updateValues (links_[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
case SolverOptions::S_PARALLEL:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
//calculateNextMessage (links_[i]);
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
//updateMessage (links_[i]);
|
||||
//updateValues (links_[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
case SolverOptions::S_MAX_RESIDUAL:
|
||||
maxResidualSchedule();
|
||||
break;
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BPSolver::converged (void) const
|
||||
{
|
||||
// this can happen if the graph is fully disconnected
|
||||
if (links_.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
if (nIter_ == 0 || nIter_ == 1) {
|
||||
return false;
|
||||
}
|
||||
bool converged = true;
|
||||
if (SolverOptions::schedule == SolverOptions::S_MAX_RESIDUAL) {
|
||||
Param maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
||||
if (maxResidual < SolverOptions::accuracy) {
|
||||
converged = true;
|
||||
} else {
|
||||
converged = false;
|
||||
}
|
||||
} else {
|
||||
CBnNodeSet nodes = bn_->getBayesNodes();
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
if (!nodes[i]->hasEvidence()) {
|
||||
double change = M(nodes[i])->getBeliefChange();
|
||||
if (DL >= 2) {
|
||||
cout << nodes[i]->getLabel() + " belief change = " ;
|
||||
cout << change << endl;
|
||||
}
|
||||
if (change > SolverOptions::accuracy) {
|
||||
converged = false;
|
||||
if (DL == 0) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return converged;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BPSolver::maxResidualSchedule (void)
|
||||
{
|
||||
if (nIter_ == 1) {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
links_[i]->setNextMessage (getMessage (links_[i]));
|
||||
links_[i]->updateResidual();
|
||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||
edgeMap_.insert (make_pair (links_[i], it));
|
||||
if (DL >= 2) {
|
||||
cout << "calculating " << links_[i]->toString() << endl;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < sortedOrder_.size(); c++) {
|
||||
if (DL >= 2) {
|
||||
cout << endl << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); it ++) {
|
||||
cout << " " << setw (30) << left << (*it)->toString();
|
||||
cout << "residual = " << (*it)->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
Edge* edge = *it;
|
||||
if (DL >= 2) {
|
||||
cout << "updating " << edge->toString() << endl;
|
||||
}
|
||||
if (edge->getResidual() < SolverOptions::accuracy) {
|
||||
return;
|
||||
}
|
||||
edge->updateMessage();
|
||||
updateValues (edge);
|
||||
edge->clearResidual();
|
||||
sortedOrder_.erase (it);
|
||||
edgeMap_.find (edge)->second = sortedOrder_.insert (edge);
|
||||
|
||||
// update the messages that depend on message source --> destin
|
||||
CEdgeSet outChildLinks =
|
||||
M(edge->getDestination())->getOutcomingChildLinks();
|
||||
for (unsigned i = 0; i < outChildLinks.size(); i++) {
|
||||
if (outChildLinks[i]->getDestination() != edge->getSource()) {
|
||||
if (DL >= 2) {
|
||||
cout << " calculating " << outChildLinks[i]->toString() << endl;
|
||||
}
|
||||
outChildLinks[i]->setNextMessage (getMessage (outChildLinks[i]));
|
||||
outChildLinks[i]->updateResidual();
|
||||
EdgeMap::iterator iter = edgeMap_.find (outChildLinks[i]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (outChildLinks[i]);
|
||||
}
|
||||
}
|
||||
CEdgeSet outParentLinks =
|
||||
M(edge->getDestination())->getOutcomingParentLinks();
|
||||
for (unsigned i = 0; i < outParentLinks.size(); i++) {
|
||||
if (outParentLinks[i]->getDestination() != edge->getSource()) {
|
||||
//&& !outParentLinks[i]->getDestination()->hasEvidence()) FIXME{
|
||||
if (DL >= 2) {
|
||||
cout << " calculating " << outParentLinks[i]->toString() << endl;
|
||||
}
|
||||
outParentLinks[i]->setNextMessage (getMessage (outParentLinks[i]));
|
||||
outParentLinks[i]->updateResidual();
|
||||
EdgeMap::iterator iter = edgeMap_.find (outParentLinks[i]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (outParentLinks[i]);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BPSolver::updatePiValues (BayesNode* x)
|
||||
{
|
||||
// π(Xi)
|
||||
if (DL >= 3) {
|
||||
cout << "updating " << PI << " values for " << x->getLabel() << endl;
|
||||
}
|
||||
CEdgeSet parentLinks = M(x)->getIncomingParentLinks();
|
||||
assert (x->getParents() == parentLinks.size());
|
||||
const vector<CptEntry>& entries = x->getCptEntries();
|
||||
stringstream* calcs1 = 0;
|
||||
stringstream* calcs2 = 0;
|
||||
|
||||
ParamSet messageProducts (entries.size());
|
||||
for (unsigned k = 0; k < entries.size(); k++) {
|
||||
if (DL >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
}
|
||||
double messageProduct = 1.0;
|
||||
const DConf& conf = entries[k].getDomainConfiguration();
|
||||
for (unsigned i = 0; i < parentLinks.size(); i++) {
|
||||
assert (parentLinks[i]->getSource() == parents[i]);
|
||||
assert (parentLinks[i]->getDestination() == x);
|
||||
messageProduct *= parentLinks[i]->getMessage()[conf[i]];
|
||||
if (DL >= 5) {
|
||||
if (i != 0) *calcs1 << "." ;
|
||||
if (i != 0) *calcs2 << "*" ;
|
||||
*calcs1 << PI << "(" << parentLinks[i]->getSource()->getLabel();
|
||||
*calcs1 << " --> " << x->getLabel() << ")" ;
|
||||
*calcs1 << "[" ;
|
||||
*calcs1 << parentLinks[i]->getSource()->getDomain()[conf[i]];
|
||||
*calcs1 << "]";
|
||||
*calcs2 << parentLinks[i]->getMessage()[conf[i]];
|
||||
}
|
||||
}
|
||||
messageProducts[k] = messageProduct;
|
||||
if (DL >= 5) {
|
||||
cout << " mp" << k;
|
||||
cout << " = " << (*calcs1).str();
|
||||
if (parentLinks.size() == 1) {
|
||||
cout << " = " << messageProduct << endl;
|
||||
} else {
|
||||
cout << " = " << (*calcs2).str();
|
||||
cout << " = " << messageProduct << endl;
|
||||
}
|
||||
delete calcs1;
|
||||
delete calcs2;
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned xi = 0; xi < x->getDomainSize(); xi++) {
|
||||
double sum = 0.0;
|
||||
if (DL >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
}
|
||||
for (unsigned k = 0; k < entries.size(); k++) {
|
||||
sum += x->getProbability (xi, entries[k]) * messageProducts[k];
|
||||
if (DL >= 5) {
|
||||
if (k != 0) *calcs1 << " + " ;
|
||||
if (k != 0) *calcs2 << " + " ;
|
||||
*calcs1 << x->cptEntryToString (xi, entries[k]);
|
||||
*calcs1 << ".mp" << k;
|
||||
*calcs2 << x->getProbability (xi, entries[k]);
|
||||
*calcs2 << "*" << messageProducts[k];
|
||||
}
|
||||
}
|
||||
M(x)->setPiValue (xi, sum);
|
||||
if (DL >= 5) {
|
||||
cout << " " << PI << "(" << x->getLabel() << ")" ;
|
||||
cout << "[" << x->getDomain()[xi] << "]" ;
|
||||
cout << " = " << (*calcs1).str();
|
||||
cout << " = " << (*calcs2).str();
|
||||
cout << " = " << sum << endl;
|
||||
delete calcs1;
|
||||
delete calcs2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BPSolver::updateLambdaValues (BayesNode* x)
|
||||
{
|
||||
// λ(Xi)
|
||||
if (DL >= 3) {
|
||||
cout << "updating " << LD << " values for " << x->getLabel() << endl;
|
||||
}
|
||||
CEdgeSet childLinks = M(x)->getIncomingChildLinks();
|
||||
stringstream* calcs1 = 0;
|
||||
stringstream* calcs2 = 0;
|
||||
|
||||
for (unsigned xi = 0; xi < x->getDomainSize(); xi++) {
|
||||
double product = 1.0;
|
||||
if (DL >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
}
|
||||
for (unsigned i = 0; i < childLinks.size(); i++) {
|
||||
assert (childLinks[i]->getDestination() == x);
|
||||
product *= childLinks[i]->getMessage()[xi];
|
||||
if (DL >= 5) {
|
||||
if (i != 0) *calcs1 << "." ;
|
||||
if (i != 0) *calcs2 << "*" ;
|
||||
*calcs1 << LD << "(" << childLinks[i]->getSource()->getLabel();
|
||||
*calcs1 << "-->" << x->getLabel() << ")" ;
|
||||
*calcs1 << "[" << x->getDomain()[xi] << "]" ;
|
||||
*calcs2 << childLinks[i]->getMessage()[xi];
|
||||
}
|
||||
}
|
||||
M(x)->setLambdaValue (xi, product);
|
||||
if (DL >= 5) {
|
||||
cout << " " << LD << "(" << x->getLabel() << ")" ;
|
||||
cout << "[" << x->getDomain()[xi] << "]" ;
|
||||
cout << " = " << (*calcs1).str();
|
||||
if (childLinks.size() == 1) {
|
||||
cout << " = " << product << endl;
|
||||
} else {
|
||||
cout << " = " << (*calcs2).str();
|
||||
cout << " = " << product << endl;
|
||||
}
|
||||
delete calcs1;
|
||||
delete calcs2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
BPSolver::calculateNextPiMessage (Edge* edge)
|
||||
{
|
||||
// πX(Zi)
|
||||
BayesNode* z = edge->getSource();
|
||||
BayesNode* x = edge->getDestination();
|
||||
ParamSet zxPiNextMessage (z->getDomainSize());
|
||||
CEdgeSet zChildLinks = M(z)->getIncomingChildLinks();
|
||||
stringstream* calcs1 = 0;
|
||||
stringstream* calcs2 = 0;
|
||||
|
||||
|
||||
for (unsigned zi = 0; zi < z->getDomainSize(); zi++) {
|
||||
double product = M(z)->getPiValue (zi);
|
||||
if (DL >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
*calcs1 << PI << "(" << z->getLabel() << ")";
|
||||
*calcs1 << "[" << z->getDomain()[zi] << "]" ;
|
||||
*calcs2 << product;
|
||||
}
|
||||
for (unsigned i = 0; i < zChildLinks.size(); i++) {
|
||||
assert (zChildLinks[i]->getDestination() == z);
|
||||
if (zChildLinks[i]->getSource() != x) {
|
||||
product *= zChildLinks[i]->getMessage()[zi];
|
||||
if (DL >= 5) {
|
||||
*calcs1 << "." << LD << "(" << zChildLinks[i]->getSource()->getLabel();
|
||||
*calcs1 << "-->" << z->getLabel() << ")";
|
||||
*calcs1 << "[" << z->getDomain()[zi] + "]" ;
|
||||
*calcs2 << " * " << zChildLinks[i]->getMessage()[zi];
|
||||
}
|
||||
}
|
||||
}
|
||||
zxPiNextMessage[zi] = product;
|
||||
if (DL >= 5) {
|
||||
cout << " " << PI << "(" << z->getLabel();
|
||||
cout << "-->" << x->getLabel() << ")" ;
|
||||
cout << "[" << z->getDomain()[zi] << "]" ;
|
||||
cout << " = " << (*calcs1).str();
|
||||
if (zChildLinks.size() == 1) {
|
||||
cout << " = " << product << endl;
|
||||
} else {
|
||||
cout << " = " << (*calcs2).str();
|
||||
cout << " = " << product << endl;
|
||||
}
|
||||
delete calcs1;
|
||||
delete calcs2;
|
||||
}
|
||||
}
|
||||
return zxPiNextMessage;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
BPSolver::calculateNextLambdaMessage (Edge* edge)
|
||||
{
|
||||
// λY(Xi)
|
||||
BayesNode* y = edge->getSource();
|
||||
BayesNode* x = edge->getDestination();
|
||||
if (!M(y)->receivedBottomInfluence()) {
|
||||
//cout << "returning 1" << endl;
|
||||
//return edge->getMessage();
|
||||
}
|
||||
if (x->hasEvidence()) {
|
||||
//cout << "returning 2" << endl;
|
||||
//return edge->getMessage();
|
||||
}
|
||||
ParamSet yxLambdaNextMessage (x->getDomainSize());
|
||||
CEdgeSet yParentLinks = M(y)->getIncomingParentLinks();
|
||||
const vector<CptEntry>& allEntries = y->getCptEntries();
|
||||
int parentIndex = y->getIndexOfParent (x);
|
||||
stringstream* calcs1 = 0;
|
||||
stringstream* calcs2 = 0;
|
||||
|
||||
vector<CptEntry> entries;
|
||||
DConstraint constr = make_pair (parentIndex, 0);
|
||||
for (unsigned i = 0; i < allEntries.size(); i++) {
|
||||
if (allEntries[i].matchConstraints(constr)) {
|
||||
entries.push_back (allEntries[i]);
|
||||
}
|
||||
}
|
||||
|
||||
ParamSet messageProducts (entries.size());
|
||||
for (unsigned k = 0; k < entries.size(); k++) {
|
||||
if (DL >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
}
|
||||
double messageProduct = 1.0;
|
||||
const DConf& conf = entries[k].getDomainConfiguration();
|
||||
for (unsigned i = 0; i < yParentLinks.size(); i++) {
|
||||
assert (yParentLinks[i]->getDestination() == y);
|
||||
if (yParentLinks[i]->getSource() != x) {
|
||||
if (DL >= 5) {
|
||||
if (messageProduct != 1.0) *calcs1 << "*" ;
|
||||
if (messageProduct != 1.0) *calcs2 << "*" ;
|
||||
*calcs1 << PI << "(" << yParentLinks[i]->getSource()->getLabel();
|
||||
*calcs1 << "-->" << y->getLabel() << ")" ;
|
||||
*calcs1 << "[" ;
|
||||
*calcs1 << yParentLinks[i]->getSource()->getDomain()[conf[i]];
|
||||
*calcs1 << "]" ;
|
||||
*calcs2 << yParentLinks[i]->getMessage()[conf[i]];
|
||||
}
|
||||
messageProduct *= yParentLinks[i]->getMessage()[conf[i]];
|
||||
}
|
||||
}
|
||||
messageProducts[k] = messageProduct;
|
||||
if (DL >= 5) {
|
||||
cout << " mp" << k;
|
||||
cout << " = " << (*calcs1).str();
|
||||
if (yParentLinks.size() == 1) {
|
||||
cout << 1 << endl;
|
||||
} else if (yParentLinks.size() == 2) {
|
||||
cout << " = " << messageProduct << endl;
|
||||
} else {
|
||||
cout << " = " << (*calcs2).str();
|
||||
cout << " = " << messageProduct << endl;
|
||||
}
|
||||
delete calcs1;
|
||||
delete calcs2;
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned xi = 0; xi < x->getDomainSize(); xi++) {
|
||||
if (DL >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
}
|
||||
vector<CptEntry> entries;
|
||||
DConstraint constr = make_pair (parentIndex, xi);
|
||||
for (unsigned i = 0; i < allEntries.size(); i++) {
|
||||
if (allEntries[i].matchConstraints(constr)) {
|
||||
entries.push_back (allEntries[i]);
|
||||
}
|
||||
}
|
||||
double outerSum = 0.0;
|
||||
for (unsigned yi = 0; yi < y->getDomainSize(); yi++) {
|
||||
if (DL >= 5) {
|
||||
(yi != 0) ? *calcs1 << " + {" : *calcs1 << "{" ;
|
||||
(yi != 0) ? *calcs2 << " + {" : *calcs2 << "{" ;
|
||||
}
|
||||
double innerSum = 0.0;
|
||||
for (unsigned k = 0; k < entries.size(); k++) {
|
||||
if (DL >= 5) {
|
||||
if (k != 0) *calcs1 << " + " ;
|
||||
if (k != 0) *calcs2 << " + " ;
|
||||
*calcs1 << y->cptEntryToString (yi, entries[k]);
|
||||
*calcs1 << ".mp" << k;
|
||||
*calcs2 << y->getProbability (yi, entries[k]);
|
||||
*calcs2 << "*" << messageProducts[k];
|
||||
}
|
||||
innerSum += y->getProbability (yi, entries[k]) * messageProducts[k];
|
||||
}
|
||||
outerSum += innerSum * M(y)->getLambdaValue (yi);
|
||||
if (DL >= 5) {
|
||||
*calcs1 << "}." << LD << "(" << y->getLabel() << ")" ;
|
||||
*calcs1 << "[" << y->getDomain()[yi] << "]";
|
||||
*calcs2 << "}*" << M(y)->getLambdaValue (yi);
|
||||
}
|
||||
}
|
||||
yxLambdaNextMessage[xi] = outerSum;
|
||||
if (DL >= 5) {
|
||||
cout << " " << LD << "(" << y->getLabel();
|
||||
cout << "-->" << x->getLabel() << ")" ;
|
||||
cout << "[" << x->getDomain()[xi] << "]" ;
|
||||
cout << " = " << (*calcs1).str();
|
||||
cout << " = " << (*calcs2).str();
|
||||
cout << " = " << outerSum << endl;
|
||||
delete calcs1;
|
||||
delete calcs2;
|
||||
}
|
||||
}
|
||||
return yxLambdaNextMessage;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
BPSolver::getJointByJunctionNode (const VidSet& jointVids) const
|
||||
{
|
||||
BnNodeSet jointVars;
|
||||
for (unsigned i = 0; i < jointVids.size(); i++) {
|
||||
jointVars.push_back (bn_->getBayesNode (jointVids[i]));
|
||||
}
|
||||
|
||||
BayesNet* mrn = bn_->getMinimalRequesiteNetwork (jointVids);
|
||||
|
||||
BnNodeSet parents;
|
||||
unsigned dsize = 1;
|
||||
for (unsigned i = 0; i < jointVars.size(); i++) {
|
||||
parents.push_back (mrn->getBayesNode (jointVids[i]));
|
||||
dsize *= jointVars[i]->getDomainSize();
|
||||
}
|
||||
|
||||
unsigned nParams = dsize * dsize;
|
||||
ParamSet params (nParams);
|
||||
|
||||
for (unsigned i = 0; i < nParams; i++) {
|
||||
unsigned row = i / dsize;
|
||||
unsigned col = i % dsize;
|
||||
if (row == col) {
|
||||
params[i] = 1;
|
||||
} else {
|
||||
params[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
unsigned maxVid = std::numeric_limits<unsigned>::max();
|
||||
Distribution* dist = new Distribution (params);
|
||||
|
||||
mrn->addNode (maxVid, dsize, NO_EVIDENCE, parents, dist);
|
||||
mrn->setIndexes();
|
||||
|
||||
BPSolver solver (*mrn);
|
||||
solver.runSolver();
|
||||
|
||||
const ParamSet& results = solver.getPosterioriOf (maxVid);
|
||||
|
||||
delete mrn;
|
||||
delete dist;
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
BPSolver::getJointByChainRule (const VidSet& jointVids) const
|
||||
{
|
||||
BnNodeSet jointVars;
|
||||
for (unsigned i = 0; i < jointVids.size(); i++) {
|
||||
jointVars.push_back (bn_->getBayesNode (jointVids[i]));
|
||||
}
|
||||
|
||||
BayesNet* mrn = bn_->getMinimalRequesiteNetwork (jointVids[0]);
|
||||
BPSolver solver (*mrn);
|
||||
solver.runSolver();
|
||||
ParamSet prevBeliefs = solver.getPosterioriOf (jointVids[0]);
|
||||
delete mrn;
|
||||
|
||||
VarSet observedVars = {jointVars[0]};
|
||||
|
||||
for (unsigned i = 1; i < jointVids.size(); i++) {
|
||||
mrn = bn_->getMinimalRequesiteNetwork (jointVids[i]);
|
||||
ParamSet newBeliefs;
|
||||
vector<DConf> confs =
|
||||
Util::getDomainConfigurations (observedVars);
|
||||
for (unsigned j = 0; j < confs.size(); j++) {
|
||||
for (unsigned k = 0; k < observedVars.size(); k++) {
|
||||
if (!observedVars[k]->hasEvidence()) {
|
||||
BayesNode* node = mrn->getBayesNode (observedVars[k]->getVarId());
|
||||
if (node) {
|
||||
node->setEvidence (confs[j][k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
BPSolver solver (*mrn);
|
||||
solver.runSolver();
|
||||
ParamSet beliefs = solver.getPosterioriOf (jointVids[i]);
|
||||
for (unsigned k = 0; k < beliefs.size(); k++) {
|
||||
newBeliefs.push_back (beliefs[k]);
|
||||
}
|
||||
}
|
||||
|
||||
int count = -1;
|
||||
for (unsigned j = 0; j < newBeliefs.size(); j++) {
|
||||
if (j % jointVars[i]->getDomainSize() == 0) {
|
||||
count ++;
|
||||
}
|
||||
newBeliefs[j] *= prevBeliefs[count];
|
||||
}
|
||||
prevBeliefs = newBeliefs;
|
||||
observedVars.push_back (jointVars[i]);
|
||||
delete mrn;
|
||||
}
|
||||
return prevBeliefs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BPSolver::printMessageStatusOf (const BayesNode* var) const
|
||||
{
|
||||
cout << left;
|
||||
cout << setw (10) << "domain" ;
|
||||
cout << setw (20) << PI << "(" + var->getLabel() + ")" ;
|
||||
cout << setw (20) << LD << "(" + var->getLabel() + ")" ;
|
||||
cout << setw (16) << "belief" ;
|
||||
cout << endl;
|
||||
cout << "--------------------------------" ;
|
||||
cout << "--------------------------------" ;
|
||||
cout << endl;
|
||||
|
||||
BPNodeInfo* x = M(var);
|
||||
ParamSet& piVals = x->getPiValues();
|
||||
ParamSet& ldVals = x->getLambdaValues();
|
||||
ParamSet beliefs = x->getBeliefs();
|
||||
const Domain& domain = var->getDomain();
|
||||
CBnNodeSet& childs = var->getChilds();
|
||||
|
||||
for (unsigned xi = 0; xi < var->getDomainSize(); xi++) {
|
||||
cout << setw (10) << domain[xi];
|
||||
cout << setw (19) << piVals[xi];
|
||||
cout << setw (19) << ldVals[xi];
|
||||
cout.precision (PRECISION);
|
||||
cout << setw (16) << beliefs[xi];
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
if (childs.size() > 0) {
|
||||
string s = "(" + var->getLabel() + ")" ;
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
cout << setw (10) << "domain" ;
|
||||
cout << setw (28) << PI + childs[j]->getLabel() + s;
|
||||
cout << setw (28) << LD + childs[j]->getLabel() + s;
|
||||
cout << endl;
|
||||
cout << "--------------------------------" ;
|
||||
cout << "--------------------------------" ;
|
||||
cout << endl;
|
||||
/* FIXME
|
||||
const ParamSet& piMessage = x->getPiMessage (childs[j]);
|
||||
const ParamSet& lambdaMessage = x->getLambdaMessage (childs[j]);
|
||||
for (unsigned xi = 0; xi < var->getDomainSize(); xi++) {
|
||||
cout << setw (10) << domain[xi];
|
||||
cout.precision (PRECISION);
|
||||
cout << setw (27) << piMessage[xi];
|
||||
cout.precision (PRECISION);
|
||||
cout << setw (27) << lambdaMessage[xi];
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
*/
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BPSolver::printAllMessageStatus (void) const
|
||||
{
|
||||
CBnNodeSet nodes = bn_->getBayesNodes();
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
printMessageStatusOf (nodes[i]);
|
||||
}
|
||||
}
|
||||
|
@ -1,192 +0,0 @@
|
||||
#ifndef BP_BP_SOLVER_H
|
||||
#define BP_BP_SOLVER_H
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
|
||||
#include "Solver.h"
|
||||
#include "BayesNet.h"
|
||||
#include "BPNodeInfo.h"
|
||||
#include "Shared.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
class BPNodeInfo;
|
||||
|
||||
static const string PI = "pi" ;
|
||||
static const string LD = "ld" ;
|
||||
|
||||
|
||||
enum MessageType {PI_MSG, LAMBDA_MSG};
|
||||
enum JointCalcType {CHAIN_RULE, JUNCTION_NODE};
|
||||
|
||||
class Edge
|
||||
{
|
||||
public:
|
||||
Edge (BayesNode* s, BayesNode* d, MessageType t)
|
||||
{
|
||||
source_ = s;
|
||||
destin_ = d;
|
||||
type_ = t;
|
||||
if (type_ == PI_MSG) {
|
||||
currMsg_.resize (s->getDomainSize(), 1);
|
||||
nextMsg_.resize (s->getDomainSize(), 1);
|
||||
} else {
|
||||
currMsg_.resize (d->getDomainSize(), 1);
|
||||
nextMsg_.resize (d->getDomainSize(), 1);
|
||||
}
|
||||
msgSended_ = false;
|
||||
residual_ = 0.0;
|
||||
}
|
||||
|
||||
//void setMessage (ParamSet msg)
|
||||
//{
|
||||
// Util::normalize (msg);
|
||||
// residual_ = Util::getMaxNorm (currMsg_, msg);
|
||||
// currMsg_ = msg;
|
||||
//}
|
||||
|
||||
void setNextMessage (CParamSet msg)
|
||||
{
|
||||
nextMsg_ = msg;
|
||||
Util::normalize (nextMsg_);
|
||||
residual_ = Util::getMaxNorm (currMsg_, nextMsg_);
|
||||
}
|
||||
|
||||
void updateMessage (void)
|
||||
{
|
||||
currMsg_ = nextMsg_;
|
||||
if (DL >= 3) {
|
||||
cout << "updating " << toString() << endl;
|
||||
}
|
||||
msgSended_ = true;
|
||||
}
|
||||
|
||||
void updateResidual (void)
|
||||
{
|
||||
residual_ = Util::getMaxNorm (currMsg_, nextMsg_);
|
||||
}
|
||||
|
||||
string toString (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
if (type_ == PI_MSG) {
|
||||
ss << PI;
|
||||
} else if (type_ == LAMBDA_MSG) {
|
||||
ss << LD;
|
||||
} else {
|
||||
abort();
|
||||
}
|
||||
ss << "(" << source_->getLabel();
|
||||
ss << " --> " << destin_->getLabel() << ")" ;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
BayesNode* getSource (void) const { return source_; }
|
||||
BayesNode* getDestination (void) const { return destin_; }
|
||||
MessageType getMessageType (void) const { return type_; }
|
||||
CParamSet getMessage (void) const { return currMsg_; }
|
||||
bool messageWasSended (void) const { return msgSended_; }
|
||||
double getResidual (void) const { return residual_; }
|
||||
void clearResidual (void) { residual_ = 0.0; }
|
||||
|
||||
private:
|
||||
BayesNode* source_;
|
||||
BayesNode* destin_;
|
||||
MessageType type_;
|
||||
ParamSet currMsg_;
|
||||
ParamSet nextMsg_;
|
||||
bool msgSended_;
|
||||
double residual_;
|
||||
};
|
||||
|
||||
|
||||
class BPSolver : public Solver
|
||||
{
|
||||
public:
|
||||
BPSolver (const BayesNet&);
|
||||
~BPSolver (void);
|
||||
|
||||
void runSolver (void);
|
||||
ParamSet getPosterioriOf (Vid) const;
|
||||
ParamSet getJointDistributionOf (const VidSet&);
|
||||
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BPSolver);
|
||||
|
||||
void initializeSolver (void);
|
||||
void runPolyTreeSolver (void);
|
||||
void runLoopySolver (void);
|
||||
void maxResidualSchedule (void);
|
||||
bool converged (void) const;
|
||||
void updatePiValues (BayesNode*);
|
||||
void updateLambdaValues (BayesNode*);
|
||||
ParamSet calculateNextLambdaMessage (Edge* edge);
|
||||
ParamSet calculateNextPiMessage (Edge* edge);
|
||||
ParamSet getJointByJunctionNode (const VidSet&) const;
|
||||
ParamSet getJointByChainRule (const VidSet&) const;
|
||||
void printMessageStatusOf (const BayesNode*) const;
|
||||
void printAllMessageStatus (void) const;
|
||||
|
||||
ParamSet getMessage (Edge* edge)
|
||||
{
|
||||
if (DL >= 3) {
|
||||
cout << " calculating " << edge->toString() << endl;
|
||||
}
|
||||
if (edge->getMessageType() == PI_MSG) {
|
||||
return calculateNextPiMessage (edge);
|
||||
} else if (edge->getMessageType() == LAMBDA_MSG) {
|
||||
return calculateNextLambdaMessage (edge);
|
||||
} else {
|
||||
abort();
|
||||
}
|
||||
return ParamSet();
|
||||
}
|
||||
|
||||
void updateValues (Edge* edge)
|
||||
{
|
||||
if (!edge->getDestination()->hasEvidence()) {
|
||||
if (edge->getMessageType() == PI_MSG) {
|
||||
updatePiValues (edge->getDestination());
|
||||
} else if (edge->getMessageType() == LAMBDA_MSG) {
|
||||
updateLambdaValues (edge->getDestination());
|
||||
} else {
|
||||
abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BPNodeInfo* M (const BayesNode* node) const
|
||||
{
|
||||
assert (node);
|
||||
assert (node == bn_->getBayesNode (node->getVarId()));
|
||||
assert (node->getIndex() < nodesI_.size());
|
||||
return nodesI_[node->getIndex()];
|
||||
}
|
||||
|
||||
const BayesNet* bn_;
|
||||
vector<BPNodeInfo*> nodesI_;
|
||||
unsigned nIter_;
|
||||
vector<Edge*> links_;
|
||||
bool useAlwaysLoopySolver_;
|
||||
JointCalcType jointCalcType_;
|
||||
|
||||
struct compare
|
||||
{
|
||||
inline bool operator() (const Edge* e1, const Edge* e2)
|
||||
{
|
||||
return e1->getResidual() > e2->getResidual();
|
||||
}
|
||||
};
|
||||
|
||||
typedef multiset<Edge*, compare> SortedOrder;
|
||||
SortedOrder sortedOrder_;
|
||||
|
||||
typedef map<Edge*, SortedOrder::iterator> EdgeMap;
|
||||
EdgeMap edgeMap_;
|
||||
|
||||
};
|
||||
|
||||
#endif //BP_BP_SOLVER_H
|
||||
|
@ -4,111 +4,12 @@
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "xmlParser/xmlParser.h"
|
||||
|
||||
#include "BayesNet.h"
|
||||
|
||||
|
||||
BayesNet::BayesNet (const char* fileName)
|
||||
{
|
||||
map<string, Domain> domains;
|
||||
XMLNode xMainNode = XMLNode::openFileHelper (fileName, "BIF");
|
||||
// only the first network is parsed, others are ignored
|
||||
XMLNode xNode = xMainNode.getChildNode ("NETWORK");
|
||||
unsigned nVars = xNode.nChildNode ("VARIABLE");
|
||||
for (unsigned i = 0; i < nVars; i++) {
|
||||
XMLNode var = xNode.getChildNode ("VARIABLE", i);
|
||||
string type = var.getAttribute ("TYPE");
|
||||
if (type != "nature") {
|
||||
cerr << "error: only \"nature\" variables are supported" << endl;
|
||||
abort();
|
||||
}
|
||||
Domain domain;
|
||||
string varLabel = var.getChildNode("NAME").getText();
|
||||
unsigned dsize = var.nChildNode ("OUTCOME");
|
||||
for (unsigned j = 0; j < dsize; j++) {
|
||||
if (var.getChildNode("OUTCOME", j).getText() == 0) {
|
||||
stringstream ss;
|
||||
ss << j + 1;
|
||||
domain.push_back (ss.str());
|
||||
} else {
|
||||
domain.push_back (var.getChildNode("OUTCOME", j).getText());
|
||||
}
|
||||
}
|
||||
domains.insert (make_pair (varLabel, domain));
|
||||
}
|
||||
|
||||
unsigned nDefs = xNode.nChildNode ("DEFINITION");
|
||||
if (nVars != nDefs) {
|
||||
cerr << "error: different number of variables and definitions" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
queue<unsigned> indexes;
|
||||
for (unsigned i = 0; i < nDefs; i++) {
|
||||
indexes.push (i);
|
||||
}
|
||||
|
||||
while (!indexes.empty()) {
|
||||
unsigned index = indexes.front();
|
||||
indexes.pop();
|
||||
XMLNode def = xNode.getChildNode ("DEFINITION", index);
|
||||
string varLabel = def.getChildNode("FOR").getText();
|
||||
map<string, Domain>::const_iterator iter;
|
||||
iter = domains.find (varLabel);
|
||||
if (iter == domains.end()) {
|
||||
cerr << "error: unknow variable `" << varLabel << "'" << endl;
|
||||
abort();
|
||||
}
|
||||
bool processItLatter = false;
|
||||
BnNodeSet parents;
|
||||
unsigned nParams = iter->second.size();
|
||||
for (int j = 0; j < def.nChildNode ("GIVEN"); j++) {
|
||||
string parentLabel = def.getChildNode("GIVEN", j).getText();
|
||||
BayesNode* parentNode = getBayesNode (parentLabel);
|
||||
if (parentNode) {
|
||||
nParams *= parentNode->getDomainSize();
|
||||
parents.push_back (parentNode);
|
||||
}
|
||||
else {
|
||||
iter = domains.find (parentLabel);
|
||||
if (iter == domains.end()) {
|
||||
cerr << "error: unknow parent `" << parentLabel << "'" << endl;
|
||||
abort();
|
||||
} else {
|
||||
// this definition contains a parent that doesn't
|
||||
// have a corresponding bayesian node instance yet,
|
||||
// so process this definition latter
|
||||
indexes.push (index);
|
||||
processItLatter = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!processItLatter) {
|
||||
unsigned count = 0;
|
||||
ParamSet params (nParams);
|
||||
stringstream s (def.getChildNode("TABLE").getText());
|
||||
while (!s.eof() && count < nParams) {
|
||||
s >> params[count];
|
||||
count ++;
|
||||
}
|
||||
if (count != nParams) {
|
||||
cerr << "error: invalid number of parameters " ;
|
||||
cerr << "for variable `" << varLabel << "'" << endl;
|
||||
abort();
|
||||
}
|
||||
params = reorderParameters (params, iter->second.size());
|
||||
addNode (varLabel, iter->second, parents, params);
|
||||
}
|
||||
}
|
||||
setIndexes();
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNet::~BayesNet (void)
|
||||
{
|
||||
@ -119,26 +20,130 @@ BayesNet::~BayesNet (void)
|
||||
|
||||
|
||||
|
||||
BayesNode*
|
||||
BayesNet::addNode (Vid vid)
|
||||
void
|
||||
BayesNet::readFromBifFormat (const char* fileName)
|
||||
{
|
||||
indexMap_.insert (make_pair (vid, nodes_.size()));
|
||||
nodes_.push_back (new BayesNode (vid));
|
||||
return nodes_.back();
|
||||
XMLNode xMainNode = XMLNode::openFileHelper (fileName, "BIF");
|
||||
// only the first network is parsed, others are ignored
|
||||
XMLNode xNode = xMainNode.getChildNode ("NETWORK");
|
||||
unsigned nVars = xNode.nChildNode ("VARIABLE");
|
||||
for (unsigned i = 0; i < nVars; i++) {
|
||||
XMLNode var = xNode.getChildNode ("VARIABLE", i);
|
||||
if (string (var.getAttribute ("TYPE")) != "nature") {
|
||||
cerr << "error: only \"nature\" variables are supported" << endl;
|
||||
abort();
|
||||
}
|
||||
States states;
|
||||
string label = var.getChildNode("NAME").getText();
|
||||
unsigned nrStates = var.nChildNode ("OUTCOME");
|
||||
for (unsigned j = 0; j < nrStates; j++) {
|
||||
if (var.getChildNode("OUTCOME", j).getText() == 0) {
|
||||
stringstream ss;
|
||||
ss << j + 1;
|
||||
states.push_back (ss.str());
|
||||
} else {
|
||||
states.push_back (var.getChildNode("OUTCOME", j).getText());
|
||||
}
|
||||
}
|
||||
addNode (label, states);
|
||||
}
|
||||
|
||||
unsigned nDefs = xNode.nChildNode ("DEFINITION");
|
||||
if (nVars != nDefs) {
|
||||
cerr << "error: different number of variables and definitions" << endl;
|
||||
abort();
|
||||
}
|
||||
for (unsigned i = 0; i < nDefs; i++) {
|
||||
XMLNode def = xNode.getChildNode ("DEFINITION", i);
|
||||
string label = def.getChildNode("FOR").getText();
|
||||
BayesNode* node = getBayesNode (label);
|
||||
if (!node) {
|
||||
cerr << "error: unknow variable `" << label << "'" << endl;
|
||||
abort();
|
||||
}
|
||||
BnNodeSet parents;
|
||||
unsigned nParams = node->nrStates();
|
||||
for (int j = 0; j < def.nChildNode ("GIVEN"); j++) {
|
||||
string parentLabel = def.getChildNode("GIVEN", j).getText();
|
||||
BayesNode* parentNode = getBayesNode (parentLabel);
|
||||
if (!parentNode) {
|
||||
cerr << "error: unknow variable `" << parentLabel << "'" << endl;
|
||||
abort();
|
||||
}
|
||||
nParams *= parentNode->nrStates();
|
||||
parents.push_back (parentNode);
|
||||
}
|
||||
node->setParents (parents);
|
||||
unsigned count = 0;
|
||||
ParamSet params (nParams);
|
||||
stringstream s (def.getChildNode("TABLE").getText());
|
||||
while (!s.eof() && count < nParams) {
|
||||
s >> params[count];
|
||||
count ++;
|
||||
}
|
||||
if (count != nParams) {
|
||||
cerr << "error: invalid number of parameters " ;
|
||||
cerr << "for variable `" << label << "'" << endl;
|
||||
abort();
|
||||
}
|
||||
params = reorderParameters (params, node->nrStates());
|
||||
Distribution* dist = new Distribution (params);
|
||||
node->setDistribution (dist);
|
||||
addDistribution (dist);
|
||||
}
|
||||
|
||||
setIndexes();
|
||||
if (NSPACE == NumberSpace::LOGARITHM) {
|
||||
distributionsToLogs();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNet::addNode (BayesNode* n)
|
||||
{
|
||||
indexMap_.insert (make_pair (n->varId(), nodes_.size()));
|
||||
nodes_.push_back (n);
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNode*
|
||||
BayesNet::addNode (Vid vid,
|
||||
BayesNet::addNode (string label, const States& states)
|
||||
{
|
||||
VarId vid = nodes_.size();
|
||||
indexMap_.insert (make_pair (vid, nodes_.size()));
|
||||
GraphicalModel::addVariableInformation (vid, label, states);
|
||||
BayesNode* node = new BayesNode (VarNode (vid, states.size()));
|
||||
nodes_.push_back (node);
|
||||
return node;
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNode*
|
||||
BayesNet::addNode (VarId vid,
|
||||
unsigned dsize,
|
||||
int evidence,
|
||||
BnNodeSet& parents,
|
||||
Distribution* dist)
|
||||
{
|
||||
indexMap_.insert (make_pair (vid, nodes_.size()));
|
||||
nodes_.push_back (new BayesNode (
|
||||
vid, dsize, evidence, parents, dist));
|
||||
nodes_.push_back (new BayesNode (vid, dsize, evidence, parents, dist));
|
||||
return nodes_.back();
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNode*
|
||||
BayesNet::addNode (VarId vid,
|
||||
unsigned dsize,
|
||||
int evidence,
|
||||
Distribution* dist)
|
||||
{
|
||||
indexMap_.insert (make_pair (vid, nodes_.size()));
|
||||
nodes_.push_back (new BayesNode (vid, dsize, evidence, dist));
|
||||
return nodes_.back();
|
||||
}
|
||||
|
||||
@ -146,14 +151,16 @@ BayesNet::addNode (Vid vid,
|
||||
|
||||
BayesNode*
|
||||
BayesNet::addNode (string label,
|
||||
Domain domain,
|
||||
States states,
|
||||
BnNodeSet& parents,
|
||||
ParamSet& params)
|
||||
{
|
||||
indexMap_.insert (make_pair (nodes_.size(), nodes_.size()));
|
||||
VarId vid = nodes_.size();
|
||||
indexMap_.insert (make_pair (vid, nodes_.size()));
|
||||
GraphicalModel::addVariableInformation (vid, label, states);
|
||||
Distribution* dist = new Distribution (params);
|
||||
BayesNode* node = new BayesNode (
|
||||
nodes_.size(), label, domain, parents, dist);
|
||||
vid, states.size(), NO_EVIDENCE, parents, dist);
|
||||
dists_.push_back (dist);
|
||||
nodes_.push_back (node);
|
||||
return node;
|
||||
@ -162,7 +169,7 @@ BayesNet::addNode (string label,
|
||||
|
||||
|
||||
BayesNode*
|
||||
BayesNet::getBayesNode (Vid vid) const
|
||||
BayesNet::getBayesNode (VarId vid) const
|
||||
{
|
||||
IndexMap::const_iterator it = indexMap_.find (vid);
|
||||
if (it == indexMap_.end()) {
|
||||
@ -179,7 +186,7 @@ BayesNet::getBayesNode (string label) const
|
||||
{
|
||||
BayesNode* node = 0;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
if (nodes_[i]->getLabel() == label) {
|
||||
if (nodes_[i]->label() == label) {
|
||||
node = nodes_[i];
|
||||
break;
|
||||
}
|
||||
@ -190,10 +197,25 @@ BayesNet::getBayesNode (string label) const
|
||||
|
||||
|
||||
|
||||
Variable*
|
||||
BayesNet::getVariable (Vid vid) const
|
||||
VarNode*
|
||||
BayesNet::getVariableNode (VarId vid) const
|
||||
{
|
||||
return getBayesNode (vid);
|
||||
BayesNode* node = getBayesNode (vid);
|
||||
assert (node);
|
||||
return node;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
VarNodes
|
||||
BayesNet::getVariableNodes (void) const
|
||||
{
|
||||
VarNodes vars;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
vars.push_back (nodes_[i]);
|
||||
}
|
||||
return vars;
|
||||
}
|
||||
|
||||
|
||||
@ -230,7 +252,7 @@ BayesNet::getBayesNodes (void) const
|
||||
|
||||
|
||||
unsigned
|
||||
BayesNet::getNumberOfNodes (void) const
|
||||
BayesNet::nrNodes (void) const
|
||||
{
|
||||
return nodes_.size();
|
||||
}
|
||||
@ -265,37 +287,25 @@ BayesNet::getLeafNodes (void) const
|
||||
|
||||
|
||||
|
||||
VarSet
|
||||
BayesNet::getVariables (void) const
|
||||
BayesNet*
|
||||
BayesNet::getMinimalRequesiteNetwork (VarId vid) const
|
||||
{
|
||||
VarSet vars;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
vars.push_back (nodes_[i]);
|
||||
}
|
||||
return vars;
|
||||
return getMinimalRequesiteNetwork (VarIdSet() = {vid});
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNet*
|
||||
BayesNet::getMinimalRequesiteNetwork (Vid vid) const
|
||||
{
|
||||
return getMinimalRequesiteNetwork (VidSet() = {vid});
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNet*
|
||||
BayesNet::getMinimalRequesiteNetwork (const VidSet& queryVids) const
|
||||
BayesNet::getMinimalRequesiteNetwork (const VarIdSet& queryVarIds) const
|
||||
{
|
||||
BnNodeSet queryVars;
|
||||
for (unsigned i = 0; i < queryVids.size(); i++) {
|
||||
assert (getBayesNode (queryVids[i]));
|
||||
queryVars.push_back (getBayesNode (queryVids[i]));
|
||||
for (unsigned i = 0; i < queryVarIds.size(); i++) {
|
||||
assert (getBayesNode (queryVarIds[i]));
|
||||
queryVars.push_back (getBayesNode (queryVarIds[i]));
|
||||
}
|
||||
// cout << "query vars: " ;
|
||||
// for (unsigned i = 0; i < queryVars.size(); i++) {
|
||||
// cout << queryVars[i]->getLabel() << " " ;
|
||||
// cout << queryVars[i]->label() << " " ;
|
||||
// }
|
||||
// cout << endl;
|
||||
|
||||
@ -344,7 +354,7 @@ BayesNet::getMinimalRequesiteNetwork (const VidSet& queryVids) const
|
||||
cout << "----------------------------------------------------------" ;
|
||||
cout << endl;
|
||||
for (unsigned i = 0; i < states.size(); i++) {
|
||||
cout << nodes_[i]->getLabel() << ":\t\t" ;
|
||||
cout << nodes_[i]->label() << ":\t\t" ;
|
||||
if (states[i]) {
|
||||
states[i]->markedOnTop ? cout << "yes\t" : cout << "no\t" ;
|
||||
states[i]->markedOnBottom ? cout << "yes\t" : cout << "no\t" ;
|
||||
@ -374,6 +384,8 @@ void
|
||||
BayesNet::constructGraph (BayesNet* bn,
|
||||
const vector<StateInfo*>& states) const
|
||||
{
|
||||
BnNodeSet mrnNodes;
|
||||
vector<VarIdSet> parents;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
bool isRequired = false;
|
||||
if (states[i]) {
|
||||
@ -382,35 +394,28 @@ BayesNet::constructGraph (BayesNet* bn,
|
||||
states[i]->markedOnTop;
|
||||
}
|
||||
if (isRequired) {
|
||||
BnNodeSet parents;
|
||||
parents.push_back (VarIdSet());
|
||||
if (states[i]->markedOnTop) {
|
||||
const BnNodeSet& ps = nodes_[i]->getParents();
|
||||
for (unsigned j = 0; j < ps.size(); j++) {
|
||||
BayesNode* parent = bn->getBayesNode (ps[j]->getVarId());
|
||||
if (!parent) {
|
||||
parent = bn->addNode (ps[j]->getVarId());
|
||||
}
|
||||
parents.push_back (parent);
|
||||
parents.back().push_back (ps[j]->varId());
|
||||
}
|
||||
}
|
||||
BayesNode* node = bn->getBayesNode (nodes_[i]->getVarId());
|
||||
if (node) {
|
||||
node->setData (nodes_[i]->getDomainSize(),
|
||||
nodes_[i]->getEvidence(), parents,
|
||||
assert (bn->getBayesNode (nodes_[i]->varId()) == 0);
|
||||
BayesNode* mrnNode = bn->addNode (nodes_[i]->varId(),
|
||||
nodes_[i]->nrStates(),
|
||||
nodes_[i]->getEvidence(),
|
||||
nodes_[i]->getDistribution());
|
||||
} else {
|
||||
node = bn->addNode (nodes_[i]->getVarId(),
|
||||
nodes_[i]->getDomainSize(),
|
||||
nodes_[i]->getEvidence(), parents,
|
||||
nodes_[i]->getDistribution());
|
||||
}
|
||||
if (nodes_[i]->hasDomain()) {
|
||||
node->setDomain (nodes_[i]->getDomain());
|
||||
}
|
||||
if (nodes_[i]->hasLabel()) {
|
||||
node->setLabel (nodes_[i]->getLabel());
|
||||
mrnNodes.push_back (mrnNode);
|
||||
}
|
||||
}
|
||||
for (unsigned i = 0; i < mrnNodes.size(); i++) {
|
||||
BnNodeSet ps;
|
||||
for (unsigned j = 0; j < parents[i].size(); j++) {
|
||||
assert (bn->getBayesNode (parents[i][j]) != 0);
|
||||
ps.push_back (bn->getBayesNode (parents[i][j]));
|
||||
}
|
||||
mrnNodes[i]->setParents (ps);
|
||||
}
|
||||
bn->setIndexes();
|
||||
}
|
||||
@ -418,7 +423,7 @@ BayesNet::constructGraph (BayesNet* bn,
|
||||
|
||||
|
||||
bool
|
||||
BayesNet::isSingleConnected (void) const
|
||||
BayesNet::isPolyTree (void) const
|
||||
{
|
||||
return !containsUndirectedCycle();
|
||||
}
|
||||
@ -435,6 +440,16 @@ BayesNet::setIndexes (void)
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNet::distributionsToLogs (void)
|
||||
{
|
||||
for (unsigned i = 0; i < dists_.size(); i++) {
|
||||
Util::toLog (dists_[i]->params);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNet::freeDistributions (void)
|
||||
{
|
||||
@ -456,9 +471,9 @@ BayesNet::printGraphicalModel (void) const
|
||||
|
||||
|
||||
void
|
||||
BayesNet::exportToDotFormat (const char* fileName,
|
||||
BayesNet::exportToGraphViz (const char* fileName,
|
||||
bool showNeighborless,
|
||||
CVidSet& highlightVids) const
|
||||
const VarIdSet& highlightVarIds) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
@ -467,27 +482,32 @@ BayesNet::exportToDotFormat (const char* fileName,
|
||||
abort();
|
||||
}
|
||||
|
||||
out << "digraph \"" << fileName << "\" {" << endl;
|
||||
|
||||
out << "digraph {" << endl;
|
||||
out << "ranksep=1" << endl;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
if (showNeighborless || nodes_[i]->hasNeighbors()) {
|
||||
out << '"' << nodes_[i]->getLabel() << '"' ;
|
||||
out << nodes_[i]->varId() ;
|
||||
if (nodes_[i]->hasEvidence()) {
|
||||
out << " [style=filled, fillcolor=yellow]" << endl;
|
||||
out << " [" ;
|
||||
out << "label=\"" << nodes_[i]->label() << "\"," ;
|
||||
out << "style=filled, fillcolor=yellow" ;
|
||||
out << "]" ;
|
||||
} else {
|
||||
out << " [" ;
|
||||
out << "label=\"" << nodes_[i]->label() << "\"" ;
|
||||
out << "]" ;
|
||||
}
|
||||
out << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < highlightVids.size(); i++) {
|
||||
BayesNode* node = getBayesNode (highlightVids[i]);
|
||||
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
|
||||
BayesNode* node = getBayesNode (highlightVarIds[i]);
|
||||
if (node) {
|
||||
out << '"' << node->getLabel() << '"' ;
|
||||
// out << " [shape=polygon, sides=6]" << endl;
|
||||
out << node->varId() ;
|
||||
out << " [shape=box3d]" << endl;
|
||||
} else {
|
||||
cout << "error: invalid variable id: " << highlightVids[i] << endl;
|
||||
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
|
||||
abort();
|
||||
}
|
||||
}
|
||||
@ -495,8 +515,7 @@ BayesNet::exportToDotFormat (const char* fileName,
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
const BnNodeSet& childs = nodes_[i]->getChilds();
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
out << '"' << nodes_[i]->getLabel() << '"' << " -> " ;
|
||||
out << '"' << childs[j]->getLabel() << '"' << endl;
|
||||
out << nodes_[i]->varId() << " -> " << childs[j]->varId() << " [style=bold]" << endl ;
|
||||
}
|
||||
}
|
||||
|
||||
@ -521,24 +540,24 @@ BayesNet::exportToBifFormat (const char* fileName) const
|
||||
out << "<NAME>" << fileName << "</NAME>" << endl << endl;
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
out << "<VARIABLE TYPE=\"nature\">" << endl;
|
||||
out << "\t<NAME>" << nodes_[i]->getLabel() << "</NAME>" << endl;
|
||||
const Domain& domain = nodes_[i]->getDomain();
|
||||
for (unsigned j = 0; j < domain.size(); j++) {
|
||||
out << "\t<OUTCOME>" << domain[j] << "</OUTCOME>" << endl;
|
||||
out << "\t<NAME>" << nodes_[i]->label() << "</NAME>" << endl;
|
||||
const States& states = nodes_[i]->states();
|
||||
for (unsigned j = 0; j < states.size(); j++) {
|
||||
out << "\t<OUTCOME>" << states[j] << "</OUTCOME>" << endl;
|
||||
}
|
||||
out << "</VARIABLE>" << endl << endl;
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
out << "<DEFINITION>" << endl;
|
||||
out << "\t<FOR>" << nodes_[i]->getLabel() << "</FOR>" << endl;
|
||||
out << "\t<FOR>" << nodes_[i]->label() << "</FOR>" << endl;
|
||||
const BnNodeSet& parents = nodes_[i]->getParents();
|
||||
for (unsigned j = 0; j < parents.size(); j++) {
|
||||
out << "\t<GIVEN>" << parents[j]->getLabel();
|
||||
out << "\t<GIVEN>" << parents[j]->label();
|
||||
out << "</GIVEN>" << endl;
|
||||
}
|
||||
ParamSet params = revertParameterReorder (nodes_[i]->getParameters(),
|
||||
nodes_[i]->getDomainSize());
|
||||
nodes_[i]->nrStates());
|
||||
out << "\t<TABLE>" ;
|
||||
for (unsigned j = 0; j < params.size(); j++) {
|
||||
out << " " << params[j];
|
||||
@ -571,9 +590,7 @@ BayesNet::containsUndirectedCycle (void) const
|
||||
|
||||
|
||||
bool
|
||||
BayesNet::containsUndirectedCycle (int v,
|
||||
int p,
|
||||
vector<bool>& visited) const
|
||||
BayesNet::containsUndirectedCycle (int v, int p, vector<bool>& visited) const
|
||||
{
|
||||
visited[v] = true;
|
||||
vector<int> adjacencies = getAdjacentNodes (v);
|
||||
@ -611,8 +628,7 @@ BayesNet::getAdjacentNodes (int v) const
|
||||
|
||||
|
||||
ParamSet
|
||||
BayesNet::reorderParameters (CParamSet params,
|
||||
unsigned domainSize) const
|
||||
BayesNet::reorderParameters (const ParamSet& params, unsigned dsize) const
|
||||
{
|
||||
// the interchange format for bayesian networks keeps the probabilities
|
||||
// in the following order:
|
||||
@ -623,13 +639,13 @@ BayesNet::reorderParameters (CParamSet params,
|
||||
// p(a1|b1,c1) p(a1|b1,c2) p(a1|b2,c1) p(a1|b2,c2) p(a2|b1,c1) p(a2|b1,c2)
|
||||
// p(a2|b2,c1) p(a2|b2,c2).
|
||||
unsigned count = 0;
|
||||
unsigned rowSize = params.size() / domainSize;
|
||||
unsigned rowSize = params.size() / dsize;
|
||||
ParamSet reordered;
|
||||
while (reordered.size() < params.size()) {
|
||||
unsigned idx = count;
|
||||
for (unsigned i = 0; i < rowSize; i++) {
|
||||
reordered.push_back (params[idx]);
|
||||
idx += domainSize;
|
||||
idx += dsize ;
|
||||
}
|
||||
count++;
|
||||
}
|
||||
@ -639,15 +655,14 @@ BayesNet::reorderParameters (CParamSet params,
|
||||
|
||||
|
||||
ParamSet
|
||||
BayesNet::revertParameterReorder (CParamSet params,
|
||||
unsigned domainSize) const
|
||||
BayesNet::revertParameterReorder (const ParamSet& params, unsigned dsize) const
|
||||
{
|
||||
unsigned count = 0;
|
||||
unsigned rowSize = params.size() / domainSize;
|
||||
unsigned rowSize = params.size() / dsize;
|
||||
ParamSet reordered;
|
||||
while (reordered.size() < params.size()) {
|
||||
unsigned idx = count;
|
||||
for (unsigned i = 0; i < domainSize; i++) {
|
||||
for (unsigned i = 0; i < dsize; i++) {
|
||||
reordered.push_back (params[idx]);
|
||||
idx += rowSize;
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
#ifndef BP_BAYES_NET_H
|
||||
#define BP_BAYES_NET_H
|
||||
#ifndef HORUS_BAYESNET_H
|
||||
#define HORUS_BAYESNET_H
|
||||
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
@ -44,60 +44,58 @@ struct StateInfo
|
||||
|
||||
typedef vector<Distribution*> DistSet;
|
||||
typedef queue<ScheduleInfo, list<ScheduleInfo> > Scheduling;
|
||||
typedef map<unsigned, unsigned> Histogram;
|
||||
typedef map<unsigned, double> Times;
|
||||
|
||||
|
||||
class BayesNet : public GraphicalModel
|
||||
{
|
||||
public:
|
||||
BayesNet (void) {};
|
||||
BayesNet (const char*);
|
||||
~BayesNet (void);
|
||||
|
||||
BayesNode* addNode (unsigned);
|
||||
BayesNode* addNode (unsigned, unsigned, int, BnNodeSet&,
|
||||
Distribution*);
|
||||
BayesNode* addNode (string, Domain, BnNodeSet&, ParamSet&);
|
||||
BayesNode* getBayesNode (Vid) const;
|
||||
void readFromBifFormat (const char*);
|
||||
void addNode (BayesNode*);
|
||||
BayesNode* addNode (string, const States&);
|
||||
BayesNode* addNode (VarId, unsigned, int, BnNodeSet&, Distribution*);
|
||||
BayesNode* addNode (VarId, unsigned, int, Distribution*);
|
||||
BayesNode* addNode (string, States, BnNodeSet&, ParamSet&);
|
||||
BayesNode* getBayesNode (VarId) const;
|
||||
BayesNode* getBayesNode (string) const;
|
||||
Variable* getVariable (Vid) const;
|
||||
VarNode* getVariableNode (VarId) const;
|
||||
VarNodes getVariableNodes (void) const;
|
||||
void addDistribution (Distribution*);
|
||||
Distribution* getDistribution (unsigned) const;
|
||||
const BnNodeSet& getBayesNodes (void) const;
|
||||
unsigned getNumberOfNodes (void) const;
|
||||
unsigned nrNodes (void) const;
|
||||
BnNodeSet getRootNodes (void) const;
|
||||
BnNodeSet getLeafNodes (void) const;
|
||||
VarSet getVariables (void) const;
|
||||
BayesNet* getMinimalRequesiteNetwork (Vid) const;
|
||||
BayesNet* getMinimalRequesiteNetwork (const VidSet&) const;
|
||||
void constructGraph (BayesNet*,
|
||||
const vector<StateInfo*>&) const;
|
||||
bool isSingleConnected (void) const;
|
||||
BayesNet* getMinimalRequesiteNetwork (VarId) const;
|
||||
BayesNet* getMinimalRequesiteNetwork (const VarIdSet&) const;
|
||||
void constructGraph (
|
||||
BayesNet*, const vector<StateInfo*>&) const;
|
||||
bool isPolyTree (void) const;
|
||||
void setIndexes (void);
|
||||
void distributionsToLogs (void);
|
||||
void freeDistributions (void);
|
||||
void printGraphicalModel (void) const;
|
||||
void exportToDotFormat (const char*, bool = true,
|
||||
CVidSet = VidSet()) const;
|
||||
void exportToGraphViz (const char*, bool = true,
|
||||
const VarIdSet& = VarIdSet()) const;
|
||||
void exportToBifFormat (const char*) const;
|
||||
|
||||
static Histogram histogram_;
|
||||
static Times times_;
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BayesNet);
|
||||
|
||||
bool containsUndirectedCycle (void) const;
|
||||
bool containsUndirectedCycle (int, int,
|
||||
vector<bool>&)const;
|
||||
vector<int> getAdjacentNodes (int) const ;
|
||||
ParamSet reorderParameters (CParamSet, unsigned) const;
|
||||
ParamSet revertParameterReorder (CParamSet, unsigned) const;
|
||||
bool containsUndirectedCycle (int, int, vector<bool>&)const;
|
||||
vector<int> getAdjacentNodes (int) const;
|
||||
ParamSet reorderParameters (const ParamSet&, unsigned) const;
|
||||
ParamSet revertParameterReorder (const ParamSet&, unsigned) const;
|
||||
void scheduleParents (const BayesNode*, Scheduling&) const;
|
||||
void scheduleChilds (const BayesNode*, Scheduling&) const;
|
||||
|
||||
BnNodeSet nodes_;
|
||||
DistSet dists_;
|
||||
|
||||
typedef unordered_map<unsigned, unsigned> IndexMap;
|
||||
IndexMap indexMap_;
|
||||
};
|
||||
|
||||
@ -123,5 +121,5 @@ BayesNet::scheduleChilds (const BayesNode* n, Scheduling& sch) const
|
||||
}
|
||||
}
|
||||
|
||||
#endif //BP_BAYES_NET_H
|
||||
#endif // HORUS_BAYESNET_H
|
||||
|
||||
|
@ -1,34 +1,30 @@
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "BayesNode.h"
|
||||
|
||||
|
||||
BayesNode::BayesNode (Vid vid,
|
||||
BayesNode::BayesNode (VarId vid,
|
||||
unsigned dsize,
|
||||
int evidence,
|
||||
const BnNodeSet& parents,
|
||||
Distribution* dist) : Variable (vid, dsize, evidence)
|
||||
Distribution* dist)
|
||||
: VarNode (vid, dsize, evidence)
|
||||
{
|
||||
parents_ = parents;
|
||||
dist_ = dist;
|
||||
for (unsigned int i = 0; i < parents.size(); i++) {
|
||||
parents[i]->addChild (this);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
BayesNode::BayesNode (Vid vid,
|
||||
string label,
|
||||
const Domain& domain,
|
||||
BayesNode::BayesNode (VarId vid,
|
||||
unsigned dsize,
|
||||
int evidence,
|
||||
const BnNodeSet& parents,
|
||||
Distribution* dist) : Variable (vid, domain,
|
||||
NO_EVIDENCE, label)
|
||||
Distribution* dist)
|
||||
: VarNode (vid, dsize, evidence)
|
||||
{
|
||||
parents_ = parents;
|
||||
dist_ = dist;
|
||||
@ -40,15 +36,9 @@ BayesNode::BayesNode (Vid vid,
|
||||
|
||||
|
||||
void
|
||||
BayesNode::setData (unsigned dsize,
|
||||
int evidence,
|
||||
const BnNodeSet& parents,
|
||||
Distribution* dist)
|
||||
BayesNode::setParents (const BnNodeSet& parents)
|
||||
{
|
||||
setDomainSize (dsize);
|
||||
setEvidence (evidence);
|
||||
parents_ = parents;
|
||||
dist_ = dist;
|
||||
for (unsigned int i = 0; i < parents.size(); i++) {
|
||||
parents[i]->addChild (this);
|
||||
}
|
||||
@ -64,6 +54,15 @@ BayesNode::addChild (BayesNode* node)
|
||||
|
||||
|
||||
|
||||
void
|
||||
BayesNode::setDistribution (Distribution* dist)
|
||||
{
|
||||
assert (dist);
|
||||
dist_ = dist;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Distribution*
|
||||
BayesNode::getDistribution (void)
|
||||
{
|
||||
@ -140,14 +139,14 @@ BayesNode::getCptEntries (void)
|
||||
for (int i = parents_.size() - 1; i >= 0; i--) {
|
||||
unsigned index = 0;
|
||||
while (index < rowSize) {
|
||||
for (unsigned j = 0; j < parents_[i]->getDomainSize(); j++) {
|
||||
for (unsigned j = 0; j < parents_[i]->nrStates(); j++) {
|
||||
for (unsigned r = 0; r < nReps; r++) {
|
||||
confs[index][i] = j;
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
nReps *= parents_[i]->getDomainSize();
|
||||
nReps *= parents_[i]->nrStates();
|
||||
}
|
||||
|
||||
dist_->entries.reserve (rowSize);
|
||||
@ -180,14 +179,14 @@ BayesNode::cptEntryToString (const CptEntry& entry) const
|
||||
ss << "p(" ;
|
||||
const DConf& conf = entry.getDomainConfiguration();
|
||||
int row = entry.getParameterIndex() / getRowSize();
|
||||
ss << getDomain()[row];
|
||||
ss << states()[row];
|
||||
if (parents_.size() > 0) {
|
||||
ss << "|" ;
|
||||
for (unsigned int i = 0; i < conf.size(); i++) {
|
||||
if (i != 0) {
|
||||
ss << ",";
|
||||
}
|
||||
ss << parents_[i]->getDomain()[conf[i]];
|
||||
ss << parents_[i]->states()[conf[i]];
|
||||
}
|
||||
}
|
||||
ss << ")" ;
|
||||
@ -202,14 +201,14 @@ BayesNode::cptEntryToString (int row, const CptEntry& entry) const
|
||||
stringstream ss;
|
||||
ss << "p(" ;
|
||||
const DConf& conf = entry.getDomainConfiguration();
|
||||
ss << getDomain()[row];
|
||||
ss << states()[row];
|
||||
if (parents_.size() > 0) {
|
||||
ss << "|" ;
|
||||
for (unsigned int i = 0; i < conf.size(); i++) {
|
||||
if (i != 0) {
|
||||
ss << ",";
|
||||
}
|
||||
ss << parents_[i]->getDomain()[conf[i]];
|
||||
ss << parents_[i]->states()[conf[i]];
|
||||
}
|
||||
}
|
||||
ss << ")" ;
|
||||
@ -226,21 +225,21 @@ BayesNode::getDomainHeaders (void) const
|
||||
unsigned nReps = 1;
|
||||
vector<string> headers (rowSize);
|
||||
for (int i = nParents - 1; i >= 0; i--) {
|
||||
Domain domain = parents_[i]->getDomain();
|
||||
States states = parents_[i]->states();
|
||||
unsigned index = 0;
|
||||
while (index < rowSize) {
|
||||
for (unsigned j = 0; j < parents_[i]->getDomainSize(); j++) {
|
||||
for (unsigned j = 0; j < parents_[i]->nrStates(); j++) {
|
||||
for (unsigned r = 0; r < nReps; r++) {
|
||||
if (headers[index] != "") {
|
||||
headers[index] = domain[j] + "," + headers[index];
|
||||
headers[index] = states[j] + "," + headers[index];
|
||||
} else {
|
||||
headers[index] = domain[j];
|
||||
headers[index] = states[j];
|
||||
}
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
nReps *= parents_[i]->getDomainSize();
|
||||
nReps *= parents_[i]->nrStates();
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
@ -251,8 +250,8 @@ ostream&
|
||||
operator << (ostream& o, const BayesNode& node)
|
||||
{
|
||||
o << "variable " << node.getIndex() << endl;
|
||||
o << "Var Id: " << node.getVarId() << endl;
|
||||
o << "Label: " << node.getLabel() << endl;
|
||||
o << "Var Id: " << node.varId() << endl;
|
||||
o << "Label: " << node.label() << endl;
|
||||
|
||||
o << "Evidence: " ;
|
||||
if (node.hasEvidence()) {
|
||||
@ -267,9 +266,9 @@ operator << (ostream& o, const BayesNode& node)
|
||||
const BnNodeSet& parents = node.getParents();
|
||||
if (parents.size() != 0) {
|
||||
for (unsigned int i = 0; i < parents.size() - 1; i++) {
|
||||
o << parents[i]->getLabel() << ", " ;
|
||||
o << parents[i]->label() << ", " ;
|
||||
}
|
||||
o << parents[parents.size() - 1]->getLabel();
|
||||
o << parents[parents.size() - 1]->label();
|
||||
}
|
||||
o << endl;
|
||||
|
||||
@ -277,19 +276,19 @@ operator << (ostream& o, const BayesNode& node)
|
||||
const BnNodeSet& childs = node.getChilds();
|
||||
if (childs.size() != 0) {
|
||||
for (unsigned int i = 0; i < childs.size() - 1; i++) {
|
||||
o << childs[i]->getLabel() << ", " ;
|
||||
o << childs[i]->label() << ", " ;
|
||||
}
|
||||
o << childs[childs.size() - 1]->getLabel();
|
||||
o << childs[childs.size() - 1]->label();
|
||||
}
|
||||
o << endl;
|
||||
|
||||
o << "Domain: " ;
|
||||
Domain domain = node.getDomain();
|
||||
for (unsigned int i = 0; i < domain.size() - 1; i++) {
|
||||
o << domain[i] << ", " ;
|
||||
States states = node.states();
|
||||
for (unsigned int i = 0; i < states.size() - 1; i++) {
|
||||
o << states[i] << ", " ;
|
||||
}
|
||||
if (domain.size() != 0) {
|
||||
o << domain[domain.size() - 1];
|
||||
if (states.size() != 0) {
|
||||
o << states[states.size() - 1];
|
||||
}
|
||||
o << endl;
|
||||
|
||||
@ -298,10 +297,10 @@ operator << (ostream& o, const BayesNode& node)
|
||||
// min width of following columns
|
||||
const unsigned int MIN_COMBO_WIDTH = 12;
|
||||
|
||||
unsigned int domainWidth = domain[0].length();
|
||||
for (unsigned int i = 1; i < domain.size(); i++) {
|
||||
if (domain[i].length() > domainWidth) {
|
||||
domainWidth = domain[i].length();
|
||||
unsigned int domainWidth = states[0].length();
|
||||
for (unsigned int i = 1; i < states.size(); i++) {
|
||||
if (states[i].length() > domainWidth) {
|
||||
domainWidth = states[i].length();
|
||||
}
|
||||
}
|
||||
domainWidth = (domainWidth < MIN_DOMAIN_WIDTH)
|
||||
@ -334,9 +333,9 @@ operator << (ostream& o, const BayesNode& node)
|
||||
}
|
||||
o << endl;
|
||||
|
||||
for (unsigned int i = 0; i < domain.size(); i++) {
|
||||
for (unsigned int i = 0; i < states.size(); i++) {
|
||||
ParamSet row = node.getRow (i);
|
||||
o << left << setw (domainWidth) << domain[i] << right;
|
||||
o << left << setw (domainWidth) << states[i] << right;
|
||||
for (unsigned j = 0; j < node.getRowSize(); j++) {
|
||||
o << setw (widths[j]) << row[j];
|
||||
}
|
||||
|
@ -1,9 +1,9 @@
|
||||
#ifndef BP_BAYES_NODE_H
|
||||
#define BP_BAYES_NODE_H
|
||||
#ifndef HORUS_BAYESNODE_H
|
||||
#define HORUS_BAYESNODE_H
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "Variable.h"
|
||||
#include "VarNode.h"
|
||||
#include "CptEntry.h"
|
||||
#include "Distribution.h"
|
||||
#include "Shared.h"
|
||||
@ -11,16 +11,16 @@
|
||||
using namespace std;
|
||||
|
||||
|
||||
class BayesNode : public Variable
|
||||
class BayesNode : public VarNode
|
||||
{
|
||||
public:
|
||||
BayesNode (Vid vid) : Variable (vid) {}
|
||||
BayesNode (Vid, unsigned, int, const BnNodeSet&, Distribution*);
|
||||
BayesNode (Vid, string, const Domain&, const BnNodeSet&, Distribution*);
|
||||
BayesNode (const VarNode& v) : VarNode (v) {}
|
||||
BayesNode (VarId, unsigned, int, Distribution*);
|
||||
BayesNode (VarId, unsigned, int, const BnNodeSet&, Distribution*);
|
||||
|
||||
void setData (unsigned, int, const BnNodeSet&,
|
||||
Distribution*);
|
||||
void setParents (const BnNodeSet&);
|
||||
void addChild (BayesNode*);
|
||||
void setDistribution (Distribution*);
|
||||
Distribution* getDistribution (void);
|
||||
const ParamSet& getParameters (void);
|
||||
ParamSet getRow (int) const;
|
||||
@ -39,7 +39,7 @@ class BayesNode : public Variable
|
||||
|
||||
unsigned getRowSize (void) const
|
||||
{
|
||||
return dist_->params.size() / getDomainSize();
|
||||
return dist_->params.size() / nrStates();
|
||||
}
|
||||
|
||||
double getProbability (int row, const CptEntry& entry)
|
||||
@ -52,7 +52,7 @@ class BayesNode : public Variable
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BayesNode);
|
||||
|
||||
Domain getDomainHeaders (void) const;
|
||||
States getDomainHeaders (void) const;
|
||||
friend ostream& operator << (ostream&, const BayesNode&);
|
||||
|
||||
BnNodeSet parents_;
|
||||
@ -62,5 +62,5 @@ class BayesNode : public Variable
|
||||
|
||||
ostream& operator << (ostream&, const BayesNode&);
|
||||
|
||||
#endif //BP_BAYES_NODE_H
|
||||
#endif // HORUS_BAYESNODE_H
|
||||
|
||||
|
962
packages/CLPBN/clpbn/bp/BnBpSolver.cpp
Normal file
962
packages/CLPBN/clpbn/bp/BnBpSolver.cpp
Normal file
@ -0,0 +1,962 @@
|
||||
#include <cstdlib>
|
||||
#include <limits>
|
||||
#include <time.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "BnBpSolver.h"
|
||||
|
||||
BnBpSolver::BnBpSolver (const BayesNet& bn) : Solver (&bn)
|
||||
{
|
||||
bayesNet_ = &bn;
|
||||
jointCalcType_ = CHAIN_RULE;
|
||||
//jointCalcType_ = JUNCTION_NODE;
|
||||
}
|
||||
|
||||
|
||||
|
||||
BnBpSolver::~BnBpSolver (void)
|
||||
{
|
||||
for (unsigned i = 0; i < nodesI_.size(); i++) {
|
||||
delete nodesI_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BnBpSolver::runSolver (void)
|
||||
{
|
||||
clock_t start;
|
||||
if (COLLECT_STATISTICS) {
|
||||
start = clock();
|
||||
}
|
||||
initializeSolver();
|
||||
if (!BpOptions::useAlwaysLoopySolver && bayesNet_->isPolyTree()) {
|
||||
runPolyTreeSolver();
|
||||
} else {
|
||||
runLoopySolver();
|
||||
if (DL >= 2) {
|
||||
cout << endl;
|
||||
if (nIters_ < BpOptions::maxIter) {
|
||||
cout << "Belief propagation converged in " ;
|
||||
cout << nIters_ << " iterations" << endl;
|
||||
} else {
|
||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
unsigned size = bayesNet_->nrNodes();
|
||||
if (COLLECT_STATISTICS) {
|
||||
unsigned nIters = 0;
|
||||
bool loopy = bayesNet_->isPolyTree() == false;
|
||||
if (loopy) nIters = nIters_;
|
||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
||||
}
|
||||
if (EXPORT_TO_GRAPHVIZ && size > EXPORT_MINIMAL_SIZE) {
|
||||
stringstream ss;
|
||||
ss << Statistics::getSolvedNetworksCounting() << "." << size << ".dot" ;
|
||||
bayesNet_->exportToGraphViz (ss.str().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
BnBpSolver::getPosterioriOf (VarId vid)
|
||||
{
|
||||
BayesNode* node = bayesNet_->getBayesNode (vid);
|
||||
assert (node);
|
||||
return nodesI_[node->getIndex()]->getBeliefs();
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
BnBpSolver::getJointDistributionOf (const VarIdSet& jointVarIds)
|
||||
{
|
||||
if (DL >= 2) {
|
||||
cout << "calculating joint distribution on: " ;
|
||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||
VarNode* var = bayesNet_->getBayesNode (jointVarIds[i]);
|
||||
cout << var->label() << " " ;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
if (jointCalcType_ == JUNCTION_NODE) {
|
||||
return getJointByJunctionNode (jointVarIds);
|
||||
} else {
|
||||
return getJointByChainRule (jointVarIds);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BnBpSolver::initializeSolver (void)
|
||||
{
|
||||
const BnNodeSet& nodes = bayesNet_->getBayesNodes();
|
||||
for (unsigned i = 0; i < nodesI_.size(); i++) {
|
||||
delete nodesI_[i];
|
||||
}
|
||||
nodesI_.clear();
|
||||
nodesI_.reserve (nodes.size());
|
||||
links_.clear();
|
||||
sortedOrder_.clear();
|
||||
linkMap_.clear();
|
||||
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
nodesI_.push_back (new BpNodeInfo (nodes[i]));
|
||||
}
|
||||
|
||||
BnNodeSet roots = bayesNet_->getRootNodes();
|
||||
for (unsigned i = 0; i < roots.size(); i++) {
|
||||
const ParamSet& params = roots[i]->getParameters();
|
||||
ParamSet& piVals = ninf(roots[i])->getPiValues();
|
||||
for (unsigned ri = 0; ri < roots[i]->nrStates(); ri++) {
|
||||
piVals[ri] = params[ri];
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
const BnNodeSet& parents = nodes[i]->getParents();
|
||||
for (unsigned j = 0; j < parents.size(); j++) {
|
||||
BpLink* newLink = new BpLink (
|
||||
parents[j], nodes[i], LinkOrientation::DOWN);
|
||||
links_.push_back (newLink);
|
||||
ninf(nodes[i])->addIncomingParentLink (newLink);
|
||||
ninf(parents[j])->addOutcomingChildLink (newLink);
|
||||
}
|
||||
const BnNodeSet& childs = nodes[i]->getChilds();
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
BpLink* newLink = new BpLink (
|
||||
childs[j], nodes[i], LinkOrientation::UP);
|
||||
links_.push_back (newLink);
|
||||
ninf(nodes[i])->addIncomingChildLink (newLink);
|
||||
ninf(childs[j])->addOutcomingParentLink (newLink);
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
if (nodes[i]->hasEvidence()) {
|
||||
ParamSet& piVals = ninf(nodes[i])->getPiValues();
|
||||
ParamSet& ldVals = ninf(nodes[i])->getLambdaValues();
|
||||
for (unsigned xi = 0; xi < nodes[i]->nrStates(); xi++) {
|
||||
piVals[xi] = Util::noEvidence();
|
||||
ldVals[xi] = Util::noEvidence();
|
||||
}
|
||||
piVals[nodes[i]->getEvidence()] = Util::withEvidence();
|
||||
ldVals[nodes[i]->getEvidence()] = Util::withEvidence();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BnBpSolver::runPolyTreeSolver (void)
|
||||
{
|
||||
const BnNodeSet& nodes = bayesNet_->getBayesNodes();
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
if (nodes[i]->isRoot()) {
|
||||
ninf(nodes[i])->markPiValuesAsCalculated();
|
||||
}
|
||||
if (nodes[i]->isLeaf()) {
|
||||
ninf(nodes[i])->markLambdaValuesAsCalculated();
|
||||
}
|
||||
}
|
||||
|
||||
bool finish = false;
|
||||
while (!finish) {
|
||||
finish = true;
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
if (ninf(nodes[i])->piValuesCalculated() == false
|
||||
&& ninf(nodes[i])->receivedAllPiMessages()) {
|
||||
if (!nodes[i]->hasEvidence()) {
|
||||
updatePiValues (nodes[i]);
|
||||
}
|
||||
ninf(nodes[i])->markPiValuesAsCalculated();
|
||||
finish = false;
|
||||
}
|
||||
|
||||
if (ninf(nodes[i])->lambdaValuesCalculated() == false
|
||||
&& ninf(nodes[i])->receivedAllLambdaMessages()) {
|
||||
if (!nodes[i]->hasEvidence()) {
|
||||
updateLambdaValues (nodes[i]);
|
||||
}
|
||||
ninf(nodes[i])->markLambdaValuesAsCalculated();
|
||||
finish = false;
|
||||
}
|
||||
|
||||
if (ninf(nodes[i])->piValuesCalculated()) {
|
||||
const BpLinkSet& outChildLinks
|
||||
= ninf(nodes[i])->getOutcomingChildLinks();
|
||||
for (unsigned j = 0; j < outChildLinks.size(); j++) {
|
||||
BayesNode* child = outChildLinks[j]->getDestination();
|
||||
if (!outChildLinks[j]->messageWasSended()) {
|
||||
if (ninf(nodes[i])->readyToSendPiMsgTo (child)) {
|
||||
calculateAndUpdateMessage (outChildLinks[j], false);
|
||||
ninf(child)->incNumPiMsgsReceived();
|
||||
}
|
||||
finish = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (ninf(nodes[i])->lambdaValuesCalculated()) {
|
||||
const BpLinkSet& outParentLinks =
|
||||
ninf(nodes[i])->getOutcomingParentLinks();
|
||||
for (unsigned j = 0; j < outParentLinks.size(); j++) {
|
||||
BayesNode* parent = outParentLinks[j]->getDestination();
|
||||
if (!outParentLinks[j]->messageWasSended()) {
|
||||
if (ninf(nodes[i])->readyToSendLambdaMsgTo (parent)) {
|
||||
calculateAndUpdateMessage (outParentLinks[j], false);
|
||||
ninf(parent)->incNumLambdaMsgsReceived();
|
||||
}
|
||||
finish = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BnBpSolver::runLoopySolver()
|
||||
{
|
||||
nIters_ = 0;
|
||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||
|
||||
nIters_++;
|
||||
if (DL >= 2) {
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
cout << endl;
|
||||
cout << " Iteration " << nIters_ << endl;
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
switch (BpOptions::schedule) {
|
||||
|
||||
case BpOptions::Schedule::SEQ_RANDOM:
|
||||
random_shuffle (links_.begin(), links_.end());
|
||||
// no break
|
||||
|
||||
case BpOptions::Schedule::SEQ_FIXED:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateAndUpdateMessage (links_[i]);
|
||||
updateValues (links_[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
case BpOptions::Schedule::PARALLEL:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
updateMessage (links_[i]);
|
||||
updateValues (links_[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||
maxResidualSchedule();
|
||||
break;
|
||||
|
||||
}
|
||||
if (DL >= 2) {
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BnBpSolver::converged (void) const
|
||||
{
|
||||
// this can happen if the graph is fully disconnected
|
||||
if (links_.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
if (nIters_ == 0 || nIters_ == 1) {
|
||||
return false;
|
||||
}
|
||||
bool converged = true;
|
||||
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
||||
Param maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
||||
if (maxResidual < BpOptions::accuracy) {
|
||||
converged = true;
|
||||
} else {
|
||||
converged = false;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
Param residual = links_[i]->getResidual();
|
||||
if (DL >= 2) {
|
||||
cout << links_[i]->toString() + " residual change = " ;
|
||||
cout << residual << endl;
|
||||
}
|
||||
if (residual > BpOptions::accuracy) {
|
||||
converged = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return converged;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BnBpSolver::maxResidualSchedule (void)
|
||||
{
|
||||
if (nIters_ == 1) {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||
linkMap_.insert (make_pair (links_[i], it));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < sortedOrder_.size(); c++) {
|
||||
if (DL >= 2) {
|
||||
cout << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); it ++) {
|
||||
cout << " " << setw (30) << left << (*it)->toString();
|
||||
cout << "residual = " << (*it)->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
BpLink* link = *it;
|
||||
if (link->getResidual() < BpOptions::accuracy) {
|
||||
sortedOrder_.erase (it);
|
||||
it = sortedOrder_.begin();
|
||||
return;
|
||||
}
|
||||
updateMessage (link);
|
||||
updateValues (link);
|
||||
link->clearResidual();
|
||||
sortedOrder_.erase (it);
|
||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||
|
||||
const BpLinkSet& outParentLinks =
|
||||
ninf(link->getDestination())->getOutcomingParentLinks();
|
||||
for (unsigned i = 0; i < outParentLinks.size(); i++) {
|
||||
if (outParentLinks[i]->getDestination() != link->getSource()
|
||||
&& outParentLinks[i]->getDestination()->hasEvidence() == false) {
|
||||
calculateMessage (outParentLinks[i]);
|
||||
BpLinkMap::iterator iter = linkMap_.find (outParentLinks[i]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (outParentLinks[i]);
|
||||
}
|
||||
}
|
||||
const BpLinkSet& outChildLinks =
|
||||
ninf(link->getDestination())->getOutcomingChildLinks();
|
||||
for (unsigned i = 0; i < outChildLinks.size(); i++) {
|
||||
if (outChildLinks[i]->getDestination() != link->getSource()) {
|
||||
calculateMessage (outChildLinks[i]);
|
||||
BpLinkMap::iterator iter = linkMap_.find (outChildLinks[i]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (outChildLinks[i]);
|
||||
}
|
||||
}
|
||||
|
||||
if (DL >= 2) {
|
||||
cout << "----------------------------------------" ;
|
||||
cout << "----------------------------------------" << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BnBpSolver::updatePiValues (BayesNode* x)
|
||||
{
|
||||
// π(Xi)
|
||||
if (DL >= 3) {
|
||||
cout << "updating " << PI_SYMBOL << " values for " << x->label() << endl;
|
||||
}
|
||||
ParamSet& piValues = ninf(x)->getPiValues();
|
||||
const BpLinkSet& parentLinks = ninf(x)->getIncomingParentLinks();
|
||||
const vector<CptEntry>& entries = x->getCptEntries();
|
||||
stringstream* calcs1 = 0;
|
||||
stringstream* calcs2 = 0;
|
||||
|
||||
ParamSet messageProducts (entries.size());
|
||||
for (unsigned k = 0; k < entries.size(); k++) {
|
||||
if (DL >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
}
|
||||
double messageProduct = Util::multIdenty();
|
||||
const DConf& conf = entries[k].getDomainConfiguration();
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
for (unsigned i = 0; i < parentLinks.size(); i++) {
|
||||
messageProduct *= parentLinks[i]->getMessage()[conf[i]];
|
||||
if (DL >= 5) {
|
||||
if (i != 0) *calcs1 << " + " ;
|
||||
if (i != 0) *calcs2 << " + " ;
|
||||
*calcs1 << parentLinks[i]->toString (conf[i]);
|
||||
*calcs2 << parentLinks[i]->getMessage()[conf[i]];
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
for (unsigned i = 0; i < parentLinks.size(); i++) {
|
||||
messageProduct += parentLinks[i]->getMessage()[conf[i]];
|
||||
}
|
||||
}
|
||||
messageProducts[k] = messageProduct;
|
||||
if (DL >= 5) {
|
||||
cout << " mp" << k;
|
||||
cout << " = " << (*calcs1).str();
|
||||
if (parentLinks.size() == 1) {
|
||||
cout << " = " << messageProduct << endl;
|
||||
} else {
|
||||
cout << " = " << (*calcs2).str();
|
||||
cout << " = " << messageProduct << endl;
|
||||
}
|
||||
delete calcs1;
|
||||
delete calcs2;
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
||||
double sum = Util::addIdenty();
|
||||
if (DL >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
}
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
for (unsigned k = 0; k < entries.size(); k++) {
|
||||
sum += x->getProbability (xi, entries[k]) * messageProducts[k];
|
||||
if (DL >= 5) {
|
||||
if (k != 0) *calcs1 << " + " ;
|
||||
if (k != 0) *calcs2 << " + " ;
|
||||
*calcs1 << x->cptEntryToString (xi, entries[k]);
|
||||
*calcs1 << ".mp" << k;
|
||||
*calcs2 << Util::fl (x->getProbability (xi, entries[k]));
|
||||
*calcs2 << "*" << messageProducts[k];
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
for (unsigned k = 0; k < entries.size(); k++) {
|
||||
Util::logSum (sum,
|
||||
x->getProbability(xi,entries[k]) + messageProducts[k]);
|
||||
}
|
||||
}
|
||||
piValues[xi] = sum;
|
||||
if (DL >= 5) {
|
||||
cout << " " << PI_SYMBOL << "(" << x->label() << ")" ;
|
||||
cout << "[" << x->states()[xi] << "]" ;
|
||||
cout << " = " << (*calcs1).str();
|
||||
cout << " = " << (*calcs2).str();
|
||||
cout << " = " << piValues[xi] << endl;
|
||||
delete calcs1;
|
||||
delete calcs2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BnBpSolver::updateLambdaValues (BayesNode* x)
|
||||
{
|
||||
// λ(Xi)
|
||||
if (DL >= 3) {
|
||||
cout << "updating " << LD_SYMBOL << " values for " << x->label() << endl;
|
||||
}
|
||||
ParamSet& lambdaValues = ninf(x)->getLambdaValues();
|
||||
const BpLinkSet& childLinks = ninf(x)->getIncomingChildLinks();
|
||||
stringstream* calcs1 = 0;
|
||||
stringstream* calcs2 = 0;
|
||||
|
||||
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
||||
if (DL >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
}
|
||||
double product = Util::multIdenty();
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
for (unsigned i = 0; i < childLinks.size(); i++) {
|
||||
product *= childLinks[i]->getMessage()[xi];
|
||||
if (DL >= 5) {
|
||||
if (i != 0) *calcs1 << "." ;
|
||||
if (i != 0) *calcs2 << "*" ;
|
||||
*calcs1 << childLinks[i]->toString (xi);
|
||||
*calcs2 << childLinks[i]->getMessage()[xi];
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
for (unsigned i = 0; i < childLinks.size(); i++) {
|
||||
product += childLinks[i]->getMessage()[xi];
|
||||
}
|
||||
}
|
||||
lambdaValues[xi] = product;
|
||||
if (DL >= 5) {
|
||||
cout << " " << LD_SYMBOL << "(" << x->label() << ")" ;
|
||||
cout << "[" << x->states()[xi] << "]" ;
|
||||
cout << " = " << (*calcs1).str();
|
||||
if (childLinks.size() == 1) {
|
||||
cout << " = " << product << endl;
|
||||
} else {
|
||||
cout << " = " << (*calcs2).str();
|
||||
cout << " = " << lambdaValues[xi] << endl;
|
||||
}
|
||||
delete calcs1;
|
||||
delete calcs2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BnBpSolver::calculatePiMessage (BpLink* link)
|
||||
{
|
||||
// πX(Zi)
|
||||
BayesNode* z = link->getSource();
|
||||
BayesNode* x = link->getDestination();
|
||||
ParamSet& zxPiNextMessage = link->getNextMessage();
|
||||
const BpLinkSet& zChildLinks = ninf(z)->getIncomingChildLinks();
|
||||
stringstream* calcs1 = 0;
|
||||
stringstream* calcs2 = 0;
|
||||
|
||||
const ParamSet& zPiValues = ninf(z)->getPiValues();
|
||||
for (unsigned zi = 0; zi < z->nrStates(); zi++) {
|
||||
double product = zPiValues[zi];
|
||||
if (DL >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
*calcs1 << PI_SYMBOL << "(" << z->label() << ")";
|
||||
*calcs1 << "[" << z->states()[zi] << "]" ;
|
||||
*calcs2 << product;
|
||||
}
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
for (unsigned i = 0; i < zChildLinks.size(); i++) {
|
||||
if (zChildLinks[i]->getSource() != x) {
|
||||
product *= zChildLinks[i]->getMessage()[zi];
|
||||
if (DL >= 5) {
|
||||
*calcs1 << "." << zChildLinks[i]->toString (zi);
|
||||
*calcs2 << " * " << zChildLinks[i]->getMessage()[zi];
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
for (unsigned i = 0; i < zChildLinks.size(); i++) {
|
||||
if (zChildLinks[i]->getSource() != x) {
|
||||
product += zChildLinks[i]->getMessage()[zi];
|
||||
}
|
||||
}
|
||||
}
|
||||
zxPiNextMessage[zi] = product;
|
||||
if (DL >= 5) {
|
||||
cout << " " << link->toString();
|
||||
cout << "[" << z->states()[zi] << "]" ;
|
||||
cout << " = " << (*calcs1).str();
|
||||
if (zChildLinks.size() == 1) {
|
||||
cout << " = " << product << endl;
|
||||
} else {
|
||||
cout << " = " << (*calcs2).str();
|
||||
cout << " = " << product << endl;
|
||||
}
|
||||
delete calcs1;
|
||||
delete calcs2;
|
||||
}
|
||||
}
|
||||
Util::normalize (zxPiNextMessage);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BnBpSolver::calculateLambdaMessage (BpLink* link)
|
||||
{
|
||||
// λY(Xi)
|
||||
BayesNode* y = link->getSource();
|
||||
BayesNode* x = link->getDestination();
|
||||
if (x->hasEvidence()) {
|
||||
return;
|
||||
}
|
||||
ParamSet& yxLambdaNextMessage = link->getNextMessage();
|
||||
const BpLinkSet& yParentLinks = ninf(y)->getIncomingParentLinks();
|
||||
const ParamSet& yLambdaValues = ninf(y)->getLambdaValues();
|
||||
const vector<CptEntry>& allEntries = y->getCptEntries();
|
||||
int parentIndex = y->getIndexOfParent (x);
|
||||
stringstream* calcs1 = 0;
|
||||
stringstream* calcs2 = 0;
|
||||
|
||||
vector<CptEntry> entries;
|
||||
DConstraint constr = make_pair (parentIndex, 0);
|
||||
for (unsigned i = 0; i < allEntries.size(); i++) {
|
||||
if (allEntries[i].matchConstraints(constr)) {
|
||||
entries.push_back (allEntries[i]);
|
||||
}
|
||||
}
|
||||
|
||||
ParamSet messageProducts (entries.size());
|
||||
for (unsigned k = 0; k < entries.size(); k++) {
|
||||
if (DL >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
}
|
||||
double messageProduct = Util::multIdenty();
|
||||
const DConf& conf = entries[k].getDomainConfiguration();
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
for (unsigned i = 0; i < yParentLinks.size(); i++) {
|
||||
if (yParentLinks[i]->getSource() != x) {
|
||||
if (DL >= 5) {
|
||||
if (messageProduct != Util::multIdenty()) *calcs1 << "*" ;
|
||||
if (messageProduct != Util::multIdenty()) *calcs2 << "*" ;
|
||||
*calcs1 << yParentLinks[i]->toString (conf[i]);
|
||||
*calcs2 << yParentLinks[i]->getMessage()[conf[i]];
|
||||
}
|
||||
messageProduct *= yParentLinks[i]->getMessage()[conf[i]];
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
for (unsigned i = 0; i < yParentLinks.size(); i++) {
|
||||
if (yParentLinks[i]->getSource() != x) {
|
||||
messageProduct += yParentLinks[i]->getMessage()[conf[i]];
|
||||
}
|
||||
}
|
||||
}
|
||||
messageProducts[k] = messageProduct;
|
||||
if (DL >= 5) {
|
||||
cout << " mp" << k;
|
||||
cout << " = " << (*calcs1).str();
|
||||
if (yParentLinks.size() == 1) {
|
||||
cout << 1 << endl;
|
||||
} else if (yParentLinks.size() == 2) {
|
||||
cout << " = " << messageProduct << endl;
|
||||
} else {
|
||||
cout << " = " << (*calcs2).str();
|
||||
cout << " = " << messageProduct << endl;
|
||||
}
|
||||
delete calcs1;
|
||||
delete calcs2;
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned xi = 0; xi < x->nrStates(); xi++) {
|
||||
if (DL >= 5) {
|
||||
calcs1 = new stringstream;
|
||||
calcs2 = new stringstream;
|
||||
}
|
||||
vector<CptEntry> entries;
|
||||
DConstraint constr = make_pair (parentIndex, xi);
|
||||
for (unsigned i = 0; i < allEntries.size(); i++) {
|
||||
if (allEntries[i].matchConstraints(constr)) {
|
||||
entries.push_back (allEntries[i]);
|
||||
}
|
||||
}
|
||||
double outerSum = Util::addIdenty();
|
||||
for (unsigned yi = 0; yi < y->nrStates(); yi++) {
|
||||
if (DL >= 5) {
|
||||
(yi != 0) ? *calcs1 << " + {" : *calcs1 << "{" ;
|
||||
(yi != 0) ? *calcs2 << " + {" : *calcs2 << "{" ;
|
||||
}
|
||||
double innerSum = Util::addIdenty();
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
for (unsigned k = 0; k < entries.size(); k++) {
|
||||
if (DL >= 5) {
|
||||
if (k != 0) *calcs1 << " + " ;
|
||||
if (k != 0) *calcs2 << " + " ;
|
||||
*calcs1 << y->cptEntryToString (yi, entries[k]);
|
||||
*calcs1 << ".mp" << k;
|
||||
*calcs2 << y->getProbability (yi, entries[k]);
|
||||
*calcs2 << "*" << messageProducts[k];
|
||||
}
|
||||
innerSum += y->getProbability (yi, entries[k]) * messageProducts[k];
|
||||
}
|
||||
outerSum += innerSum * yLambdaValues[yi];
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
for (unsigned k = 0; k < entries.size(); k++) {
|
||||
Util::logSum (innerSum,
|
||||
y->getProbability(yi, entries[k]) + messageProducts[k]);
|
||||
}
|
||||
Util::logSum (outerSum, innerSum + yLambdaValues[yi]);
|
||||
}
|
||||
if (DL >= 5) {
|
||||
*calcs1 << "}." << LD_SYMBOL << "(" << y->label() << ")" ;
|
||||
*calcs1 << "[" << y->states()[yi] << "]";
|
||||
*calcs2 << "}*" << yLambdaValues[yi];
|
||||
}
|
||||
}
|
||||
yxLambdaNextMessage[xi] = outerSum;
|
||||
if (DL >= 5) {
|
||||
cout << " " << link->toString();
|
||||
cout << "[" << x->states()[xi] << "]" ;
|
||||
cout << " = " << (*calcs1).str();
|
||||
cout << " = " << (*calcs2).str();
|
||||
cout << " = " << yxLambdaNextMessage[xi] << endl;
|
||||
delete calcs1;
|
||||
delete calcs2;
|
||||
}
|
||||
}
|
||||
Util::normalize (yxLambdaNextMessage);
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
BnBpSolver::getJointByJunctionNode (const VarIdSet& jointVarIds)
|
||||
{
|
||||
unsigned msgSize = 1;
|
||||
vector<unsigned> dsizes (jointVarIds.size());
|
||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||
dsizes[i] = bayesNet_->getBayesNode (jointVarIds[i])->nrStates();
|
||||
msgSize *= dsizes[i];
|
||||
}
|
||||
unsigned reps = 1;
|
||||
ParamSet jointDist (msgSize, Util::multIdenty());
|
||||
for (int i = jointVarIds.size() - 1 ; i >= 0; i--) {
|
||||
Util::multiply (jointDist, getPosterioriOf (jointVarIds[i]), reps);
|
||||
reps *= dsizes[i] ;
|
||||
}
|
||||
return jointDist;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
BnBpSolver::getJointByChainRule (const VarIdSet& jointVarIds) const
|
||||
{
|
||||
BnNodeSet jointVars;
|
||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||
jointVars.push_back (bayesNet_->getBayesNode (jointVarIds[i]));
|
||||
}
|
||||
|
||||
BayesNet* mrn = bayesNet_->getMinimalRequesiteNetwork (jointVarIds[0]);
|
||||
BnBpSolver solver (*mrn);
|
||||
solver.runSolver();
|
||||
ParamSet prevBeliefs = solver.getPosterioriOf (jointVarIds[0]);
|
||||
delete mrn;
|
||||
|
||||
VarNodes observedVars = {jointVars[0]};
|
||||
|
||||
for (unsigned i = 1; i < jointVarIds.size(); i++) {
|
||||
mrn = bayesNet_->getMinimalRequesiteNetwork (jointVarIds[i]);
|
||||
ParamSet newBeliefs;
|
||||
vector<DConf> confs =
|
||||
Util::getDomainConfigurations (observedVars);
|
||||
for (unsigned j = 0; j < confs.size(); j++) {
|
||||
for (unsigned k = 0; k < observedVars.size(); k++) {
|
||||
if (!observedVars[k]->hasEvidence()) {
|
||||
BayesNode* node = mrn->getBayesNode (observedVars[k]->varId());
|
||||
if (node) {
|
||||
node->setEvidence (confs[j][k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
BnBpSolver solver (*mrn);
|
||||
solver.runSolver();
|
||||
ParamSet beliefs = solver.getPosterioriOf (jointVarIds[i]);
|
||||
for (unsigned k = 0; k < beliefs.size(); k++) {
|
||||
newBeliefs.push_back (beliefs[k]);
|
||||
}
|
||||
}
|
||||
|
||||
int count = -1;
|
||||
for (unsigned j = 0; j < newBeliefs.size(); j++) {
|
||||
if (j % jointVars[i]->nrStates() == 0) {
|
||||
count ++;
|
||||
}
|
||||
newBeliefs[j] *= prevBeliefs[count];
|
||||
}
|
||||
prevBeliefs = newBeliefs;
|
||||
observedVars.push_back (jointVars[i]);
|
||||
delete mrn;
|
||||
}
|
||||
return prevBeliefs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BnBpSolver::printPiLambdaValues (const BayesNode* var) const
|
||||
{
|
||||
cout << left;
|
||||
cout << setw (10) << "states" ;
|
||||
cout << setw (20) << PI_SYMBOL << "(" + var->label() + ")" ;
|
||||
cout << setw (20) << LD_SYMBOL << "(" + var->label() + ")" ;
|
||||
cout << setw (16) << "belief" ;
|
||||
cout << endl;
|
||||
cout << "--------------------------------" ;
|
||||
cout << "--------------------------------" ;
|
||||
cout << endl;
|
||||
const States& states = var->states();
|
||||
const ParamSet& piVals = ninf(var)->getPiValues();
|
||||
const ParamSet& ldVals = ninf(var)->getLambdaValues();
|
||||
const ParamSet& beliefs = ninf(var)->getBeliefs();
|
||||
for (unsigned xi = 0; xi < var->nrStates(); xi++) {
|
||||
cout << setw (10) << states[xi];
|
||||
cout << setw (19) << piVals[xi];
|
||||
cout << setw (19) << ldVals[xi];
|
||||
cout.precision (PRECISION);
|
||||
cout << setw (16) << beliefs[xi];
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BnBpSolver::printAllMessageStatus (void) const
|
||||
{
|
||||
const BnNodeSet& nodes = bayesNet_->getBayesNodes();
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
printPiLambdaValues (nodes[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
BpNodeInfo::BpNodeInfo (BayesNode* node)
|
||||
{
|
||||
node_ = node;
|
||||
piValsCalc_ = false;
|
||||
ldValsCalc_ = false;
|
||||
nPiMsgsRcv_ = 0;
|
||||
nLdMsgsRcv_ = 0;
|
||||
piVals_.resize (node->nrStates(), Util::one());
|
||||
ldVals_.resize (node->nrStates(), Util::one());
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
BpNodeInfo::getBeliefs (void) const
|
||||
{
|
||||
double sum = 0.0;
|
||||
ParamSet beliefs (node_->nrStates());
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
|
||||
beliefs[xi] = piVals_[xi] * ldVals_[xi];
|
||||
sum += beliefs[xi];
|
||||
}
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
|
||||
beliefs[xi] = exp (piVals_[xi] + ldVals_[xi]);
|
||||
sum += beliefs[xi];
|
||||
}
|
||||
}
|
||||
assert (sum);
|
||||
for (unsigned xi = 0; xi < node_->nrStates(); xi++) {
|
||||
beliefs[xi] /= sum;
|
||||
}
|
||||
return beliefs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNodeInfo::markPiValuesAsCalculated (void)
|
||||
{
|
||||
piValsCalc_ = true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNodeInfo::markLambdaValuesAsCalculated (void)
|
||||
{
|
||||
ldValsCalc_ = true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BpNodeInfo::receivedAllPiMessages (void)
|
||||
{
|
||||
return node_->getParents().size() == nPiMsgsRcv_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BpNodeInfo::receivedAllLambdaMessages (void)
|
||||
{
|
||||
return node_->getChilds().size() == nLdMsgsRcv_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BpNodeInfo::readyToSendPiMsgTo (const BayesNode* child) const
|
||||
{
|
||||
for (unsigned i = 0; i < inChildLinks_.size(); i++) {
|
||||
if (inChildLinks_[i]->getSource() != child
|
||||
&& inChildLinks_[i]->messageWasSended() == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BpNodeInfo::readyToSendLambdaMsgTo (const BayesNode* parent) const
|
||||
{
|
||||
for (unsigned i = 0; i < inParentLinks_.size(); i++) {
|
||||
if (inParentLinks_[i]->getSource() != parent
|
||||
&& inParentLinks_[i]->messageWasSended() == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BpNodeInfo::receivedBottomInfluence (void) const
|
||||
{
|
||||
// if all lambda values are equal, then neither
|
||||
// this node neither its descendents have evidence,
|
||||
// we can use this to don't send lambda messages his parents
|
||||
bool childInfluenced = false;
|
||||
for (unsigned xi = 1; xi < node_->nrStates(); xi++) {
|
||||
if (ldVals_[xi] != ldVals_[0]) {
|
||||
childInfluenced = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return childInfluenced;
|
||||
}
|
||||
|
262
packages/CLPBN/clpbn/bp/BnBpSolver.h
Normal file
262
packages/CLPBN/clpbn/bp/BnBpSolver.h
Normal file
@ -0,0 +1,262 @@
|
||||
#ifndef HORUS_BNBPSOLVER_H
|
||||
#define HORUS_BNBPSOLVER_H
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
|
||||
#include "Solver.h"
|
||||
#include "BayesNet.h"
|
||||
#include "Shared.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
class BpNodeInfo;
|
||||
|
||||
static const string PI_SYMBOL = "pi" ;
|
||||
static const string LD_SYMBOL = "ld" ;
|
||||
|
||||
enum LinkOrientation {UP, DOWN};
|
||||
enum JointCalcType {CHAIN_RULE, JUNCTION_NODE};
|
||||
|
||||
class BpLink
|
||||
{
|
||||
public:
|
||||
BpLink (BayesNode* s, BayesNode* d, LinkOrientation o)
|
||||
{
|
||||
source_ = s;
|
||||
destin_ = d;
|
||||
orientation_ = o;
|
||||
if (orientation_ == LinkOrientation::DOWN) {
|
||||
v1_.resize (s->nrStates(), Util::tl (1.0/s->nrStates()));
|
||||
v2_.resize (s->nrStates(), Util::tl (1.0/s->nrStates()));
|
||||
} else {
|
||||
v1_.resize (d->nrStates(), Util::tl (1.0/d->nrStates()));
|
||||
v2_.resize (d->nrStates(), Util::tl (1.0/d->nrStates()));
|
||||
}
|
||||
currMsg_ = &v1_;
|
||||
nextMsg_ = &v2_;
|
||||
residual_ = 0;
|
||||
msgSended_ = false;
|
||||
}
|
||||
|
||||
void updateMessage (void)
|
||||
{
|
||||
swap (currMsg_, nextMsg_);
|
||||
msgSended_ = true;
|
||||
}
|
||||
|
||||
void updateResidual (void)
|
||||
{
|
||||
residual_ = Util::getMaxNorm (v1_, v2_);
|
||||
}
|
||||
|
||||
string toString (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
if (orientation_ == LinkOrientation::DOWN) {
|
||||
ss << PI_SYMBOL;
|
||||
} else {
|
||||
ss << LD_SYMBOL;
|
||||
}
|
||||
ss << "(" << source_->label();
|
||||
ss << " --> " << destin_->label() << ")" ;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
string toString (unsigned stateIndex) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << toString() << "[" ;
|
||||
if (orientation_ == LinkOrientation::DOWN) {
|
||||
ss << source_->states()[stateIndex] << "]" ;
|
||||
} else {
|
||||
ss << destin_->states()[stateIndex] << "]" ;
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
BayesNode* getSource (void) const { return source_; }
|
||||
BayesNode* getDestination (void) const { return destin_; }
|
||||
LinkOrientation getOrientation (void) const { return orientation_; }
|
||||
const ParamSet& getMessage (void) const { return *currMsg_; }
|
||||
ParamSet& getNextMessage (void) { return *nextMsg_; }
|
||||
bool messageWasSended (void) const { return msgSended_; }
|
||||
double getResidual (void) const { return residual_; }
|
||||
void clearResidual (void) { residual_ = 0;}
|
||||
|
||||
private:
|
||||
BayesNode* source_;
|
||||
BayesNode* destin_;
|
||||
LinkOrientation orientation_;
|
||||
ParamSet v1_;
|
||||
ParamSet v2_;
|
||||
ParamSet* currMsg_;
|
||||
ParamSet* nextMsg_;
|
||||
bool msgSended_;
|
||||
double residual_;
|
||||
};
|
||||
|
||||
|
||||
typedef vector<BpLink*> BpLinkSet;
|
||||
|
||||
|
||||
class BpNodeInfo
|
||||
{
|
||||
public:
|
||||
BpNodeInfo (BayesNode*);
|
||||
|
||||
ParamSet getBeliefs (void) const;
|
||||
bool receivedBottomInfluence (void) const;
|
||||
|
||||
ParamSet& getPiValues (void) { return piVals_; }
|
||||
ParamSet& getLambdaValues (void) { return ldVals_; }
|
||||
void incNumPiMsgsReceived (void) { nPiMsgsRcv_ ++; }
|
||||
void incNumLambdaMsgsReceived (void) { nLdMsgsRcv_ ++; }
|
||||
bool piValuesCalculated (void) { return piValsCalc_; }
|
||||
bool lambdaValuesCalculated (void) { return ldValsCalc_; }
|
||||
|
||||
void markPiValuesAsCalculated (void);
|
||||
void markLambdaValuesAsCalculated (void);
|
||||
bool receivedAllPiMessages (void);
|
||||
bool receivedAllLambdaMessages (void);
|
||||
bool readyToSendPiMsgTo (const BayesNode*) const ;
|
||||
bool readyToSendLambdaMsgTo (const BayesNode*) const;
|
||||
|
||||
const BpLinkSet& getIncomingParentLinks (void) { return inParentLinks_; }
|
||||
const BpLinkSet& getIncomingChildLinks (void) { return inChildLinks_; }
|
||||
const BpLinkSet& getOutcomingParentLinks (void) { return outParentLinks_; }
|
||||
const BpLinkSet& getOutcomingChildLinks (void) { return outChildLinks_; }
|
||||
|
||||
void addIncomingParentLink (BpLink* l) { inParentLinks_.push_back (l); }
|
||||
void addIncomingChildLink (BpLink* l) { inChildLinks_.push_back (l); }
|
||||
void addOutcomingParentLink (BpLink* l) { outParentLinks_.push_back (l); }
|
||||
void addOutcomingChildLink (BpLink* l) { outChildLinks_.push_back (l); }
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BpNodeInfo);
|
||||
|
||||
ParamSet piVals_; // pi values
|
||||
ParamSet ldVals_; // lambda values
|
||||
unsigned nPiMsgsRcv_;
|
||||
unsigned nLdMsgsRcv_;
|
||||
bool piValsCalc_;
|
||||
bool ldValsCalc_;
|
||||
BpLinkSet inParentLinks_;
|
||||
BpLinkSet inChildLinks_;
|
||||
BpLinkSet outParentLinks_;
|
||||
BpLinkSet outChildLinks_;
|
||||
const BayesNode* node_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class BnBpSolver : public Solver
|
||||
{
|
||||
public:
|
||||
BnBpSolver (const BayesNet&);
|
||||
~BnBpSolver (void);
|
||||
|
||||
void runSolver (void);
|
||||
ParamSet getPosterioriOf (VarId);
|
||||
ParamSet getJointDistributionOf (const VarIdSet&);
|
||||
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BnBpSolver);
|
||||
|
||||
void initializeSolver (void);
|
||||
void runPolyTreeSolver (void);
|
||||
void runLoopySolver (void);
|
||||
void maxResidualSchedule (void);
|
||||
bool converged (void) const;
|
||||
void updatePiValues (BayesNode*);
|
||||
void updateLambdaValues (BayesNode*);
|
||||
void calculateLambdaMessage (BpLink*);
|
||||
void calculatePiMessage (BpLink*);
|
||||
ParamSet getJointByJunctionNode (const VarIdSet&);
|
||||
ParamSet getJointByChainRule (const VarIdSet&) const;
|
||||
void printPiLambdaValues (const BayesNode*) const;
|
||||
void printAllMessageStatus (void) const;
|
||||
|
||||
void calculateAndUpdateMessage (BpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (DL >= 3) {
|
||||
cout << "calculating & updating " << link->toString() << endl;
|
||||
}
|
||||
if (link->getOrientation() == LinkOrientation::DOWN) {
|
||||
calculatePiMessage (link);
|
||||
} else if (link->getOrientation() == LinkOrientation::UP) {
|
||||
calculateLambdaMessage (link);
|
||||
}
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
link->updateMessage();
|
||||
}
|
||||
|
||||
void calculateMessage (BpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (DL >= 3) {
|
||||
cout << "calculating " << link->toString() << endl;
|
||||
}
|
||||
if (link->getOrientation() == LinkOrientation::DOWN) {
|
||||
calculatePiMessage (link);
|
||||
} else if (link->getOrientation() == LinkOrientation::UP) {
|
||||
calculateLambdaMessage (link);
|
||||
}
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
}
|
||||
|
||||
void updateMessage (BpLink* link)
|
||||
{
|
||||
if (DL >= 3) {
|
||||
cout << "updating " << link->toString() << endl;
|
||||
}
|
||||
link->updateMessage();
|
||||
}
|
||||
|
||||
void updateValues (BpLink* link)
|
||||
{
|
||||
if (!link->getDestination()->hasEvidence()) {
|
||||
if (link->getOrientation() == LinkOrientation::DOWN) {
|
||||
updatePiValues (link->getDestination());
|
||||
} else if (link->getOrientation() == LinkOrientation::UP) {
|
||||
updateLambdaValues (link->getDestination());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BpNodeInfo* ninf (const BayesNode* node) const
|
||||
{
|
||||
assert (node);
|
||||
assert (node == bayesNet_->getBayesNode (node->varId()));
|
||||
assert (node->getIndex() < nodesI_.size());
|
||||
return nodesI_[node->getIndex()];
|
||||
}
|
||||
|
||||
const BayesNet* bayesNet_;
|
||||
vector<BpLink*> links_;
|
||||
vector<BpNodeInfo*> nodesI_;
|
||||
unsigned nIters_;
|
||||
JointCalcType jointCalcType_;
|
||||
|
||||
struct compare
|
||||
{
|
||||
inline bool operator() (const BpLink* e1, const BpLink* e2)
|
||||
{
|
||||
return e1->getResidual() > e2->getResidual();
|
||||
}
|
||||
};
|
||||
|
||||
typedef multiset<BpLink*, compare> SortedOrder;
|
||||
SortedOrder sortedOrder_;
|
||||
|
||||
typedef unordered_map<BpLink*, SortedOrder::iterator> BpLinkMap;
|
||||
BpLinkMap linkMap_;
|
||||
|
||||
};
|
||||
|
||||
#endif // HORUS_BNBPSOLVER_H
|
||||
|
@ -1,811 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <cstdlib>
|
||||
#include <sstream>
|
||||
|
||||
#include "BpNetwork.h"
|
||||
#include "BpNode.h"
|
||||
#include "CptEntry.h"
|
||||
|
||||
BpNetwork::BpNetwork (void)
|
||||
{
|
||||
schedule_ = SEQUENTIAL_SCHEDULE;
|
||||
maxIter_ = 150;
|
||||
stableThreashold_ = 0.00000000000000000001;
|
||||
}
|
||||
|
||||
|
||||
|
||||
BpNetwork::~BpNetwork (void)
|
||||
{
|
||||
for (unsigned int i = 0; i < nodes_.size(); i++) {
|
||||
delete nodes_[i];
|
||||
}
|
||||
nodes_.clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNetwork::setSolverParameters (Schedule schedule,
|
||||
int maxIter,
|
||||
double stableThreashold)
|
||||
{
|
||||
if (maxIter <= 0) {
|
||||
cerr << "error: maxIter must be greater or equal to 1" << endl;
|
||||
abort();
|
||||
}
|
||||
if (stableThreashold <= 0.0 || stableThreashold >= 1.0) {
|
||||
cerr << "error: stableThreashold must be greater than 0.0 " ;
|
||||
cerr << "and lesser than 1.0" << endl;
|
||||
abort();
|
||||
}
|
||||
schedule_ = schedule;
|
||||
maxIter_ = maxIter;
|
||||
stableThreashold_ = stableThreashold;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNetwork::runSolver (BayesianNode* queryVar)
|
||||
{
|
||||
vector<BayesianNode*> queryVars;
|
||||
queryVars.push_back (queryVar);
|
||||
runSolver (queryVars);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNetwork::runSolver (vector<BayesianNode*> queryVars)
|
||||
{
|
||||
if (queryVars.size() > 1) {
|
||||
addJunctionNode (queryVars);
|
||||
}
|
||||
else {
|
||||
string varName = queryVars[0]->getVariableName();
|
||||
queryNode_ = static_cast<BpNode*> (getNode (varName));
|
||||
}
|
||||
|
||||
if (!isPolyTree()) {
|
||||
if (DL_ >= 1) {
|
||||
cout << "The graph is not single connected. " ;
|
||||
cout << "Iterative belief propagation will be used." ;
|
||||
cout << endl << endl;
|
||||
}
|
||||
schedule_ = PARALLEL_SCHEDULE;
|
||||
}
|
||||
|
||||
if (schedule_ == SEQUENTIAL_SCHEDULE) {
|
||||
initializeSolver (queryVars);
|
||||
runNeapolitanSolver();
|
||||
for (unsigned int i = 0; i < nodes_.size(); i++) {
|
||||
if (nodes_[i]->hasEvidence()) {
|
||||
BpNode* v = static_cast<BpNode*> (nodes_[i]);
|
||||
addEvidence (v);
|
||||
vector<BpNode*> parents = cast (v->getParents());
|
||||
for (unsigned int i = 0; i < parents.size(); i++) {
|
||||
if (!parents[i]->hasEvidence()) {
|
||||
sendLambdaMessage (v, parents[i]);
|
||||
}
|
||||
}
|
||||
vector<BpNode*> childs = cast (v->getChilds());
|
||||
for (unsigned int i = 0; i < childs.size(); i++) {
|
||||
sendPiMessage (v, childs[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (schedule_ == PARALLEL_SCHEDULE) {
|
||||
BpNode::enableParallelSchedule();
|
||||
initializeSolver (queryVars);
|
||||
for (unsigned int i = 0; i < nodes_.size(); i++) {
|
||||
if (nodes_[i]->hasEvidence()) {
|
||||
addEvidence (static_cast<BpNode*> (nodes_[i]));
|
||||
}
|
||||
}
|
||||
runIterativeBpSolver();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNetwork::printCurrentStatus (void)
|
||||
{
|
||||
for (unsigned int i = 0; i < nodes_.size(); i++) {
|
||||
printCurrentStatusOf (static_cast<BpNode*> (nodes_[i]));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNetwork::printCurrentStatusOf (BpNode* x)
|
||||
{
|
||||
vector<BpNode*> childs = cast (x->getChilds());
|
||||
vector<string> domain = x->getDomain();
|
||||
|
||||
cout << left;
|
||||
cout << setw (10) << "domain" ;
|
||||
cout << setw (20) << "π(" + x->getVariableName() + ")" ;
|
||||
cout << setw (20) << "λ(" + x->getVariableName() + ")" ;
|
||||
cout << setw (16) << "belief" ;
|
||||
cout << endl;
|
||||
|
||||
cout << "--------------------------------" ;
|
||||
cout << "--------------------------------" ;
|
||||
cout << endl;
|
||||
|
||||
double* piValues = x->getPiValues();
|
||||
double* lambdaValues = x->getLambdaValues();
|
||||
double* beliefs = x->getBeliefs();
|
||||
for (int xi = 0; xi < x->getDomainSize(); xi++) {
|
||||
cout << setw (10) << domain[xi];
|
||||
cout << setw (19) << piValues[xi];
|
||||
cout << setw (19) << lambdaValues[xi];
|
||||
cout.precision (PRECISION_);
|
||||
cout << setw (16) << beliefs[xi];
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
if (childs.size() > 0) {
|
||||
string s = "(" + x->getVariableName() + ")" ;
|
||||
for (unsigned int j = 0; j < childs.size(); j++) {
|
||||
cout << setw (10) << "domain" ;
|
||||
cout << setw (28) << "π" + childs[j]->getVariableName() + s;
|
||||
cout << setw (28) << "λ" + childs[j]->getVariableName() + s;
|
||||
cout << endl;
|
||||
cout << "--------------------------------" ;
|
||||
cout << "--------------------------------" ;
|
||||
cout << endl;
|
||||
for (int xi = 0; xi < x->getDomainSize(); xi++) {
|
||||
cout << setw (10) << domain[xi];
|
||||
cout.precision (PRECISION_);
|
||||
cout << setw (27) << x->getPiMessage(childs[j], xi);
|
||||
cout.precision (PRECISION_);
|
||||
cout << setw (27) << x->getLambdaMessage(childs[j], xi);
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNetwork::printBeliefs (void)
|
||||
{
|
||||
for (unsigned int i = 0; i < nodes_.size(); i++) {
|
||||
BpNode* x = static_cast<BpNode*> (nodes_[i]);
|
||||
vector<string> domain = x->getDomain();
|
||||
cout << setw (20) << left << x->getVariableName() ;
|
||||
cout << setw (26) << "belief" ;
|
||||
cout << endl;
|
||||
cout << "--------------------------------------" ;
|
||||
cout << endl;
|
||||
double* beliefs = x->getBeliefs();
|
||||
for (int xi = 0; xi < x->getDomainSize(); xi++) {
|
||||
cout << setw (20) << domain[xi];
|
||||
cout.precision (PRECISION_);
|
||||
cout << setw (26) << beliefs[xi];
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<double>
|
||||
BpNetwork::getBeliefs (void)
|
||||
{
|
||||
return getBeliefs (queryNode_);
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<double>
|
||||
BpNetwork::getBeliefs (BpNode* x)
|
||||
{
|
||||
double* beliefs = x->getBeliefs();
|
||||
vector<double> beliefsVec;
|
||||
for (int xi = 0; xi < x->getDomainSize(); xi++) {
|
||||
beliefsVec.push_back (beliefs[xi]);
|
||||
}
|
||||
return beliefsVec;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNetwork::initializeSolver (vector<BayesianNode*> queryVars)
|
||||
{
|
||||
if (DL_ >= 1) {
|
||||
cout << "Initializing solver" << endl;
|
||||
if (schedule_ == SEQUENTIAL_SCHEDULE) {
|
||||
cout << "-> schedule = sequential" << endl;
|
||||
} else {
|
||||
cout << "-> schedule = parallel" << endl;
|
||||
}
|
||||
cout << "-> max iters = " << maxIter_ << endl;
|
||||
cout << "-> stable threashold = " << stableThreashold_ << endl;
|
||||
cout << "-> query vars = " ;
|
||||
for (unsigned int i = 0; i < queryVars.size(); i++) {
|
||||
cout << queryVars[i]->getVariableName() << " " ;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
nIter_ = 0;
|
||||
|
||||
for (unsigned int i = 0; i < nodes_.size(); i++) {
|
||||
BpNode* node = static_cast<BpNode*> (nodes_[i]);
|
||||
node->allocateMemory();
|
||||
}
|
||||
|
||||
for (unsigned int i = 0; i < nodes_.size(); i++) {
|
||||
BpNode* x = static_cast<BpNode*> (nodes_[i]);
|
||||
|
||||
double* piValues = x->getPiValues();
|
||||
double* lambdaValues = x->getLambdaValues();
|
||||
for (int xi = 0; xi < x->getDomainSize(); xi++) {
|
||||
piValues[xi] = 1.0;
|
||||
lambdaValues[xi] = 1.0;
|
||||
}
|
||||
|
||||
vector<BpNode*> xChilds = cast (x->getChilds());
|
||||
for (unsigned int j = 0; j < xChilds.size(); j++) {
|
||||
double* piMessages = x->getPiMessages (xChilds[j]);
|
||||
for (int xi = 0; xi < x->getDomainSize(); xi++) {
|
||||
piMessages[xi] = 1.0;
|
||||
//x->setPiMessage (xChilds[j], xi, 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
vector<BpNode*> xParents = cast (x->getParents());
|
||||
for (unsigned int j = 0; j < xParents.size(); j++) {
|
||||
double* lambdaMessages = xParents[j]->getLambdaMessages (x);
|
||||
for (int xi = 0; xi < xParents[j]->getDomainSize(); xi++) {
|
||||
lambdaMessages[xi] = 1.0;
|
||||
//xParents[j]->setLambdaMessage (x, xi, 1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned int i = 0; i < nodes_.size(); i++) {
|
||||
BpNode* x = static_cast<BpNode*> (nodes_[i]);
|
||||
x->normalizeMessages();
|
||||
}
|
||||
printCurrentStatus();
|
||||
|
||||
|
||||
vector<BpNode*> roots = cast (getRootNodes());
|
||||
for (unsigned int i = 0; i < roots.size(); i++) {
|
||||
double* params = roots[i]->getParameters();
|
||||
double* piValues = roots[i]->getPiValues();
|
||||
for (int ri = 0; ri < roots[i]->getDomainSize(); ri++) {
|
||||
piValues[ri] = params[ri];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNetwork::addJunctionNode (vector<BayesianNode*> queryVars)
|
||||
{
|
||||
const string VAR_NAME = "_Jn";
|
||||
int nStates = 1;
|
||||
vector<BayesianNode*> parents;
|
||||
vector<string> domain;
|
||||
for (unsigned int i = 0; i < queryVars.size(); i++) {
|
||||
parents.push_back (queryVars[i]);
|
||||
nStates *= queryVars[i]->getDomainSize();
|
||||
}
|
||||
|
||||
for (int i = 0; i < nStates; i++) {
|
||||
stringstream ss;
|
||||
ss << "_jn" << i;
|
||||
domain.push_back (ss.str()); // FIXME make domain optional
|
||||
}
|
||||
|
||||
int nParams = nStates * nStates;
|
||||
double* params = new double [nParams];
|
||||
for (int i = 0; i < nParams; i++) {
|
||||
int row = i / nStates;
|
||||
int col = i % nStates;
|
||||
if (row == col) {
|
||||
params[i] = 1;
|
||||
} else {
|
||||
params[i] = 0;
|
||||
}
|
||||
}
|
||||
addNode (VAR_NAME, parents, params, nParams, domain);
|
||||
queryNode_ = static_cast<BpNode*> (getNode (VAR_NAME));
|
||||
printNetwork();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNetwork::addEvidence (BpNode* v)
|
||||
{
|
||||
if (DL_ >= 1) {
|
||||
cout << "Adding evidence: node " ;
|
||||
cout << "`" << v->getVariableName() << "' was instantiated as " ;
|
||||
cout << "`" << v->getDomain()[v->getEvidence()] << "'" ;
|
||||
cout << endl;
|
||||
}
|
||||
double* piValues = v->getPiValues();
|
||||
double* lambdaValues = v->getLambdaValues();
|
||||
for (int vi = 0; vi < v->getDomainSize(); vi++) {
|
||||
if (vi == v->getEvidence()) {
|
||||
piValues[vi] = 1.0;
|
||||
lambdaValues[vi] = 1.0;
|
||||
} else {
|
||||
piValues[vi] = 0.0;
|
||||
lambdaValues[vi] = 0.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNetwork::runNeapolitanSolver (void)
|
||||
{
|
||||
vector<BpNode*> roots = cast (getRootNodes());
|
||||
for (unsigned int i = 0; i < roots.size(); i++) {
|
||||
vector<BpNode*> childs = cast (roots[i]->getChilds());
|
||||
for (unsigned int j = 0; j < childs.size(); j++) {
|
||||
sendPiMessage (roots[i], static_cast<BpNode*> (childs[j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNetwork::sendPiMessage (BpNode* z, BpNode* x)
|
||||
{
|
||||
nIter_ ++;
|
||||
if (!(maxIter_ == -1 || nIter_ < maxIter_)) {
|
||||
cout << "the maximum number of iterations was achieved, terminating..." ;
|
||||
cout << endl;
|
||||
return;
|
||||
}
|
||||
|
||||
if (DL_ >= 1) {
|
||||
cout << "π message " << z->getVariableName();
|
||||
cout << " --> " << x->getVariableName() << endl;
|
||||
}
|
||||
|
||||
updatePiMessages(z, x);
|
||||
|
||||
if (!x->hasEvidence()) {
|
||||
updatePiValues (x);
|
||||
vector<BpNode*> xChilds = cast (x->getChilds());
|
||||
for (unsigned int i = 0; i < xChilds.size(); i++) {
|
||||
sendPiMessage (x, xChilds[i]);
|
||||
}
|
||||
}
|
||||
|
||||
bool isAllOnes = true;
|
||||
double* lambdaValues = x->getLambdaValues();
|
||||
for (int xi = 0; xi < x->getDomainSize(); xi++) {
|
||||
if (lambdaValues[xi] != 1.0) {
|
||||
isAllOnes = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!isAllOnes) {
|
||||
vector<BpNode*> xParents = cast (x->getParents());
|
||||
for (unsigned int i = 0; i < xParents.size(); i++) {
|
||||
if (xParents[i] != z && !xParents[i]->hasEvidence()) {
|
||||
sendLambdaMessage (x, xParents[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNetwork::sendLambdaMessage (BpNode* y, BpNode* x)
|
||||
{
|
||||
nIter_ ++;
|
||||
if (!(maxIter_ == -1 || nIter_ < maxIter_)) {
|
||||
cout << "the maximum number of iterations was achieved, terminating..." ;
|
||||
cout << endl;
|
||||
return;
|
||||
}
|
||||
|
||||
if (DL_ >= 1) {
|
||||
cout << "λ message " << y->getVariableName();
|
||||
cout << " --> " << x->getVariableName() << endl;
|
||||
}
|
||||
|
||||
updateLambdaMessages (x, y);
|
||||
updateLambdaValues (x);
|
||||
|
||||
vector<BpNode*> xParents = cast (x->getParents());
|
||||
for (unsigned int i = 0; i < xParents.size(); i++) {
|
||||
if (!xParents[i]->hasEvidence()) {
|
||||
sendLambdaMessage (x, xParents[i]);
|
||||
}
|
||||
}
|
||||
|
||||
vector<BpNode*> xChilds = cast (x->getChilds());
|
||||
for (unsigned int i = 0; i < xChilds.size(); i++) {
|
||||
if (xChilds[i] != y) {
|
||||
sendPiMessage (x, xChilds[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNetwork::updatePiValues (BpNode* x)
|
||||
{
|
||||
// π(Xi)
|
||||
vector<BpNode*> parents = cast (x->getParents());
|
||||
for (int xi = 0; xi < x->getDomainSize(); xi++) {
|
||||
stringstream calcs1;
|
||||
stringstream calcs2;
|
||||
if (DL_ >= 2) {
|
||||
calcs1 << "π("<< x->getDomain()[xi] << ")" << endl << "= " ;
|
||||
}
|
||||
double sum = 0.0;
|
||||
vector<pair<int, int> > constraints;
|
||||
vector<CptEntry> entries = x->getCptEntriesOfRow (xi);
|
||||
for (unsigned int k = 0; k < entries.size(); k++) {
|
||||
double prod = x->getProbability (entries[k]);
|
||||
if (DL_ >= 2) {
|
||||
if (k != 0) calcs1 << endl << "+ " ;
|
||||
calcs1 << x->entryToString (entries[k]);
|
||||
if (DL_ >= 3) {
|
||||
(k == 0) ? calcs2 << "(" << prod : calcs2 << endl << "+ (" << prod;
|
||||
}
|
||||
}
|
||||
vector<int> insts = entries[k].getDomainInstantiations();
|
||||
for (unsigned int i = 0; i < parents.size(); i++) {
|
||||
double value = parents[i]->getPiMessage (x, insts[i + 1]);
|
||||
prod *= value;
|
||||
if (DL_ >= 2) {
|
||||
calcs1 << ".π" << x->getVariableName();
|
||||
calcs1 << "(" << parents[i]->getDomain()[insts[i + 1]] << ")";
|
||||
if (DL_ >= 3) calcs2 << "x" << value;
|
||||
}
|
||||
}
|
||||
sum += prod;
|
||||
if (DL_ >= 3) calcs2 << ")";
|
||||
}
|
||||
x->setPiValue (xi, sum);
|
||||
if (DL_ >= 2) {
|
||||
cout << calcs1.str();
|
||||
if (DL_ >= 3) cout << endl << "= " << calcs2.str();
|
||||
cout << " = " << sum << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNetwork::updatePiMessages (BpNode* z, BpNode* x)
|
||||
{
|
||||
// πX(Zi)
|
||||
vector<BpNode*> zChilds = cast (z->getChilds());
|
||||
for (int zi = 0; zi < z->getDomainSize(); zi++) {
|
||||
stringstream calcs1;
|
||||
stringstream calcs2;
|
||||
if (DL_ >= 2) {
|
||||
calcs1 << "π" << x->getVariableName();
|
||||
calcs1 << "(" << z->getDomain()[zi] << ") = " ;
|
||||
}
|
||||
double prod = z->getPiValue (zi);
|
||||
if (DL_ >= 2) {
|
||||
calcs1 << "π(" << z->getDomain()[zi] << ")" ;
|
||||
if (DL_ >= 3) calcs2 << prod;
|
||||
}
|
||||
for (unsigned int i = 0; i < zChilds.size(); i++) {
|
||||
if (zChilds[i] != x) {
|
||||
double value = z->getLambdaMessage (zChilds[i], zi);
|
||||
prod *= value;
|
||||
if (DL_ >= 2) {
|
||||
calcs1 << ".λ" << zChilds[i]->getVariableName();
|
||||
calcs1 << "(" << z->getDomain()[zi] + ")" ;
|
||||
if (DL_ >= 3) calcs2 << " x " << value;
|
||||
}
|
||||
}
|
||||
}
|
||||
z->setPiMessage (x, zi, prod);
|
||||
if (DL_ >= 2) {
|
||||
cout << calcs1.str();
|
||||
if (DL_ >= 3) cout << " = " << calcs2.str();
|
||||
cout << " = " << prod << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNetwork::updateLambdaValues (BpNode* x)
|
||||
{
|
||||
// λ(Xi)
|
||||
vector<BpNode*> childs = cast (x->getChilds());
|
||||
for (int xi = 0; xi < x->getDomainSize(); xi++) {
|
||||
stringstream calcs1;
|
||||
stringstream calcs2;
|
||||
if (DL_ >= 2) {
|
||||
calcs1 << "λ" << "(" << x->getDomain()[xi] << ") = " ;
|
||||
}
|
||||
double prod = 1.0;
|
||||
for (unsigned int i = 0; i < childs.size(); i++) {
|
||||
double val = x->getLambdaMessage (childs[i], xi);
|
||||
prod *= val;
|
||||
if (DL_ >= 2) {
|
||||
if (i != 0) calcs1 << "." ;
|
||||
calcs1 << "λ" << childs[i]->getVariableName();
|
||||
calcs1 << "(" << x->getDomain()[xi] + ")" ;
|
||||
if (DL_ >= 3) (i == 0) ? calcs2 << val : calcs2 << " x " << val;
|
||||
}
|
||||
}
|
||||
x->setLambdaValue (xi, prod);
|
||||
if (DL_ >= 2) {
|
||||
cout << calcs1.str();
|
||||
if (childs.size() == 0) {
|
||||
cout << 1 << endl;
|
||||
} else {
|
||||
if (DL_ >= 3) cout << " = " << calcs2.str();
|
||||
cout << " = " << prod << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNetwork::updateLambdaMessages (BpNode* x, BpNode* y)
|
||||
{
|
||||
// λY(Xi)
|
||||
int parentIndex = y->getIndexOfParent (x) + 1;
|
||||
vector<BpNode*> yParents = cast (y->getParents());
|
||||
for (int xi = 0; xi < x->getDomainSize(); xi++) {
|
||||
stringstream calcs1;
|
||||
stringstream calcs2;
|
||||
if (DL_ >= 2) {
|
||||
calcs1 << "λ" << y->getVariableName() ;
|
||||
calcs1 << "(" << x->getDomain()[xi] << ")" << endl << "= " ;
|
||||
}
|
||||
double outer_sum = 0.0;
|
||||
for (int yi = 0; yi < y->getDomainSize(); yi++) {
|
||||
if (DL_ >= 2) {
|
||||
(yi == 0) ? calcs1 << "[" : calcs1 << endl << "+ [" ;
|
||||
if (DL_ >= 3) {
|
||||
(yi == 0) ? calcs2 << "[" : calcs2 << endl << "+ [" ;
|
||||
}
|
||||
}
|
||||
double inner_sum = 0.0;
|
||||
vector<pair<int, int> > constraints;
|
||||
constraints.push_back (make_pair (0, yi));
|
||||
constraints.push_back (make_pair (parentIndex, xi));
|
||||
vector<CptEntry> entries = y->getCptEntries (constraints);
|
||||
for (unsigned int k = 0; k < entries.size(); k++) {
|
||||
double prod = y->getProbability (entries[k]);
|
||||
if (DL_ >= 2) {
|
||||
if (k != 0) calcs1 << " + " ;
|
||||
calcs1 << y->entryToString (entries[k]);
|
||||
if (DL_ >= 3) {
|
||||
(k == 0) ? calcs2 << "(" << prod : calcs2 << " + (" << prod;
|
||||
}
|
||||
}
|
||||
vector<int> insts = entries[k].getDomainInstantiations();
|
||||
for (unsigned int i = 0; i < yParents.size(); i++) {
|
||||
if (yParents[i] != x) {
|
||||
double val = yParents[i]->getPiMessage (y, insts[i + 1]);
|
||||
prod *= val;
|
||||
if (DL_ >= 2) {
|
||||
calcs1 << ".π" << y->getVariableName();
|
||||
calcs1 << "(" << yParents[i]->getDomain()[insts[i + 1]] << ")" ;
|
||||
if (DL_ >= 3) calcs2 << "x" << val;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
inner_sum += prod;
|
||||
if (DL_ >= 3) {
|
||||
calcs2 << ")" ;
|
||||
}
|
||||
}
|
||||
outer_sum += inner_sum * y->getLambdaValue (yi);
|
||||
if (DL_ >= 2) {
|
||||
calcs1 << "].λ(" << y->getDomain()[yi] << ")";
|
||||
if (DL_ >= 3) calcs2 << "]x" << y->getLambdaValue (yi);
|
||||
}
|
||||
}
|
||||
x->setLambdaMessage (y, xi, outer_sum);
|
||||
if (DL_ >= 2) {
|
||||
cout << calcs1.str();
|
||||
if (DL_ >= 3) cout << endl << "= " << calcs2.str();
|
||||
cout << " = " << outer_sum << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNetwork::runIterativeBpSolver()
|
||||
{
|
||||
int nIter = 0;
|
||||
maxIter_ = 100;
|
||||
bool converged = false;
|
||||
while (nIter < maxIter_ && !converged) {
|
||||
if (DL_ >= 1) {
|
||||
cout << endl << endl;
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
cout << endl;
|
||||
cout << " Iteration " << nIter + 1 << endl;
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
}
|
||||
|
||||
for (unsigned int i = 0; i < nodes_.size(); i++) {
|
||||
BpNode* x = static_cast<BpNode*>(nodes_[i]);
|
||||
vector<BpNode*> xParents = cast (x->getParents());
|
||||
for (unsigned int j = 0; j < xParents.size(); j++) {
|
||||
//if (!xParents[j]->hasEvidence()) {
|
||||
if (DL_ >= 1) {
|
||||
cout << endl << "λ message " << x->getVariableName();
|
||||
cout << " --> " << xParents[j]->getVariableName() << endl;
|
||||
}
|
||||
updateLambdaMessages (xParents[j], x);
|
||||
//}
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned int i = 0; i < nodes_.size(); i++) {
|
||||
BpNode* x = static_cast<BpNode*>(nodes_[i]);
|
||||
vector<BpNode*> xChilds = cast (x->getChilds());
|
||||
for (unsigned int j = 0; j < xChilds.size(); j++) {
|
||||
if (DL_ >= 1) {
|
||||
cout << endl << "π message " << x->getVariableName();
|
||||
cout << " --> " << xChilds[j]->getVariableName() << endl;
|
||||
}
|
||||
updatePiMessages (x, xChilds[j]);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
for (unsigned int i = 0; i < nodes_.size(); i++) {
|
||||
BpNode* x = static_cast<BpNode*>(nodes_[i]);
|
||||
vector<BpNode*> xChilds = cast (x->getChilds());
|
||||
for (unsigned int j = 0; j < xChilds.size(); j++) {
|
||||
if (DL_ >= 1) {
|
||||
cout << "π message " << x->getVariableName();
|
||||
cout << " --> " << xChilds[j]->getVariableName() << endl;
|
||||
}
|
||||
updatePiMessages (x, xChilds[j]);
|
||||
}
|
||||
vector<BpNode*> xParents = cast (x->getParents());
|
||||
for (unsigned int j = 0; j < xParents.size(); j++) {
|
||||
//if (!xParents[j]->hasEvidence()) {
|
||||
if (DL_ >= 1) {
|
||||
cout << "λ message " << x->getVariableName();
|
||||
cout << " --> " << xParents[j]->getVariableName() << endl;
|
||||
}
|
||||
updateLambdaMessages (xParents[j], x);
|
||||
//}
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
for (unsigned int i = 0; i < nodes_.size(); i++) {
|
||||
BpNode* x = static_cast<BpNode*> (nodes_[i]);
|
||||
//cout << endl << "SWAPING MESSAGES FOR " << x->getVariableName() << ":" ;
|
||||
//cout << endl << endl;
|
||||
//printCurrentStatusOf (x);
|
||||
x->swapMessages();
|
||||
x->normalizeMessages();
|
||||
//cout << endl << "messages swaped " << endl;
|
||||
//printCurrentStatusOf (x);
|
||||
}
|
||||
|
||||
converged = true;
|
||||
for (unsigned int i = 0; i < nodes_.size(); i++) {
|
||||
BpNode* x = static_cast<BpNode*>(nodes_[i]);
|
||||
if (DL_ >= 1) {
|
||||
cout << endl << "var " << x->getVariableName() << ":" << endl;
|
||||
}
|
||||
//if (!x->hasEvidence()) {
|
||||
updatePiValues (x);
|
||||
updateLambdaValues (x);
|
||||
double change = x->getBeliefChange();
|
||||
if (DL_ >= 1) {
|
||||
cout << "belief change = " << change << endl;
|
||||
}
|
||||
if (change > stableThreashold_) {
|
||||
converged = false;
|
||||
}
|
||||
//}
|
||||
}
|
||||
|
||||
if (converged) {
|
||||
// converged = false;
|
||||
}
|
||||
if (DL_ >= 2) {
|
||||
cout << endl;
|
||||
printCurrentStatus();
|
||||
}
|
||||
nIter++;
|
||||
}
|
||||
|
||||
if (DL_ >= 1) {
|
||||
cout << endl;
|
||||
if (converged) {
|
||||
cout << "Iterative belief propagation converged in " ;
|
||||
cout << nIter << " iterations" << endl;
|
||||
} else {
|
||||
cout << "Iterative belief propagation converged didn't converge" ;
|
||||
cout << endl;
|
||||
}
|
||||
if (DL_ == 1) {
|
||||
cout << endl;
|
||||
printBeliefs();
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNetwork::addNode (string varName,
|
||||
vector<BayesianNode*> parents,
|
||||
int evidence,
|
||||
int distId)
|
||||
{
|
||||
for (unsigned int i = 0; i < dists_.size(); i++) {
|
||||
if (dists_[i]->id == distId) {
|
||||
BpNode* node = new BpNode (varName, parents, dists_[i], evidence);
|
||||
nodes_.push_back (node);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNetwork::addNode (string varName,
|
||||
vector<BayesianNode*> parents,
|
||||
double* params,
|
||||
int nParams,
|
||||
vector<string> domain)
|
||||
{
|
||||
Distribution* dist = new Distribution (params, nParams, domain);
|
||||
BpNode* node = new BpNode (varName, parents, dist);
|
||||
dists_.push_back (dist);
|
||||
nodes_.push_back (node);
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<BpNode*>
|
||||
BpNetwork::cast (vector<BayesianNode*> nodes)
|
||||
{
|
||||
vector<BpNode*> castedNodes (nodes.size());
|
||||
for (unsigned int i = 0; i < nodes.size(); i++) {
|
||||
castedNodes[i] = static_cast<BpNode*> (nodes[i]);
|
||||
}
|
||||
return castedNodes;
|
||||
}
|
||||
|
@ -1,66 +0,0 @@
|
||||
#ifndef BP_BP_NETWORK_H
|
||||
#define BP_BP_NETWORK_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "BayesianNetwork.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
class BpNode;
|
||||
|
||||
enum Schedule
|
||||
{
|
||||
SEQUENTIAL_SCHEDULE,
|
||||
PARALLEL_SCHEDULE
|
||||
};
|
||||
|
||||
class BpNetwork : public BayesianNetwork
|
||||
{
|
||||
public:
|
||||
// constructs
|
||||
BpNetwork (void);
|
||||
// destruct
|
||||
~BpNetwork (void);
|
||||
// methods
|
||||
void setSolverParameters (Schedule, int, double);
|
||||
void runSolver (BayesianNode* queryVar);
|
||||
void runSolver (vector<BayesianNode*>);
|
||||
void printCurrentStatus (void);
|
||||
void printCurrentStatusOf (BpNode*);
|
||||
void printBeliefs (void);
|
||||
vector<double> getBeliefs (void);
|
||||
vector<double> getBeliefs (BpNode*);
|
||||
|
||||
private:
|
||||
BpNetwork (const BpNetwork&); // disallow copy
|
||||
void operator= (const BpNetwork&); // disallow assign
|
||||
// methods
|
||||
void initializeSolver (vector<BayesianNode*>);
|
||||
void addJunctionNode (vector<BayesianNode*>);
|
||||
void addEvidence (BpNode*);
|
||||
void runNeapolitanSolver (void);
|
||||
void sendLambdaMessage (BpNode*, BpNode*);
|
||||
void sendPiMessage (BpNode*, BpNode*);
|
||||
void updatePiValues (BpNode*);
|
||||
void updatePiMessages (BpNode*, BpNode*);
|
||||
void updateLambdaValues (BpNode*);
|
||||
void updateLambdaMessages (BpNode*, BpNode*);
|
||||
void runIterativeBpSolver (void);
|
||||
void addNode (string, vector<BayesianNode*>, int, int);
|
||||
void addNode (string, vector<BayesianNode*>,
|
||||
double*, int, vector<string>);
|
||||
vector<BpNode*> cast (vector<BayesianNode*>);
|
||||
// members
|
||||
Schedule schedule_;
|
||||
int nIter_;
|
||||
int maxIter_;
|
||||
double stableThreashold_;
|
||||
BpNode* queryNode_;
|
||||
static const int DL_ = 3;
|
||||
static const int PRECISION_ = 10;
|
||||
};
|
||||
|
||||
#endif // BP_BP_NETWORK_H
|
||||
|
@ -1,250 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
#include "BpNode.h"
|
||||
|
||||
bool BpNode::calculateMessageResidual_ = true;
|
||||
|
||||
|
||||
BpNode::BpNode (BayesNode* node)
|
||||
{
|
||||
ds_ = node->getDomainSize();
|
||||
const NodeSet& childs = node->getChilds();
|
||||
piVals_.resize (ds_, 1);
|
||||
ldVals_.resize (ds_, 1);
|
||||
if (calculateMessageResidual_) {
|
||||
piResiduals_.resize (childs.size(), 0.0);
|
||||
ldResiduals_.resize (childs.size(), 0.0);
|
||||
}
|
||||
childs_ = &childs;
|
||||
for (unsigned i = 0; i < childs.size(); i++) {
|
||||
//indexMap_.insert (make_pair (childs[i]->getVarId(), i));
|
||||
currPiMsgs_.push_back (ParamSet (ds_, 1));
|
||||
currLdMsgs_.push_back (ParamSet (ds_, 1));
|
||||
nextPiMsgs_.push_back (ParamSet (ds_, 1));
|
||||
nextLdMsgs_.push_back (ParamSet (ds_, 1));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
BpNode::getBeliefs (void) const
|
||||
{
|
||||
double sum = 0.0;
|
||||
ParamSet beliefs (ds_);
|
||||
for (int xi = 0; xi < ds_; xi++) {
|
||||
double prod = piVals_[xi] * ldVals_[xi];
|
||||
beliefs[xi] = prod;
|
||||
sum += prod;
|
||||
}
|
||||
assert (sum);
|
||||
//normalize the beliefs
|
||||
for (int xi = 0; xi < ds_; xi++) {
|
||||
beliefs[xi] /= sum;
|
||||
}
|
||||
return beliefs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
BpNode::getPiValue (int idx) const
|
||||
{
|
||||
assert (idx >=0 && idx < ds_);
|
||||
return piVals_[idx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNode::setPiValue (int idx, double value)
|
||||
{
|
||||
assert (idx >=0 && idx < ds_);
|
||||
piVals_[idx] = value;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
BpNode::getLambdaValue (int idx) const
|
||||
{
|
||||
assert (idx >=0 && idx < ds_);
|
||||
return ldVals_[idx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNode::setLambdaValue (int idx, double value)
|
||||
{
|
||||
assert (idx >=0 && idx < ds_);
|
||||
ldVals_[idx] = value;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet&
|
||||
BpNode::getPiValues (void)
|
||||
{
|
||||
return piVals_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet&
|
||||
BpNode::getLambdaValues (void)
|
||||
{
|
||||
return ldVals_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
BpNode::getPiMessageValue (const BayesNode* destination, int idx) const
|
||||
{
|
||||
assert (idx >=0 && idx < ds_);
|
||||
return currPiMsgs_[getIndex(destination)][idx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
BpNode::getLambdaMessageValue (const BayesNode* source, int idx) const
|
||||
{
|
||||
assert (idx >=0 && idx < ds_);
|
||||
return currLdMsgs_[getIndex(source)][idx];
|
||||
}
|
||||
|
||||
|
||||
|
||||
const ParamSet&
|
||||
BpNode::getPiMessage (const BayesNode* destination) const
|
||||
{
|
||||
return currPiMsgs_[getIndex(destination)];
|
||||
}
|
||||
|
||||
|
||||
|
||||
const ParamSet&
|
||||
BpNode::getLambdaMessage (const BayesNode* source) const
|
||||
{
|
||||
return currLdMsgs_[getIndex(source)];
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet&
|
||||
BpNode::piNextMessageReference (const BayesNode* destination)
|
||||
{
|
||||
return nextPiMsgs_[getIndex(destination)];
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet&
|
||||
BpNode::lambdaNextMessageReference (const BayesNode* source)
|
||||
{
|
||||
return nextLdMsgs_[getIndex(source)];
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNode::updatePiMessage (const BayesNode* destination)
|
||||
{
|
||||
int idx = getIndex (destination);
|
||||
currPiMsgs_[idx] = nextPiMsgs_[idx];
|
||||
Util::normalize (currPiMsgs_[idx]);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNode::updateLambdaMessage (const BayesNode* source)
|
||||
{
|
||||
int idx = getIndex (source);
|
||||
currLdMsgs_[idx] = nextLdMsgs_[idx];
|
||||
Util::normalize (currLdMsgs_[idx]);
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
BpNode::getBeliefChange (void)
|
||||
{
|
||||
double change = 0.0;
|
||||
if (oldBeliefs_.size() == 0) {
|
||||
oldBeliefs_ = getBeliefs();
|
||||
change = 9999999999.0;
|
||||
} else {
|
||||
ParamSet currentBeliefs = getBeliefs();
|
||||
for (int xi = 0; xi < ds_; xi++) {
|
||||
change += abs (currentBeliefs[xi] - oldBeliefs_[xi]);
|
||||
}
|
||||
oldBeliefs_ = currentBeliefs;
|
||||
}
|
||||
return change;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNode::updatePiResidual (const BayesNode* destination)
|
||||
{
|
||||
int idx = getIndex (destination);
|
||||
Util::normalize (nextPiMsgs_[idx]);
|
||||
//piResiduals_[idx] = Util::getL1dist (
|
||||
// currPiMsgs_[idx], nextPiMsgs_[idx]);
|
||||
piResiduals_[idx] = Util::getMaxNorm (
|
||||
currPiMsgs_[idx], nextPiMsgs_[idx]);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNode::updateLambdaResidual (const BayesNode* source)
|
||||
{
|
||||
int idx = getIndex (source);
|
||||
Util::normalize (nextLdMsgs_[idx]);
|
||||
//ldResiduals_[idx] = Util::getL1dist (
|
||||
// currLdMsgs_[idx], nextLdMsgs_[idx]);
|
||||
ldResiduals_[idx] = Util::getMaxNorm (
|
||||
currLdMsgs_[idx], nextLdMsgs_[idx]);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNode::clearPiResidual (const BayesNode* destination)
|
||||
{
|
||||
piResiduals_[getIndex(destination)] = 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
BpNode::clearLambdaResidual (const BayesNode* source)
|
||||
{
|
||||
ldResiduals_[getIndex(source)] = 0;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
BpNode::hasReceivedChildInfluence (void) const
|
||||
{
|
||||
// if all lambda values are equal, then neither
|
||||
// this node neither its descendents have evidence,
|
||||
// we can use this to don't send lambda messages his parents
|
||||
bool childInfluenced = false;
|
||||
for (int xi = 1; xi < ds_; xi++) {
|
||||
if (ldVals_[xi] != ldVals_[0]) {
|
||||
childInfluenced = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return childInfluenced;
|
||||
}
|
||||
|
@ -1,99 +0,0 @@
|
||||
#ifndef BP_BPNODE_H
|
||||
#define BP_BPNODE_H
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "BayesNode.h"
|
||||
#include "Shared.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
class BpNode
|
||||
{
|
||||
public:
|
||||
BpNode (int);
|
||||
BpNode (BayesNode*);
|
||||
|
||||
ParamSet getBeliefs (void) const;
|
||||
double getPiValue (int) const;
|
||||
void setPiValue (int, double);
|
||||
double getLambdaValue (int) const;
|
||||
void setLambdaValue (int, double);
|
||||
ParamSet& getPiValues (void);
|
||||
ParamSet& getLambdaValues (void);
|
||||
double getPiMessageValue (const BayesNode*, int) const;
|
||||
double getLambdaMessageValue (const BayesNode*, int) const;
|
||||
const ParamSet& getPiMessage (const BayesNode*) const;
|
||||
const ParamSet& getLambdaMessage (const BayesNode*) const;
|
||||
ParamSet& piNextMessageReference (const BayesNode*);
|
||||
ParamSet& lambdaNextMessageReference (const BayesNode*);
|
||||
void updatePiMessage (const BayesNode*);
|
||||
void updateLambdaMessage (const BayesNode*);
|
||||
double getBeliefChange (void);
|
||||
void updatePiResidual (const BayesNode*);
|
||||
void updateLambdaResidual (const BayesNode*);
|
||||
void clearPiResidual (const BayesNode*);
|
||||
void clearLambdaResidual (const BayesNode*);
|
||||
bool hasReceivedChildInfluence (void) const;
|
||||
// inlines
|
||||
double getPiResidual (const BayesNode*);
|
||||
double getLambdaResidual (const BayesNode*);
|
||||
int getIndex (const BayesNode*) const;
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (BpNode);
|
||||
|
||||
IndexMap indexMap_;
|
||||
ParamSet piVals_; // pi values
|
||||
ParamSet ldVals_; // lambda values
|
||||
vector<ParamSet> currPiMsgs_; // current pi messages
|
||||
vector<ParamSet> currLdMsgs_; // current lambda messages
|
||||
vector<ParamSet> nextPiMsgs_;
|
||||
vector<ParamSet> nextLdMsgs_;
|
||||
ParamSet oldBeliefs_;
|
||||
ParamSet piResiduals_;
|
||||
ParamSet ldResiduals_;
|
||||
int ds_;
|
||||
const NodeSet* childs_;
|
||||
static bool calculateMessageResidual_;
|
||||
// static const double MAX_CHANGE_ = 10000000.0;
|
||||
};
|
||||
|
||||
|
||||
|
||||
inline double
|
||||
BpNode::getPiResidual (const BayesNode* destination)
|
||||
{
|
||||
return piResiduals_[getIndex(destination)];
|
||||
}
|
||||
|
||||
|
||||
inline double
|
||||
BpNode::getLambdaResidual (const BayesNode* source)
|
||||
{
|
||||
return ldResiduals_[getIndex(source)];
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline int
|
||||
BpNode::getIndex (const BayesNode* node) const
|
||||
{
|
||||
assert (node);
|
||||
//assert (indexMap_.find(node->getVarId()) != indexMap_.end());
|
||||
//return indexMap_.find(node->getVarId())->second;
|
||||
for (unsigned i = 0; childs_->size(); i++) {
|
||||
if ((*childs_)[i]->getVarId() == node->getVarId()) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
assert (false);
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
#endif
|
||||
|
344
packages/CLPBN/clpbn/bp/CFactorGraph.cpp
Normal file
344
packages/CLPBN/clpbn/bp/CFactorGraph.cpp
Normal file
@ -0,0 +1,344 @@
|
||||
|
||||
#include "CFactorGraph.h"
|
||||
#include "Factor.h"
|
||||
#include "Distribution.h"
|
||||
|
||||
|
||||
bool CFactorGraph::checkForIdenticalFactors_ = true;
|
||||
|
||||
CFactorGraph::CFactorGraph (const FactorGraph& fg)
|
||||
{
|
||||
groundFg_ = &fg;
|
||||
freeColor_ = 0;
|
||||
|
||||
const FgVarSet& varNodes = fg.getVarNodes();
|
||||
varSignatures_.reserve (varNodes.size());
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
unsigned c = (varNodes[i]->neighbors().size() * 2) + 1;
|
||||
varSignatures_.push_back (Signature (c));
|
||||
}
|
||||
|
||||
const FgFacSet& facNodes = fg.getFactorNodes();
|
||||
factorSignatures_.reserve (facNodes.size());
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
unsigned c = facNodes[i]->neighbors().size() + 1;
|
||||
factorSignatures_.push_back (Signature (c));
|
||||
}
|
||||
|
||||
varColors_.resize (varNodes.size());
|
||||
factorColors_.resize (facNodes.size());
|
||||
setInitialColors();
|
||||
createGroups();
|
||||
}
|
||||
|
||||
|
||||
|
||||
CFactorGraph::~CFactorGraph (void)
|
||||
{
|
||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||
delete varClusters_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < factorClusters_.size(); i++) {
|
||||
delete factorClusters_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CFactorGraph::setInitialColors (void)
|
||||
{
|
||||
// create the initial variable colors
|
||||
VarColorMap colorMap;
|
||||
const FgVarSet& varNodes = groundFg_->getVarNodes();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
unsigned dsize = varNodes[i]->nrStates();
|
||||
VarColorMap::iterator it = colorMap.find (dsize);
|
||||
if (it == colorMap.end()) {
|
||||
it = colorMap.insert (make_pair (
|
||||
dsize, vector<Color> (dsize+1,-1))).first;
|
||||
}
|
||||
unsigned idx;
|
||||
if (varNodes[i]->hasEvidence()) {
|
||||
idx = varNodes[i]->getEvidence();
|
||||
} else {
|
||||
idx = dsize;
|
||||
}
|
||||
vector<Color>& stateColors = it->second;
|
||||
if (stateColors[idx] == -1) {
|
||||
stateColors[idx] = getFreeColor();
|
||||
}
|
||||
setColor (varNodes[i], stateColors[idx]);
|
||||
}
|
||||
|
||||
const FgFacSet& facNodes = groundFg_->getFactorNodes();
|
||||
if (checkForIdenticalFactors_) {
|
||||
for (unsigned i = 0; i < facNodes.size() - 1; i++) {
|
||||
// facNodes[i]->factor()->orderFactorVariables();
|
||||
// FIXME
|
||||
}
|
||||
for (unsigned i = 0, s = facNodes.size(); i < s; i++) {
|
||||
Distribution* dist1 = facNodes[i]->getDistribution();
|
||||
for (unsigned j = 0; j < i; j++) {
|
||||
Distribution* dist2 = facNodes[j]->getDistribution();
|
||||
if (dist1 != dist2 && dist1->params == dist2->params) {
|
||||
facNodes[i]->factor()->setDistribution (dist2);
|
||||
// delete dist2;
|
||||
break;
|
||||
}
|
||||
/*
|
||||
if (ok) {
|
||||
const FgVarSet& fiVars = factors[i]->getFgVarNodes();
|
||||
const FgVarSet& fjVars = factors[j]->getFgVarNodes();
|
||||
if (fiVars.size() != fjVars.size()) continue;
|
||||
for (unsigned k = 0; k < fiVars.size(); k++) {
|
||||
if (fiVars[k]->nrStates() != fjVars[k]->nrStates()) {
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// create the initial factor colors
|
||||
DistColorMap distColors;
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
const Distribution* dist = facNodes[i]->getDistribution();
|
||||
DistColorMap::iterator it = distColors.find (dist);
|
||||
if (it == distColors.end()) {
|
||||
it = distColors.insert (make_pair (dist, getFreeColor())).first;
|
||||
}
|
||||
setColor (facNodes[i], it->second);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CFactorGraph::createGroups (void)
|
||||
{
|
||||
VarSignMap varGroups;
|
||||
FacSignMap factorGroups;
|
||||
unsigned nIters = 0;
|
||||
bool groupsHaveChanged = true;
|
||||
const FgVarSet& varNodes = groundFg_->getVarNodes();
|
||||
const FgFacSet& facNodes = groundFg_->getFactorNodes();
|
||||
|
||||
while (groupsHaveChanged || nIters == 1) {
|
||||
nIters ++;
|
||||
|
||||
unsigned prevFactorGroupsSize = factorGroups.size();
|
||||
factorGroups.clear();
|
||||
// set a new color to the factors with the same signature
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
const Signature& signature = getSignature (facNodes[i]);
|
||||
FacSignMap::iterator it = factorGroups.find (signature);
|
||||
if (it == factorGroups.end()) {
|
||||
it = factorGroups.insert (make_pair (signature, FgFacSet())).first;
|
||||
}
|
||||
it->second.push_back (facNodes[i]);
|
||||
}
|
||||
for (FacSignMap::iterator it = factorGroups.begin();
|
||||
it != factorGroups.end(); it++) {
|
||||
Color newColor = getFreeColor();
|
||||
FgFacSet& groupMembers = it->second;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
setColor (groupMembers[i], newColor);
|
||||
}
|
||||
}
|
||||
|
||||
// set a new color to the variables with the same signature
|
||||
unsigned prevVarGroupsSize = varGroups.size();
|
||||
varGroups.clear();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
const Signature& signature = getSignature (varNodes[i]);
|
||||
VarSignMap::iterator it = varGroups.find (signature);
|
||||
if (it == varGroups.end()) {
|
||||
it = varGroups.insert (make_pair (signature, FgVarSet())).first;
|
||||
}
|
||||
it->second.push_back (varNodes[i]);
|
||||
}
|
||||
for (VarSignMap::iterator it = varGroups.begin();
|
||||
it != varGroups.end(); it++) {
|
||||
Color newColor = getFreeColor();
|
||||
FgVarSet& groupMembers = it->second;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
setColor (groupMembers[i], newColor);
|
||||
}
|
||||
}
|
||||
|
||||
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
||||
|| prevFactorGroupsSize != factorGroups.size();
|
||||
}
|
||||
//printGroups (varGroups, factorGroups);
|
||||
createClusters (varGroups, factorGroups);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CFactorGraph::createClusters (const VarSignMap& varGroups,
|
||||
const FacSignMap& factorGroups)
|
||||
{
|
||||
varClusters_.reserve (varGroups.size());
|
||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||
it != varGroups.end(); it++) {
|
||||
const FgVarSet& groupVars = it->second;
|
||||
VarCluster* vc = new VarCluster (groupVars);
|
||||
for (unsigned i = 0; i < groupVars.size(); i++) {
|
||||
vid2VarCluster_.insert (make_pair (groupVars[i]->varId(), vc));
|
||||
}
|
||||
varClusters_.push_back (vc);
|
||||
}
|
||||
|
||||
factorClusters_.reserve (factorGroups.size());
|
||||
for (FacSignMap::const_iterator it = factorGroups.begin();
|
||||
it != factorGroups.end(); it++) {
|
||||
FgFacNode* groupFactor = it->second[0];
|
||||
const FgVarSet& neighs = groupFactor->neighbors();
|
||||
VarClusterSet varClusters;
|
||||
varClusters.reserve (neighs.size());
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
VarId vid = neighs[i]->varId();
|
||||
varClusters.push_back (vid2VarCluster_.find (vid)->second);
|
||||
}
|
||||
factorClusters_.push_back (new FacCluster (it->second, varClusters));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
const Signature&
|
||||
CFactorGraph::getSignature (const FgVarNode* varNode)
|
||||
{
|
||||
Signature& sign = varSignatures_[varNode->getIndex()];
|
||||
vector<Color>::iterator it = sign.colors.begin();
|
||||
const FgFacSet& neighs = varNode->neighbors();
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
*it = getColor (neighs[i]);
|
||||
it ++;
|
||||
*it = neighs[i]->factor()->getPositionOf (varNode->varId());
|
||||
it ++;
|
||||
}
|
||||
*it = getColor (varNode);
|
||||
return sign;
|
||||
}
|
||||
|
||||
|
||||
|
||||
const Signature&
|
||||
CFactorGraph::getSignature (const FgFacNode* facNode)
|
||||
{
|
||||
Signature& sign = factorSignatures_[facNode->getIndex()];
|
||||
vector<Color>::iterator it = sign.colors.begin();
|
||||
const FgVarSet& neighs = facNode->neighbors();
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
*it = getColor (neighs[i]);
|
||||
it ++;
|
||||
}
|
||||
*it = getColor (facNode);
|
||||
return sign;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph*
|
||||
CFactorGraph::getCompressedFactorGraph (void)
|
||||
{
|
||||
FactorGraph* fg = new FactorGraph();
|
||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||
FgVarNode* var = varClusters_[i]->getGroundFgVarNodes()[0];
|
||||
FgVarNode* newVar = new FgVarNode (var);
|
||||
varClusters_[i]->setRepresentativeVariable (newVar);
|
||||
fg->addVariable (newVar);
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < factorClusters_.size(); i++) {
|
||||
const VarClusterSet& myVarClusters = factorClusters_[i]->getVarClusters();
|
||||
VarNodes myGroundVars;
|
||||
myGroundVars.reserve (myVarClusters.size());
|
||||
for (unsigned j = 0; j < myVarClusters.size(); j++) {
|
||||
FgVarNode* v = myVarClusters[j]->getRepresentativeVariable();
|
||||
myGroundVars.push_back (v);
|
||||
}
|
||||
Factor* newFactor = new Factor (myGroundVars,
|
||||
factorClusters_[i]->getGroundFactors()[0]->getDistribution());
|
||||
FgFacNode* fn = new FgFacNode (newFactor);
|
||||
factorClusters_[i]->setRepresentativeFactor (fn);
|
||||
fg->addFactor (fn);
|
||||
for (unsigned j = 0; j < myGroundVars.size(); j++) {
|
||||
fg->addEdge (fn, static_cast<FgVarNode*> (myGroundVars[j]));
|
||||
}
|
||||
}
|
||||
fg->setIndexes();
|
||||
return fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
CFactorGraph::getGroundEdgeCount (const FacCluster* fc,
|
||||
const VarCluster* vc) const
|
||||
{
|
||||
const FgFacSet& clusterGroundFactors = fc->getGroundFactors();
|
||||
FgVarNode* varNode = vc->getGroundFgVarNodes()[0];
|
||||
unsigned count = 0;
|
||||
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||
if (clusterGroundFactors[i]->factor()->getPositionOf (varNode->varId()) != -1) {
|
||||
count ++;
|
||||
}
|
||||
}
|
||||
// CFgVarSet vars = vc->getGroundFgVarNodes();
|
||||
// for (unsigned i = 1; i < vars.size(); i++) {
|
||||
// FgVarNode* var = vc->getGroundFgVarNodes()[i];
|
||||
// unsigned count2 = 0;
|
||||
// for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||
// if (clusterGroundFactors[i]->getPositionOf (var) != -1) {
|
||||
// count2 ++;
|
||||
// }
|
||||
// }
|
||||
// if (count != count2) { cout << "oops!" << endl; abort(); }
|
||||
// }
|
||||
return count;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CFactorGraph::printGroups (const VarSignMap& varGroups,
|
||||
const FacSignMap& factorGroups) const
|
||||
{
|
||||
unsigned count = 1;
|
||||
cout << "variable groups:" << endl;
|
||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||
it != varGroups.end(); it++) {
|
||||
const FgVarSet& groupMembers = it->second;
|
||||
if (groupMembers.size() > 0) {
|
||||
cout << count << ": " ;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
cout << groupMembers[i]->label() << " " ;
|
||||
}
|
||||
count ++;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
count = 1;
|
||||
cout << endl << "factor groups:" << endl;
|
||||
for (FacSignMap::const_iterator it = factorGroups.begin();
|
||||
it != factorGroups.end(); it++) {
|
||||
const FgFacSet& groupMembers = it->second;
|
||||
if (groupMembers.size() > 0) {
|
||||
cout << ++count << ": " ;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
cout << groupMembers[i]->getLabel() << " " ;
|
||||
}
|
||||
count ++;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
237
packages/CLPBN/clpbn/bp/CFactorGraph.h
Normal file
237
packages/CLPBN/clpbn/bp/CFactorGraph.h
Normal file
@ -0,0 +1,237 @@
|
||||
#ifndef HORUS_CFACTORGRAPH_H
|
||||
#define HORUS_CFACTORGRAPH_H
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "Factor.h"
|
||||
#include "Shared.h"
|
||||
|
||||
class VarCluster;
|
||||
class FacCluster;
|
||||
class Distribution;
|
||||
class Signature;
|
||||
|
||||
class SignatureHash;
|
||||
|
||||
|
||||
typedef long Color;
|
||||
typedef unordered_map<unsigned, vector<Color> > VarColorMap;
|
||||
typedef unordered_map<const Distribution*, Color> DistColorMap;
|
||||
typedef unordered_map<VarId, VarCluster*> VarId2VarCluster;
|
||||
typedef vector<VarCluster*> VarClusterSet;
|
||||
typedef vector<FacCluster*> FacClusterSet;
|
||||
typedef unordered_map<Signature, FgVarSet, SignatureHash> VarSignMap;
|
||||
typedef unordered_map<Signature, FgFacSet, SignatureHash> FacSignMap;
|
||||
|
||||
|
||||
|
||||
struct Signature {
|
||||
Signature (unsigned size)
|
||||
{
|
||||
colors.resize (size);
|
||||
}
|
||||
bool operator< (const Signature& sig) const
|
||||
{
|
||||
if (colors.size() < sig.colors.size()) {
|
||||
return true;
|
||||
} else if (colors.size() > sig.colors.size()) {
|
||||
return false;
|
||||
} else {
|
||||
for (unsigned i = 0; i < colors.size(); i++) {
|
||||
if (colors[i] < sig.colors[i]) {
|
||||
return true;
|
||||
} else if (colors[i] > sig.colors[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool operator== (const Signature& sig) const
|
||||
{
|
||||
if (colors.size() != sig.colors.size()) {
|
||||
return false;
|
||||
}
|
||||
for (unsigned i = 0; i < colors.size(); i++) {
|
||||
if (colors[i] != sig.colors[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
vector<Color> colors;
|
||||
};
|
||||
|
||||
|
||||
|
||||
struct SignatureHash {
|
||||
size_t operator() (const Signature &sig) const
|
||||
{
|
||||
size_t val = hash<size_t>()(sig.colors.size());
|
||||
for (unsigned i = 0; i < sig.colors.size(); i++) {
|
||||
val ^= hash<size_t>()(sig.colors[i]);
|
||||
}
|
||||
return val;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
class VarCluster
|
||||
{
|
||||
public:
|
||||
VarCluster (const FgVarSet& vs)
|
||||
{
|
||||
for (unsigned i = 0; i < vs.size(); i++) {
|
||||
groundVars_.push_back (vs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void addFacCluster (FacCluster* fc)
|
||||
{
|
||||
factorClusters_.push_back (fc);
|
||||
}
|
||||
|
||||
const FacClusterSet& getFacClusters (void) const
|
||||
{
|
||||
return factorClusters_;
|
||||
}
|
||||
|
||||
FgVarNode* getRepresentativeVariable (void) const { return representVar_; }
|
||||
void setRepresentativeVariable (FgVarNode* v) { representVar_ = v; }
|
||||
const FgVarSet& getGroundFgVarNodes (void) const { return groundVars_; }
|
||||
|
||||
private:
|
||||
FgVarSet groundVars_;
|
||||
FacClusterSet factorClusters_;
|
||||
FgVarNode* representVar_;
|
||||
};
|
||||
|
||||
|
||||
class FacCluster
|
||||
{
|
||||
public:
|
||||
FacCluster (const FgFacSet& groundFactors, const VarClusterSet& vcs)
|
||||
{
|
||||
groundFactors_ = groundFactors;
|
||||
varClusters_ = vcs;
|
||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||
varClusters_[i]->addFacCluster (this);
|
||||
}
|
||||
}
|
||||
|
||||
const VarClusterSet& getVarClusters (void) const
|
||||
{
|
||||
return varClusters_;
|
||||
}
|
||||
|
||||
bool containsGround (const FgFacNode* fn)
|
||||
{
|
||||
for (unsigned i = 0; i < groundFactors_.size(); i++) {
|
||||
if (groundFactors_[i] == fn) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
FgFacNode* getRepresentativeFactor (void) const
|
||||
{
|
||||
return representFactor_;
|
||||
}
|
||||
void setRepresentativeFactor (FgFacNode* fn)
|
||||
{
|
||||
representFactor_ = fn;
|
||||
}
|
||||
const FgFacSet& getGroundFactors (void) const
|
||||
{
|
||||
return groundFactors_;
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
FgFacSet groundFactors_;
|
||||
VarClusterSet varClusters_;
|
||||
FgFacNode* representFactor_;
|
||||
};
|
||||
|
||||
|
||||
class CFactorGraph
|
||||
{
|
||||
public:
|
||||
CFactorGraph (const FactorGraph&);
|
||||
~CFactorGraph (void);
|
||||
|
||||
FactorGraph* getCompressedFactorGraph (void);
|
||||
unsigned getGroundEdgeCount (const FacCluster*, const VarCluster*) const;
|
||||
|
||||
FgVarNode* getEquivalentVariable (VarId vid)
|
||||
{
|
||||
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||
return vc->getRepresentativeVariable();
|
||||
}
|
||||
|
||||
const VarClusterSet& getVariableClusters (void) { return varClusters_; }
|
||||
const FacClusterSet& getFacClusters (void) { return factorClusters_; }
|
||||
|
||||
static void enableCheckForIdenticalFactors (void)
|
||||
{
|
||||
checkForIdenticalFactors_ = true;
|
||||
}
|
||||
|
||||
static void disableCheckForIdenticalFactors (void)
|
||||
{
|
||||
checkForIdenticalFactors_ = false;
|
||||
}
|
||||
|
||||
private:
|
||||
void setInitialColors (void);
|
||||
void createGroups (void);
|
||||
void createClusters (const VarSignMap&, const FacSignMap&);
|
||||
const Signature& getSignature (const FgVarNode*);
|
||||
const Signature& getSignature (const FgFacNode*);
|
||||
void printGroups (const VarSignMap&, const FacSignMap&) const;
|
||||
|
||||
Color getFreeColor (void) {
|
||||
++ freeColor_;
|
||||
return freeColor_ - 1;
|
||||
}
|
||||
|
||||
Color getColor (const FgVarNode* vn) const
|
||||
{
|
||||
return varColors_[vn->getIndex()];
|
||||
}
|
||||
Color getColor (const FgFacNode* fn) const {
|
||||
return factorColors_[fn->getIndex()];
|
||||
}
|
||||
|
||||
void setColor (const FgVarNode* vn, Color c)
|
||||
{
|
||||
varColors_[vn->getIndex()] = c;
|
||||
}
|
||||
|
||||
void setColor (const FgFacNode* fn, Color c)
|
||||
{
|
||||
factorColors_[fn->getIndex()] = c;
|
||||
}
|
||||
|
||||
VarCluster* getVariableCluster (VarId vid) const
|
||||
{
|
||||
return vid2VarCluster_.find (vid)->second;
|
||||
}
|
||||
|
||||
Color freeColor_;
|
||||
vector<Color> varColors_;
|
||||
vector<Color> factorColors_;
|
||||
vector<Signature> varSignatures_;
|
||||
vector<Signature> factorSignatures_;
|
||||
VarClusterSet varClusters_;
|
||||
FacClusterSet factorClusters_;
|
||||
VarId2VarCluster vid2VarCluster_;
|
||||
const FactorGraph* groundFg_;
|
||||
bool static checkForIdenticalFactors_;
|
||||
};
|
||||
|
||||
#endif // HORUS_CFACTORGRAPH_H
|
||||
|
263
packages/CLPBN/clpbn/bp/CbpSolver.cpp
Normal file
263
packages/CLPBN/clpbn/bp/CbpSolver.cpp
Normal file
@ -0,0 +1,263 @@
|
||||
#include "CbpSolver.h"
|
||||
|
||||
|
||||
CbpSolver::~CbpSolver (void)
|
||||
{
|
||||
delete lfg_;
|
||||
delete factorGraph_;
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
links_.clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
CbpSolver::getPosterioriOf (VarId vid)
|
||||
{
|
||||
FgVarNode* var = lfg_->getEquivalentVariable (vid);
|
||||
ParamSet probs;
|
||||
if (var->hasEvidence()) {
|
||||
probs.resize (var->nrStates(), Util::noEvidence());
|
||||
probs[var->getEvidence()] = Util::withEvidence();
|
||||
} else {
|
||||
probs.resize (var->nrStates(), Util::multIdenty());
|
||||
const SpLinkSet& links = ninf(var)->getLinks();
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::multiply (probs, l->getPoweredMessage());
|
||||
}
|
||||
Util::normalize (probs);
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::add (probs, l->getPoweredMessage());
|
||||
}
|
||||
Util::normalize (probs);
|
||||
Util::fromLog (probs);
|
||||
}
|
||||
}
|
||||
return probs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
CbpSolver::getJointDistributionOf (const VarIdSet& jointVarIds)
|
||||
{
|
||||
unsigned msgSize = 1;
|
||||
vector<unsigned> dsizes (jointVarIds.size());
|
||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||
dsizes[i] = lfg_->getEquivalentVariable (jointVarIds[i])->nrStates();
|
||||
msgSize *= dsizes[i];
|
||||
}
|
||||
unsigned reps = 1;
|
||||
ParamSet jointDist (msgSize, Util::multIdenty());
|
||||
for (int i = jointVarIds.size() - 1 ; i >= 0; i--) {
|
||||
Util::multiply (jointDist, getPosterioriOf (jointVarIds[i]), reps);
|
||||
reps *= dsizes[i];
|
||||
}
|
||||
return jointDist;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::initializeSolver (void)
|
||||
{
|
||||
unsigned nGroundVars, nGroundFacs, nWithoutNeighs;
|
||||
if (COLLECT_STATISTICS) {
|
||||
nGroundVars = factorGraph_->getVarNodes().size();
|
||||
nGroundFacs = factorGraph_->getFactorNodes().size();
|
||||
const FgVarSet& vars = factorGraph_->getVarNodes();
|
||||
nWithoutNeighs = 0;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
const FgFacSet& factors = vars[i]->neighbors();
|
||||
if (factors.size() == 1 && factors[0]->neighbors().size() == 1) {
|
||||
nWithoutNeighs ++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
lfg_ = new CFactorGraph (*factorGraph_);
|
||||
|
||||
// cout << "Uncompressed Factor Graph" << endl;
|
||||
// factorGraph_->printGraphicalModel();
|
||||
// factorGraph_->exportToGraphViz ("uncompressed_fg.dot");
|
||||
factorGraph_ = lfg_->getCompressedFactorGraph();
|
||||
|
||||
if (COLLECT_STATISTICS) {
|
||||
unsigned nClusterVars = factorGraph_->getVarNodes().size();
|
||||
unsigned nClusterFacs = factorGraph_->getFactorNodes().size();
|
||||
Statistics::updateCompressingStatistics (nGroundVars, nGroundFacs,
|
||||
nClusterVars, nClusterFacs,
|
||||
nWithoutNeighs);
|
||||
}
|
||||
|
||||
// cout << "Compressed Factor Graph" << endl;
|
||||
// factorGraph_->printGraphicalModel();
|
||||
// factorGraph_->exportToGraphViz ("compressed_fg.dot");
|
||||
// abort();
|
||||
FgBpSolver::initializeSolver();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::createLinks (void)
|
||||
{
|
||||
const FacClusterSet fcs = lfg_->getFacClusters();
|
||||
for (unsigned i = 0; i < fcs.size(); i++) {
|
||||
const VarClusterSet vcs = fcs[i]->getVarClusters();
|
||||
for (unsigned j = 0; j < vcs.size(); j++) {
|
||||
unsigned c = lfg_->getGroundEdgeCount (fcs[i], vcs[j]);
|
||||
links_.push_back (new CbpSolverLink (fcs[i]->getRepresentativeFactor(),
|
||||
vcs[j]->getRepresentativeVariable(), c));
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::maxResidualSchedule (void)
|
||||
{
|
||||
if (nIters_ == 1) {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||
linkMap_.insert (make_pair (links_[i], it));
|
||||
if (DL >= 2 && DL < 5) {
|
||||
cout << "calculating " << links_[i]->toString() << endl;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < links_.size(); c++) {
|
||||
if (DL >= 2) {
|
||||
cout << endl << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); it ++) {
|
||||
cout << " " << setw (30) << left << (*it)->toString();
|
||||
cout << "residual = " << (*it)->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
SpLink* link = *it;
|
||||
if (DL >= 2) {
|
||||
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||
}
|
||||
if (link->getResidual() < BpOptions::accuracy) {
|
||||
return;
|
||||
}
|
||||
link->updateMessage();
|
||||
link->clearResidual();
|
||||
sortedOrder_.erase (it);
|
||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||
|
||||
// update the messages that depend on message source --> destin
|
||||
const FgFacSet& factorNeighbors = link->getVariable()->neighbors();
|
||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (links[j]->getVariable() != link->getVariable()) {
|
||||
if (DL >= 2 && DL < 5) {
|
||||
cout << " calculating " << links[j]->toString() << endl;
|
||||
}
|
||||
calculateMessage (links[j]);
|
||||
SpLinkMap::iterator iter = linkMap_.find (links[j]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (links[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
// in counting bp, the message that a variable X sends to
|
||||
// to a factor F depends on the message that F sent to the X
|
||||
const SpLinkSet& links = ninf(link->getFactor())->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getVariable() != link->getVariable()) {
|
||||
if (DL >= 2 && DL < 5) {
|
||||
cout << " calculating " << links[i]->toString() << endl;
|
||||
}
|
||||
calculateMessage (links[i]);
|
||||
SpLinkMap::iterator iter = linkMap_.find (links[i]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (links[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
CbpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
{
|
||||
ParamSet msg;
|
||||
const FgVarNode* src = link->getVariable();
|
||||
const FgFacNode* dst = link->getFactor();
|
||||
const CbpSolverLink* l = static_cast<const CbpSolverLink*> (link);
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->nrStates(), Util::noEvidence());
|
||||
double value = link->getMessage()[src->getEvidence()];
|
||||
msg[src->getEvidence()] = Util::pow (value, l->getNumberOfEdges() - 1);
|
||||
} else {
|
||||
msg = link->getMessage();
|
||||
Util::pow (msg, l->getNumberOfEdges() - 1);
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << " " << "init: " << Util::parametersToString (msg) << endl;
|
||||
}
|
||||
const SpLinkSet& links = ninf(src)->getLinks();
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::multiply (msg, l->getPoweredMessage());
|
||||
if (DL >= 5) {
|
||||
cout << " msg from " << l->getFactor()->getLabel() << ": " ;
|
||||
cout << Util::parametersToString (l->getPoweredMessage()) << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links[i]);
|
||||
Util::add (msg, l->getPoweredMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << " result = " << Util::parametersToString (msg) << endl;
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CbpSolver::printLinkInformation (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
CbpSolverLink* l = static_cast<CbpSolverLink*> (links_[i]);
|
||||
cout << l->toString() << ":" << endl;
|
||||
cout << " curr msg = " ;
|
||||
cout << Util::parametersToString (l->getMessage()) << endl;
|
||||
cout << " next msg = " ;
|
||||
cout << Util::parametersToString (l->getNextMessage()) << endl;
|
||||
cout << " powered = " ;
|
||||
cout << Util::parametersToString (l->getPoweredMessage()) << endl;
|
||||
cout << " residual = " << l->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
58
packages/CLPBN/clpbn/bp/CbpSolver.h
Normal file
58
packages/CLPBN/clpbn/bp/CbpSolver.h
Normal file
@ -0,0 +1,58 @@
|
||||
#ifndef HORUS_CBP_H
|
||||
#define HORUS_CBP_H
|
||||
|
||||
#include "FgBpSolver.h"
|
||||
#include "CFactorGraph.h"
|
||||
|
||||
class Factor;
|
||||
|
||||
class CbpSolverLink : public SpLink
|
||||
{
|
||||
public:
|
||||
CbpSolverLink (FgFacNode* fn, FgVarNode* vn, unsigned c) : SpLink (fn, vn)
|
||||
{
|
||||
edgeCount_ = c;
|
||||
poweredMsg_.resize (vn->nrStates(), Util::one());
|
||||
}
|
||||
|
||||
void updateMessage (void)
|
||||
{
|
||||
poweredMsg_ = *nextMsg_;
|
||||
swap (currMsg_, nextMsg_);
|
||||
msgSended_ = true;
|
||||
Util::pow (poweredMsg_, edgeCount_);
|
||||
}
|
||||
|
||||
unsigned getNumberOfEdges (void) const { return edgeCount_; }
|
||||
const ParamSet& getPoweredMessage (void) const { return poweredMsg_; }
|
||||
|
||||
private:
|
||||
ParamSet poweredMsg_;
|
||||
unsigned edgeCount_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class CbpSolver : public FgBpSolver
|
||||
{
|
||||
public:
|
||||
CbpSolver (FactorGraph& fg) : FgBpSolver (fg) { }
|
||||
~CbpSolver (void);
|
||||
|
||||
ParamSet getPosterioriOf (VarId);
|
||||
ParamSet getJointDistributionOf (const VarIdSet&);
|
||||
|
||||
private:
|
||||
void initializeSolver (void);
|
||||
void createLinks (void);
|
||||
|
||||
void maxResidualSchedule (void);
|
||||
ParamSet getVar2FactorMsg (const SpLink*) const;
|
||||
void printLinkInformation (void) const;
|
||||
|
||||
|
||||
CFactorGraph* lfg_;
|
||||
};
|
||||
|
||||
#endif // HORUS_CBP_H
|
||||
|
@ -1,198 +0,0 @@
|
||||
#include "CountingBP.h"
|
||||
|
||||
|
||||
CountingBP::~CountingBP (void)
|
||||
{
|
||||
delete lfg_;
|
||||
delete fg_;
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
links_.clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
CountingBP::getPosterioriOf (Vid vid) const
|
||||
{
|
||||
FgVarNode* var = lfg_->getEquivalentVariable (vid);
|
||||
ParamSet probs;
|
||||
|
||||
if (var->hasEvidence()) {
|
||||
probs.resize (var->getDomainSize(), 0.0);
|
||||
probs[var->getEvidence()] = 1.0;
|
||||
} else {
|
||||
probs.resize (var->getDomainSize(), 1.0);
|
||||
CLinkSet links = varsI_[var->getIndex()]->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
ParamSet msg = links[i]->getMessage();
|
||||
CountingBPLink* l = static_cast<CountingBPLink*> (links[i]);
|
||||
Util::pow (msg, l->getNumberOfEdges());
|
||||
for (unsigned j = 0; j < msg.size(); j++) {
|
||||
probs[j] *= msg[j];
|
||||
}
|
||||
}
|
||||
Util::normalize (probs);
|
||||
}
|
||||
return probs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CountingBP::initializeSolver (void)
|
||||
{
|
||||
lfg_ = new LiftedFG (*fg_);
|
||||
unsigned nUncVars = fg_->getFgVarNodes().size();
|
||||
unsigned nUncFactors = fg_->getFactors().size();
|
||||
CFgVarSet vars = fg_->getFgVarNodes();
|
||||
unsigned nNeighborLessVars = 0;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
CFactorSet factors = vars[i]->getFactors();
|
||||
if (factors.size() == 1 && factors[0]->getFgVarNodes().size() == 1) {
|
||||
nNeighborLessVars ++;
|
||||
}
|
||||
}
|
||||
// cout << "UNCOMPRESSED FACTOR GRAPH" << endl;
|
||||
// fg_->printGraphicalModel();
|
||||
fg_->exportToDotFormat ("uncompress.dot");
|
||||
|
||||
FactorGraph *temp;
|
||||
temp = fg_;
|
||||
fg_ = lfg_->getCompressedFactorGraph();
|
||||
unsigned nCompVars = fg_->getFgVarNodes().size();
|
||||
unsigned nCompFactors = fg_->getFactors().size();
|
||||
|
||||
Statistics::updateCompressingStats (nUncVars,
|
||||
nUncFactors,
|
||||
nCompVars,
|
||||
nCompFactors,
|
||||
nNeighborLessVars);
|
||||
|
||||
cout << "COMPRESSED FACTOR GRAPH" << endl;
|
||||
fg_->printGraphicalModel();
|
||||
//fg_->exportToDotFormat ("compress.dot");
|
||||
|
||||
SPSolver::initializeSolver();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CountingBP::createLinks (void)
|
||||
{
|
||||
const FactorClusterSet fcs = lfg_->getFactorClusters();
|
||||
for (unsigned i = 0; i < fcs.size(); i++) {
|
||||
const VarClusterSet vcs = fcs[i]->getVarClusters();
|
||||
for (unsigned j = 0; j < vcs.size(); j++) {
|
||||
unsigned c = lfg_->getGroundEdgeCount (fcs[i], vcs[j]);
|
||||
links_.push_back (
|
||||
new CountingBPLink (fcs[i]->getRepresentativeFactor(),
|
||||
vcs[j]->getRepresentativeVariable(), c));
|
||||
//cout << (links_.back())->toString() << " edge count =" << c << endl;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CountingBP::deleteJunction (Factor* f, FgVarNode*)
|
||||
{
|
||||
f->freeDistribution();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
CountingBP::maxResidualSchedule (void)
|
||||
{
|
||||
if (nIter_ == 1) {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
|
||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||
linkMap_.insert (make_pair (links_[i], it));
|
||||
if (DL >= 2 && DL < 5) {
|
||||
cout << "calculating " << links_[i]->toString() << endl;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < links_.size(); c++) {
|
||||
if (DL >= 2) {
|
||||
cout << endl << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); it ++) {
|
||||
cout << " " << setw (30) << left << (*it)->toString();
|
||||
cout << "residual = " << (*it)->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
Link* link = *it;
|
||||
if (DL >= 2) {
|
||||
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||
}
|
||||
if (link->getResidual() < SolverOptions::accuracy) {
|
||||
return;
|
||||
}
|
||||
link->updateMessage();
|
||||
link->clearResidual();
|
||||
sortedOrder_.erase (it);
|
||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||
|
||||
// update the messages that depend on message source --> destin
|
||||
CFactorSet factorNeighbors = link->getVariable()->getFactors();
|
||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||
CLinkSet links = factorsI_[factorNeighbors[i]->getIndex()]->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (links[j]->getVariable() != link->getVariable()) { //FIXMEFIXME
|
||||
if (DL >= 2 && DL < 5) {
|
||||
cout << " calculating " << links[j]->toString() << endl;
|
||||
}
|
||||
links[j]->setNextMessage (getFactor2VarMsg (links[j]));
|
||||
LinkMap::iterator iter = linkMap_.find (links[j]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (links[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
CountingBP::getVar2FactorMsg (const Link* link) const
|
||||
{
|
||||
const FgVarNode* src = link->getVariable();
|
||||
const Factor* dest = link->getFactor();
|
||||
ParamSet msg;
|
||||
if (src->hasEvidence()) {
|
||||
cout << "has evidence" << endl;
|
||||
msg.resize (src->getDomainSize(), 0.0);
|
||||
msg[src->getEvidence()] = link->getMessage()[src->getEvidence()];
|
||||
cout << "-> " << link->getVariable()->getLabel() << " " << link->getFactor()->getLabel() << endl;
|
||||
cout << "-> p2s " << Util::parametersToString (msg) << endl;
|
||||
} else {
|
||||
msg = link->getMessage();
|
||||
}
|
||||
const CountingBPLink* l = static_cast<const CountingBPLink*> (link);
|
||||
Util::pow (msg, l->getNumberOfEdges() - 1);
|
||||
CLinkSet links = varsI_[src->getIndex()]->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dest) {
|
||||
ParamSet msgFromFactor = links[i]->getMessage();
|
||||
CountingBPLink* l = static_cast<CountingBPLink*> (links[i]);
|
||||
Util::pow (msgFromFactor, l->getNumberOfEdges());
|
||||
for (unsigned j = 0; j < msgFromFactor.size(); j++) {
|
||||
msg[j] *= msgFromFactor[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
@ -1,45 +0,0 @@
|
||||
#ifndef BP_COUNTING_BP_H
|
||||
#define BP_COUNTING_BP_H
|
||||
|
||||
#include "SPSolver.h"
|
||||
#include "LiftedFG.h"
|
||||
|
||||
class Factor;
|
||||
class FgVarNode;
|
||||
|
||||
class CountingBPLink : public Link
|
||||
{
|
||||
public:
|
||||
CountingBPLink (Factor* f, FgVarNode* v, unsigned c) : Link (f, v)
|
||||
{
|
||||
edgeCount_ = c;
|
||||
}
|
||||
|
||||
unsigned getNumberOfEdges (void) const { return edgeCount_; }
|
||||
|
||||
private:
|
||||
unsigned edgeCount_;
|
||||
};
|
||||
|
||||
|
||||
class CountingBP : public SPSolver
|
||||
{
|
||||
public:
|
||||
CountingBP (FactorGraph& fg) : SPSolver (fg) { }
|
||||
~CountingBP (void);
|
||||
|
||||
ParamSet getPosterioriOf (Vid) const;
|
||||
|
||||
private:
|
||||
void initializeSolver (void);
|
||||
void createLinks (void);
|
||||
void deleteJunction (Factor*, FgVarNode*);
|
||||
|
||||
void maxResidualSchedule (void);
|
||||
ParamSet getVar2FactorMsg (const Link*) const;
|
||||
|
||||
LiftedFG* lfg_;
|
||||
};
|
||||
|
||||
#endif // BP_COUNTING_BP_H
|
||||
|
@ -1,5 +1,5 @@
|
||||
#ifndef BP_CPT_ENTRY_H
|
||||
#define BP_CPT_ENTRY_H
|
||||
#ifndef HORUS_CPTENTRY_H
|
||||
#define HORUS_CPTENTRY_H
|
||||
|
||||
#include <vector>
|
||||
|
||||
@ -39,5 +39,5 @@ class CptEntry
|
||||
DConf conf_;
|
||||
};
|
||||
|
||||
#endif //BP_CPT_ENTRY_H
|
||||
#endif // HORUS_CPTENTRY_H
|
||||
|
||||
|
@ -1,40 +0,0 @@
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include <Distribution.h>
|
||||
|
||||
Distribution::Distribution (int id,
|
||||
double* params,
|
||||
int nParams,
|
||||
vector<string> domain)
|
||||
{
|
||||
this->id = id;
|
||||
this->params = params;
|
||||
this->nParams = nParams;
|
||||
this->domain = domain;
|
||||
}
|
||||
|
||||
|
||||
Distribution::Distribution (double* params,
|
||||
int nParams,
|
||||
vector<string> domain)
|
||||
{
|
||||
this->id = -1;
|
||||
this->params = params;
|
||||
this->nParams = nParams;
|
||||
this->domain = domain;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/*
|
||||
Distribution::~Distribution()
|
||||
{
|
||||
delete params;
|
||||
for (unsigned int i = 0; i < cptEntries.size(); i++) {
|
||||
delete cptEntries[i];
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
#ifndef BP_DISTRIBUTION_H
|
||||
#define BP_DISTRIBUTION_H
|
||||
#ifndef HORUS_DISTRIBUTION_H
|
||||
#define HORUS_DISTRIBUTION_H
|
||||
|
||||
#include <vector>
|
||||
|
||||
@ -11,18 +11,16 @@ using namespace std;
|
||||
struct Distribution
|
||||
{
|
||||
public:
|
||||
Distribution (unsigned id, bool shared = false)
|
||||
Distribution (unsigned id)
|
||||
{
|
||||
this->id = id;
|
||||
this->params = params;
|
||||
this->shared = shared;
|
||||
}
|
||||
|
||||
Distribution (const ParamSet& params, bool shared = false)
|
||||
Distribution (const ParamSet& params, unsigned id = -1)
|
||||
{
|
||||
this->id = -1;
|
||||
this->id = id;
|
||||
this->params = params;
|
||||
this->shared = shared;
|
||||
}
|
||||
|
||||
void updateParameters (const ParamSet& params)
|
||||
@ -33,11 +31,10 @@ struct Distribution
|
||||
unsigned id;
|
||||
ParamSet params;
|
||||
vector<CptEntry> entries;
|
||||
bool shared;
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (Distribution);
|
||||
};
|
||||
|
||||
#endif //BP_DISTRIBUTION_H
|
||||
#endif // HORUS_DISTRIBUTION_H
|
||||
|
||||
|
322
packages/CLPBN/clpbn/bp/ElimGraph.cpp
Normal file
322
packages/CLPBN/clpbn/bp/ElimGraph.cpp
Normal file
@ -0,0 +1,322 @@
|
||||
#include <limits>
|
||||
|
||||
#include "ElimGraph.h"
|
||||
#include "BayesNet.h"
|
||||
|
||||
|
||||
ElimHeuristic ElimGraph::elimHeuristic_ = MIN_NEIGHBORS;
|
||||
|
||||
|
||||
ElimGraph::ElimGraph (const BayesNet& bayesNet)
|
||||
{
|
||||
const BnNodeSet& bnNodes = bayesNet.getBayesNodes();
|
||||
for (unsigned i = 0; i < bnNodes.size(); i++) {
|
||||
if (bnNodes[i]->hasEvidence() == false) {
|
||||
addNode (new EgNode (bnNodes[i]));
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < bnNodes.size(); i++) {
|
||||
if (bnNodes[i]->hasEvidence() == false) {
|
||||
EgNode* n = getEgNode (bnNodes[i]->varId());
|
||||
const BnNodeSet& childs = bnNodes[i]->getChilds();
|
||||
for (unsigned j = 0; j < childs.size(); j++) {
|
||||
if (childs[j]->hasEvidence() == false) {
|
||||
addEdge (n, getEgNode (childs[j]->varId()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < bnNodes.size(); i++) {
|
||||
vector<EgNode*> neighs;
|
||||
const vector<BayesNode*>& parents = bnNodes[i]->getParents();
|
||||
for (unsigned i = 0; i < parents.size(); i++) {
|
||||
if (parents[i]->hasEvidence() == false) {
|
||||
neighs.push_back (getEgNode (parents[i]->varId()));
|
||||
}
|
||||
}
|
||||
if (neighs.size() > 0) {
|
||||
for (unsigned i = 0; i < neighs.size() - 1; i++) {
|
||||
for (unsigned j = i+1; j < neighs.size(); j++) {
|
||||
if (!neighbors (neighs[i], neighs[j])) {
|
||||
addEdge (neighs[i], neighs[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setIndexes();
|
||||
}
|
||||
|
||||
|
||||
|
||||
ElimGraph::~ElimGraph (void)
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
delete nodes_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::addNode (EgNode* n)
|
||||
{
|
||||
nodes_.push_back (n);
|
||||
vid2nodes_.insert (make_pair (n->varId(), n));
|
||||
}
|
||||
|
||||
|
||||
|
||||
EgNode*
|
||||
ElimGraph::getEgNode (VarId vid) const
|
||||
{
|
||||
unordered_map<VarId,EgNode*>::const_iterator it = vid2nodes_.find (vid);
|
||||
if (it == vid2nodes_.end()) {
|
||||
return 0;
|
||||
} else {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarIdSet
|
||||
ElimGraph::getEliminatingOrder (const VarIdSet& exclude)
|
||||
{
|
||||
VarIdSet elimOrder;
|
||||
marked_.resize (nodes_.size(), false);
|
||||
|
||||
for (unsigned i = 0; i < exclude.size(); i++) {
|
||||
EgNode* node = getEgNode (exclude[i]);
|
||||
assert (node);
|
||||
marked_[*node] = true;
|
||||
}
|
||||
|
||||
unsigned nVarsToEliminate = nodes_.size() - exclude.size();
|
||||
for (unsigned i = 0; i < nVarsToEliminate; i++) {
|
||||
EgNode* node = getLowestCostNode();
|
||||
marked_[*node] = true;
|
||||
elimOrder.push_back (node->varId());
|
||||
connectAllNeighbors (node);
|
||||
}
|
||||
return elimOrder;
|
||||
}
|
||||
|
||||
|
||||
|
||||
EgNode*
|
||||
ElimGraph::getLowestCostNode (void) const
|
||||
{
|
||||
EgNode* bestNode = 0;
|
||||
unsigned minCost = std::numeric_limits<unsigned>::max();
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
if (marked_[i]) continue;
|
||||
unsigned cost = 0;
|
||||
switch (elimHeuristic_) {
|
||||
case MIN_NEIGHBORS:
|
||||
cost = getNeighborsCost (nodes_[i]);
|
||||
break;
|
||||
case MIN_WEIGHT:
|
||||
cost = getWeightCost (nodes_[i]);
|
||||
break;
|
||||
case MIN_FILL:
|
||||
cost = getFillCost (nodes_[i]);
|
||||
break;
|
||||
case WEIGHTED_MIN_FILL:
|
||||
cost = getWeightedFillCost (nodes_[i]);
|
||||
break;
|
||||
default:
|
||||
assert (false);
|
||||
}
|
||||
if (cost < minCost) {
|
||||
bestNode = nodes_[i];
|
||||
minCost = cost;
|
||||
}
|
||||
}
|
||||
assert (bestNode);
|
||||
return bestNode;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
ElimGraph::getNeighborsCost (const EgNode* n) const
|
||||
{
|
||||
unsigned cost = 0;
|
||||
const vector<EgNode*>& neighs = n->neighbors();
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
if (marked_[*neighs[i]] == false) {
|
||||
cost ++;
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
ElimGraph::getWeightCost (const EgNode* n) const
|
||||
{
|
||||
unsigned cost = 1;
|
||||
const vector<EgNode*>& neighs = n->neighbors();
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
if (marked_[*neighs[i]] == false) {
|
||||
cost *= neighs[i]->nrStates();
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
ElimGraph::getFillCost (const EgNode* n) const
|
||||
{
|
||||
unsigned cost = 0;
|
||||
const vector<EgNode*>& neighs = n->neighbors();
|
||||
if (neighs.size() > 0) {
|
||||
for (unsigned i = 0; i < neighs.size() - 1; i++) {
|
||||
if (marked_[*neighs[i]] == true) continue;
|
||||
for (unsigned j = i+1; j < neighs.size(); j++) {
|
||||
if (marked_[*neighs[j]] == true) continue;
|
||||
if (!neighbors (neighs[i], neighs[j])) {
|
||||
cost ++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
ElimGraph::getWeightedFillCost (const EgNode* n) const
|
||||
{
|
||||
unsigned cost = 0;
|
||||
const vector<EgNode*>& neighs = n->neighbors();
|
||||
if (neighs.size() > 0) {
|
||||
for (unsigned i = 0; i < neighs.size() - 1; i++) {
|
||||
if (marked_[*neighs[i]] == true) continue;
|
||||
for (unsigned j = i+1; j < neighs.size(); j++) {
|
||||
if (marked_[*neighs[j]] == true) continue;
|
||||
if (!neighbors (neighs[i], neighs[j])) {
|
||||
cost += neighs[i]->nrStates() * neighs[j]->nrStates();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::connectAllNeighbors (const EgNode* n)
|
||||
{
|
||||
const vector<EgNode*>& neighs = n->neighbors();
|
||||
if (neighs.size() > 0) {
|
||||
for (unsigned i = 0; i < neighs.size() - 1; i++) {
|
||||
if (marked_[*neighs[i]] == true) continue;
|
||||
for (unsigned j = i+1; j < neighs.size(); j++) {
|
||||
if (marked_[*neighs[j]] == true) continue;
|
||||
if (!neighbors (neighs[i], neighs[j])) {
|
||||
addEdge (neighs[i], neighs[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
ElimGraph::neighbors (const EgNode* n1, const EgNode* n2) const
|
||||
{
|
||||
const vector<EgNode*>& neighs = n1->neighbors();
|
||||
for (unsigned i = 0; i < neighs.size(); i++) {
|
||||
if (neighs[i] == n2) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::setIndexes (void)
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
nodes_[i]->setIndex (i);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::printGraphicalModel (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
cout << "node " << nodes_[i]->label() << " neighs:" ;
|
||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
cout << " " << neighs[j]->label();
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
ElimGraph::exportToGraphViz (const char* fileName,
|
||||
bool showNeighborless,
|
||||
const VarIdSet& highlightVarIds) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "Markov::exportToDotFile()" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
out << "strict graph {" << endl;
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
|
||||
out << '"' << nodes_[i]->label() << '"' ;
|
||||
if (nodes_[i]->hasEvidence()) {
|
||||
out << " [style=filled, fillcolor=yellow]" << endl;
|
||||
} else {
|
||||
out << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < highlightVarIds.size(); i++) {
|
||||
EgNode* node =getEgNode (highlightVarIds[i]);
|
||||
if (node) {
|
||||
out << '"' << node->label() << '"' ;
|
||||
out << " [shape=box3d]" << endl;
|
||||
} else {
|
||||
cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes_.size(); i++) {
|
||||
vector<EgNode*> neighs = nodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
out << '"' << nodes_[i]->label() << '"' << " -- " ;
|
||||
out << '"' << neighs[j]->label() << '"' << endl;
|
||||
}
|
||||
}
|
||||
|
||||
out << "}" << endl;
|
||||
out.close();
|
||||
}
|
||||
|
76
packages/CLPBN/clpbn/bp/ElimGraph.h
Normal file
76
packages/CLPBN/clpbn/bp/ElimGraph.h
Normal file
@ -0,0 +1,76 @@
|
||||
#ifndef HORUS_ELIMGRAPH_H
|
||||
#define HORUS_ELIMGRAPH_H
|
||||
|
||||
#include "unordered_map"
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "Shared.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
enum ElimHeuristic
|
||||
{
|
||||
MIN_NEIGHBORS,
|
||||
MIN_WEIGHT,
|
||||
MIN_FILL,
|
||||
WEIGHTED_MIN_FILL
|
||||
};
|
||||
|
||||
|
||||
class EgNode : public VarNode {
|
||||
public:
|
||||
EgNode (VarNode* var) : VarNode (var) { }
|
||||
void addNeighbor (EgNode* n)
|
||||
{
|
||||
neighs_.push_back (n);
|
||||
}
|
||||
|
||||
const vector<EgNode*>& neighbors (void) const { return neighs_; }
|
||||
private:
|
||||
vector<EgNode*> neighs_;
|
||||
};
|
||||
|
||||
|
||||
class ElimGraph
|
||||
{
|
||||
public:
|
||||
ElimGraph (const BayesNet&);
|
||||
~ElimGraph (void);
|
||||
|
||||
void addEdge (EgNode* n1, EgNode* n2)
|
||||
{
|
||||
assert (n1 != n2);
|
||||
n1->addNeighbor (n2);
|
||||
n2->addNeighbor (n1);
|
||||
}
|
||||
void addNode (EgNode*);
|
||||
EgNode* getEgNode (VarId) const;
|
||||
VarIdSet getEliminatingOrder (const VarIdSet&);
|
||||
void printGraphicalModel (void) const;
|
||||
void exportToGraphViz (const char*, bool = true,
|
||||
const VarIdSet& = VarIdSet()) const;
|
||||
void setIndexes();
|
||||
|
||||
static void setEliminationHeuristic (ElimHeuristic h)
|
||||
{
|
||||
elimHeuristic_ = h;
|
||||
}
|
||||
|
||||
private:
|
||||
EgNode* getLowestCostNode (void) const;
|
||||
unsigned getNeighborsCost (const EgNode*) const;
|
||||
unsigned getWeightCost (const EgNode*) const;
|
||||
unsigned getFillCost (const EgNode*) const;
|
||||
unsigned getWeightedFillCost (const EgNode*) const;
|
||||
void connectAllNeighbors (const EgNode*);
|
||||
bool neighbors (const EgNode*, const EgNode*) const;
|
||||
|
||||
|
||||
vector<EgNode*> nodes_;
|
||||
vector<bool> marked_;
|
||||
unordered_map<VarId,EgNode*> vid2nodes_;
|
||||
static ElimHeuristic elimHeuristic_;
|
||||
};
|
||||
|
||||
#endif // HORUS_ELIMGRAPH_H
|
||||
|
@ -1,33 +1,38 @@
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "Factor.h"
|
||||
#include "FgVarNode.h"
|
||||
#include "StatesIndexer.h"
|
||||
|
||||
|
||||
Factor::Factor (const Factor& g)
|
||||
{
|
||||
copyFactor (g);
|
||||
copyFromFactor (g);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (FgVarNode* var)
|
||||
Factor::Factor (VarId vid, unsigned nStates)
|
||||
{
|
||||
Factor (FgVarSet() = {var});
|
||||
varids_.push_back (vid);
|
||||
ranges_.push_back (nStates);
|
||||
dist_ = new Distribution (ParamSet (nStates, 1.0));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (const FgVarSet& vars)
|
||||
Factor::Factor (const VarNodes& vars)
|
||||
{
|
||||
vars_ = vars;
|
||||
int nParams = 1;
|
||||
for (unsigned i = 0; i < vars_.size(); i++) {
|
||||
nParams *= vars_[i]->getDomainSize();
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
varids_.push_back (vars[i]->varId());
|
||||
ranges_.push_back (vars[i]->nrStates());
|
||||
nParams *= vars[i]->nrStates();
|
||||
}
|
||||
// create a uniform distribution
|
||||
double val = 1.0 / nParams;
|
||||
@ -36,28 +41,43 @@ Factor::Factor (const FgVarSet& vars)
|
||||
|
||||
|
||||
|
||||
Factor::Factor (FgVarNode* var,
|
||||
const ParamSet& params)
|
||||
Factor::Factor (VarId vid, unsigned nStates, const ParamSet& params)
|
||||
{
|
||||
vars_.push_back (var);
|
||||
varids_.push_back (vid);
|
||||
ranges_.push_back (nStates);
|
||||
dist_ = new Distribution (params);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (FgVarSet& vars,
|
||||
Distribution* dist)
|
||||
Factor::Factor (VarNodes& vars, Distribution* dist)
|
||||
{
|
||||
vars_ = vars;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
varids_.push_back (vars[i]->varId());
|
||||
ranges_.push_back (vars[i]->nrStates());
|
||||
}
|
||||
dist_ = dist;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (const FgVarSet& vars,
|
||||
Factor::Factor (const VarNodes& vars, const ParamSet& params)
|
||||
{
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
varids_.push_back (vars[i]->varId());
|
||||
ranges_.push_back (vars[i]->nrStates());
|
||||
}
|
||||
dist_ = new Distribution (params);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Factor::Factor (const VarIdSet& vids,
|
||||
const Ranges& ranges,
|
||||
const ParamSet& params)
|
||||
{
|
||||
vars_ = vars;
|
||||
varids_ = vids;
|
||||
ranges_ = ranges;
|
||||
dist_ = new Distribution (params);
|
||||
}
|
||||
|
||||
@ -73,9 +93,10 @@ Factor::setParameters (const ParamSet& params)
|
||||
|
||||
|
||||
void
|
||||
Factor::copyFactor (const Factor& g)
|
||||
Factor::copyFromFactor (const Factor& g)
|
||||
{
|
||||
vars_ = g.getFgVarNodes();
|
||||
varids_ = g.getVarIds();
|
||||
ranges_ = g.getRanges();
|
||||
dist_ = new Distribution (g.getDistribution()->params);
|
||||
}
|
||||
|
||||
@ -84,50 +105,43 @@ Factor::copyFactor (const Factor& g)
|
||||
void
|
||||
Factor::multiplyByFactor (const Factor& g, const vector<CptEntry>* entries)
|
||||
{
|
||||
if (vars_.size() == 0) {
|
||||
copyFactor (g);
|
||||
if (varids_.size() == 0) {
|
||||
copyFromFactor (g);
|
||||
return;
|
||||
}
|
||||
|
||||
const FgVarSet& gVs = g.getFgVarNodes();
|
||||
const ParamSet& gPs = g.getParameters();
|
||||
const VarIdSet& gvarids = g.getVarIds();
|
||||
const Ranges& granges = g.getRanges();
|
||||
const ParamSet& gparams = g.getParameters();
|
||||
|
||||
bool factorsAreEqual = true;
|
||||
if (gVs.size() == vars_.size()) {
|
||||
for (unsigned i = 0; i < vars_.size(); i++) {
|
||||
if (gVs[i] != vars_[i]) {
|
||||
factorsAreEqual = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
factorsAreEqual = false;
|
||||
}
|
||||
|
||||
if (factorsAreEqual) {
|
||||
if (varids_ == gvarids) {
|
||||
// optimization: if the factors contain the same set of variables,
|
||||
// we can do 1 to 1 operations on the parameteres
|
||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||
dist_->params[i] *= gPs[i];
|
||||
// we can do a 1 to 1 operation on the parameters
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
Util::multiply (dist_->params, gparams);
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
Util::add (dist_->params, gparams);
|
||||
}
|
||||
} else {
|
||||
bool hasCommonVars = false;
|
||||
vector<unsigned> gVsIndexes;
|
||||
for (unsigned i = 0; i < gVs.size(); i++) {
|
||||
int idx = getIndexOf (gVs[i]);
|
||||
if (idx == -1) {
|
||||
insertVariable (gVs[i]);
|
||||
gVsIndexes.push_back (vars_.size() - 1);
|
||||
vector<unsigned> gvarpos;
|
||||
for (unsigned i = 0; i < gvarids.size(); i++) {
|
||||
int pos = getPositionOf (gvarids[i]);
|
||||
if (pos == -1) {
|
||||
insertVariable (gvarids[i], granges[i]);
|
||||
gvarpos.push_back (varids_.size() - 1);
|
||||
} else {
|
||||
hasCommonVars = true;
|
||||
gVsIndexes.push_back (idx);
|
||||
gvarpos.push_back (pos);
|
||||
}
|
||||
}
|
||||
if (hasCommonVars) {
|
||||
vector<unsigned> gVsOffsets (gVs.size());
|
||||
gVsOffsets[gVs.size() - 1] = 1;
|
||||
for (int i = gVs.size() - 2; i >= 0; i--) {
|
||||
gVsOffsets[i] = gVsOffsets[i + 1] * gVs[i + 1]->getDomainSize();
|
||||
vector<unsigned> gvaroffsets (gvarids.size());
|
||||
gvaroffsets[gvarids.size() - 1] = 1;
|
||||
for (int i = gvarids.size() - 2; i >= 0; i--) {
|
||||
gvaroffsets[i] = gvaroffsets[i + 1] * granges[i + 1];
|
||||
}
|
||||
|
||||
if (entries == 0) {
|
||||
@ -137,50 +151,88 @@ Factor::multiplyByFactor (const Factor& g, const vector<CptEntry>* entries)
|
||||
for (unsigned i = 0; i < entries->size(); i++) {
|
||||
unsigned idx = 0;
|
||||
const DConf& conf = (*entries)[i].getDomainConfiguration();
|
||||
for (unsigned j = 0; j < gVsIndexes.size(); j++) {
|
||||
idx += gVsOffsets[j] * conf[ gVsIndexes[j] ];
|
||||
for (unsigned j = 0; j < gvarpos.size(); j++) {
|
||||
idx += gvaroffsets[j] * conf[ gvarpos[j] ];
|
||||
}
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
dist_->params[i] *= gparams[idx];
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
dist_->params[i] += gparams[idx];
|
||||
}
|
||||
dist_->params[i] = dist_->params[i] * gPs[idx];
|
||||
}
|
||||
} else {
|
||||
// optimization: if the original factors doesn't have common variables,
|
||||
// we don't need to marry the states of the common variables
|
||||
unsigned count = 0;
|
||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||
dist_->params[i] *= gPs[count];
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
dist_->params[i] *= gparams[count];
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
dist_->params[i] += gparams[count];
|
||||
}
|
||||
count ++;
|
||||
if (count >= gPs.size()) {
|
||||
if (count >= gparams.size()) {
|
||||
count = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
dist_->entries.clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::insertVariable (FgVarNode* var)
|
||||
Factor::insertVariable (VarId vid, unsigned nStates)
|
||||
{
|
||||
assert (getIndexOf (var) == -1);
|
||||
assert (getPositionOf (vid) == -1);
|
||||
ParamSet newPs;
|
||||
newPs.reserve (dist_->params.size() * var->getDomainSize());
|
||||
newPs.reserve (dist_->params.size() * nStates);
|
||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||
for (unsigned j = 0; j < var->getDomainSize(); j++) {
|
||||
for (unsigned j = 0; j < nStates; j++) {
|
||||
newPs.push_back (dist_->params[i]);
|
||||
}
|
||||
}
|
||||
vars_.push_back (var);
|
||||
varids_.push_back (vid);
|
||||
ranges_.push_back (nStates);
|
||||
dist_->updateParameters (newPs);
|
||||
dist_->entries.clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::removeVariable (const FgVarNode* var)
|
||||
Factor::removeAllVariablesExcept (VarId vid)
|
||||
{
|
||||
int varIndex = getIndexOf (var);
|
||||
assert (varIndex >= 0 && varIndex < (int)vars_.size());
|
||||
assert (getPositionOf (vid) != -1);
|
||||
while (varids_.back() != vid) {
|
||||
removeLastVariable();
|
||||
}
|
||||
while (varids_.front() != vid) {
|
||||
removeFirstVariable();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::removeVariable (VarId vid)
|
||||
{
|
||||
int pos = getPositionOf (vid);
|
||||
assert (pos != -1);
|
||||
|
||||
if (vid == varids_.back()) {
|
||||
removeLastVariable(); // optimization
|
||||
return;
|
||||
}
|
||||
if (vid == varids_.front()) {
|
||||
removeFirstVariable(); // optimization
|
||||
return;
|
||||
}
|
||||
|
||||
// number of parameters separating a different state of `var',
|
||||
// with the states of the remaining variables fixed
|
||||
@ -190,36 +242,36 @@ Factor::removeVariable (const FgVarNode* var)
|
||||
// on the left of `var', with the states of the remaining vars fixed
|
||||
unsigned leftVarOffset = 1;
|
||||
|
||||
for (int i = vars_.size() - 1; i > varIndex; i--) {
|
||||
varOffset *= vars_[i]->getDomainSize();
|
||||
leftVarOffset *= vars_[i]->getDomainSize();
|
||||
for (int i = varids_.size() - 1; i > pos; i--) {
|
||||
varOffset *= ranges_[i];
|
||||
leftVarOffset *= ranges_[i];
|
||||
}
|
||||
leftVarOffset *= vars_[varIndex]->getDomainSize();
|
||||
leftVarOffset *= ranges_[pos];
|
||||
|
||||
unsigned offset = 0;
|
||||
unsigned count1 = 0;
|
||||
unsigned count2 = 0;
|
||||
unsigned newPsSize = dist_->params.size() / vars_[varIndex]->getDomainSize();
|
||||
unsigned newPsSize = dist_->params.size() / ranges_[pos];
|
||||
|
||||
ParamSet newPs;
|
||||
newPs.reserve (newPsSize);
|
||||
|
||||
// stringstream ss;
|
||||
// ss << "marginalizing " << vars_[varIndex]->getLabel();
|
||||
// ss << " from factor " << getLabel() << endl;
|
||||
while (newPs.size() < newPsSize) {
|
||||
// ss << " sum = ";
|
||||
double sum = 0.0;
|
||||
for (unsigned i = 0; i < vars_[varIndex]->getDomainSize(); i++) {
|
||||
// if (i != 0) ss << " + ";
|
||||
// ss << dist_->params[offset];
|
||||
double sum = Util::addIdenty();
|
||||
for (unsigned i = 0; i < ranges_[pos]; i++) {
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
sum += dist_->params[offset];
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
Util::logSum (sum, dist_->params[offset]);
|
||||
}
|
||||
offset += varOffset;
|
||||
}
|
||||
newPs.push_back (sum);
|
||||
count1 ++;
|
||||
if (varIndex == (int)vars_.size() - 1) {
|
||||
offset = count1 * vars_[varIndex]->getDomainSize();
|
||||
if (pos == (int)varids_.size() - 1) {
|
||||
offset = count1 * ranges_[pos];
|
||||
} else {
|
||||
if (((offset - varOffset + 1) % leftVarOffset) == 0) {
|
||||
count1 = 0;
|
||||
@ -227,11 +279,200 @@ Factor::removeVariable (const FgVarNode* var)
|
||||
}
|
||||
offset = (leftVarOffset * count2) + count1;
|
||||
}
|
||||
// ss << " = " << sum << endl;
|
||||
}
|
||||
// cout << ss.str() << endl;
|
||||
vars_.erase (vars_.begin() + varIndex);
|
||||
varids_.erase (varids_.begin() + pos);
|
||||
ranges_.erase (ranges_.begin() + pos);
|
||||
dist_->updateParameters (newPs);
|
||||
dist_->entries.clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::removeFirstVariable (void)
|
||||
{
|
||||
ParamSet& params = dist_->params;
|
||||
unsigned nStates = ranges_.front();
|
||||
unsigned sep = params.size() / nStates;
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
for (unsigned i = sep; i < params.size(); i++) {
|
||||
params[i % sep] += params[i];
|
||||
}
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
for (unsigned i = sep; i < params.size(); i++) {
|
||||
Util::logSum (params[i % sep], params[i]);
|
||||
}
|
||||
}
|
||||
params.resize (sep);
|
||||
varids_.erase (varids_.begin());
|
||||
ranges_.erase (ranges_.begin());
|
||||
dist_->entries.clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::removeLastVariable (void)
|
||||
{
|
||||
ParamSet& params = dist_->params;
|
||||
unsigned nStates = ranges_.back();
|
||||
unsigned idx1 = 0;
|
||||
unsigned idx2 = 0;
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
while (idx1 < params.size()) {
|
||||
params[idx2] = params[idx1];
|
||||
idx1 ++;
|
||||
for (unsigned j = 1; j < nStates; j++) {
|
||||
params[idx2] += params[idx1];
|
||||
idx1 ++;
|
||||
}
|
||||
idx2 ++;
|
||||
}
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
while (idx1 < params.size()) {
|
||||
params[idx2] = params[idx1];
|
||||
idx1 ++;
|
||||
for (unsigned j = 1; j < nStates; j++) {
|
||||
Util::logSum (params[idx2], params[idx1]);
|
||||
idx1 ++;
|
||||
}
|
||||
idx2 ++;
|
||||
}
|
||||
}
|
||||
params.resize (idx2);
|
||||
varids_.pop_back();
|
||||
ranges_.pop_back();
|
||||
dist_->entries.clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::orderVariables (void)
|
||||
{
|
||||
VarIdSet sortedVarIds = varids_;
|
||||
sort (sortedVarIds.begin(), sortedVarIds.end());
|
||||
orderVariables (sortedVarIds);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::orderVariables (const VarIdSet& newVarIdOrder)
|
||||
{
|
||||
assert (newVarIdOrder.size() == varids_.size());
|
||||
if (newVarIdOrder == varids_) {
|
||||
return;
|
||||
}
|
||||
|
||||
Ranges newRangeOrder;
|
||||
for (unsigned i = 0; i < newVarIdOrder.size(); i++) {
|
||||
unsigned pos = getPositionOf (newVarIdOrder[i]);
|
||||
newRangeOrder.push_back (ranges_[pos]);
|
||||
}
|
||||
|
||||
vector<unsigned> positions;
|
||||
for (unsigned i = 0; i < newVarIdOrder.size(); i++) {
|
||||
positions.push_back (getPositionOf (newVarIdOrder[i]));
|
||||
}
|
||||
|
||||
unsigned N = ranges_.size();
|
||||
ParamSet newPs (dist_->params.size());
|
||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||
unsigned li = i;
|
||||
// calculate vector index corresponding to linear index
|
||||
vector<unsigned> vi (N);
|
||||
for (int k = N-1; k >= 0; k--) {
|
||||
vi[k] = li % ranges_[k];
|
||||
li /= ranges_[k];
|
||||
}
|
||||
// convert permuted vector index to corresponding linear index
|
||||
unsigned prod = 1;
|
||||
unsigned new_li = 0;
|
||||
for (int k = N-1; k >= 0; k--) {
|
||||
new_li += vi[positions[k]] * prod;
|
||||
prod *= ranges_[positions[k]];
|
||||
}
|
||||
newPs[new_li] = dist_->params[i];
|
||||
}
|
||||
varids_ = newVarIdOrder;
|
||||
ranges_ = newRangeOrder;
|
||||
dist_->params = newPs;
|
||||
dist_->entries.clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::removeInconsistentEntries (VarId vid, unsigned evidence)
|
||||
{
|
||||
int pos = getPositionOf (vid);
|
||||
assert (pos != -1);
|
||||
ParamSet newPs;
|
||||
newPs.reserve (dist_->params.size() / ranges_[pos]);
|
||||
StatesIndexer idx (ranges_);
|
||||
for (unsigned i = 0; i < evidence; i++) {
|
||||
idx.incrementState (pos);
|
||||
}
|
||||
while (idx.valid()) {
|
||||
newPs.push_back (dist_->params[idx.getLinearIndex()]);
|
||||
idx.nextSameState (pos);
|
||||
}
|
||||
varids_.erase (varids_.begin() + pos);
|
||||
ranges_.erase (ranges_.begin() + pos);
|
||||
dist_->updateParameters (newPs);
|
||||
dist_->entries.clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
Factor::getLabel (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "f(" ;
|
||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
||||
if (i != 0) ss << "," ;
|
||||
ss << VarNode (varids_[i], ranges_[i]).label();
|
||||
}
|
||||
ss << ")" ;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::printFactor (void) const
|
||||
{
|
||||
VarNodes vars;
|
||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
||||
vars.push_back (new VarNode (varids_[i], ranges_[i]));
|
||||
}
|
||||
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||
cout << "f(" << jointStrings[i] << ")" ;
|
||||
cout << " = " << dist_->params[i] << endl;
|
||||
}
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
delete vars[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
Factor::getPositionOf (VarId vid) const
|
||||
{
|
||||
for (unsigned i = 0; i < varids_.size(); i++) {
|
||||
if (varids_[i] == vid) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
@ -242,21 +483,20 @@ Factor::getCptEntries (void) const
|
||||
if (dist_->entries.size() == 0) {
|
||||
vector<DConf> confs (dist_->params.size());
|
||||
for (unsigned i = 0; i < dist_->params.size(); i++) {
|
||||
confs[i].resize (vars_.size());
|
||||
confs[i].resize (varids_.size());
|
||||
}
|
||||
|
||||
unsigned nReps = 1;
|
||||
for (int i = vars_.size() - 1; i >= 0; i--) {
|
||||
for (int i = varids_.size() - 1; i >= 0; i--) {
|
||||
unsigned index = 0;
|
||||
while (index < dist_->params.size()) {
|
||||
for (unsigned j = 0; j < vars_[i]->getDomainSize(); j++) {
|
||||
for (unsigned j = 0; j < ranges_[i]; j++) {
|
||||
for (unsigned r = 0; r < nReps; r++) {
|
||||
confs[index][i] = j;
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
nReps *= vars_[i]->getDomainSize();
|
||||
nReps *= ranges_[i];
|
||||
}
|
||||
dist_->entries.clear();
|
||||
dist_->entries.reserve (dist_->params.size());
|
||||
@ -267,53 +507,3 @@ Factor::getCptEntries (void) const
|
||||
return dist_->entries;
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
Factor::getLabel (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "Φ(" ;
|
||||
for (unsigned i = 0; i < vars_.size(); i++) {
|
||||
if (i != 0) ss << "," ;
|
||||
ss << vars_[i]->getLabel();
|
||||
}
|
||||
ss << ")" ;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Factor::printFactor (void)
|
||||
{
|
||||
stringstream ss;
|
||||
ss << getLabel() << endl;
|
||||
ss << "--------------------" << endl;
|
||||
VarSet vs;
|
||||
for (unsigned i = 0; i < vars_.size(); i++) {
|
||||
vs.push_back (vars_[i]);
|
||||
}
|
||||
vector<string> domainConfs = Util::getInstantiations (vs);
|
||||
const vector<CptEntry>& entries = getCptEntries();
|
||||
for (unsigned i = 0; i < entries.size(); i++) {
|
||||
ss << "Φ(" << domainConfs[i] << ")" ;
|
||||
unsigned idx = entries[i].getParameterIndex();
|
||||
ss << " = " << dist_->params[idx] << endl;
|
||||
}
|
||||
cout << ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
Factor::getIndexOf (const FgVarNode* var) const
|
||||
{
|
||||
for (unsigned i = 0; i < vars_.size(); i++) {
|
||||
if (vars_[i] == var) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
@ -1,48 +1,69 @@
|
||||
#ifndef BP_FACTOR_H
|
||||
#define BP_FACTOR_H
|
||||
#ifndef HORUS_FACTOR_H
|
||||
#define HORUS_FACTOR_H
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "Distribution.h"
|
||||
#include "CptEntry.h"
|
||||
#include "VarNode.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
class FgVarNode;
|
||||
class Distribution;
|
||||
|
||||
|
||||
class Factor
|
||||
{
|
||||
public:
|
||||
Factor (void) { }
|
||||
Factor (const Factor&);
|
||||
Factor (FgVarNode*);
|
||||
Factor (CFgVarSet);
|
||||
Factor (FgVarNode*, const ParamSet&);
|
||||
Factor (FgVarSet&, Distribution*);
|
||||
Factor (CFgVarSet, CParamSet);
|
||||
Factor (VarId, unsigned);
|
||||
Factor (const VarNodes&);
|
||||
Factor (VarId, unsigned, const ParamSet&);
|
||||
Factor (VarNodes&, Distribution*);
|
||||
Factor (const VarNodes&, const ParamSet&);
|
||||
Factor (const VarIdSet&, const Ranges&, const ParamSet&);
|
||||
|
||||
void setParameters (CParamSet);
|
||||
void copyFactor (const Factor& f);
|
||||
void multiplyByFactor (const Factor& f, const vector<CptEntry>* = 0);
|
||||
void insertVariable (FgVarNode* index);
|
||||
void removeVariable (const FgVarNode* var);
|
||||
const vector<CptEntry>& getCptEntries (void) const;
|
||||
void setParameters (const ParamSet&);
|
||||
void copyFromFactor (const Factor& f);
|
||||
void multiplyByFactor (const Factor&, const vector<CptEntry>* = 0);
|
||||
void insertVariable (VarId, unsigned);
|
||||
void removeAllVariablesExcept (VarId);
|
||||
void removeVariable (VarId);
|
||||
void removeFirstVariable (void);
|
||||
void removeLastVariable (void);
|
||||
void orderVariables (void);
|
||||
void orderVariables (const VarIdSet&);
|
||||
void removeInconsistentEntries (VarId, unsigned);
|
||||
string getLabel (void) const;
|
||||
void printFactor (void);
|
||||
void printFactor (void) const;
|
||||
int getPositionOf (VarId) const;
|
||||
const vector<CptEntry>& getCptEntries (void) const;
|
||||
|
||||
CFgVarSet getFgVarNodes (void) const { return vars_; }
|
||||
CParamSet getParameters (void) const { return dist_->params; }
|
||||
const VarIdSet& getVarIds (void) const { return varids_; }
|
||||
const Ranges& getRanges (void) const { return ranges_; }
|
||||
const ParamSet& getParameters (void) const { return dist_->params; }
|
||||
Distribution* getDistribution (void) const { return dist_; }
|
||||
unsigned getIndex (void) const { return index_; }
|
||||
void setIndex (unsigned index) { index_ = index; }
|
||||
void freeDistribution (void) { delete dist_; dist_ = 0;}
|
||||
int getIndexOf (const FgVarNode*) const;
|
||||
unsigned nrVariables (void) const { return varids_.size(); }
|
||||
unsigned nrParameters() const { return dist_->params.size(); }
|
||||
|
||||
void setDistribution (Distribution* dist)
|
||||
{
|
||||
dist_ = dist;
|
||||
}
|
||||
void freeDistribution (void)
|
||||
{
|
||||
delete dist_;
|
||||
dist_ = 0;
|
||||
}
|
||||
|
||||
private:
|
||||
FgVarSet vars_;
|
||||
|
||||
VarIdSet varids_;
|
||||
Ranges ranges_;
|
||||
Distribution* dist_;
|
||||
unsigned index_;
|
||||
};
|
||||
|
||||
#endif //BP_FACTOR_H
|
||||
#endif // HORUS_FACTOR_H
|
||||
|
||||
|
@ -1,18 +1,48 @@
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "FgVarNode.h"
|
||||
#include "Factor.h"
|
||||
#include "BayesNet.h"
|
||||
|
||||
|
||||
FactorGraph::FactorGraph (const char* fileName)
|
||||
|
||||
FactorGraph::FactorGraph (const BayesNet& bn)
|
||||
{
|
||||
const BnNodeSet& nodes = bn.getBayesNodes();
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
FgVarNode* varNode = new FgVarNode (nodes[i]);
|
||||
addVariable (varNode);
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
const BnNodeSet& parents = nodes[i]->getParents();
|
||||
if (!(nodes[i]->hasEvidence() && parents.size() == 0)) {
|
||||
VarNodes neighs;
|
||||
neighs.push_back (varNodes_[nodes[i]->getIndex()]);
|
||||
for (unsigned j = 0; j < parents.size(); j++) {
|
||||
neighs.push_back (varNodes_[parents[j]->getIndex()]);
|
||||
}
|
||||
FgFacNode* fn = new FgFacNode (
|
||||
new Factor (neighs, nodes[i]->getDistribution()));
|
||||
addFactor (fn);
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
setIndexes();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::readFromUaiFormat (const char* fileName)
|
||||
{
|
||||
ifstream is (fileName);
|
||||
if (!is.is_open()) {
|
||||
@ -29,90 +59,159 @@ FactorGraph::FactorGraph (const char* fileName)
|
||||
}
|
||||
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
int nVars;
|
||||
unsigned nVars;
|
||||
is >> nVars;
|
||||
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
vector<int> domainSizes (nVars);
|
||||
for (int i = 0; i < nVars; i++) {
|
||||
int ds;
|
||||
for (unsigned i = 0; i < nVars; i++) {
|
||||
unsigned ds;
|
||||
is >> ds;
|
||||
domainSizes[i] = ds;
|
||||
}
|
||||
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
for (int i = 0; i < nVars; i++) {
|
||||
for (unsigned i = 0; i < nVars; i++) {
|
||||
addVariable (new FgVarNode (i, domainSizes[i]));
|
||||
}
|
||||
|
||||
int nFactors;
|
||||
unsigned nFactors;
|
||||
is >> nFactors;
|
||||
for (int i = 0; i < nFactors; i++) {
|
||||
for (unsigned i = 0; i < nFactors; i++) {
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
int nFactorVars;
|
||||
unsigned nFactorVars;
|
||||
is >> nFactorVars;
|
||||
FgVarSet factorVars;
|
||||
for (int j = 0; j < nFactorVars; j++) {
|
||||
int vid;
|
||||
VarNodes neighs;
|
||||
for (unsigned j = 0; j < nFactorVars; j++) {
|
||||
unsigned vid;
|
||||
is >> vid;
|
||||
FgVarNode* var = getFgVarNode (vid);
|
||||
if (!var) {
|
||||
FgVarNode* neigh = getFgVarNode (vid);
|
||||
if (!neigh) {
|
||||
cerr << "error: invalid variable identifier (" << vid << ")" << endl;
|
||||
abort();
|
||||
}
|
||||
factorVars.push_back (var);
|
||||
neighs.push_back (neigh);
|
||||
}
|
||||
Factor* f = new Factor (factorVars);
|
||||
factors_.push_back (f);
|
||||
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||
factorVars[j]->addFactor (f);
|
||||
FgFacNode* fn = new FgFacNode (new Factor (neighs));
|
||||
addFactor (fn);
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < nFactors; i++) {
|
||||
for (unsigned i = 0; i < nFactors; i++) {
|
||||
while (is.peek() == '#' || is.peek() == '\n') getline (is, line);
|
||||
int nParams;
|
||||
unsigned nParams;
|
||||
is >> nParams;
|
||||
if (facNodes_[i]->getParameters().size() != nParams) {
|
||||
cerr << "error: invalid number of parameters for factor " ;
|
||||
cerr << facNodes_[i]->getLabel() ;
|
||||
cerr << ", expected: " << facNodes_[i]->getParameters().size();
|
||||
cerr << ", given: " << nParams << endl;
|
||||
abort();
|
||||
}
|
||||
ParamSet params (nParams);
|
||||
for (int j = 0; j < nParams; j++) {
|
||||
for (unsigned j = 0; j < nParams; j++) {
|
||||
double param;
|
||||
is >> param;
|
||||
params[j] = param;
|
||||
}
|
||||
factors_[i]->setParameters (params);
|
||||
if (NSPACE == NumberSpace::LOGARITHM) {
|
||||
Util::toLog (params);
|
||||
}
|
||||
facNodes_[i]->factor()->setParameters (params);
|
||||
}
|
||||
is.close();
|
||||
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
varNodes_[i]->setIndex (i);
|
||||
}
|
||||
setIndexes();
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph::FactorGraph (const BayesNet& bn)
|
||||
void
|
||||
FactorGraph::readFromLibDaiFormat (const char* fileName)
|
||||
{
|
||||
const BnNodeSet& nodes = bn.getBayesNodes();
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
FgVarNode* varNode = new FgVarNode (nodes[i]);
|
||||
varNode->setIndex (i);
|
||||
addVariable (varNode);
|
||||
ifstream is (fileName);
|
||||
if (!is.is_open()) {
|
||||
cerr << "error: cannot read from file " + std::string (fileName) << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
const BnNodeSet& parents = nodes[i]->getParents();
|
||||
if (!(nodes[i]->hasEvidence() && parents.size() == 0)) {
|
||||
FgVarSet factorVars = { varNodes_[nodes[i]->getIndex()] };
|
||||
for (unsigned j = 0; j < parents.size(); j++) {
|
||||
factorVars.push_back (varNodes_[parents[j]->getIndex()]);
|
||||
string line;
|
||||
unsigned nFactors;
|
||||
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
is >> nFactors;
|
||||
|
||||
if (is.fail()) {
|
||||
cerr << "error: cannot read the number of factors" << endl;
|
||||
abort();
|
||||
}
|
||||
Factor* f = new Factor (factorVars, nodes[i]->getDistribution());
|
||||
factors_.push_back (f);
|
||||
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||
factorVars[j]->addFactor (f);
|
||||
|
||||
getline (is, line);
|
||||
if (is.fail() || line.size() > 0) {
|
||||
cerr << "error: cannot read the number of factors" << endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < nFactors; i++) {
|
||||
unsigned nVars;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
|
||||
is >> nVars;
|
||||
VarIdSet vids;
|
||||
for (unsigned j = 0; j < nVars; j++) {
|
||||
VarId vid;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
is >> vid;
|
||||
vids.push_back (vid);
|
||||
}
|
||||
|
||||
VarNodes neighs;
|
||||
unsigned nParams = 1;
|
||||
for (unsigned j = 0; j < nVars; j++) {
|
||||
unsigned dsize;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
is >> dsize;
|
||||
FgVarNode* var = getFgVarNode (vids[j]);
|
||||
if (var == 0) {
|
||||
var = new FgVarNode (vids[j], dsize);
|
||||
addVariable (var);
|
||||
} else {
|
||||
if (var->nrStates() != dsize) {
|
||||
cerr << "error: variable `" << vids[j] << "' appears in two or " ;
|
||||
cerr << "more factors with different domain sizes" << endl;
|
||||
}
|
||||
}
|
||||
neighs.push_back (var);
|
||||
nParams *= var->nrStates();
|
||||
}
|
||||
ParamSet params (nParams, 0);
|
||||
unsigned nNonzeros;
|
||||
while ((is.peek()) == '#')
|
||||
getline (is, line);
|
||||
is >> nNonzeros;
|
||||
|
||||
for (unsigned j = 0; j < nNonzeros; j++) {
|
||||
unsigned index;
|
||||
Param val;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
is >> index;
|
||||
while ((is.peek()) == '#') getline (is, line);
|
||||
is >> val;
|
||||
params[index] = val;
|
||||
}
|
||||
reverse (neighs.begin(), neighs.end());
|
||||
if (NSPACE == NumberSpace::LOGARITHM) {
|
||||
Util::toLog (params);
|
||||
}
|
||||
FgFacNode* fn = new FgFacNode (new Factor (neighs, params));
|
||||
addFactor (fn);
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
addEdge (fn, static_cast<FgVarNode*> (neighs[j]));
|
||||
}
|
||||
}
|
||||
is.close();
|
||||
setIndexes();
|
||||
}
|
||||
|
||||
|
||||
@ -122,82 +221,63 @@ FactorGraph::~FactorGraph (void)
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
delete varNodes_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||
delete factors_[i];
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
delete facNodes_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addVariable (FgVarNode* varNode)
|
||||
FactorGraph::addVariable (FgVarNode* vn)
|
||||
{
|
||||
varNodes_.push_back (varNode);
|
||||
varNode->setIndex (varNodes_.size() - 1);
|
||||
indexMap_.insert (make_pair (varNode->getVarId(), varNodes_.size() - 1));
|
||||
varNodes_.push_back (vn);
|
||||
vn->setIndex (varNodes_.size() - 1);
|
||||
indexMap_.insert (make_pair (vn->varId(), varNodes_.size() - 1));
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::removeVariable (const FgVarNode* var)
|
||||
FactorGraph::addFactor (FgFacNode* fn)
|
||||
{
|
||||
if (varNodes_[varNodes_.size() - 1] == var) {
|
||||
varNodes_.pop_back();
|
||||
} else {
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
if (varNodes_[i] == var) {
|
||||
varNodes_.erase (varNodes_.begin() + i);
|
||||
return;
|
||||
}
|
||||
}
|
||||
assert (false);
|
||||
}
|
||||
indexMap_.erase (indexMap_.find (var->getVarId()));
|
||||
facNodes_.push_back (fn);
|
||||
fn->setIndex (facNodes_.size() - 1);
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addEdge (FgVarNode* vn, FgFacNode* fn)
|
||||
{
|
||||
vn->addNeighbor (fn);
|
||||
fn->addNeighbor (vn);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::addFactor (Factor* f)
|
||||
FactorGraph::addEdge (FgFacNode* fn, FgVarNode* vn)
|
||||
{
|
||||
factors_.push_back (f);
|
||||
const FgVarSet& factorVars = f->getFgVarNodes();
|
||||
for (unsigned i = 0; i < factorVars.size(); i++) {
|
||||
factorVars[i]->addFactor (f);
|
||||
}
|
||||
fn->addNeighbor (vn);
|
||||
vn->addNeighbor (fn);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::removeFactor (const Factor* f)
|
||||
VarNode*
|
||||
FactorGraph::getVariableNode (VarId vid) const
|
||||
{
|
||||
const FgVarSet& factorVars = f->getFgVarNodes();
|
||||
for (unsigned i = 0; i < factorVars.size(); i++) {
|
||||
if (factorVars[i]) {
|
||||
factorVars[i]->removeFactor (f);
|
||||
}
|
||||
}
|
||||
if (factors_[factors_.size() - 1] == f) {
|
||||
factors_.pop_back();
|
||||
} else {
|
||||
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||
if (factors_[i] == f) {
|
||||
factors_.erase (factors_.begin() + i);
|
||||
return;
|
||||
}
|
||||
}
|
||||
assert (false);
|
||||
}
|
||||
FgVarNode* vn = getFgVarNode (vid);
|
||||
assert (vn);
|
||||
return vn;
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarSet
|
||||
FactorGraph::getVariables (void) const
|
||||
VarNodes
|
||||
FactorGraph::getVariableNodes (void) const
|
||||
{
|
||||
VarSet vars;
|
||||
VarNodes vars;
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
vars.push_back (varNodes_[i]);
|
||||
}
|
||||
@ -206,10 +286,10 @@ FactorGraph::getVariables (void) const
|
||||
|
||||
|
||||
|
||||
Variable*
|
||||
FactorGraph::getVariable (Vid vid) const
|
||||
bool
|
||||
FactorGraph::isTree (void) const
|
||||
{
|
||||
return getFgVarNode (vid);
|
||||
return !containsCycle();
|
||||
}
|
||||
|
||||
|
||||
@ -220,8 +300,8 @@ FactorGraph::setIndexes (void)
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
varNodes_[i]->setIndex (i);
|
||||
}
|
||||
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||
factors_[i]->setIndex (i);
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
facNodes_[i]->setIndex (i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -231,8 +311,8 @@ void
|
||||
FactorGraph::freeDistributions (void)
|
||||
{
|
||||
set<Distribution*> dists;
|
||||
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||
dists.insert (factors_[i]->getDistribution());
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
dists.insert (facNodes_[i]->factor()->getDistribution());
|
||||
}
|
||||
for (set<Distribution*>::iterator it = dists.begin();
|
||||
it != dists.end(); it++) {
|
||||
@ -246,19 +326,18 @@ void
|
||||
FactorGraph::printGraphicalModel (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
cout << "variable number " << varNodes_[i]->getIndex() << endl;
|
||||
cout << "Id = " << varNodes_[i]->getVarId() << endl;
|
||||
cout << "Label = " << varNodes_[i]->getLabel() << endl;
|
||||
cout << "Domain size = " << varNodes_[i]->getDomainSize() << endl;
|
||||
cout << "VarId = " << varNodes_[i]->varId() << endl;
|
||||
cout << "Label = " << varNodes_[i]->label() << endl;
|
||||
cout << "Nr States = " << varNodes_[i]->nrStates() << endl;
|
||||
cout << "Evidence = " << varNodes_[i]->getEvidence() << endl;
|
||||
cout << "Factors = " ;
|
||||
for (unsigned j = 0; j < varNodes_[i]->getFactors().size(); j++) {
|
||||
cout << varNodes_[i]->getFactors()[j]->getLabel() << " " ;
|
||||
for (unsigned j = 0; j < varNodes_[i]->neighbors().size(); j++) {
|
||||
cout << varNodes_[i]->neighbors()[j]->getLabel() << " " ;
|
||||
}
|
||||
cout << endl << endl;
|
||||
}
|
||||
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||
factors_[i]->printFactor();
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
facNodes_[i]->factor()->printFactor();
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
@ -266,7 +345,7 @@ FactorGraph::printGraphicalModel (void) const
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::exportToDotFormat (const char* fileName) const
|
||||
FactorGraph::exportToGraphViz (const char* fileName) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
@ -279,24 +358,23 @@ FactorGraph::exportToDotFormat (const char* fileName) const
|
||||
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
if (varNodes_[i]->hasEvidence()) {
|
||||
out << '"' << varNodes_[i]->getLabel() << '"' ;
|
||||
out << '"' << varNodes_[i]->label() << '"' ;
|
||||
out << " [style=filled, fillcolor=yellow]" << endl;
|
||||
}
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||
out << '"' << factors_[i]->getLabel() << '"' ;
|
||||
out << " [label=\"" << factors_[i]->getLabel() << "\\n(";
|
||||
out << factors_[i]->getDistribution()->id << ")" << "\"" ;
|
||||
out << ", shape=box]" << endl;
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
||||
out << " [label=\"" << facNodes_[i]->getLabel();
|
||||
out << "\"" << ", shape=box]" << endl;
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||
CFgVarSet myVars = factors_[i]->getFgVarNodes();
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const FgVarSet& myVars = facNodes_[i]->neighbors();
|
||||
for (unsigned j = 0; j < myVars.size(); j++) {
|
||||
out << '"' << factors_[i]->getLabel() << '"' ;
|
||||
out << '"' << facNodes_[i]->getLabel() << '"' ;
|
||||
out << " -- " ;
|
||||
out << '"' << myVars[j]->getLabel() << '"' << endl;
|
||||
out << '"' << myVars[j]->label() << '"' << endl;
|
||||
}
|
||||
}
|
||||
|
||||
@ -319,13 +397,13 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
|
||||
out << "MARKOV" << endl;
|
||||
out << varNodes_.size() << endl;
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
out << varNodes_[i]->getDomainSize() << " " ;
|
||||
out << varNodes_[i]->nrStates() << " " ;
|
||||
}
|
||||
out << endl;
|
||||
|
||||
out << factors_.size() << endl;
|
||||
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||
CFgVarSet factorVars = factors_[i]->getFgVarNodes();
|
||||
out << facNodes_.size() << endl;
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const FgVarSet& factorVars = facNodes_[i]->neighbors();
|
||||
out << factorVars.size();
|
||||
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||
out << " " << factorVars[j]->getIndex();
|
||||
@ -333,8 +411,8 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
|
||||
out << endl;
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||
CParamSet params = factors_[i]->getParameters();
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const ParamSet& params = facNodes_[i]->getParameters();
|
||||
out << endl << params.size() << endl << " " ;
|
||||
for (unsigned j = 0; j < params.size(); j++) {
|
||||
out << params[j] << " " ;
|
||||
@ -345,3 +423,102 @@ FactorGraph::exportToUaiFormat (const char* fileName) const
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FactorGraph::exportToLibDaiFormat (const char* fileName) const
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "FactorGraph::exportToLibDaiFormat()" << endl;
|
||||
abort();
|
||||
}
|
||||
out << facNodes_.size() << endl << endl;
|
||||
for (unsigned i = 0; i < facNodes_.size(); i++) {
|
||||
const FgVarSet& factorVars = facNodes_[i]->neighbors();
|
||||
out << factorVars.size() << endl;
|
||||
for (int j = factorVars.size() - 1; j >= 0; j--) {
|
||||
out << factorVars[j]->varId() << " " ;
|
||||
}
|
||||
out << endl;
|
||||
for (unsigned j = 0; j < factorVars.size(); j++) {
|
||||
out << factorVars[j]->nrStates() << " " ;
|
||||
}
|
||||
out << endl;
|
||||
const ParamSet& params = facNodes_[i]->factor()->getParameters();
|
||||
out << params.size() << endl;
|
||||
for (unsigned j = 0; j < params.size(); j++) {
|
||||
out << j << " " << params[j] << endl;
|
||||
}
|
||||
out << endl;
|
||||
}
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::containsCycle (void) const
|
||||
{
|
||||
vector<bool> visitedVars (varNodes_.size(), false);
|
||||
vector<bool> visitedFactors (facNodes_.size(), false);
|
||||
for (unsigned i = 0; i < varNodes_.size(); i++) {
|
||||
int v = varNodes_[i]->getIndex();
|
||||
if (!visitedVars[v]) {
|
||||
if (containsCycle (varNodes_[i], 0, visitedVars, visitedFactors)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::containsCycle (const FgVarNode* v,
|
||||
const FgFacNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
{
|
||||
visitedVars[v->getIndex()] = true;
|
||||
const FgFacSet& adjacencies = v->neighbors();
|
||||
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
||||
int w = adjacencies[i]->getIndex();
|
||||
if (!visitedFactors[w]) {
|
||||
if (containsCycle (adjacencies[i], v, visitedVars, visitedFactors)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
else if (visitedFactors[w] && adjacencies[i] != p) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false; // no cycle detected in this component
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FactorGraph::containsCycle (const FgFacNode* v,
|
||||
const FgVarNode* p,
|
||||
vector<bool>& visitedVars,
|
||||
vector<bool>& visitedFactors) const
|
||||
{
|
||||
visitedFactors[v->getIndex()] = true;
|
||||
const FgVarSet& adjacencies = v->neighbors();
|
||||
for (unsigned i = 0; i < adjacencies.size(); i++) {
|
||||
int w = adjacencies[i]->getIndex();
|
||||
if (!visitedVars[w]) {
|
||||
if (containsCycle (adjacencies[i], v, visitedVars, visitedFactors)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
else if (visitedVars[w] && adjacencies[i] != p) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false; // no cycle detected in this component
|
||||
}
|
||||
|
||||
|
@ -1,41 +1,116 @@
|
||||
#ifndef BP_FACTOR_GRAPH_H
|
||||
#define BP_FACTOR_GRAPH_H
|
||||
#ifndef HORUS_FACTORGRAPH_H
|
||||
#define HORUS_FACTORGRAPH_H
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "GraphicalModel.h"
|
||||
#include "Shared.h"
|
||||
#include "Distribution.h"
|
||||
#include "Factor.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
class FgVarNode;
|
||||
class Factor;
|
||||
class BayesNet;
|
||||
|
||||
class FgFacNode;
|
||||
|
||||
class FgVarNode : public VarNode
|
||||
{
|
||||
public:
|
||||
FgVarNode (VarId varId, unsigned nrStates) : VarNode (varId, nrStates) { }
|
||||
FgVarNode (const VarNode* v) : VarNode (v) { }
|
||||
void addNeighbor (FgFacNode* fn)
|
||||
{
|
||||
neighs_.push_back (fn);
|
||||
}
|
||||
const vector<FgFacNode*>& neighbors (void) const
|
||||
{
|
||||
return neighs_;
|
||||
}
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (FgVarNode);
|
||||
// members
|
||||
vector<FgFacNode*> neighs_;
|
||||
};
|
||||
|
||||
|
||||
class FgFacNode
|
||||
{
|
||||
public:
|
||||
FgFacNode (Factor* factor)
|
||||
{
|
||||
factor_ = factor;
|
||||
index_ = -1;
|
||||
}
|
||||
Factor* factor() const
|
||||
{
|
||||
return factor_;
|
||||
}
|
||||
void addNeighbor (FgVarNode* vn)
|
||||
{
|
||||
neighs_.push_back (vn);
|
||||
}
|
||||
const vector<FgVarNode*>& neighbors (void) const
|
||||
{
|
||||
return neighs_;
|
||||
}
|
||||
int getIndex (void) const
|
||||
{
|
||||
assert (index_ != -1);
|
||||
return index_;
|
||||
}
|
||||
void setIndex (int index)
|
||||
{
|
||||
index_ = index;
|
||||
}
|
||||
Distribution* getDistribution (void)
|
||||
{
|
||||
return factor_->getDistribution();
|
||||
}
|
||||
const ParamSet& getParameters (void) const
|
||||
{
|
||||
return factor_->getParameters();
|
||||
}
|
||||
string getLabel (void)
|
||||
{
|
||||
return factor_->getLabel();
|
||||
}
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (FgFacNode);
|
||||
|
||||
Factor* factor_;
|
||||
int index_;
|
||||
vector<FgVarNode*> neighs_;
|
||||
};
|
||||
|
||||
|
||||
class FactorGraph : public GraphicalModel
|
||||
{
|
||||
public:
|
||||
FactorGraph (void) {};
|
||||
FactorGraph (const char*);
|
||||
FactorGraph (const BayesNet&);
|
||||
~FactorGraph (void);
|
||||
|
||||
void readFromUaiFormat (const char*);
|
||||
void readFromLibDaiFormat (const char*);
|
||||
void addVariable (FgVarNode*);
|
||||
void removeVariable (const FgVarNode*);
|
||||
void addFactor (Factor*);
|
||||
void removeFactor (const Factor*);
|
||||
VarSet getVariables (void) const;
|
||||
Variable* getVariable (unsigned) const;
|
||||
void addFactor (FgFacNode*);
|
||||
void addEdge (FgVarNode*, FgFacNode*);
|
||||
void addEdge (FgFacNode*, FgVarNode*);
|
||||
VarNode* getVariableNode (unsigned) const;
|
||||
VarNodes getVariableNodes (void) const;
|
||||
bool isTree (void) const;
|
||||
void setIndexes (void);
|
||||
void freeDistributions (void);
|
||||
void printGraphicalModel (void) const;
|
||||
void exportToDotFormat (const char*) const;
|
||||
void exportToGraphViz (const char*) const;
|
||||
void exportToUaiFormat (const char*) const;
|
||||
void exportToLibDaiFormat (const char*) const;
|
||||
|
||||
const FgVarSet& getFgVarNodes (void) const { return varNodes_; }
|
||||
const FactorSet& getFactors (void) const { return factors_; }
|
||||
const FgVarSet& getVarNodes (void) const { return varNodes_; }
|
||||
const FgFacSet& getFactorNodes (void) const { return facNodes_; }
|
||||
|
||||
FgVarNode* getFgVarNode (Vid vid) const
|
||||
FgVarNode* getFgVarNode (VarId vid) const
|
||||
{
|
||||
IndexMap::const_iterator it = indexMap_.find (vid);
|
||||
if (it == indexMap_.end()) {
|
||||
@ -46,12 +121,20 @@ class FactorGraph : public GraphicalModel
|
||||
}
|
||||
|
||||
private:
|
||||
bool containsCycle (void) const;
|
||||
bool containsCycle (const FgVarNode*, const FgFacNode*,
|
||||
vector<bool>&, vector<bool>&) const;
|
||||
bool containsCycle (const FgFacNode*, const FgVarNode*,
|
||||
vector<bool>&, vector<bool>&) const;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN (FactorGraph);
|
||||
|
||||
FgVarSet varNodes_;
|
||||
FactorSet factors_;
|
||||
FgFacSet facNodes_;
|
||||
|
||||
typedef unordered_map<unsigned, unsigned> IndexMap;
|
||||
IndexMap indexMap_;
|
||||
};
|
||||
|
||||
#endif // BP_FACTOR_GRAPH_H
|
||||
#endif // HORUS_FACTORGRAPH_H
|
||||
|
||||
|
499
packages/CLPBN/clpbn/bp/FgBpSolver.cpp
Normal file
499
packages/CLPBN/clpbn/bp/FgBpSolver.cpp
Normal file
@ -0,0 +1,499 @@
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "FgBpSolver.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "Factor.h"
|
||||
#include "Shared.h"
|
||||
|
||||
|
||||
FgBpSolver::FgBpSolver (const FactorGraph& fg) : Solver (&fg)
|
||||
{
|
||||
factorGraph_ = &fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
FgBpSolver::~FgBpSolver (void)
|
||||
{
|
||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||
delete varsI_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < facsI_.size(); i++) {
|
||||
delete facsI_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::runSolver (void)
|
||||
{
|
||||
clock_t start;
|
||||
if (COLLECT_STATISTICS) {
|
||||
start = clock();
|
||||
}
|
||||
if (false) {
|
||||
//if (!BpOptions::useAlwaysLoopySolver && factorGraph_->isTree()) {
|
||||
runTreeSolver();
|
||||
} else {
|
||||
runLoopySolver();
|
||||
if (DL >= 2) {
|
||||
cout << endl;
|
||||
if (nIters_ < BpOptions::maxIter) {
|
||||
cout << "Sum-Product converged in " ;
|
||||
cout << nIters_ << " iterations" << endl;
|
||||
} else {
|
||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
unsigned size = factorGraph_->getVarNodes().size();
|
||||
if (COLLECT_STATISTICS) {
|
||||
unsigned nIters = 0;
|
||||
bool loopy = factorGraph_->isTree() == false;
|
||||
if (loopy) nIters = nIters_;
|
||||
double time = (double (clock() - start)) / CLOCKS_PER_SEC;
|
||||
Statistics::updateStatistics (size, loopy, nIters, time);
|
||||
}
|
||||
if (EXPORT_TO_GRAPHVIZ && size > EXPORT_MINIMAL_SIZE) {
|
||||
stringstream ss;
|
||||
ss << Statistics::getSolvedNetworksCounting() << "." << size << ".dot" ;
|
||||
factorGraph_->exportToGraphViz (ss.str().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
FgBpSolver::getPosterioriOf (VarId vid)
|
||||
{
|
||||
assert (factorGraph_->getFgVarNode (vid));
|
||||
FgVarNode* var = factorGraph_->getFgVarNode (vid);
|
||||
ParamSet probs;
|
||||
|
||||
if (var->hasEvidence()) {
|
||||
probs.resize (var->nrStates(), Util::noEvidence());
|
||||
probs[var->getEvidence()] = Util::withEvidence();
|
||||
} else {
|
||||
probs.resize (var->nrStates(), Util::multIdenty());
|
||||
const SpLinkSet& links = ninf(var)->getLinks();
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Util::multiply (probs, links[i]->getMessage());
|
||||
}
|
||||
Util::normalize (probs);
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
Util::add (probs, links[i]->getMessage());
|
||||
}
|
||||
Util::normalize (probs);
|
||||
Util::fromLog (probs);
|
||||
}
|
||||
}
|
||||
return probs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
FgBpSolver::getJointDistributionOf (const VarIdSet& jointVarIds)
|
||||
{
|
||||
unsigned msgSize = 1;
|
||||
vector<unsigned> dsizes (jointVarIds.size());
|
||||
for (unsigned i = 0; i < jointVarIds.size(); i++) {
|
||||
dsizes[i] = factorGraph_->getFgVarNode (jointVarIds[i])->nrStates();
|
||||
msgSize *= dsizes[i];
|
||||
}
|
||||
unsigned reps = 1;
|
||||
ParamSet jointDist (msgSize, Util::multIdenty());
|
||||
for (int i = jointVarIds.size() - 1 ; i >= 0; i--) {
|
||||
Util::multiply (jointDist, getPosterioriOf (jointVarIds[i]), reps);
|
||||
reps *= dsizes[i];
|
||||
}
|
||||
return jointDist;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::runTreeSolver (void)
|
||||
{
|
||||
initializeSolver();
|
||||
const FgFacSet& facNodes = factorGraph_->getFactorNodes();
|
||||
bool finish = false;
|
||||
while (!finish) {
|
||||
finish = true;
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
const SpLinkSet& links = ninf (facNodes[i])->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (!links[j]->messageWasSended()) {
|
||||
if (readyToSendMessage (links[j])) {
|
||||
calculateAndUpdateMessage (links[j], false);
|
||||
}
|
||||
finish = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FgBpSolver::readyToSendMessage (const SpLink* link) const
|
||||
{
|
||||
const FgVarSet& neighbors = link->getFactor()->neighbors();
|
||||
for (unsigned i = 0; i < neighbors.size(); i++) {
|
||||
if (neighbors[i] != link->getVariable()) {
|
||||
const SpLinkSet& links = ninf (neighbors[i])->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (links[j]->getFactor() != link->getFactor() &&
|
||||
!links[j]->messageWasSended()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::runLoopySolver (void)
|
||||
{
|
||||
initializeSolver();
|
||||
nIters_ = 0;
|
||||
|
||||
while (!converged() && nIters_ < BpOptions::maxIter) {
|
||||
|
||||
nIters_ ++;
|
||||
if (DL >= 2) {
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
cout << endl;
|
||||
cout << " Iteration " << nIters_ << endl;
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
switch (BpOptions::schedule) {
|
||||
case BpOptions::Schedule::SEQ_RANDOM:
|
||||
random_shuffle (links_.begin(), links_.end());
|
||||
// no break
|
||||
|
||||
case BpOptions::Schedule::SEQ_FIXED:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateAndUpdateMessage (links_[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
case BpOptions::Schedule::PARALLEL:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
updateMessage(links_[i]);
|
||||
}
|
||||
break;
|
||||
|
||||
case BpOptions::Schedule::MAX_RESIDUAL:
|
||||
maxResidualSchedule();
|
||||
break;
|
||||
}
|
||||
if (DL >= 2) {
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::initializeSolver (void)
|
||||
{
|
||||
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||
delete varsI_[i];
|
||||
}
|
||||
varsI_.reserve (varNodes.size());
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
varsI_.push_back (new SPNodeInfo());
|
||||
}
|
||||
|
||||
const FgFacSet& facNodes = factorGraph_->getFactorNodes();
|
||||
for (unsigned i = 0; i < facsI_.size(); i++) {
|
||||
delete facsI_[i];
|
||||
}
|
||||
facsI_.reserve (facNodes.size());
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
facsI_.push_back (new SPNodeInfo());
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
createLinks();
|
||||
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
FgFacNode* src = links_[i]->getFactor();
|
||||
FgVarNode* dst = links_[i]->getVariable();
|
||||
ninf (dst)->addSpLink (links_[i]);
|
||||
ninf (src)->addSpLink (links_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::createLinks (void)
|
||||
{
|
||||
const FgFacSet& facNodes = factorGraph_->getFactorNodes();
|
||||
for (unsigned i = 0; i < facNodes.size(); i++) {
|
||||
const FgVarSet& neighbors = facNodes[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighbors.size(); j++) {
|
||||
links_.push_back (new SpLink (facNodes[i], neighbors[j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
FgBpSolver::converged (void)
|
||||
{
|
||||
if (links_.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
if (nIters_ == 0 || nIters_ == 1) {
|
||||
return false;
|
||||
}
|
||||
bool converged = true;
|
||||
if (BpOptions::schedule == BpOptions::Schedule::MAX_RESIDUAL) {
|
||||
Param maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
||||
if (maxResidual > BpOptions::accuracy) {
|
||||
converged = false;
|
||||
} else {
|
||||
converged = true;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
double residual = links_[i]->getResidual();
|
||||
if (DL >= 2) {
|
||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||
}
|
||||
if (residual > BpOptions::accuracy) {
|
||||
converged = false;
|
||||
if (DL == 0) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return converged;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::maxResidualSchedule (void)
|
||||
{
|
||||
if (nIters_ == 1) {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
calculateMessage (links_[i]);
|
||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||
linkMap_.insert (make_pair (links_[i], it));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < links_.size(); c++) {
|
||||
if (DL >= 2) {
|
||||
cout << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); it ++) {
|
||||
cout << " " << setw (30) << left << (*it)->toString();
|
||||
cout << "residual = " << (*it)->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
SpLink* link = *it;
|
||||
if (link->getResidual() < BpOptions::accuracy) {
|
||||
return;
|
||||
}
|
||||
updateMessage (link);
|
||||
link->clearResidual();
|
||||
sortedOrder_.erase (it);
|
||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||
|
||||
// update the messages that depend on message source --> destin
|
||||
const FgFacSet& factorNeighbors = link->getVariable()->neighbors();
|
||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||
if (factorNeighbors[i] != link->getFactor()) {
|
||||
const SpLinkSet& links = ninf(factorNeighbors[i])->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (links[j]->getVariable() != link->getVariable()) {
|
||||
calculateMessage (links[j]);
|
||||
SpLinkMap::iterator iter = linkMap_.find (links[j]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (links[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (DL >= 2) {
|
||||
cout << "----------------------------------------" ;
|
||||
cout << "----------------------------------------" << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::calculateFactor2VariableMsg (SpLink* link) const
|
||||
{
|
||||
const FgFacNode* src = link->getFactor();
|
||||
const FgVarNode* dst = link->getVariable();
|
||||
const SpLinkSet& links = ninf(src)->getLinks();
|
||||
// calculate the product of messages that were sent
|
||||
// to factor `src', except from var `dst'
|
||||
unsigned msgSize = 1;
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
msgSize *= links[i]->getVariable()->nrStates();
|
||||
}
|
||||
unsigned repetitions = 1;
|
||||
ParamSet msgProduct (msgSize, Util::multIdenty());
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
for (int i = links.size() - 1; i >= 0; i--) {
|
||||
if (links[i]->getVariable() != dst) {
|
||||
if (DL >= 5) {
|
||||
cout << " message from " << links[i]->getVariable()->label();
|
||||
cout << ": " << endl;
|
||||
}
|
||||
Util::multiply (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||
repetitions *= links[i]->getVariable()->nrStates();
|
||||
} else {
|
||||
unsigned ds = links[i]->getVariable()->nrStates();
|
||||
Util::multiply (msgProduct, ParamSet (ds, 1.0), repetitions);
|
||||
repetitions *= ds;
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
for (int i = links.size() - 1; i >= 0; i--) {
|
||||
if (links[i]->getVariable() != dst) {
|
||||
Util::add (msgProduct, getVar2FactorMsg (links[i]), repetitions);
|
||||
repetitions *= links[i]->getVariable()->nrStates();
|
||||
} else {
|
||||
unsigned ds = links[i]->getVariable()->nrStates();
|
||||
Util::add (msgProduct, ParamSet (ds, 1.0), repetitions);
|
||||
repetitions *= ds;
|
||||
}
|
||||
}
|
||||
}
|
||||
Factor result (src->factor()->getVarIds(),
|
||||
src->factor()->getRanges(),
|
||||
msgProduct);
|
||||
result.multiplyByFactor (*(src->factor()));
|
||||
if (DL >= 5) {
|
||||
cout << " message product: " ;
|
||||
cout << Util::parametersToString (msgProduct) << endl;
|
||||
cout << " original factor: " ;
|
||||
cout << Util::parametersToString (src->getParameters()) << endl;
|
||||
cout << " factor product: " ;
|
||||
cout << Util::parametersToString (result.getParameters()) << endl;
|
||||
}
|
||||
result.removeAllVariablesExcept (dst->varId());
|
||||
if (DL >= 5) {
|
||||
cout << " marginalized: " ;
|
||||
cout << Util::parametersToString (result.getParameters()) << endl;
|
||||
}
|
||||
const ParamSet& resultParams = result.getParameters();
|
||||
ParamSet& message = link->getNextMessage();
|
||||
for (unsigned i = 0; i < resultParams.size(); i++) {
|
||||
message[i] = resultParams[i];
|
||||
}
|
||||
Util::normalize (message);
|
||||
if (DL >= 5) {
|
||||
cout << " curr msg: " ;
|
||||
cout << Util::parametersToString (link->getMessage()) << endl;
|
||||
cout << " next msg: " ;
|
||||
cout << Util::parametersToString (message) << endl;
|
||||
}
|
||||
result.freeDistribution();
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
FgBpSolver::getVar2FactorMsg (const SpLink* link) const
|
||||
{
|
||||
const FgVarNode* src = link->getVariable();
|
||||
const FgFacNode* dst = link->getFactor();
|
||||
ParamSet msg;
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->nrStates(), Util::noEvidence());
|
||||
msg[src->getEvidence()] = Util::withEvidence();
|
||||
if (DL >= 5) {
|
||||
cout << Util::parametersToString (msg);
|
||||
}
|
||||
} else {
|
||||
msg.resize (src->nrStates(), Util::one());
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << Util::parametersToString (msg);
|
||||
}
|
||||
const SpLinkSet& links = ninf (src)->getLinks();
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
Util::multiply (msg, links[i]->getMessage());
|
||||
if (DL >= 5) {
|
||||
cout << " x " << Util::parametersToString (links[i]->getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dst) {
|
||||
Util::add (msg, links[i]->getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << " = " << Util::parametersToString (msg);
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
FgBpSolver::printLinkInformation (void) const
|
||||
{
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
SpLink* l = links_[i];
|
||||
cout << l->toString() << ":" << endl;
|
||||
cout << " curr msg = " ;
|
||||
cout << Util::parametersToString (l->getMessage()) << endl;
|
||||
cout << " next msg = " ;
|
||||
cout << Util::parametersToString (l->getNextMessage()) << endl;
|
||||
cout << " residual = " << l->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
175
packages/CLPBN/clpbn/bp/FgBpSolver.h
Normal file
175
packages/CLPBN/clpbn/bp/FgBpSolver.h
Normal file
@ -0,0 +1,175 @@
|
||||
#ifndef HORUS_FGBPSOLVER_H
|
||||
#define HORUS_FGBPSOLVER_H
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
|
||||
#include "Solver.h"
|
||||
#include "Factor.h"
|
||||
#include "FactorGraph.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
|
||||
class SpLink
|
||||
{
|
||||
public:
|
||||
SpLink (FgFacNode* fn, FgVarNode* vn)
|
||||
{
|
||||
fac_ = fn;
|
||||
var_ = vn;
|
||||
v1_.resize (vn->nrStates(), Util::tl (1.0 / vn->nrStates()));
|
||||
v2_.resize (vn->nrStates(), Util::tl (1.0 / vn->nrStates()));
|
||||
currMsg_ = &v1_;
|
||||
nextMsg_ = &v2_;
|
||||
msgSended_ = false;
|
||||
residual_ = 0.0;
|
||||
}
|
||||
|
||||
virtual ~SpLink (void) {};
|
||||
|
||||
virtual void updateMessage (void)
|
||||
{
|
||||
swap (currMsg_, nextMsg_);
|
||||
msgSended_ = true;
|
||||
}
|
||||
|
||||
void updateResidual (void)
|
||||
{
|
||||
residual_ = Util::getMaxNorm (v1_, v2_);
|
||||
}
|
||||
|
||||
string toString (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << fac_->getLabel();
|
||||
ss << " -- " ;
|
||||
ss << var_->label();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
FgFacNode* getFactor (void) const { return fac_; }
|
||||
FgVarNode* getVariable (void) const { return var_; }
|
||||
const ParamSet& getMessage (void) const { return *currMsg_; }
|
||||
ParamSet& getNextMessage (void) { return *nextMsg_; }
|
||||
bool messageWasSended (void) const { return msgSended_; }
|
||||
double getResidual (void) const { return residual_; }
|
||||
void clearResidual (void) { residual_ = 0.0; }
|
||||
|
||||
protected:
|
||||
FgFacNode* fac_;
|
||||
FgVarNode* var_;
|
||||
ParamSet v1_;
|
||||
ParamSet v2_;
|
||||
ParamSet* currMsg_;
|
||||
ParamSet* nextMsg_;
|
||||
bool msgSended_;
|
||||
double residual_;
|
||||
};
|
||||
|
||||
|
||||
typedef vector<SpLink*> SpLinkSet;
|
||||
|
||||
|
||||
class SPNodeInfo
|
||||
{
|
||||
public:
|
||||
void addSpLink (SpLink* link) { links_.push_back (link); }
|
||||
const SpLinkSet& getLinks (void) { return links_; }
|
||||
|
||||
private:
|
||||
SpLinkSet links_;
|
||||
};
|
||||
|
||||
|
||||
class FgBpSolver : public Solver
|
||||
{
|
||||
public:
|
||||
FgBpSolver (const FactorGraph&);
|
||||
virtual ~FgBpSolver (void);
|
||||
|
||||
void runSolver (void);
|
||||
virtual ParamSet getPosterioriOf (VarId);
|
||||
virtual ParamSet getJointDistributionOf (const VarIdSet&);
|
||||
|
||||
protected:
|
||||
virtual void initializeSolver (void);
|
||||
virtual void createLinks (void);
|
||||
virtual void maxResidualSchedule (void);
|
||||
virtual void calculateFactor2VariableMsg (SpLink*) const;
|
||||
virtual ParamSet getVar2FactorMsg (const SpLink*) const;
|
||||
virtual void printLinkInformation (void) const;
|
||||
|
||||
void calculateAndUpdateMessage (SpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (DL >= 3) {
|
||||
cout << "calculating & updating " << link->toString() << endl;
|
||||
}
|
||||
calculateFactor2VariableMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
link->updateMessage();
|
||||
}
|
||||
|
||||
void calculateMessage (SpLink* link, bool calcResidual = true)
|
||||
{
|
||||
if (DL >= 3) {
|
||||
cout << "calculating " << link->toString() << endl;
|
||||
}
|
||||
calculateFactor2VariableMsg (link);
|
||||
if (calcResidual) {
|
||||
link->updateResidual();
|
||||
}
|
||||
}
|
||||
|
||||
void updateMessage (SpLink* link)
|
||||
{
|
||||
link->updateMessage();
|
||||
if (DL >= 3) {
|
||||
cout << "updating " << link->toString() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
SPNodeInfo* ninf (const FgVarNode* var) const
|
||||
{
|
||||
return varsI_[var->getIndex()];
|
||||
}
|
||||
|
||||
SPNodeInfo* ninf (const FgFacNode* fac) const
|
||||
{
|
||||
return facsI_[fac->getIndex()];
|
||||
}
|
||||
|
||||
struct CompareResidual {
|
||||
inline bool operator() (const SpLink* link1, const SpLink* link2)
|
||||
{
|
||||
return link1->getResidual() > link2->getResidual();
|
||||
}
|
||||
};
|
||||
|
||||
SpLinkSet links_;
|
||||
unsigned nIters_;
|
||||
vector<SPNodeInfo*> varsI_;
|
||||
vector<SPNodeInfo*> facsI_;
|
||||
const FactorGraph* factorGraph_;
|
||||
|
||||
typedef multiset<SpLink*, CompareResidual> SortedOrder;
|
||||
SortedOrder sortedOrder_;
|
||||
|
||||
typedef unordered_map<SpLink*, SortedOrder::iterator> SpLinkMap;
|
||||
SpLinkMap linkMap_;
|
||||
|
||||
private:
|
||||
void runTreeSolver (void);
|
||||
bool readyToSendMessage (const SpLink*) const;
|
||||
void runLoopySolver (void);
|
||||
bool converged (void);
|
||||
|
||||
|
||||
};
|
||||
|
||||
#endif // HORUS_FGBPSOLVER_H
|
||||
|
@ -1,43 +0,0 @@
|
||||
#ifndef BP_FG_VAR_NODE_H
|
||||
#define BP_FG_VAR_NODE_H
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "Variable.h"
|
||||
#include "Shared.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
class Factor;
|
||||
|
||||
class FgVarNode : public Variable
|
||||
{
|
||||
public:
|
||||
FgVarNode (unsigned vid, unsigned dsize) : Variable (vid, dsize) { }
|
||||
FgVarNode (const Variable* v) : Variable (v) { }
|
||||
|
||||
void addFactor (Factor* f) { factors_.push_back (f); }
|
||||
CFactorSet getFactors (void) const { return factors_; }
|
||||
|
||||
void removeFactor (const Factor* f)
|
||||
{
|
||||
if (factors_[factors_.size() -1] == f) {
|
||||
factors_.pop_back();
|
||||
} else {
|
||||
for (unsigned i = 0; i < factors_.size(); i++) {
|
||||
if (factors_[i] == f) {
|
||||
factors_.erase (factors_.begin() + i);
|
||||
return;
|
||||
}
|
||||
}
|
||||
assert (false);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (FgVarNode);
|
||||
// members
|
||||
FactorSet factors_;
|
||||
};
|
||||
|
||||
#endif // BP_FG_VAR_NODE_H
|
@ -1,18 +1,54 @@
|
||||
#ifndef BP_GRAPHICAL_MODEL_H
|
||||
#define BP_GRAPHICAL_MODEL_H
|
||||
#ifndef HORUS_GRAPHICALMODEL_H
|
||||
#define HORUS_GRAPHICALMODEL_H
|
||||
|
||||
#include "Variable.h"
|
||||
#include "VarNode.h"
|
||||
#include "Shared.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
struct VariableInfo
|
||||
{
|
||||
VariableInfo (string l, const States& sts)
|
||||
{
|
||||
label = l;
|
||||
states = sts;
|
||||
}
|
||||
string label;
|
||||
States states;
|
||||
};
|
||||
|
||||
|
||||
class GraphicalModel
|
||||
{
|
||||
public:
|
||||
virtual ~GraphicalModel (void) {};
|
||||
virtual Variable* getVariable (Vid) const = 0;
|
||||
virtual VarSet getVariables (void) const = 0;
|
||||
virtual VarNode* getVariableNode (VarId) const = 0;
|
||||
virtual VarNodes getVariableNodes (void) const = 0;
|
||||
virtual void printGraphicalModel (void) const = 0;
|
||||
|
||||
static void addVariableInformation (VarId vid, string label,
|
||||
const States& states)
|
||||
{
|
||||
assert (varsInfo_.find (vid) == varsInfo_.end());
|
||||
varsInfo_.insert (make_pair (vid, VariableInfo (label, states)));
|
||||
}
|
||||
static VariableInfo getVariableInformation (VarId vid)
|
||||
{
|
||||
assert (varsInfo_.find (vid) != varsInfo_.end());
|
||||
return varsInfo_.find (vid)->second;
|
||||
}
|
||||
static bool variablesHaveInformation (void)
|
||||
{
|
||||
return varsInfo_.size() != 0;
|
||||
}
|
||||
static void clearVariablesInformation (void)
|
||||
{
|
||||
varsInfo_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
static unordered_map<VarId,VariableInfo> varsInfo_;
|
||||
};
|
||||
|
||||
#endif // BP_GRAPHICAL_MODEL_H
|
||||
#endif // HORUS_GRAPHICALMODEL_H
|
||||
|
||||
|
@ -5,15 +5,18 @@
|
||||
|
||||
#include "BayesNet.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "SPSolver.h"
|
||||
#include "BPSolver.h"
|
||||
#include "CountingBP.h"
|
||||
#include "VarElimSolver.h"
|
||||
#include "BnBpSolver.h"
|
||||
#include "FgBpSolver.h"
|
||||
#include "CbpSolver.h"
|
||||
|
||||
#include "StatesIndexer.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
void BayesianNetwork (int, const char* []);
|
||||
void markovNetwork (int, const char* []);
|
||||
void runSolver (Solver*, const VarSet&);
|
||||
void processArguments (BayesNet&, int, const char* []);
|
||||
void processArguments (FactorGraph&, int, const char* []);
|
||||
void runSolver (Solver*, const VarNodes&);
|
||||
|
||||
const string USAGE = "usage: \
|
||||
./hcli FILE [VARIABLE | OBSERVED_VARIABLE=EVIDENCE]..." ;
|
||||
@ -22,32 +25,6 @@ const string USAGE = "usage: \
|
||||
int
|
||||
main (int argc, const char* argv[])
|
||||
{
|
||||
/*
|
||||
FactorGraph fg;
|
||||
FgVarNode* varNode1 = new FgVarNode (0, 2);
|
||||
FgVarNode* varNode2 = new FgVarNode (1, 2);
|
||||
FgVarNode* varNode3 = new FgVarNode (2, 2);
|
||||
fg.addVariable (varNode1);
|
||||
fg.addVariable (varNode2);
|
||||
fg.addVariable (varNode3);
|
||||
Distribution* dist = new Distribution (ParamSet() = {1.2, 1.4, 2.0, 0.4});
|
||||
fg.addFactor (new Factor (FgVarSet() = {varNode1, varNode2}, dist));
|
||||
fg.addFactor (new Factor (FgVarSet() = {varNode3, varNode2}, dist));
|
||||
//fg.printGraphicalModel();
|
||||
//SPSolver sp (fg);
|
||||
//sp.runSolver();
|
||||
//sp.printAllPosterioris();
|
||||
//ParamSet p = sp.getJointDistributionOf (VidSet() = {0, 1, 2});
|
||||
//cout << Util::parametersToString (p) << endl;
|
||||
CountingBP cbp (fg);
|
||||
//cbp.runSolver();
|
||||
//cbp.printAllPosterioris();
|
||||
ParamSet p2 = cbp.getJointDistributionOf (VidSet() = {0, 1, 2});
|
||||
cout << Util::parametersToString (p2) << endl;
|
||||
fg.freeDistributions();
|
||||
Statistics::printCompressingStats ("compressing.stats");
|
||||
return 0;
|
||||
*/
|
||||
if (!argv[1]) {
|
||||
cerr << "error: no graphical model specified" << endl;
|
||||
cerr << USAGE << endl;
|
||||
@ -56,12 +33,20 @@ main (int argc, const char* argv[])
|
||||
const string& fileName = argv[1];
|
||||
const string& extension = fileName.substr (fileName.find_last_of ('.') + 1);
|
||||
if (extension == "xml") {
|
||||
BayesianNetwork (argc, argv);
|
||||
BayesNet bn;
|
||||
bn.readFromBifFormat (argv[1]);
|
||||
processArguments (bn, argc, argv);
|
||||
} else if (extension == "uai") {
|
||||
markovNetwork (argc, argv);
|
||||
FactorGraph fg;
|
||||
fg.readFromUaiFormat (argv[1]);
|
||||
processArguments (fg, argc, argv);
|
||||
} else if (extension == "fg") {
|
||||
FactorGraph fg;
|
||||
fg.readFromLibDaiFormat (argv[1]);
|
||||
processArguments (fg, argc, argv);
|
||||
} else {
|
||||
cerr << "error: the graphical model must be defined either " ;
|
||||
cerr << "in a xml file or uai file" << endl;
|
||||
cerr << "in a xml, uai or libDAI file" << endl;
|
||||
exit (0);
|
||||
}
|
||||
return 0;
|
||||
@ -70,12 +55,9 @@ main (int argc, const char* argv[])
|
||||
|
||||
|
||||
void
|
||||
BayesianNetwork (int argc, const char* argv[])
|
||||
processArguments (BayesNet& bn, int argc, const char* argv[])
|
||||
{
|
||||
BayesNet bn (argv[1]);
|
||||
//bn.printGraphicalModel();
|
||||
|
||||
VarSet queryVars;
|
||||
VarNodes queryVars;
|
||||
for (int i = 2; i < argc; i++) {
|
||||
const string& arg = argv[i];
|
||||
if (arg.find ('=') == std::string::npos) {
|
||||
@ -86,6 +68,7 @@ BayesianNetwork (int argc, const char* argv[])
|
||||
cerr << "error: there isn't a variable labeled of " ;
|
||||
cerr << "`" << arg << "'" ;
|
||||
cerr << endl;
|
||||
bn.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
} else {
|
||||
@ -95,11 +78,13 @@ BayesianNetwork (int argc, const char* argv[])
|
||||
if (label.empty()) {
|
||||
cerr << "error: missing left argument" << endl;
|
||||
cerr << USAGE << endl;
|
||||
bn.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
if (state.empty()) {
|
||||
cerr << "error: missing right argument" << endl;
|
||||
cerr << USAGE << endl;
|
||||
bn.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
BayesNode* node = bn.getBayesNode (label);
|
||||
@ -109,42 +94,54 @@ BayesianNetwork (int argc, const char* argv[])
|
||||
} else {
|
||||
cerr << "error: `" << state << "' " ;
|
||||
cerr << "is not a valid state for " ;
|
||||
cerr << "`" << node->getLabel() << "'" ;
|
||||
cerr << "`" << node->label() << "'" ;
|
||||
cerr << endl;
|
||||
bn.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
} else {
|
||||
cerr << "error: there isn't a variable labeled of " ;
|
||||
cerr << "`" << label << "'" ;
|
||||
cerr << endl;
|
||||
bn.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Solver* solver;
|
||||
if (SolverOptions::convertBn2Fg) {
|
||||
FactorGraph* fg = new FactorGraph (bn);
|
||||
Solver* solver = 0;
|
||||
FactorGraph* fg = 0;
|
||||
switch (InfAlgorithms::infAlgorithm) {
|
||||
case InfAlgorithms::VE:
|
||||
fg = new FactorGraph (bn);
|
||||
solver = new VarElimSolver (*fg);
|
||||
break;
|
||||
case InfAlgorithms::BN_BP:
|
||||
solver = new BnBpSolver (bn);
|
||||
break;
|
||||
case InfAlgorithms::FG_BP:
|
||||
fg = new FactorGraph (bn);
|
||||
fg->printGraphicalModel();
|
||||
solver = new SPSolver (*fg);
|
||||
solver = new FgBpSolver (*fg);
|
||||
break;
|
||||
case InfAlgorithms::CBP:
|
||||
fg = new FactorGraph (bn);
|
||||
solver = new CbpSolver (*fg);
|
||||
break;
|
||||
default:
|
||||
assert (false);
|
||||
}
|
||||
runSolver (solver, queryVars);
|
||||
delete fg;
|
||||
} else {
|
||||
solver = new BPSolver (bn);
|
||||
runSolver (solver, queryVars);
|
||||
}
|
||||
bn.freeDistributions();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
markovNetwork (int argc, const char* argv[])
|
||||
processArguments (FactorGraph& fg, int argc, const char* argv[])
|
||||
{
|
||||
FactorGraph fg (argv[1]);
|
||||
//fg.printGraphicalModel();
|
||||
|
||||
VarSet queryVars;
|
||||
VarNodes queryVars;
|
||||
for (int i = 2; i < argc; i++) {
|
||||
const string& arg = argv[i];
|
||||
if (arg.find ('=') == std::string::npos) {
|
||||
@ -152,19 +149,21 @@ markovNetwork (int argc, const char* argv[])
|
||||
cerr << "error: `" << arg << "' " ;
|
||||
cerr << "is not a valid variable id" ;
|
||||
cerr << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
Vid vid;
|
||||
VarId vid;
|
||||
stringstream ss;
|
||||
ss << arg;
|
||||
ss >> vid;
|
||||
Variable* queryVar = fg.getFgVarNode (vid);
|
||||
VarNode* queryVar = fg.getFgVarNode (vid);
|
||||
if (queryVar) {
|
||||
queryVars.push_back (queryVar);
|
||||
} else {
|
||||
cerr << "error: there isn't a variable with " ;
|
||||
cerr << "`" << vid << "' as id" ;
|
||||
cerr << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
} else {
|
||||
@ -172,53 +171,73 @@ markovNetwork (int argc, const char* argv[])
|
||||
if (arg.substr (0, pos).empty()) {
|
||||
cerr << "error: missing left argument" << endl;
|
||||
cerr << USAGE << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
if (arg.substr (pos + 1).empty()) {
|
||||
cerr << "error: missing right argument" << endl;
|
||||
cerr << USAGE << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
if (!Util::isInteger (arg.substr (0, pos))) {
|
||||
cerr << "error: `" << arg.substr (0, pos) << "' " ;
|
||||
cerr << "is not a variable id" ;
|
||||
cerr << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
Vid vid;
|
||||
VarId vid;
|
||||
stringstream ss;
|
||||
ss << arg.substr (0, pos);
|
||||
ss >> vid;
|
||||
Variable* var = fg.getFgVarNode (vid);
|
||||
VarNode* var = fg.getFgVarNode (vid);
|
||||
if (var) {
|
||||
if (!Util::isInteger (arg.substr (pos + 1))) {
|
||||
cerr << "error: `" << arg.substr (pos + 1) << "' " ;
|
||||
cerr << "is not a state index" ;
|
||||
cerr << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
int stateIndex;
|
||||
stringstream ss;
|
||||
ss << arg.substr (pos + 1);
|
||||
ss >> stateIndex;
|
||||
if (var->isValidStateIndex (stateIndex)) {
|
||||
if (var->isValidState (stateIndex)) {
|
||||
var->setEvidence (stateIndex);
|
||||
} else {
|
||||
cerr << "error: `" << stateIndex << "' " ;
|
||||
cerr << "is not a valid state index for variable " ;
|
||||
cerr << "`" << var->getVarId() << "'" ;
|
||||
cerr << "`" << var->varId() << "'" ;
|
||||
cerr << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
} else {
|
||||
cerr << "error: there isn't a variable with " ;
|
||||
cerr << "`" << vid << "' as id" ;
|
||||
cerr << endl;
|
||||
fg.freeDistributions();
|
||||
exit (0);
|
||||
}
|
||||
}
|
||||
}
|
||||
Solver* solver = new SPSolver (fg);
|
||||
Solver* solver = 0;
|
||||
switch (InfAlgorithms::infAlgorithm) {
|
||||
case InfAlgorithms::VE:
|
||||
solver = new VarElimSolver (fg);
|
||||
break;
|
||||
case InfAlgorithms::BN_BP:
|
||||
case InfAlgorithms::FG_BP:
|
||||
solver = new FgBpSolver (fg);
|
||||
break;
|
||||
case InfAlgorithms::CBP:
|
||||
solver = new CbpSolver (fg);
|
||||
break;
|
||||
default:
|
||||
assert (false);
|
||||
}
|
||||
runSolver (solver, queryVars);
|
||||
fg.freeDistributions();
|
||||
}
|
||||
@ -226,11 +245,11 @@ markovNetwork (int argc, const char* argv[])
|
||||
|
||||
|
||||
void
|
||||
runSolver (Solver* solver, const VarSet& queryVars)
|
||||
runSolver (Solver* solver, const VarNodes& queryVars)
|
||||
{
|
||||
VidSet vids;
|
||||
VarIdSet vids;
|
||||
for (unsigned i = 0; i < queryVars.size(); i++) {
|
||||
vids.push_back (queryVars[i]->getVarId());
|
||||
vids.push_back (queryVars[i]->varId());
|
||||
}
|
||||
if (queryVars.size() == 0) {
|
||||
solver->runSolver();
|
||||
@ -239,6 +258,7 @@ runSolver (Solver* solver, const VarSet& queryVars)
|
||||
solver->runSolver();
|
||||
solver->printPosterioriOf (vids[0]);
|
||||
} else {
|
||||
solver->runSolver();
|
||||
solver->printJointDistributionOf (vids);
|
||||
}
|
||||
delete solver;
|
||||
|
@ -8,9 +8,11 @@
|
||||
|
||||
#include "BayesNet.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "BPSolver.h"
|
||||
#include "SPSolver.h"
|
||||
#include "CountingBP.h"
|
||||
#include "VarElimSolver.h"
|
||||
#include "BnBpSolver.h"
|
||||
#include "FgBpSolver.h"
|
||||
#include "CbpSolver.h"
|
||||
#include "ElimGraph.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
@ -18,26 +20,27 @@ using namespace std;
|
||||
int
|
||||
createNetwork (void)
|
||||
{
|
||||
//Statistics::numCreatedNets ++;
|
||||
//cout << "creating network number " << Statistics::numCreatedNets << endl;
|
||||
|
||||
Statistics::incrementPrimaryNetworksCounting();
|
||||
// cout << "creating network number " ;
|
||||
// cout << Statistics::getPrimaryNetworksCounting() << endl;
|
||||
// if (Statistics::getPrimaryNetworksCounting() > 98) {
|
||||
// Statistics::writeStatisticsToFile ("../../compressing.stats");
|
||||
// }
|
||||
BayesNet* bn = new BayesNet();
|
||||
YAP_Term varList = YAP_ARG1;
|
||||
BnNodeSet nodes;
|
||||
vector<VarIdSet> parents;
|
||||
while (varList != YAP_TermNil()) {
|
||||
YAP_Term var = YAP_HeadOfTerm (varList);
|
||||
Vid vid = (Vid) YAP_IntOfTerm (YAP_ArgOfTerm (1, var));
|
||||
VarId vid = (VarId) YAP_IntOfTerm (YAP_ArgOfTerm (1, var));
|
||||
unsigned dsize = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (2, var));
|
||||
int evidence = (int) YAP_IntOfTerm (YAP_ArgOfTerm (3, var));
|
||||
YAP_Term parentL = YAP_ArgOfTerm (4, var);
|
||||
unsigned distId = (unsigned) YAP_IntOfTerm (YAP_ArgOfTerm (5, var));
|
||||
BnNodeSet parents;
|
||||
parents.push_back (VarIdSet());
|
||||
while (parentL != YAP_TermNil()) {
|
||||
unsigned parentId = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (parentL));
|
||||
BayesNode* parent = bn->getBayesNode (parentId);
|
||||
if (!parent) {
|
||||
parent = bn->addNode (parentId);
|
||||
}
|
||||
parents.push_back (parent);
|
||||
parents.back().push_back (parentId);
|
||||
parentL = YAP_TailOfTerm (parentL);
|
||||
}
|
||||
Distribution* dist = bn->getDistribution (distId);
|
||||
@ -45,20 +48,19 @@ createNetwork (void)
|
||||
dist = new Distribution (distId);
|
||||
bn->addDistribution (dist);
|
||||
}
|
||||
BayesNode* node = bn->getBayesNode (vid);
|
||||
if (node) {
|
||||
node->setData (dsize, evidence, parents, dist);
|
||||
} else {
|
||||
bn->addNode (vid, dsize, evidence, parents, dist);
|
||||
}
|
||||
assert (bn->getBayesNode (vid) == 0);
|
||||
nodes.push_back (bn->addNode (vid, dsize, evidence, dist));
|
||||
varList = YAP_TailOfTerm (varList);
|
||||
}
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
BnNodeSet ps;
|
||||
for (unsigned j = 0; j < parents[i].size(); j++) {
|
||||
assert (bn->getBayesNode (parents[i][j]) != 0);
|
||||
ps.push_back (bn->getBayesNode (parents[i][j]));
|
||||
}
|
||||
nodes[i]->setParents (ps);
|
||||
}
|
||||
bn->setIndexes();
|
||||
|
||||
// if (Statistics::numCreatedNets == 1688) {
|
||||
// Statistics::writeStats();
|
||||
// exit (0);
|
||||
// }
|
||||
YAP_Int p = (YAP_Int) (bn);
|
||||
return YAP_Unify (YAP_MkIntTerm (p), YAP_ARG2);
|
||||
}
|
||||
@ -68,23 +70,22 @@ createNetwork (void)
|
||||
int
|
||||
setExtraVarsInfo (void)
|
||||
{
|
||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
// BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
GraphicalModel::clearVariablesInformation();
|
||||
YAP_Term varsInfoL = YAP_ARG2;
|
||||
while (varsInfoL != YAP_TermNil()) {
|
||||
YAP_Term head = YAP_HeadOfTerm (varsInfoL);
|
||||
Vid vid = YAP_IntOfTerm (YAP_ArgOfTerm (1, head));
|
||||
VarId vid = YAP_IntOfTerm (YAP_ArgOfTerm (1, head));
|
||||
YAP_Atom label = YAP_AtomOfTerm (YAP_ArgOfTerm (2, head));
|
||||
YAP_Term domainL = YAP_ArgOfTerm (3, head);
|
||||
Domain domain;
|
||||
while (domainL != YAP_TermNil()) {
|
||||
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (domainL));
|
||||
domain.push_back ((char*) YAP_AtomName (atom));
|
||||
domainL = YAP_TailOfTerm (domainL);
|
||||
YAP_Term statesL = YAP_ArgOfTerm (3, head);
|
||||
States states;
|
||||
while (statesL != YAP_TermNil()) {
|
||||
YAP_Atom atom = YAP_AtomOfTerm (YAP_HeadOfTerm (statesL));
|
||||
states.push_back ((char*) YAP_AtomName (atom));
|
||||
statesL = YAP_TailOfTerm (statesL);
|
||||
}
|
||||
BayesNode* node = bn->getBayesNode (vid);
|
||||
assert (node);
|
||||
node->setLabel ((char*) YAP_AtomName (label));
|
||||
node->setDomain (domain);
|
||||
GraphicalModel::addVariableInformation (vid,
|
||||
(char*) YAP_AtomName (label), states);
|
||||
varsInfoL = YAP_TailOfTerm (varsInfoL);
|
||||
}
|
||||
return TRUE;
|
||||
@ -106,12 +107,10 @@ setParameters (void)
|
||||
params.push_back ((double) YAP_FloatOfTerm (YAP_HeadOfTerm (paramL)));
|
||||
paramL = YAP_TailOfTerm (paramL);
|
||||
}
|
||||
bn->getDistribution(distId)->updateParameters(params);
|
||||
if (Statistics::numCreatedNets == 4) {
|
||||
cout << "dist " << distId << " parameters:" ;
|
||||
cout << Util::parametersToString (params);
|
||||
cout << endl;
|
||||
if (NSPACE == NumberSpace::LOGARITHM) {
|
||||
Util::toLog (params);
|
||||
}
|
||||
bn->getDistribution(distId)->updateParameters (params);
|
||||
distList = YAP_TailOfTerm (distList);
|
||||
}
|
||||
return TRUE;
|
||||
@ -124,113 +123,73 @@ runSolver (void)
|
||||
{
|
||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
YAP_Term taskList = YAP_ARG2;
|
||||
vector<VidSet> tasks;
|
||||
VidSet marginalVids;
|
||||
|
||||
vector<VarIdSet> tasks;
|
||||
std::set<VarId> vids;
|
||||
while (taskList != YAP_TermNil()) {
|
||||
if (YAP_IsPairTerm (YAP_HeadOfTerm (taskList))) {
|
||||
VidSet jointVids;
|
||||
tasks.push_back (VarIdSet());
|
||||
YAP_Term jointList = YAP_HeadOfTerm (taskList);
|
||||
while (jointList != YAP_TermNil()) {
|
||||
Vid vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
|
||||
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (jointList));
|
||||
assert (bn->getBayesNode (vid));
|
||||
jointVids.push_back (vid);
|
||||
tasks.back().push_back (vid);
|
||||
vids.insert (vid);
|
||||
jointList = YAP_TailOfTerm (jointList);
|
||||
}
|
||||
tasks.push_back (jointVids);
|
||||
} else {
|
||||
Vid vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (taskList));
|
||||
VarId vid = (unsigned) YAP_IntOfTerm (YAP_HeadOfTerm (taskList));
|
||||
assert (bn->getBayesNode (vid));
|
||||
tasks.push_back (VidSet() = {vid});
|
||||
marginalVids.push_back (vid);
|
||||
tasks.push_back (VarIdSet() = {vid});
|
||||
vids.insert (vid);
|
||||
}
|
||||
taskList = YAP_TailOfTerm (taskList);
|
||||
}
|
||||
|
||||
// cout << "inference tasks:" << endl;
|
||||
// for (unsigned i = 0; i < tasks.size(); i++) {
|
||||
// cout << "i" << ": " ;
|
||||
// if (tasks[i].size() == 1) {
|
||||
// cout << tasks[i][0] << endl;
|
||||
// } else {
|
||||
// for (unsigned j = 0; j < tasks[i].size(); j++) {
|
||||
// cout << tasks[i][j] << " " ;
|
||||
// }
|
||||
// cout << endl;
|
||||
// }
|
||||
// }
|
||||
|
||||
Solver* solver = 0;
|
||||
GraphicalModel* gm = 0;
|
||||
VidSet vids;
|
||||
const BnNodeSet& nodes = bn->getBayesNodes();
|
||||
for (unsigned i = 0; i < nodes.size(); i++) {
|
||||
vids.push_back (nodes[i]->getVarId());
|
||||
}
|
||||
if (marginalVids.size() != 0) {
|
||||
bn->exportToDotFormat ("bn unbayes.dot");
|
||||
BayesNet* mrn = bn->getMinimalRequesiteNetwork (marginalVids);
|
||||
mrn->exportToDotFormat ("bn bayes.dot");
|
||||
//BayesNet* mrn = bn->getMinimalRequesiteNetwork (vids);
|
||||
if (SolverOptions::convertBn2Fg) {
|
||||
gm = new FactorGraph (*mrn);
|
||||
if (SolverOptions::compressFactorGraph) {
|
||||
solver = new CountingBP (*static_cast<FactorGraph*> (gm));
|
||||
} else {
|
||||
solver = new SPSolver (*static_cast<FactorGraph*> (gm));
|
||||
}
|
||||
if (SolverOptions::runBayesBall) {
|
||||
Solver* bpSolver = 0;
|
||||
GraphicalModel* graphicalModel = 0;
|
||||
CFactorGraph::disableCheckForIdenticalFactors();
|
||||
if (InfAlgorithms::infAlgorithm != InfAlgorithms::VE) {
|
||||
BayesNet* mrn = bn->getMinimalRequesiteNetwork (
|
||||
VarIdSet (vids.begin(), vids.end()));
|
||||
if (InfAlgorithms::infAlgorithm == InfAlgorithms::BN_BP) {
|
||||
graphicalModel = mrn;
|
||||
bpSolver = new BnBpSolver (*static_cast<BayesNet*> (graphicalModel));
|
||||
} else if (InfAlgorithms::infAlgorithm == InfAlgorithms::FG_BP) {
|
||||
graphicalModel = new FactorGraph (*mrn);
|
||||
bpSolver = new FgBpSolver (*static_cast<FactorGraph*> (graphicalModel));
|
||||
delete mrn;
|
||||
} else if (InfAlgorithms::infAlgorithm == InfAlgorithms::CBP) {
|
||||
graphicalModel = new FactorGraph (*mrn);
|
||||
bpSolver = new CbpSolver (*static_cast<FactorGraph*> (graphicalModel));
|
||||
delete mrn;
|
||||
}
|
||||
} else {
|
||||
gm = mrn;
|
||||
solver = new BPSolver (*static_cast<BayesNet*> (gm));
|
||||
}
|
||||
solver->runSolver();
|
||||
bpSolver->runSolver();
|
||||
}
|
||||
|
||||
vector<ParamSet> results;
|
||||
results.reserve (tasks.size());
|
||||
for (unsigned i = 0; i < tasks.size(); i++) {
|
||||
//if (i == 1) exit (0);
|
||||
if (InfAlgorithms::infAlgorithm == InfAlgorithms::VE) {
|
||||
BayesNet* mrn = bn->getMinimalRequesiteNetwork (tasks[i]);
|
||||
VarElimSolver* veSolver = new VarElimSolver (*mrn);
|
||||
if (tasks[i].size() == 1) {
|
||||
results.push_back (solver->getPosterioriOf (tasks[i][0]));
|
||||
results.push_back (veSolver->getPosterioriOf (tasks[i][0]));
|
||||
} else {
|
||||
static int count = 0;
|
||||
cout << "calculating joint... " << count ++ << endl;
|
||||
//if (count == 5225) {
|
||||
// Statistics::printCompressingStats ("compressing.stats");
|
||||
//}
|
||||
Solver* solver2 = 0;
|
||||
GraphicalModel* gm2 = 0;
|
||||
bn->exportToDotFormat ("joint.dot");
|
||||
BayesNet* mrn2;
|
||||
if (SolverOptions::runBayesBall) {
|
||||
mrn2 = bn->getMinimalRequesiteNetwork (tasks[i]);
|
||||
results.push_back (veSolver->getJointDistributionOf (tasks[i]));
|
||||
}
|
||||
delete mrn;
|
||||
delete veSolver;
|
||||
} else {
|
||||
mrn2 = bn;
|
||||
}
|
||||
if (SolverOptions::convertBn2Fg) {
|
||||
gm2 = new FactorGraph (*mrn2);
|
||||
if (SolverOptions::compressFactorGraph) {
|
||||
solver2 = new CountingBP (*static_cast<FactorGraph*> (gm2));
|
||||
if (tasks[i].size() == 1) {
|
||||
results.push_back (bpSolver->getPosterioriOf (tasks[i][0]));
|
||||
} else {
|
||||
solver2 = new SPSolver (*static_cast<FactorGraph*> (gm2));
|
||||
}
|
||||
if (SolverOptions::runBayesBall) {
|
||||
delete mrn2;
|
||||
}
|
||||
} else {
|
||||
gm2 = mrn2;
|
||||
solver2 = new BPSolver (*static_cast<BayesNet*> (gm2));
|
||||
}
|
||||
results.push_back (solver2->getJointDistributionOf (tasks[i]));
|
||||
delete solver2;
|
||||
delete gm2;
|
||||
results.push_back (bpSolver->getJointDistributionOf (tasks[i]));
|
||||
}
|
||||
}
|
||||
|
||||
delete solver;
|
||||
delete gm;
|
||||
}
|
||||
delete bpSolver;
|
||||
delete graphicalModel;
|
||||
|
||||
YAP_Term list = YAP_TermNil();
|
||||
for (int i = results.size() - 1; i >= 0; i--) {
|
||||
@ -251,10 +210,91 @@ runSolver (void)
|
||||
|
||||
|
||||
|
||||
int
|
||||
setSolverParameter (void)
|
||||
{
|
||||
string key ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG1)));
|
||||
if (key == "inf_alg") {
|
||||
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
||||
if ( value == "ve") {
|
||||
InfAlgorithms::infAlgorithm = InfAlgorithms::VE;
|
||||
} else if (value == "bn_bp") {
|
||||
InfAlgorithms::infAlgorithm = InfAlgorithms::BN_BP;
|
||||
} else if (value == "fg_bp") {
|
||||
InfAlgorithms::infAlgorithm = InfAlgorithms::FG_BP;
|
||||
} else if (value == "cbp") {
|
||||
InfAlgorithms::infAlgorithm = InfAlgorithms::CBP;
|
||||
} else {
|
||||
cerr << "warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << key << "'" << endl;
|
||||
return FALSE;
|
||||
}
|
||||
} else if (key == "elim_heuristic") {
|
||||
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
||||
if ( value == "min_neighbors") {
|
||||
ElimGraph::setEliminationHeuristic (ElimHeuristic::MIN_NEIGHBORS);
|
||||
} else if (value == "min_weight") {
|
||||
ElimGraph::setEliminationHeuristic (ElimHeuristic::MIN_WEIGHT);
|
||||
} else if (value == "min_fill") {
|
||||
ElimGraph::setEliminationHeuristic (ElimHeuristic::MIN_FILL);
|
||||
} else if (value == "weighted_min_fill") {
|
||||
ElimGraph::setEliminationHeuristic (ElimHeuristic::WEIGHTED_MIN_FILL);
|
||||
} else {
|
||||
cerr << "warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << key << "'" << endl;
|
||||
return FALSE;
|
||||
}
|
||||
} else if (key == "schedule") {
|
||||
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
||||
if ( value == "seq_fixed") {
|
||||
BpOptions::schedule = BpOptions::Schedule::SEQ_FIXED;
|
||||
} else if (value == "seq_random") {
|
||||
BpOptions::schedule = BpOptions::Schedule::SEQ_RANDOM;
|
||||
} else if (value == "parallel") {
|
||||
BpOptions::schedule = BpOptions::Schedule::PARALLEL;
|
||||
} else if (value == "max_residual") {
|
||||
BpOptions::schedule = BpOptions::Schedule::MAX_RESIDUAL;
|
||||
} else {
|
||||
cerr << "warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << key << "'" << endl;
|
||||
return FALSE;
|
||||
}
|
||||
} else if (key == "accuracy") {
|
||||
BpOptions::accuracy = (double) YAP_FloatOfTerm (YAP_ARG2);
|
||||
} else if (key == "max_iter") {
|
||||
BpOptions::maxIter = (int) YAP_IntOfTerm (YAP_ARG2);
|
||||
} else if (key == "always_loopy_solver") {
|
||||
string value ((char*) YAP_AtomName (YAP_AtomOfTerm (YAP_ARG2)));
|
||||
if (value == "true") {
|
||||
BpOptions::useAlwaysLoopySolver = true;
|
||||
} else if (value == "false") {
|
||||
BpOptions::useAlwaysLoopySolver = false;
|
||||
} else {
|
||||
cerr << "warning: invalid value `" << value << "' " ;
|
||||
cerr << "for `" << key << "'" << endl;
|
||||
return FALSE;
|
||||
}
|
||||
} else {
|
||||
cerr << "warning: invalid key `" << key << "'" << endl;
|
||||
return FALSE;
|
||||
}
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int useLogSpace (void)
|
||||
{
|
||||
NSPACE = NumberSpace::LOGARITHM;
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int
|
||||
freeBayesNetwork (void)
|
||||
{
|
||||
//Statistics::printCompressingStats ("../../compressing.stats");
|
||||
//Statistics::writeStatisticsToFile ("stats.txt");
|
||||
BayesNet* bn = (BayesNet*) YAP_IntOfTerm (YAP_ARG1);
|
||||
bn->freeDistributions();
|
||||
delete bn;
|
||||
@ -270,6 +310,8 @@ init_predicates (void)
|
||||
YAP_UserCPredicate ("set_extra_vars_info", setExtraVarsInfo, 2);
|
||||
YAP_UserCPredicate ("set_parameters", setParameters, 2);
|
||||
YAP_UserCPredicate ("run_solver", runSolver, 3);
|
||||
YAP_UserCPredicate ("set_solver_parameter", setSolverParameter, 2);
|
||||
YAP_UserCPredicate ("use_log_space", useLogSpace, 0);
|
||||
YAP_UserCPredicate ("free_bayesian_network", freeBayesNetwork, 1);
|
||||
}
|
||||
|
||||
|
@ -1,278 +0,0 @@
|
||||
|
||||
#include "LiftedFG.h"
|
||||
#include "FgVarNode.h"
|
||||
#include "Factor.h"
|
||||
#include "Distribution.h"
|
||||
|
||||
LiftedFG::LiftedFG (const FactorGraph& fg)
|
||||
{
|
||||
groundFg_ = &fg;
|
||||
freeColor_ = 0;
|
||||
|
||||
const FgVarSet& varNodes = fg.getFgVarNodes();
|
||||
const FactorSet& factors = fg.getFactors();
|
||||
varColors_.resize (varNodes.size());
|
||||
factorColors_.resize (factors.size());
|
||||
for (unsigned i = 0; i < factors.size(); i++) {
|
||||
factors[i]->setIndex (i);
|
||||
}
|
||||
|
||||
// create the initial variable colors
|
||||
VarColorMap colorMap;
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
unsigned dsize = varNodes[i]->getDomainSize();
|
||||
VarColorMap::iterator it = colorMap.find (dsize);
|
||||
if (it == colorMap.end()) {
|
||||
it = colorMap.insert (make_pair (
|
||||
dsize, vector<Color> (dsize + 1,-1))).first;
|
||||
}
|
||||
unsigned idx;
|
||||
if (varNodes[i]->hasEvidence()) {
|
||||
idx = varNodes[i]->getEvidence();
|
||||
} else {
|
||||
idx = dsize;
|
||||
}
|
||||
vector<Color>& stateColors = it->second;
|
||||
if (stateColors[idx] == -1) {
|
||||
stateColors[idx] = getFreeColor();
|
||||
}
|
||||
setColor (varNodes[i], stateColors[idx]);
|
||||
}
|
||||
|
||||
// create the initial factor colors
|
||||
DistColorMap distColors;
|
||||
for (unsigned i = 0; i < factors.size(); i++) {
|
||||
Distribution* dist = factors[i]->getDistribution();
|
||||
DistColorMap::iterator it = distColors.find (dist);
|
||||
if (it == distColors.end()) {
|
||||
it = distColors.insert (make_pair (dist, getFreeColor())).first;
|
||||
}
|
||||
setColor (factors[i], it->second);
|
||||
}
|
||||
|
||||
VarSignMap varGroups;
|
||||
FactorSignMap factorGroups;
|
||||
bool groupsHaveChanged = true;
|
||||
unsigned nIter = 0;
|
||||
while (groupsHaveChanged || nIter == 1) {
|
||||
nIter ++;
|
||||
if (Statistics::numCreatedNets == 4) {
|
||||
cout << "--------------------------------------------" << endl;
|
||||
cout << "Iteration " << nIter << endl;
|
||||
cout << "--------------------------------------------" << endl;
|
||||
}
|
||||
|
||||
unsigned prevFactorGroupsSize = factorGroups.size();
|
||||
factorGroups.clear();
|
||||
// set a new color to the factors with the same signature
|
||||
for (unsigned i = 0; i < factors.size(); i++) {
|
||||
const string& signatureId = getSignatureId (factors[i]);
|
||||
// cout << factors[i]->getLabel() << " signature: " ;
|
||||
// cout<< signatureId << endl;
|
||||
FactorSignMap::iterator it = factorGroups.find (signatureId);
|
||||
if (it == factorGroups.end()) {
|
||||
it = factorGroups.insert (make_pair (signatureId, FactorSet())).first;
|
||||
}
|
||||
it->second.push_back (factors[i]);
|
||||
}
|
||||
if (nIter > 0)
|
||||
for (FactorSignMap::iterator it = factorGroups.begin();
|
||||
it != factorGroups.end(); it++) {
|
||||
Color newColor = getFreeColor();
|
||||
FactorSet& groupMembers = it->second;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
setColor (groupMembers[i], newColor);
|
||||
}
|
||||
}
|
||||
|
||||
// set a new color to the variables with the same signature
|
||||
unsigned prevVarGroupsSize = varGroups.size();
|
||||
varGroups.clear();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
const string& signatureId = getSignatureId (varNodes[i]);
|
||||
VarSignMap::iterator it = varGroups.find (signatureId);
|
||||
// cout << varNodes[i]->getLabel() << " signature: " ;
|
||||
// cout << signatureId << endl;
|
||||
if (it == varGroups.end()) {
|
||||
it = varGroups.insert (make_pair (signatureId, FgVarSet())).first;
|
||||
}
|
||||
it->second.push_back (varNodes[i]);
|
||||
}
|
||||
if (nIter > 0)
|
||||
for (VarSignMap::iterator it = varGroups.begin();
|
||||
it != varGroups.end(); it++) {
|
||||
Color newColor = getFreeColor();
|
||||
FgVarSet& groupMembers = it->second;
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
setColor (groupMembers[i], newColor);
|
||||
}
|
||||
}
|
||||
|
||||
//if (nIter >= 3) cout << "bigger than three: " << nIter << endl;
|
||||
groupsHaveChanged = prevVarGroupsSize != varGroups.size()
|
||||
|| prevFactorGroupsSize != factorGroups.size();
|
||||
}
|
||||
|
||||
printGroups (varGroups, factorGroups);
|
||||
for (VarSignMap::iterator it = varGroups.begin();
|
||||
it != varGroups.end(); it++) {
|
||||
CFgVarSet vars = it->second;
|
||||
VarCluster* vc = new VarCluster (vars);
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
vid2VarCluster_.insert (make_pair (vars[i]->getVarId(), vc));
|
||||
}
|
||||
varClusters_.push_back (vc);
|
||||
}
|
||||
|
||||
for (FactorSignMap::iterator it = factorGroups.begin();
|
||||
it != factorGroups.end(); it++) {
|
||||
VarClusterSet varClusters;
|
||||
Factor* groundFactor = it->second[0];
|
||||
FgVarSet groundVars = groundFactor->getFgVarNodes();
|
||||
for (unsigned i = 0; i < groundVars.size(); i++) {
|
||||
Vid vid = groundVars[i]->getVarId();
|
||||
varClusters.push_back (vid2VarCluster_.find (vid)->second);
|
||||
}
|
||||
factorClusters_.push_back (new FactorCluster (it->second, varClusters));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
LiftedFG::~LiftedFG (void)
|
||||
{
|
||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||
delete varClusters_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < factorClusters_.size(); i++) {
|
||||
delete factorClusters_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
LiftedFG::getSignatureId (FgVarNode* var) const
|
||||
{
|
||||
stringstream ss;
|
||||
CFactorSet myFactors = var->getFactors();
|
||||
ss << myFactors.size();
|
||||
for (unsigned i = 0; i < myFactors.size(); i++) {
|
||||
ss << "." << getColor (myFactors[i]);
|
||||
ss << "." << myFactors[i]->getIndexOf(var);
|
||||
}
|
||||
ss << "." << getColor (var);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
LiftedFG::getSignatureId (Factor* factor) const
|
||||
{
|
||||
stringstream ss;
|
||||
CFgVarSet myVars = factor->getFgVarNodes();
|
||||
ss << myVars.size();
|
||||
for (unsigned i = 0; i < myVars.size(); i++) {
|
||||
ss << "." << getColor (myVars[i]);
|
||||
}
|
||||
ss << "." << getColor (factor);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
FactorGraph*
|
||||
LiftedFG::getCompressedFactorGraph (void)
|
||||
{
|
||||
FactorGraph* fg = new FactorGraph();
|
||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||
FgVarNode* var = varClusters_[i]->getGroundFgVarNodes()[0];
|
||||
FgVarNode* newVar = new FgVarNode (var);
|
||||
newVar->setIndex (i);
|
||||
varClusters_[i]->setRepresentativeVariable (newVar);
|
||||
fg->addVariable (newVar);
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < factorClusters_.size(); i++) {
|
||||
FgVarSet myGroundVars;
|
||||
const VarClusterSet& myVarClusters = factorClusters_[i]->getVarClusters();
|
||||
for (unsigned j = 0; j < myVarClusters.size(); j++) {
|
||||
myGroundVars.push_back (myVarClusters[j]->getRepresentativeVariable());
|
||||
}
|
||||
Factor* newFactor = new Factor (myGroundVars,
|
||||
factorClusters_[i]->getGroundFactors()[0]->getDistribution());
|
||||
factorClusters_[i]->setRepresentativeFactor (newFactor);
|
||||
fg->addFactor (newFactor);
|
||||
}
|
||||
return fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
LiftedFG::getGroundEdgeCount (FactorCluster* fc, VarCluster* vc) const
|
||||
{
|
||||
CFactorSet clusterGroundFactors = fc->getGroundFactors();
|
||||
FgVarNode* var = vc->getGroundFgVarNodes()[0];
|
||||
unsigned count = 0;
|
||||
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||
if (clusterGroundFactors[i]->getIndexOf (var) != -1) {
|
||||
count ++;
|
||||
}
|
||||
}
|
||||
/*
|
||||
CFgVarSet vars = vc->getGroundFgVarNodes();
|
||||
for (unsigned i = 1; i < vars.size(); i++) {
|
||||
FgVarNode* var = vc->getGroundFgVarNodes()[i];
|
||||
unsigned count2 = 0;
|
||||
for (unsigned i = 0; i < clusterGroundFactors.size(); i++) {
|
||||
if (clusterGroundFactors[i]->getIndexOf (var) != -1) {
|
||||
count2 ++;
|
||||
}
|
||||
}
|
||||
if (count != count2) { cout << "oops!" << endl; abort(); }
|
||||
}
|
||||
*/
|
||||
return count;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
LiftedFG::printGroups (const VarSignMap& varGroups,
|
||||
const FactorSignMap& factorGroups) const
|
||||
{
|
||||
cout << "variable groups:" << endl;
|
||||
unsigned count = 0;
|
||||
for (VarSignMap::const_iterator it = varGroups.begin();
|
||||
it != varGroups.end(); it++) {
|
||||
const FgVarSet& groupMembers = it->second;
|
||||
if (groupMembers.size() > 0) {
|
||||
cout << ++count << ": " ;
|
||||
//if (groupMembers.size() > 1) {
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
cout << groupMembers[i]->getLabel() << " " ;
|
||||
}
|
||||
//}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
cout << endl;
|
||||
cout << "factor groups:" << endl;
|
||||
count = 0;
|
||||
for (FactorSignMap::const_iterator it = factorGroups.begin();
|
||||
it != factorGroups.end(); it++) {
|
||||
const FactorSet& groupMembers = it->second;
|
||||
if (groupMembers.size() > 0) {
|
||||
cout << ++count << ": " ;
|
||||
//if (groupMembers.size() > 1) {
|
||||
for (unsigned i = 0; i < groupMembers.size(); i++) {
|
||||
cout << groupMembers[i]->getLabel() << " " ;
|
||||
}
|
||||
//}
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,152 +0,0 @@
|
||||
#ifndef BP_LIFTED_FG_H
|
||||
#define BP_LIFTED_FG_H
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "FactorGraph.h"
|
||||
#include "FgVarNode.h"
|
||||
#include "Factor.h"
|
||||
#include "Shared.h"
|
||||
|
||||
class VarCluster;
|
||||
class FactorCluster;
|
||||
class Distribution;
|
||||
|
||||
typedef long Color;
|
||||
typedef vector<Color> Signature;
|
||||
typedef vector<VarCluster*> VarClusterSet;
|
||||
typedef vector<FactorCluster*> FactorClusterSet;
|
||||
|
||||
typedef map<string, FgVarSet> VarSignMap;
|
||||
typedef map<string, FactorSet> FactorSignMap;
|
||||
|
||||
typedef map<unsigned, vector<Color> > VarColorMap;
|
||||
typedef map<Distribution*, Color> DistColorMap;
|
||||
|
||||
typedef map<Vid, VarCluster*> Vid2VarCluster;
|
||||
|
||||
|
||||
class VarCluster
|
||||
{
|
||||
public:
|
||||
VarCluster (CFgVarSet vs)
|
||||
{
|
||||
for (unsigned i = 0; i < vs.size(); i++) {
|
||||
groundVars_.push_back (vs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void addFactorCluster (FactorCluster* fc)
|
||||
{
|
||||
factorClusters_.push_back (fc);
|
||||
}
|
||||
|
||||
const FactorClusterSet& getFactorClusters (void) const
|
||||
{
|
||||
return factorClusters_;
|
||||
}
|
||||
|
||||
FgVarNode* getRepresentativeVariable (void) const { return representVar_; }
|
||||
void setRepresentativeVariable (FgVarNode* v) { representVar_ = v; }
|
||||
CFgVarSet getGroundFgVarNodes (void) const { return groundVars_; }
|
||||
|
||||
private:
|
||||
FgVarSet groundVars_;
|
||||
FactorClusterSet factorClusters_;
|
||||
FgVarNode* representVar_;
|
||||
};
|
||||
|
||||
|
||||
class FactorCluster
|
||||
{
|
||||
public:
|
||||
FactorCluster (CFactorSet groundFactors, const VarClusterSet& vcs)
|
||||
{
|
||||
groundFactors_ = groundFactors;
|
||||
varClusters_ = vcs;
|
||||
for (unsigned i = 0; i < varClusters_.size(); i++) {
|
||||
varClusters_[i]->addFactorCluster (this);
|
||||
}
|
||||
}
|
||||
|
||||
const VarClusterSet& getVarClusters (void) const
|
||||
{
|
||||
return varClusters_;
|
||||
}
|
||||
|
||||
bool containsGround (const Factor* f)
|
||||
{
|
||||
for (unsigned i = 0; i < groundFactors_.size(); i++) {
|
||||
if (groundFactors_[i] == f) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Factor* getRepresentativeFactor (void) const { return representFactor_; }
|
||||
void setRepresentativeFactor (Factor* f) { representFactor_ = f; }
|
||||
CFactorSet getGroundFactors (void) const { return groundFactors_; }
|
||||
|
||||
|
||||
private:
|
||||
FactorSet groundFactors_;
|
||||
VarClusterSet varClusters_;
|
||||
Factor* representFactor_;
|
||||
};
|
||||
|
||||
|
||||
class LiftedFG
|
||||
{
|
||||
public:
|
||||
LiftedFG (const FactorGraph&);
|
||||
~LiftedFG (void);
|
||||
|
||||
FactorGraph* getCompressedFactorGraph (void);
|
||||
unsigned getGroundEdgeCount (FactorCluster*, VarCluster*) const;
|
||||
void printGroups (const VarSignMap& varGroups,
|
||||
const FactorSignMap& factorGroups) const;
|
||||
|
||||
FgVarNode* getEquivalentVariable (Vid vid)
|
||||
{
|
||||
VarCluster* vc = vid2VarCluster_.find (vid)->second;
|
||||
return vc->getRepresentativeVariable();
|
||||
}
|
||||
|
||||
const VarClusterSet& getVariableClusters (void) { return varClusters_; }
|
||||
const FactorClusterSet& getFactorClusters (void) { return factorClusters_; }
|
||||
|
||||
private:
|
||||
string getSignatureId (FgVarNode*) const;
|
||||
string getSignatureId (Factor*) const;
|
||||
|
||||
Color getFreeColor (void) { return ++freeColor_ -1; }
|
||||
Color getColor (FgVarNode* v) const { return varColors_[v->getIndex()]; }
|
||||
Color getColor (Factor* f) const { return factorColors_[f->getIndex()]; }
|
||||
|
||||
void setColor (FgVarNode* v, Color c)
|
||||
{
|
||||
varColors_[v->getIndex()] = c;
|
||||
}
|
||||
|
||||
void setColor (Factor* f, Color c)
|
||||
{
|
||||
factorColors_[f->getIndex()] = c;
|
||||
}
|
||||
|
||||
VarCluster* getVariableCluster (Vid vid) const
|
||||
{
|
||||
return vid2VarCluster_.find (vid)->second;
|
||||
}
|
||||
|
||||
Color freeColor_;
|
||||
vector<Color> varColors_;
|
||||
vector<Color> factorColors_;
|
||||
VarClusterSet varClusters_;
|
||||
FactorClusterSet factorClusters_;
|
||||
Vid2VarCluster vid2VarCluster_;
|
||||
const FactorGraph* groundFg_;
|
||||
};
|
||||
|
||||
#endif // BP_LIFTED_FG_H
|
||||
|
@ -26,10 +26,7 @@ CXX=@CXX@
|
||||
CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../../.. -I$(srcdir)/../../../../include @CPPFLAGS@ -DNDEBUG
|
||||
|
||||
# debug
|
||||
#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../../.. -I$(srcdir)/../../../../include @CPPFLAGS@ -g -O0
|
||||
|
||||
# profiling (callgrind)
|
||||
#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../../.. -I$(srcdir)/../../../../include @CPPFLAGS@ -g -DNDEBUG
|
||||
#CXXFLAGS= -std=c++0x @SHLIB_CXXFLAGS@ $(YAP_EXTRAS) $(DEFS) -D_YAP_NOT_INSTALLED_=1 -I$(srcdir) -I../../../.. -I$(srcdir)/../../../../include @CPPFLAGS@ -g -O0 -Wextra
|
||||
|
||||
|
||||
#
|
||||
@ -49,33 +46,37 @@ CWD=$(PWD)
|
||||
|
||||
HEADERS = \
|
||||
$(srcdir)/GraphicalModel.h \
|
||||
$(srcdir)/Variable.h \
|
||||
$(srcdir)/VarNode.h \
|
||||
$(srcdir)/Distribution.h \
|
||||
$(srcdir)/BayesNet.h \
|
||||
$(srcdir)/BayesNode.h \
|
||||
$(srcdir)/LiftedFG.h \
|
||||
$(srcdir)/ElimGraph.h \
|
||||
$(srcdir)/CFactorGraph.h \
|
||||
$(srcdir)/CptEntry.h \
|
||||
$(srcdir)/FactorGraph.h \
|
||||
$(srcdir)/FgVarNode.h \
|
||||
$(srcdir)/Factor.h \
|
||||
$(srcdir)/Solver.h \
|
||||
$(srcdir)/BPSolver.h \
|
||||
$(srcdir)/BPNodeInfo.h \
|
||||
$(srcdir)/SPSolver.h \
|
||||
$(srcdir)/CountingBP.h \
|
||||
$(srcdir)/VarElimSolver.h \
|
||||
$(srcdir)/BnBpSolver.h \
|
||||
$(srcdir)/FgBpSolver.h \
|
||||
$(srcdir)/CbpSolver.h \
|
||||
$(srcdir)/Shared.h \
|
||||
$(srcdir)/StatesIndexer.h \
|
||||
$(srcdir)/xmlParser/xmlParser.h
|
||||
|
||||
CPP_SOURCES = \
|
||||
$(srcdir)/BayesNet.cpp \
|
||||
$(srcdir)/BayesNode.cpp \
|
||||
$(srcdir)/ElimGraph.cpp \
|
||||
$(srcdir)/FactorGraph.cpp \
|
||||
$(srcdir)/Factor.cpp \
|
||||
$(srcdir)/LiftedFG.cpp \
|
||||
$(srcdir)/BPSolver.cpp \
|
||||
$(srcdir)/BPNodeInfo.cpp \
|
||||
$(srcdir)/SPSolver.cpp \
|
||||
$(srcdir)/CountingBP.cpp \
|
||||
$(srcdir)/CFactorGraph.cpp \
|
||||
$(srcdir)/VarNode.cpp \
|
||||
$(srcdir)/Solver.cpp \
|
||||
$(srcdir)/VarElimSolver.cpp \
|
||||
$(srcdir)/BnBpSolver.cpp \
|
||||
$(srcdir)/FgBpSolver.cpp \
|
||||
$(srcdir)/CbpSolver.cpp \
|
||||
$(srcdir)/Util.cpp \
|
||||
$(srcdir)/HorusYap.cpp \
|
||||
$(srcdir)/HorusCli.cpp \
|
||||
@ -84,29 +85,35 @@ CPP_SOURCES = \
|
||||
OBJS = \
|
||||
BayesNet.o \
|
||||
BayesNode.o \
|
||||
ElimGraph.o \
|
||||
FactorGraph.o \
|
||||
Factor.o \
|
||||
BPSolver.o \
|
||||
BPNodeInfo.o \
|
||||
SPSolver.o \
|
||||
CFactorGraph.o \
|
||||
VarNode.o \
|
||||
Solver.o \
|
||||
VarElimSolver.o \
|
||||
BnBpSolver.o \
|
||||
FgBpSolver.o \
|
||||
CbpSolver.o \
|
||||
Util.o \
|
||||
LiftedFG.o \
|
||||
CountingBP.o \
|
||||
HorusYap.o
|
||||
|
||||
HCLI_OBJS = \
|
||||
BayesNet.o \
|
||||
BayesNode.o \
|
||||
ElimGraph.o \
|
||||
FactorGraph.o \
|
||||
Factor.o \
|
||||
BPSolver.o \
|
||||
BPNodeInfo.o \
|
||||
SPSolver.o \
|
||||
CFactorGraph.o \
|
||||
VarNode.o \
|
||||
Solver.o \
|
||||
VarElimSolver.o \
|
||||
BnBpSolver.o \
|
||||
FgBpSolver.o \
|
||||
CbpSolver.o \
|
||||
Util.o \
|
||||
LiftedFG.o \
|
||||
CountingBP.o \
|
||||
HorusCli.o \
|
||||
xmlParser/xmlParser.o
|
||||
xmlParser/xmlParser.o \
|
||||
HorusCli.o
|
||||
|
||||
SOBJS=horus.@SO@
|
||||
|
||||
|
@ -1,470 +0,0 @@
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "SPSolver.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "FgVarNode.h"
|
||||
#include "Factor.h"
|
||||
#include "Shared.h"
|
||||
|
||||
|
||||
SPSolver::SPSolver (FactorGraph& fg) : Solver (&fg)
|
||||
{
|
||||
fg_ = &fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
SPSolver::~SPSolver (void)
|
||||
{
|
||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||
delete varsI_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < factorsI_.size(); i++) {
|
||||
delete factorsI_[i];
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::runTreeSolver (void)
|
||||
{
|
||||
CFactorSet factors = fg_->getFactors();
|
||||
bool finish = false;
|
||||
while (!finish) {
|
||||
finish = true;
|
||||
for (unsigned i = 0; i < factors.size(); i++) {
|
||||
CLinkSet links = factorsI_[factors[i]->getIndex()]->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (!links[j]->messageWasSended()) {
|
||||
if (readyToSendMessage(links[j])) {
|
||||
links[j]->setNextMessage (getFactor2VarMsg (links[j]));
|
||||
links[j]->updateMessage();
|
||||
}
|
||||
finish = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
SPSolver::readyToSendMessage (const Link* link) const
|
||||
{
|
||||
CFgVarSet factorVars = link->getFactor()->getFgVarNodes();
|
||||
for (unsigned i = 0; i < factorVars.size(); i++) {
|
||||
if (factorVars[i] != link->getVariable()) {
|
||||
CLinkSet links = varsI_[factorVars[i]->getIndex()]->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (links[j]->getFactor() != link->getFactor() &&
|
||||
!links[j]->messageWasSended()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::runSolver (void)
|
||||
{
|
||||
initializeSolver();
|
||||
runTreeSolver();
|
||||
return;
|
||||
nIter_ = 0;
|
||||
while (!converged() && nIter_ < SolverOptions::maxIter) {
|
||||
|
||||
nIter_ ++;
|
||||
if (DL >= 2) {
|
||||
cout << endl;
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
cout << endl;
|
||||
cout << " Iteration " << nIter_ << endl;
|
||||
cout << "****************************************" ;
|
||||
cout << "****************************************" ;
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
switch (SolverOptions::schedule) {
|
||||
case SolverOptions::S_SEQ_RANDOM:
|
||||
random_shuffle (links_.begin(), links_.end());
|
||||
// no break
|
||||
|
||||
case SolverOptions::S_SEQ_FIXED:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
|
||||
links_[i]->updateMessage();
|
||||
}
|
||||
break;
|
||||
|
||||
case SolverOptions::S_PARALLEL:
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
|
||||
}
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
links_[i]->updateMessage();
|
||||
}
|
||||
break;
|
||||
|
||||
case SolverOptions::S_MAX_RESIDUAL:
|
||||
maxResidualSchedule();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (DL >= 2) {
|
||||
cout << endl;
|
||||
if (nIter_ < SolverOptions::maxIter) {
|
||||
cout << "Loopy Sum-Product converged in " ;
|
||||
cout << nIter_ << " iterations" << endl;
|
||||
} else {
|
||||
cout << "The maximum number of iterations was hit, terminating..." ;
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
SPSolver::getPosterioriOf (Vid vid) const
|
||||
{
|
||||
assert (fg_->getFgVarNode (vid));
|
||||
FgVarNode* var = fg_->getFgVarNode (vid);
|
||||
ParamSet probs;
|
||||
|
||||
if (var->hasEvidence()) {
|
||||
probs.resize (var->getDomainSize(), 0.0);
|
||||
probs[var->getEvidence()] = 1.0;
|
||||
} else {
|
||||
probs.resize (var->getDomainSize(), 1.0);
|
||||
CLinkSet links = varsI_[var->getIndex()]->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
CParamSet msg = links[i]->getMessage();
|
||||
for (unsigned j = 0; j < msg.size(); j++) {
|
||||
probs[j] *= msg[j];
|
||||
}
|
||||
}
|
||||
Util::normalize (probs);
|
||||
}
|
||||
return probs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
SPSolver::getJointDistributionOf (const VidSet& jointVids)
|
||||
{
|
||||
FgVarSet jointVars;
|
||||
unsigned dsize = 1;
|
||||
for (unsigned i = 0; i < jointVids.size(); i++) {
|
||||
FgVarNode* varNode = fg_->getFgVarNode (jointVids[i]);
|
||||
dsize *= varNode->getDomainSize();
|
||||
jointVars.push_back (varNode);
|
||||
}
|
||||
|
||||
unsigned maxVid = std::numeric_limits<unsigned>::max();
|
||||
FgVarNode* junctionVar = new FgVarNode (maxVid, dsize);
|
||||
FgVarSet factorVars = { junctionVar };
|
||||
for (unsigned i = 0; i < jointVars.size(); i++) {
|
||||
factorVars.push_back (jointVars[i]);
|
||||
}
|
||||
|
||||
unsigned nParams = dsize * dsize;
|
||||
ParamSet params (nParams);
|
||||
for (unsigned i = 0; i < nParams; i++) {
|
||||
unsigned row = i / dsize;
|
||||
unsigned col = i % dsize;
|
||||
if (row == col) {
|
||||
params[i] = 1;
|
||||
} else {
|
||||
params[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Distribution* dist = new Distribution (params, maxVid);
|
||||
Factor* newFactor = new Factor (factorVars, dist);
|
||||
fg_->addVariable (junctionVar);
|
||||
fg_->addFactor (newFactor);
|
||||
|
||||
runSolver();
|
||||
ParamSet results = getPosterioriOf (maxVid);
|
||||
deleteJunction (newFactor, junctionVar);
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::initializeSolver (void)
|
||||
{
|
||||
fg_->setIndexes();
|
||||
|
||||
CFgVarSet vars = fg_->getFgVarNodes();
|
||||
for (unsigned i = 0; i < varsI_.size(); i++) {
|
||||
delete varsI_[i];
|
||||
}
|
||||
varsI_.reserve (vars.size());
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
varsI_.push_back (new SPNodeInfo());
|
||||
}
|
||||
|
||||
CFactorSet factors = fg_->getFactors();
|
||||
for (unsigned i = 0; i < factorsI_.size(); i++) {
|
||||
delete factorsI_[i];
|
||||
}
|
||||
factorsI_.reserve (factors.size());
|
||||
for (unsigned i = 0; i < factors.size(); i++) {
|
||||
factorsI_.push_back (new SPNodeInfo());
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
delete links_[i];
|
||||
}
|
||||
createLinks();
|
||||
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
Factor* source = links_[i]->getFactor();
|
||||
FgVarNode* dest = links_[i]->getVariable();
|
||||
varsI_[dest->getIndex()]->addLink (links_[i]);
|
||||
factorsI_[source->getIndex()]->addLink (links_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::createLinks (void)
|
||||
{
|
||||
CFactorSet factors = fg_->getFactors();
|
||||
for (unsigned i = 0; i < factors.size(); i++) {
|
||||
CFgVarSet neighbors = factors[i]->getFgVarNodes();
|
||||
for (unsigned j = 0; j < neighbors.size(); j++) {
|
||||
links_.push_back (new Link (factors[i], neighbors[j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::deleteJunction (Factor* f, FgVarNode* v)
|
||||
{
|
||||
fg_->removeFactor (f);
|
||||
f->freeDistribution();
|
||||
delete f;
|
||||
fg_->removeVariable (v);
|
||||
delete v;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
SPSolver::converged (void)
|
||||
{
|
||||
// this can happen if the graph is fully disconnected
|
||||
if (links_.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
if (nIter_ == 0 || nIter_ == 1) {
|
||||
return false;
|
||||
}
|
||||
bool converged = true;
|
||||
if (SolverOptions::schedule == SolverOptions::S_MAX_RESIDUAL) {
|
||||
Param maxResidual = (*(sortedOrder_.begin()))->getResidual();
|
||||
if (maxResidual < SolverOptions::accuracy) {
|
||||
converged = true;
|
||||
} else {
|
||||
converged = false;
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
double residual = links_[i]->getResidual();
|
||||
if (DL >= 2) {
|
||||
cout << links_[i]->toString() + " residual = " << residual << endl;
|
||||
}
|
||||
if (residual > SolverOptions::accuracy) {
|
||||
converged = false;
|
||||
if (DL == 0) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return converged;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
SPSolver::maxResidualSchedule (void)
|
||||
{
|
||||
if (nIter_ == 1) {
|
||||
for (unsigned i = 0; i < links_.size(); i++) {
|
||||
links_[i]->setNextMessage (getFactor2VarMsg (links_[i]));
|
||||
SortedOrder::iterator it = sortedOrder_.insert (links_[i]);
|
||||
linkMap_.insert (make_pair (links_[i], it));
|
||||
if (DL >= 2 && DL < 5) {
|
||||
cout << "calculating " << links_[i]->toString() << endl;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < links_.size(); c++) {
|
||||
if (DL >= 2) {
|
||||
cout << endl << "current residuals:" << endl;
|
||||
for (SortedOrder::iterator it = sortedOrder_.begin();
|
||||
it != sortedOrder_.end(); it ++) {
|
||||
cout << " " << setw (30) << left << (*it)->toString();
|
||||
cout << "residual = " << (*it)->getResidual() << endl;
|
||||
}
|
||||
}
|
||||
|
||||
SortedOrder::iterator it = sortedOrder_.begin();
|
||||
Link* link = *it;
|
||||
if (DL >= 2) {
|
||||
cout << "updating " << (*sortedOrder_.begin())->toString() << endl;
|
||||
}
|
||||
if (link->getResidual() < SolverOptions::accuracy) {
|
||||
return;
|
||||
}
|
||||
link->updateMessage();
|
||||
link->clearResidual();
|
||||
sortedOrder_.erase (it);
|
||||
linkMap_.find (link)->second = sortedOrder_.insert (link);
|
||||
|
||||
// update the messages that depend on message source --> destin
|
||||
CFactorSet factorNeighbors = link->getVariable()->getFactors();
|
||||
for (unsigned i = 0; i < factorNeighbors.size(); i++) {
|
||||
if (factorNeighbors[i] != link->getFactor()) {
|
||||
CLinkSet links = factorsI_[factorNeighbors[i]->getIndex()]->getLinks();
|
||||
for (unsigned j = 0; j < links.size(); j++) {
|
||||
if (links[j]->getVariable() != link->getVariable()) {
|
||||
if (DL >= 2 && DL < 5) {
|
||||
cout << " calculating " << links[j]->toString() << endl;
|
||||
}
|
||||
links[j]->setNextMessage (getFactor2VarMsg (links[j]));
|
||||
LinkMap::iterator iter = linkMap_.find (links[j]);
|
||||
sortedOrder_.erase (iter->second);
|
||||
iter->second = sortedOrder_.insert (links[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
SPSolver::getFactor2VarMsg (const Link* link) const
|
||||
{
|
||||
const Factor* src = link->getFactor();
|
||||
const FgVarNode* dest = link->getVariable();
|
||||
CFgVarSet neighbors = src->getFgVarNodes();
|
||||
CLinkSet links = factorsI_[src->getIndex()]->getLinks();
|
||||
// calculate the product of messages that were sent
|
||||
// to factor `src', except from var `dest'
|
||||
Factor result (*src);
|
||||
Factor temp;
|
||||
if (DL >= 5) {
|
||||
cout << "calculating " ;
|
||||
cout << src->getLabel() << " --> " << dest->getLabel();
|
||||
cout << endl;
|
||||
}
|
||||
for (unsigned i = 0; i < neighbors.size(); i++) {
|
||||
if (links[i]->getVariable() != dest) {
|
||||
if (DL >= 5) {
|
||||
cout << " message from " << links[i]->getVariable()->getLabel();
|
||||
cout << ": " ;
|
||||
ParamSet p = getVar2FactorMsg (links[i]);
|
||||
cout << endl;
|
||||
Factor temp2 (links[i]->getVariable(), p);
|
||||
temp.multiplyByFactor (temp2);
|
||||
temp2.freeDistribution();
|
||||
} else {
|
||||
Factor temp2 (links[i]->getVariable(), getVar2FactorMsg (links[i]));
|
||||
temp.multiplyByFactor (temp2);
|
||||
temp2.freeDistribution();
|
||||
}
|
||||
}
|
||||
}
|
||||
if (links.size() >= 2) {
|
||||
result.multiplyByFactor (temp, &(src->getCptEntries()));
|
||||
if (DL >= 5) {
|
||||
cout << " message product: " ;
|
||||
cout << Util::parametersToString (temp.getParameters()) << endl;
|
||||
cout << " factor product: " ;
|
||||
cout << Util::parametersToString (src->getParameters());
|
||||
cout << " x " ;
|
||||
cout << Util::parametersToString (temp.getParameters());
|
||||
cout << " = " ;
|
||||
cout << Util::parametersToString (result.getParameters()) << endl;
|
||||
}
|
||||
temp.freeDistribution();
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getVariable() != dest) {
|
||||
result.removeVariable (links[i]->getVariable());
|
||||
}
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << " final message: " ;
|
||||
cout << Util::parametersToString (result.getParameters()) << endl << endl;
|
||||
}
|
||||
ParamSet msg = result.getParameters();
|
||||
result.freeDistribution();
|
||||
return msg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
SPSolver::getVar2FactorMsg (const Link* link) const
|
||||
{
|
||||
const FgVarNode* src = link->getVariable();
|
||||
const Factor* dest = link->getFactor();
|
||||
ParamSet msg;
|
||||
if (src->hasEvidence()) {
|
||||
msg.resize (src->getDomainSize(), 0.0);
|
||||
msg[src->getEvidence()] = 1.0;
|
||||
if (DL >= 5) {
|
||||
cout << Util::parametersToString (msg);
|
||||
}
|
||||
} else {
|
||||
msg.resize (src->getDomainSize(), 1.0);
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << Util::parametersToString (msg);
|
||||
}
|
||||
CLinkSet links = varsI_[src->getIndex()]->getLinks();
|
||||
for (unsigned i = 0; i < links.size(); i++) {
|
||||
if (links[i]->getFactor() != dest) {
|
||||
CParamSet msgFromFactor = links[i]->getMessage();
|
||||
for (unsigned j = 0; j < msgFromFactor.size(); j++) {
|
||||
msg[j] *= msgFromFactor[j];
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << " x " << Util::parametersToString (msgFromFactor);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (DL >= 5) {
|
||||
cout << " = " << Util::parametersToString (msg);
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
@ -1,130 +0,0 @@
|
||||
#ifndef BP_SP_SOLVER_H
|
||||
#define BP_SP_SOLVER_H
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
|
||||
#include "Solver.h"
|
||||
#include "FgVarNode.h"
|
||||
#include "Factor.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
class FactorGraph;
|
||||
class SPSolver;
|
||||
|
||||
|
||||
class Link
|
||||
{
|
||||
public:
|
||||
Link (Factor* f, FgVarNode* v)
|
||||
{
|
||||
factor_ = f;
|
||||
var_ = v;
|
||||
currMsg_.resize (v->getDomainSize(), 1);
|
||||
nextMsg_.resize (v->getDomainSize(), 1);
|
||||
msgSended_ = false;
|
||||
residual_ = 0.0;
|
||||
}
|
||||
|
||||
void setMessage (ParamSet msg)
|
||||
{
|
||||
Util::normalize (msg);
|
||||
residual_ = Util::getMaxNorm (currMsg_, msg);
|
||||
currMsg_ = msg;
|
||||
}
|
||||
|
||||
void setNextMessage (CParamSet msg)
|
||||
{
|
||||
nextMsg_ = msg;
|
||||
Util::normalize (nextMsg_);
|
||||
residual_ = Util::getMaxNorm (currMsg_, nextMsg_);
|
||||
}
|
||||
|
||||
void updateMessage (void)
|
||||
{
|
||||
currMsg_ = nextMsg_;
|
||||
msgSended_ = true;
|
||||
}
|
||||
|
||||
string toString (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << factor_->getLabel();
|
||||
ss << " -- " ;
|
||||
ss << var_->getLabel();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
Factor* getFactor (void) const { return factor_; }
|
||||
FgVarNode* getVariable (void) const { return var_; }
|
||||
CParamSet getMessage (void) const { return currMsg_; }
|
||||
bool messageWasSended (void) const { return msgSended_; }
|
||||
double getResidual (void) const { return residual_; }
|
||||
void clearResidual (void) { residual_ = 0.0; }
|
||||
|
||||
private:
|
||||
Factor* factor_;
|
||||
FgVarNode* var_;
|
||||
ParamSet currMsg_;
|
||||
ParamSet nextMsg_;
|
||||
bool msgSended_;
|
||||
double residual_;
|
||||
};
|
||||
|
||||
|
||||
class SPNodeInfo
|
||||
{
|
||||
public:
|
||||
void addLink (Link* link) { links_.push_back (link); }
|
||||
CLinkSet getLinks (void) { return links_; }
|
||||
|
||||
private:
|
||||
LinkSet links_;
|
||||
};
|
||||
|
||||
|
||||
class SPSolver : public Solver
|
||||
{
|
||||
public:
|
||||
SPSolver (FactorGraph&);
|
||||
virtual ~SPSolver (void);
|
||||
|
||||
void runSolver (void);
|
||||
virtual ParamSet getPosterioriOf (Vid) const;
|
||||
ParamSet getJointDistributionOf (CVidSet);
|
||||
|
||||
protected:
|
||||
virtual void initializeSolver (void);
|
||||
void runTreeSolver (void);
|
||||
bool readyToSendMessage (const Link*) const;
|
||||
virtual void createLinks (void);
|
||||
virtual void deleteJunction (Factor*, FgVarNode*);
|
||||
bool converged (void);
|
||||
virtual void maxResidualSchedule (void);
|
||||
virtual ParamSet getFactor2VarMsg (const Link*) const;
|
||||
virtual ParamSet getVar2FactorMsg (const Link*) const;
|
||||
|
||||
struct CompareResidual {
|
||||
inline bool operator() (const Link* link1, const Link* link2)
|
||||
{
|
||||
return link1->getResidual() > link2->getResidual();
|
||||
}
|
||||
};
|
||||
|
||||
FactorGraph* fg_;
|
||||
LinkSet links_;
|
||||
vector<SPNodeInfo*> varsI_;
|
||||
vector<SPNodeInfo*> factorsI_;
|
||||
unsigned nIter_;
|
||||
|
||||
typedef multiset<Link*, CompareResidual> SortedOrder;
|
||||
SortedOrder sortedOrder_;
|
||||
|
||||
typedef map<Link*, SortedOrder::iterator> LinkMap;
|
||||
LinkMap linkMap_;
|
||||
|
||||
};
|
||||
|
||||
#endif // BP_SP_SOLVER_H
|
||||
|
@ -1,15 +1,15 @@
|
||||
#ifndef BP_SHARED_H
|
||||
#define BP_SHARED_H
|
||||
#ifndef HORUS_SHARED_H
|
||||
#define HORUS_SHARED_H
|
||||
|
||||
#include <cmath>
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
|
||||
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
|
||||
TypeName(const TypeName&); \
|
||||
@ -17,34 +17,29 @@
|
||||
|
||||
using namespace std;
|
||||
|
||||
class Variable;
|
||||
class VarNode;
|
||||
class BayesNet;
|
||||
class BayesNode;
|
||||
class FgVarNode;
|
||||
class Factor;
|
||||
class Link;
|
||||
class Edge;
|
||||
class FgVarNode;
|
||||
class FgFacNode;
|
||||
class SpLink;
|
||||
class BpLink;
|
||||
|
||||
typedef double Param;
|
||||
typedef vector<Param> ParamSet;
|
||||
typedef const ParamSet& CParamSet;
|
||||
typedef unsigned Vid;
|
||||
typedef vector<Vid> VidSet;
|
||||
typedef const VidSet& CVidSet;
|
||||
typedef vector<Variable*> VarSet;
|
||||
typedef unsigned VarId;
|
||||
typedef vector<VarId> VarIdSet;
|
||||
typedef vector<VarNode*> VarNodes;
|
||||
typedef vector<BayesNode*> BnNodeSet;
|
||||
typedef const BnNodeSet& CBnNodeSet;
|
||||
typedef vector<FgVarNode*> FgVarSet;
|
||||
typedef const FgVarSet& CFgVarSet;
|
||||
typedef vector<FgFacNode*> FgFacSet;
|
||||
typedef vector<Factor*> FactorSet;
|
||||
typedef const FactorSet& CFactorSet;
|
||||
typedef vector<Link*> LinkSet;
|
||||
typedef const LinkSet& CLinkSet;
|
||||
typedef vector<Edge*> EdgeSet;
|
||||
typedef const EdgeSet& CEdgeSet;
|
||||
typedef vector<string> Domain;
|
||||
typedef vector<string> States;
|
||||
typedef vector<unsigned> Ranges;
|
||||
typedef vector<unsigned> DConf;
|
||||
typedef pair<unsigned, unsigned> DConstraint;
|
||||
typedef map<unsigned, unsigned> IndexMap;
|
||||
|
||||
|
||||
// level of debug information
|
||||
static const unsigned DL = 0;
|
||||
@ -54,197 +49,260 @@ static const int NO_EVIDENCE = -1;
|
||||
// number of digits to show when printing a parameter
|
||||
static const unsigned PRECISION = 5;
|
||||
|
||||
static const bool EXPORT_TO_DOT = false;
|
||||
static const unsigned EXPORT_MIN_SIZE = 30;
|
||||
static const bool COLLECT_STATISTICS = false;
|
||||
|
||||
static const bool EXPORT_TO_GRAPHVIZ = false;
|
||||
static const unsigned EXPORT_MINIMAL_SIZE = 100;
|
||||
|
||||
static const double INF = -numeric_limits<Param>::infinity();
|
||||
|
||||
|
||||
namespace SolverOptions
|
||||
{
|
||||
enum Schedule
|
||||
{
|
||||
S_SEQ_FIXED,
|
||||
S_SEQ_RANDOM,
|
||||
S_PARALLEL,
|
||||
S_MAX_RESIDUAL
|
||||
namespace NumberSpace {
|
||||
enum ns {
|
||||
NORMAL,
|
||||
LOGARITHM
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
|
||||
extern NumberSpace::ns NSPACE;
|
||||
|
||||
|
||||
namespace InfAlgorithms {
|
||||
enum InfAlgs
|
||||
{
|
||||
VE, // variable elimination
|
||||
BN_BP, // bayesian network belief propagation
|
||||
FG_BP, // factor graph belief propagation
|
||||
CBP // counting bp solver
|
||||
};
|
||||
extern InfAlgs infAlgorithm;
|
||||
};
|
||||
|
||||
|
||||
namespace BpOptions
|
||||
{
|
||||
enum Schedule {
|
||||
SEQ_FIXED,
|
||||
SEQ_RANDOM,
|
||||
PARALLEL,
|
||||
MAX_RESIDUAL
|
||||
};
|
||||
extern bool runBayesBall;
|
||||
extern bool convertBn2Fg;
|
||||
extern bool compressFactorGraph;
|
||||
extern Schedule schedule;
|
||||
extern double accuracy;
|
||||
extern unsigned maxIter;
|
||||
extern bool useAlwaysLoopySolver;
|
||||
}
|
||||
|
||||
|
||||
namespace Util
|
||||
{
|
||||
void toLog (ParamSet&);
|
||||
void fromLog (ParamSet&);
|
||||
void normalize (ParamSet&);
|
||||
void logSum (Param&, Param);
|
||||
void multiply (ParamSet&, const ParamSet&);
|
||||
void multiply (ParamSet&, const ParamSet&, unsigned);
|
||||
void add (ParamSet&, const ParamSet&);
|
||||
void add (ParamSet&, const ParamSet&, unsigned);
|
||||
void pow (ParamSet&, unsigned);
|
||||
double getL1dist (CParamSet, CParamSet);
|
||||
double getMaxNorm (CParamSet, CParamSet);
|
||||
Param pow (Param, unsigned);
|
||||
double getL1Distance (const ParamSet&, const ParamSet&);
|
||||
double getMaxNorm (const ParamSet&, const ParamSet&);
|
||||
unsigned getNumberOfDigits (int);
|
||||
bool isInteger (const string&);
|
||||
string parametersToString (CParamSet);
|
||||
vector<DConf> getDomainConfigurations (const VarSet&);
|
||||
vector<string> getInstantiations (const VarSet&);
|
||||
string parametersToString (const ParamSet&, unsigned = PRECISION);
|
||||
BayesNet* generateBayesianNetworkTreeWithLevel (unsigned);
|
||||
vector<DConf> getDomainConfigurations (const VarNodes&);
|
||||
vector<string> getJointStateStrings (const VarNodes&);
|
||||
double tl (Param v);
|
||||
double fl (Param v);
|
||||
double multIdenty();
|
||||
double addIdenty();
|
||||
double withEvidence();
|
||||
double noEvidence();
|
||||
double one();
|
||||
double zero();
|
||||
};
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Util::logSum (Param& x, Param y)
|
||||
{
|
||||
// x = log (exp (x) + exp (y)); return;
|
||||
assert (isfinite (x) && finite (y));
|
||||
// If one value is much smaller than the other, keep the larger value.
|
||||
if (x < (y - log (1e200))) {
|
||||
x = y;
|
||||
return;
|
||||
}
|
||||
if (y < (x - log (1e200))) {
|
||||
return;
|
||||
}
|
||||
double diff = x - y;
|
||||
assert (isfinite (diff) && finite (x) && finite (y));
|
||||
if (!isfinite (exp (diff))) { // difference is too large
|
||||
x = x > y ? x : y;
|
||||
} else { // otherwise return the sum.
|
||||
x = y + log (static_cast<double>(1.0) + exp (diff));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Util::multiply (ParamSet& v1, const ParamSet& v2)
|
||||
{
|
||||
assert (v1.size() == v2.size());
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
v1[i] *= v2[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Util::multiply (ParamSet& v1, const ParamSet& v2, unsigned repetitions)
|
||||
{
|
||||
for (unsigned count = 0; count < v1.size(); ) {
|
||||
for (unsigned i = 0; i < v2.size(); i++) {
|
||||
for (unsigned r = 0; r < repetitions; r++) {
|
||||
v1[count] *= v2[i];
|
||||
count ++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Util::add (ParamSet& v1, const ParamSet& v2)
|
||||
{
|
||||
assert (v1.size() == v2.size());
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
v1[i] += v2[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline void
|
||||
Util::add (ParamSet& v1, const ParamSet& v2, unsigned repetitions)
|
||||
{
|
||||
for (unsigned count = 0; count < v1.size(); ) {
|
||||
for (unsigned i = 0; i < v2.size(); i++) {
|
||||
for (unsigned r = 0; r < repetitions; r++) {
|
||||
v1[count] += v2[i];
|
||||
count ++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
inline double
|
||||
Util::tl (Param v)
|
||||
{
|
||||
return NSPACE == NumberSpace::NORMAL ? v : log(v);
|
||||
}
|
||||
|
||||
inline double
|
||||
Util::fl (Param v)
|
||||
{
|
||||
return NSPACE == NumberSpace::NORMAL ? v : exp(v);
|
||||
}
|
||||
|
||||
inline double
|
||||
Util::multIdenty() {
|
||||
return NSPACE == NumberSpace::NORMAL ? 1.0 : 0.0;
|
||||
}
|
||||
|
||||
inline double
|
||||
Util::addIdenty()
|
||||
{
|
||||
return NSPACE == NumberSpace::NORMAL ? 0.0 : INF;
|
||||
}
|
||||
|
||||
inline double
|
||||
Util::withEvidence()
|
||||
{
|
||||
return NSPACE == NumberSpace::NORMAL ? 1.0 : 0.0;
|
||||
}
|
||||
|
||||
inline double
|
||||
Util::noEvidence() {
|
||||
return NSPACE == NumberSpace::NORMAL ? 0.0 : INF;
|
||||
}
|
||||
|
||||
inline double
|
||||
Util::one()
|
||||
{
|
||||
return NSPACE == NumberSpace::NORMAL ? 1.0 : 0.0;
|
||||
}
|
||||
|
||||
inline double
|
||||
Util::zero() {
|
||||
return NSPACE == NumberSpace::NORMAL ? 0.0 : INF;
|
||||
}
|
||||
|
||||
|
||||
struct NetInfo
|
||||
{
|
||||
NetInfo (void)
|
||||
NetInfo (unsigned size, bool loopy, unsigned nIters, double time)
|
||||
{
|
||||
counting = 0;
|
||||
nIters = 0;
|
||||
solvingTime = 0.0;
|
||||
this->size = size;
|
||||
this->loopy = loopy;
|
||||
this->nIters = nIters;
|
||||
this->time = time;
|
||||
}
|
||||
unsigned counting;
|
||||
double solvingTime;
|
||||
unsigned size;
|
||||
bool loopy;
|
||||
unsigned nIters;
|
||||
double time;
|
||||
};
|
||||
|
||||
|
||||
struct CompressInfo
|
||||
{
|
||||
CompressInfo (unsigned a, unsigned b, unsigned c,
|
||||
unsigned d, unsigned e) {
|
||||
nUncVars = a;
|
||||
nUncFactors = b;
|
||||
nCompVars = c;
|
||||
nCompFactors = d;
|
||||
nNeighborlessVars = e;
|
||||
CompressInfo (unsigned a, unsigned b, unsigned c, unsigned d, unsigned e)
|
||||
{
|
||||
nGroundVars = a;
|
||||
nGroundFactors = b;
|
||||
nClusterVars = c;
|
||||
nClusterFactors = d;
|
||||
nWithoutNeighs = e;
|
||||
}
|
||||
unsigned nUncVars;
|
||||
unsigned nUncFactors;
|
||||
unsigned nCompVars;
|
||||
unsigned nCompFactors;
|
||||
unsigned nNeighborlessVars;
|
||||
unsigned nGroundVars;
|
||||
unsigned nGroundFactors;
|
||||
unsigned nClusterVars;
|
||||
unsigned nClusterFactors;
|
||||
unsigned nWithoutNeighs;
|
||||
};
|
||||
|
||||
|
||||
typedef map<unsigned, NetInfo> StatisticMap;
|
||||
class Statistics
|
||||
{
|
||||
public:
|
||||
|
||||
static void updateStats (unsigned size, unsigned nIters, double time)
|
||||
{
|
||||
StatisticMap::iterator it = stats_.find (size);
|
||||
if (it == stats_.end()) {
|
||||
it = (stats_.insert (make_pair (size, NetInfo()))).first;
|
||||
} else {
|
||||
it->second.counting ++;
|
||||
it->second.nIters += nIters;
|
||||
it->second.solvingTime += time;
|
||||
totalOfIterations += nIters;
|
||||
if (nIters > maxIterations) {
|
||||
maxIterations = nIters;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void updateCompressingStats (unsigned nUncVars,
|
||||
unsigned nUncFactors,
|
||||
unsigned nCompVars,
|
||||
unsigned nCompFactors,
|
||||
unsigned nNeighborlessVars) {
|
||||
compressInfo_.push_back (CompressInfo (
|
||||
nUncVars, nUncFactors, nCompVars, nCompFactors, nNeighborlessVars));
|
||||
}
|
||||
|
||||
static void printCompressingStats (const char* fileName)
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "BayesNet::printCompressingStats()" << endl;
|
||||
abort();
|
||||
}
|
||||
out << "--------------------------------------" ;
|
||||
out << "--------------------------------------" << endl;
|
||||
out << " Compression Stats" << endl;
|
||||
out << "--------------------------------------" ;
|
||||
out << "--------------------------------------" << endl;
|
||||
out << left;
|
||||
out << "Uncompress Compressed Uncompress Compressed Neighborless";
|
||||
out << endl;
|
||||
out << "Vars Vars Factors Factors Vars" ;
|
||||
out << endl;
|
||||
for (unsigned i = 0; i < compressInfo_.size(); i++) {
|
||||
out << setw (13) << compressInfo_[i].nUncVars;
|
||||
out << setw (13) << compressInfo_[i].nCompVars;
|
||||
out << setw (13) << compressInfo_[i].nUncFactors;
|
||||
out << setw (13) << compressInfo_[i].nCompFactors;
|
||||
out << setw (13) << compressInfo_[i].nNeighborlessVars;
|
||||
out << endl;
|
||||
}
|
||||
}
|
||||
|
||||
static unsigned getCounting (unsigned size)
|
||||
{
|
||||
StatisticMap::iterator it = stats_.find(size);
|
||||
assert (it != stats_.end());
|
||||
return it->second.counting;
|
||||
}
|
||||
|
||||
static void writeStats (void)
|
||||
{
|
||||
ofstream out ("../../stats.txt");
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "Statistics::updateStats()" << endl;
|
||||
abort();
|
||||
}
|
||||
unsigned avgIterations = 0;
|
||||
if (numSolvedLoopyNets > 0) {
|
||||
avgIterations = totalOfIterations / numSolvedLoopyNets;
|
||||
}
|
||||
double totalSolvingTime = 0.0;
|
||||
for (StatisticMap::iterator it = stats_.begin();
|
||||
it != stats_.end(); it++) {
|
||||
totalSolvingTime += it->second.solvingTime;
|
||||
}
|
||||
out << "created networks: " << numCreatedNets << endl;
|
||||
out << "solver runs on polytrees: " << numSolvedPolyTrees << endl;
|
||||
out << "solver runs on loopy networks: " << numSolvedLoopyNets << endl;
|
||||
out << " unconverged: " << numUnconvergedRuns << endl;
|
||||
out << " max iterations: " << maxIterations << endl;
|
||||
out << " average iterations: " << avgIterations << endl;
|
||||
out << "total solving time " << totalSolvingTime << endl;
|
||||
out << endl;
|
||||
out << left << endl;
|
||||
out << setw (15) << "Network Size" ;
|
||||
out << setw (15) << "Counting" ;
|
||||
out << setw (15) << "Solving Time" ;
|
||||
out << setw (15) << "Average Time" ;
|
||||
out << setw (15) << "#Iterations" ;
|
||||
out << endl;
|
||||
for (StatisticMap::iterator it = stats_.begin();
|
||||
it != stats_.end(); it++) {
|
||||
out << setw (15) << it->first;
|
||||
out << setw (15) << it->second.counting;
|
||||
out << setw (15) << it->second.solvingTime;
|
||||
if (it->second.counting > 0) {
|
||||
out << setw (15) << it->second.solvingTime / it->second.counting;
|
||||
} else {
|
||||
out << setw (15) << "0.0" ;
|
||||
}
|
||||
out << setw (15) << it->second.nIters;
|
||||
out << endl;
|
||||
}
|
||||
out.close();
|
||||
}
|
||||
|
||||
static unsigned numCreatedNets;
|
||||
static unsigned numSolvedPolyTrees;
|
||||
static unsigned numSolvedLoopyNets;
|
||||
static unsigned numUnconvergedRuns;
|
||||
static unsigned getSolvedNetworksCounting (void);
|
||||
static void incrementPrimaryNetworksCounting (void);
|
||||
static unsigned getPrimaryNetworksCounting (void);
|
||||
static void updateStatistics (unsigned, bool, unsigned, double);
|
||||
static void printStatistics (void);
|
||||
static void writeStatisticsToFile (const char*);
|
||||
static void updateCompressingStatistics (
|
||||
unsigned, unsigned, unsigned, unsigned, unsigned);
|
||||
|
||||
private:
|
||||
static StatisticMap stats_;
|
||||
static unsigned maxIterations;
|
||||
static unsigned totalOfIterations;
|
||||
static string getStatisticString (void);
|
||||
|
||||
static vector<NetInfo> netInfo_;
|
||||
static vector<CompressInfo> compressInfo_;
|
||||
static unsigned primaryNetCount_;
|
||||
};
|
||||
|
||||
#endif //BP_SHARED_H
|
||||
#endif // HORUS_SHARED_H
|
||||
|
||||
|
53
packages/CLPBN/clpbn/bp/Solver.cpp
Normal file
53
packages/CLPBN/clpbn/bp/Solver.cpp
Normal file
@ -0,0 +1,53 @@
|
||||
#include "Solver.h"
|
||||
|
||||
|
||||
void
|
||||
Solver::printAllPosterioris (void)
|
||||
{
|
||||
const VarNodes& vars = gm_->getVariableNodes();
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
printPosterioriOf (vars[i]->varId());
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Solver::printPosterioriOf (VarId vid)
|
||||
{
|
||||
VarNode* var = gm_->getVariableNode (vid);
|
||||
const ParamSet& posterioriDist = getPosterioriOf (vid);
|
||||
const States& states = var->states();
|
||||
for (unsigned i = 0; i < states.size(); i++) {
|
||||
cout << "P(" << var->label() << "=" << states[i] << ") = " ;
|
||||
cout << setprecision (PRECISION) << posterioriDist[i];
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Solver::printJointDistributionOf (const VarIdSet& vids)
|
||||
{
|
||||
VarNodes vars;
|
||||
VarIdSet vidsWithoutEvidence;
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
VarNode* var = gm_->getVariableNode (vids[i]);
|
||||
if (var->hasEvidence() == false) {
|
||||
vars.push_back (var);
|
||||
vidsWithoutEvidence.push_back (vids[i]);
|
||||
}
|
||||
}
|
||||
const ParamSet& jointDist = getJointDistributionOf (vidsWithoutEvidence);
|
||||
vector<string> jointStrings = Util::getJointStateStrings (vars);
|
||||
for (unsigned i = 0; i < jointDist.size(); i++) {
|
||||
cout << "P(" << jointStrings[i] << ") = " ;
|
||||
cout << setprecision (PRECISION) << jointDist[i];
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
@ -1,10 +1,10 @@
|
||||
#ifndef BP_SOLVER_H
|
||||
#define BP_SOLVER_H
|
||||
#ifndef HORUS_SOLVER_H
|
||||
#define HORUS_SOLVER_H
|
||||
|
||||
#include <iomanip>
|
||||
|
||||
#include "GraphicalModel.h"
|
||||
#include "Variable.h"
|
||||
#include "VarNode.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
@ -15,66 +15,18 @@ class Solver
|
||||
{
|
||||
gm_ = gm;
|
||||
}
|
||||
virtual ~Solver() {} // to call subclass destructor
|
||||
virtual ~Solver() {} // to ensure that subclass destructor is called
|
||||
virtual void runSolver (void) = 0;
|
||||
virtual ParamSet getPosterioriOf (Vid) const = 0;
|
||||
virtual ParamSet getJointDistributionOf (const VidSet&) = 0;
|
||||
virtual ParamSet getPosterioriOf (VarId) = 0;
|
||||
virtual ParamSet getJointDistributionOf (const VarIdSet&) = 0;
|
||||
|
||||
void printAllPosterioris (void) const
|
||||
{
|
||||
VarSet vars = gm_->getVariables();
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
printPosterioriOf (vars[i]->getVarId());
|
||||
}
|
||||
}
|
||||
|
||||
void printPosterioriOf (Vid vid) const
|
||||
{
|
||||
Variable* var = gm_->getVariable (vid);
|
||||
cout << endl;
|
||||
cout << setw (20) << left << var->getLabel() << "posteriori" ;
|
||||
cout << endl;
|
||||
cout << "------------------------------" ;
|
||||
cout << endl;
|
||||
const Domain& domain = var->getDomain();
|
||||
ParamSet results = getPosterioriOf (vid);
|
||||
for (unsigned xi = 0; xi < var->getDomainSize(); xi++) {
|
||||
cout << setw (20) << domain[xi];
|
||||
cout << setprecision (PRECISION) << results[xi];
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
void printJointDistributionOf (const VidSet& vids)
|
||||
{
|
||||
const ParamSet& jointDist = getJointDistributionOf (vids);
|
||||
cout << endl;
|
||||
cout << "joint distribution of " ;
|
||||
VarSet vars;
|
||||
for (unsigned i = 0; i < vids.size() - 1; i++) {
|
||||
Variable* var = gm_->getVariable (vids[i]);
|
||||
cout << var->getLabel() << ", " ;
|
||||
vars.push_back (var);
|
||||
}
|
||||
Variable* var = gm_->getVariable (vids[vids.size() - 1]);
|
||||
cout << var->getLabel() ;
|
||||
vars.push_back (var);
|
||||
cout << endl;
|
||||
cout << "------------------------------" ;
|
||||
cout << endl;
|
||||
const vector<string>& domainConfs = Util::getInstantiations (vars);
|
||||
for (unsigned i = 0; i < jointDist.size(); i++) {
|
||||
cout << left << setw (20) << domainConfs[i];
|
||||
cout << setprecision (PRECISION) << jointDist[i];
|
||||
cout << endl;
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
void printAllPosterioris (void);
|
||||
void printPosterioriOf (VarId vid);
|
||||
void printJointDistributionOf (const VarIdSet& vids);
|
||||
|
||||
private:
|
||||
const GraphicalModel* gm_;
|
||||
};
|
||||
|
||||
#endif //BP_SOLVER_H
|
||||
#endif // HORUS_SOLVER_H
|
||||
|
||||
|
246
packages/CLPBN/clpbn/bp/StatesIndexer.h
Normal file
246
packages/CLPBN/clpbn/bp/StatesIndexer.h
Normal file
@ -0,0 +1,246 @@
|
||||
#ifndef HORUS_STATESINDEXER_H
|
||||
#define HORUS_STATESINDEXER_H
|
||||
|
||||
#include <iomanip>
|
||||
|
||||
class StatesIndexer {
|
||||
public:
|
||||
|
||||
StatesIndexer (const Ranges& ranges)
|
||||
{
|
||||
maxIndex_ = 1;
|
||||
states_.resize (ranges.size(), 0);
|
||||
ranges_ = ranges;
|
||||
for (unsigned i = 0; i < ranges.size(); i++) {
|
||||
maxIndex_ *= ranges[i];
|
||||
}
|
||||
linearIndex_ = 0;
|
||||
}
|
||||
|
||||
|
||||
StatesIndexer (const VarNodes& vars)
|
||||
{
|
||||
maxIndex_ = 1;
|
||||
states_.resize (vars.size(), 0);
|
||||
ranges_.reserve (vars.size());
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
ranges_.push_back (vars[i]->nrStates());
|
||||
maxIndex_ *= vars[i]->nrStates();
|
||||
}
|
||||
linearIndex_ = 0;
|
||||
}
|
||||
|
||||
StatesIndexer& operator++ (void) {
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
states_[i] ++;
|
||||
if (states_[i] == (int)ranges_[i]) {
|
||||
states_[i] = 0;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
linearIndex_ ++;
|
||||
return *this;
|
||||
}
|
||||
|
||||
StatesIndexer& operator-- (void) {
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
states_[i] --;
|
||||
if (states_[i] == -1) {
|
||||
states_[i] = ranges_[i] - 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
linearIndex_ --;
|
||||
return *this;
|
||||
}
|
||||
|
||||
void incrementState (unsigned whichVar)
|
||||
{
|
||||
for (int i = whichVar; i >= 0; i--) {
|
||||
states_[i] ++;
|
||||
if (states_[i] == (int)ranges_[i] && i != 0) {
|
||||
if (i == 0) {
|
||||
linearIndex_ = maxIndex_;
|
||||
} else {
|
||||
states_[i] = 0;
|
||||
}
|
||||
} else {
|
||||
linearIndex_ = getLinearIndexFromStates();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void decrementState (unsigned whichVar)
|
||||
{
|
||||
for (int i = whichVar; i >= 0; i--) {
|
||||
states_[i] --;
|
||||
if (states_[i] == -1) {
|
||||
if (i == 0) {
|
||||
linearIndex_ = -1;
|
||||
} else {
|
||||
states_[i] = ranges_[i] - 1;
|
||||
}
|
||||
} else {
|
||||
linearIndex_ = getLinearIndexFromStates();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void nextSameState (unsigned whichVar)
|
||||
{
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
if (i != (int)whichVar) {
|
||||
states_[i] ++;
|
||||
if (states_[i] == (int)ranges_[i]) {
|
||||
if (i == 0 || (i-1 == (int)whichVar && whichVar == 0)) {
|
||||
linearIndex_ = maxIndex_;
|
||||
} else {
|
||||
states_[i] = 0;
|
||||
}
|
||||
} else {
|
||||
linearIndex_ = getLinearIndexFromStates();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void previousSameState (unsigned whichVar)
|
||||
{
|
||||
for (int i = ranges_.size() - 1; i >= 0; i--) {
|
||||
if (i != (int)whichVar) {
|
||||
states_[i] --;
|
||||
if (states_[i] == - 1) {
|
||||
if (i == 0 || (i-1 == (int)whichVar && whichVar == 0)) {
|
||||
linearIndex_ = -1;
|
||||
} else {
|
||||
states_[i] = ranges_[i] - 1;
|
||||
}
|
||||
} else {
|
||||
linearIndex_ = getLinearIndexFromStates();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void moveToBegin (void)
|
||||
{
|
||||
std::fill (states_.begin(), states_.end(), 0);
|
||||
linearIndex_ = 0;
|
||||
}
|
||||
|
||||
void moveToEnd (void)
|
||||
{
|
||||
for (unsigned i = 0; i < states_.size(); i++) {
|
||||
states_[i] = ranges_[i] - 1;
|
||||
}
|
||||
linearIndex_ = maxIndex_ - 1;
|
||||
}
|
||||
|
||||
bool valid (void) const
|
||||
{
|
||||
return linearIndex_ >= 0 && linearIndex_ < (int)maxIndex_;
|
||||
}
|
||||
|
||||
unsigned getLinearIndex (void) const
|
||||
{
|
||||
return linearIndex_;
|
||||
}
|
||||
|
||||
const vector<int>& getStates (void) const
|
||||
{
|
||||
return states_;
|
||||
}
|
||||
|
||||
unsigned operator[] (unsigned whichVar) const
|
||||
{
|
||||
assert (valid());
|
||||
assert (whichVar < states_.size());
|
||||
return states_[whichVar];
|
||||
}
|
||||
|
||||
string toString (void) const
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "linear index=" << setw (3) << linearIndex_ << " " ;
|
||||
ss << "states= [" << states_[0] ;
|
||||
for (unsigned i = 1; i < states_.size(); i++) {
|
||||
ss << ", " << states_[i];
|
||||
}
|
||||
ss << "]" ;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned getLinearIndexFromStates (void)
|
||||
{
|
||||
unsigned prod = 1;
|
||||
unsigned linearIndex = 0;
|
||||
for (int i = states_.size() - 1; i >= 0; i--) {
|
||||
linearIndex += states_[i] * prod;
|
||||
prod *= ranges_[i];
|
||||
}
|
||||
return linearIndex;
|
||||
}
|
||||
|
||||
int linearIndex_;
|
||||
int maxIndex_;
|
||||
vector<int> states_;
|
||||
vector<unsigned> ranges_;
|
||||
};
|
||||
|
||||
|
||||
/*
|
||||
FgVarNode* v1 = new FgVarNode (0, 4);
|
||||
FgVarNode* v2 = new FgVarNode (1, 3);
|
||||
FgVarNode* v3 = new FgVarNode (2, 2);
|
||||
FgVarSet vars = {v1,v2,v3};
|
||||
ParamSet params = {
|
||||
0.2, 0.44, 0.1, 0.88, 0.22,0.62,0.32, 0.42, 0.11, 0.88, 0.8,0.5,
|
||||
0.22, 0.4, 0.11, 0.8, 0.224,0.6,0.21, 0.44, 0.14, 0.68, 0.41,0.6
|
||||
};
|
||||
Factor f (vars,params);
|
||||
StatesIndexer idx (vars);
|
||||
while (idx.valid())
|
||||
{
|
||||
cout << idx.toString() << " p=" << params[idx.getLinearIndex()] << endl;
|
||||
idx.incrementVariableState (0);
|
||||
idx.nextSameState (1);
|
||||
++idx;
|
||||
}
|
||||
cout << endl;
|
||||
idx.moveToEnd();
|
||||
while (idx.valid())
|
||||
{
|
||||
cout << idx.toString() << " p=" << params[idx.getLinearIndex()] << endl;
|
||||
idx.decrementVariableState (0);
|
||||
idx.previousSameState (1);
|
||||
--idx;
|
||||
}
|
||||
*/
|
||||
|
||||
|
||||
/*
|
||||
FgVarNode* x0 = new FgVarNode (0, 2);
|
||||
FgVarNode* x1 = new FgVarNode (1, 2);
|
||||
FgVarNode* x2 = new FgVarNode (2, 2);
|
||||
FgVarNode* x3 = new FgVarNode (2, 2);
|
||||
FgVarNode* x4 = new FgVarNode (2, 2);
|
||||
FgVarSet vars_ = {x0,x1,x2,x3,x4};
|
||||
ParamSet params_ = {
|
||||
0.2, 0.44, 0.1, 0.88, 0.11, 0.88, 0.8, 0.5,
|
||||
0.2, 0.44, 0.1, 0.88, 0.11, 0.88, 0.8, 0.5,
|
||||
0.2, 0.44, 0.1, 0.88, 0.11, 0.88, 0.8, 0.5,
|
||||
0.2, 0.44, 0.1, 0.88, 0.11, 0.88, 0.8, 0.5
|
||||
};
|
||||
Factor ff (vars_,params_);
|
||||
ff.printFactor();
|
||||
*/
|
||||
|
||||
#endif // HORUS_STATESINDEXER_H
|
||||
|
@ -1,41 +1,66 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "Variable.h"
|
||||
#include "BayesNet.h"
|
||||
#include "VarNode.h"
|
||||
#include "Shared.h"
|
||||
#include "StatesIndexer.h"
|
||||
|
||||
namespace SolverOptions {
|
||||
|
||||
bool runBayesBall = false;
|
||||
bool convertBn2Fg = true;
|
||||
bool compressFactorGraph = true;
|
||||
Schedule schedule = S_SEQ_FIXED;
|
||||
//Schedule schedule = S_SEQ_RANDOM;
|
||||
//Schedule schedule = S_PARALLEL;
|
||||
//Schedule schedule = S_MAX_RESIDUAL;
|
||||
double accuracy = 0.0001;
|
||||
unsigned maxIter = 1000; //FIXME
|
||||
|
||||
namespace InfAlgorithms {
|
||||
InfAlgs infAlgorithm = InfAlgorithms::VE;
|
||||
//InfAlgs infAlgorithm = InfAlgorithms::BN_BP;
|
||||
//InfAlgs infAlgorithm = InfAlgorithms::FG_BP;
|
||||
//InfAlgs infAlgorithm = InfAlgorithms::CBP;
|
||||
}
|
||||
|
||||
|
||||
namespace BpOptions {
|
||||
Schedule schedule = BpOptions::Schedule::SEQ_FIXED;
|
||||
//Schedule schedule = BpOptions::Schedule::SEQ_RANDOM;
|
||||
//Schedule schedule = BpOptions::Schedule::PARALLEL;
|
||||
//Schedule schedule = BpOptions::Schedule::MAX_RESIDUAL;
|
||||
double accuracy = 0.0001;
|
||||
unsigned maxIter = 1000;
|
||||
bool useAlwaysLoopySolver = true;
|
||||
}
|
||||
|
||||
unsigned Statistics::numCreatedNets = 0;
|
||||
unsigned Statistics::numSolvedPolyTrees = 0;
|
||||
unsigned Statistics::numSolvedLoopyNets = 0;
|
||||
unsigned Statistics::numUnconvergedRuns = 0;
|
||||
unsigned Statistics::maxIterations = 0;
|
||||
unsigned Statistics::totalOfIterations = 0;
|
||||
NumberSpace::ns NSPACE = NumberSpace::NORMAL;
|
||||
|
||||
unordered_map<VarId,VariableInfo> GraphicalModel::varsInfo_;
|
||||
|
||||
vector<NetInfo> Statistics::netInfo_;
|
||||
vector<CompressInfo> Statistics::compressInfo_;
|
||||
StatisticMap Statistics::stats_;
|
||||
|
||||
unsigned Statistics::primaryNetCount_;
|
||||
|
||||
|
||||
namespace Util {
|
||||
|
||||
void
|
||||
toLog (ParamSet& v)
|
||||
{
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] = log (v[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
fromLog (ParamSet& v)
|
||||
{
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] = exp (v[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
normalize (ParamSet& v)
|
||||
{
|
||||
double sum = 0.0;
|
||||
double sum;
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
sum = 0.0;
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
sum += v[i];
|
||||
}
|
||||
@ -43,49 +68,124 @@ normalize (ParamSet& v)
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] /= sum;
|
||||
}
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
sum = addIdenty();
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
logSum (sum, v[i]);
|
||||
}
|
||||
assert (sum != -numeric_limits<Param>::infinity());
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] -= sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
pow (ParamSet& v, unsigned expoent)
|
||||
{
|
||||
if (expoent == 1) {
|
||||
return; // optimization
|
||||
}
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
double value = 1;
|
||||
double value = 1.0;
|
||||
for (unsigned j = 0; j < expoent; j++) {
|
||||
value *= v[i];
|
||||
}
|
||||
v[i] = value;
|
||||
}
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
for (unsigned i = 0; i < v.size(); i++) {
|
||||
v[i] *= expoent;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Param
|
||||
pow (Param p, unsigned expoent)
|
||||
{
|
||||
double value = 1.0;
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
for (unsigned i = 0; i < expoent; i++) {
|
||||
value *= p;
|
||||
}
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
value = p * expoent;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
getL1dist (const ParamSet& v1, const ParamSet& v2)
|
||||
getL1Distance (const ParamSet& v1, const ParamSet& v2)
|
||||
{
|
||||
assert (v1.size() == v2.size());
|
||||
double dist = 0.0;
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
dist += abs (v1[i] - v2[i]);
|
||||
}
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
dist += abs (exp(v1[i]) - exp(v2[i]));
|
||||
}
|
||||
}
|
||||
return dist;
|
||||
}
|
||||
|
||||
|
||||
|
||||
double
|
||||
getMaxNorm (const ParamSet& v1, const ParamSet& v2)
|
||||
{
|
||||
assert (v1.size() == v2.size());
|
||||
double max = 0.0;
|
||||
switch (NSPACE) {
|
||||
case NumberSpace::NORMAL:
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
double diff = abs (v1[i] - v2[i]);
|
||||
if (diff > max) {
|
||||
max = diff;
|
||||
}
|
||||
}
|
||||
break;
|
||||
case NumberSpace::LOGARITHM:
|
||||
for (unsigned i = 0; i < v1.size(); i++) {
|
||||
double diff = abs (exp(v1[i]) - exp(v2[i]));
|
||||
if (diff > max) {
|
||||
max = diff;
|
||||
}
|
||||
}
|
||||
}
|
||||
return max;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
getNumberOfDigits (int number) {
|
||||
unsigned count = 1;
|
||||
while (number >= 10) {
|
||||
number /= 10;
|
||||
count ++;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
isInteger (const string& s)
|
||||
{
|
||||
@ -100,9 +200,10 @@ isInteger (const string& s)
|
||||
|
||||
|
||||
string
|
||||
parametersToString (CParamSet v)
|
||||
parametersToString (const ParamSet& v, unsigned precision)
|
||||
{
|
||||
stringstream ss;
|
||||
ss.precision (precision);
|
||||
ss << "[" ;
|
||||
for (unsigned i = 0; i < v.size() - 1; i++) {
|
||||
ss << v[i] << ", " ;
|
||||
@ -116,12 +217,44 @@ parametersToString (CParamSet v)
|
||||
|
||||
|
||||
|
||||
vector<DConf>
|
||||
getDomainConfigurations (const VarSet& vars)
|
||||
BayesNet*
|
||||
generateBayesianNetworkTreeWithLevel (unsigned level)
|
||||
{
|
||||
BayesNet* bn = new BayesNet();
|
||||
Distribution* dist = new Distribution (ParamSet() = {0.1, 0.5, 0.2, 0.7});
|
||||
BayesNode* root = bn->addNode (0, 2, -1, BnNodeSet() = {},
|
||||
new Distribution (ParamSet() = {0.1, 0.5}));
|
||||
BnNodeSet prevLevel = { root };
|
||||
BnNodeSet currLevel;
|
||||
VarId vidCount = 1;
|
||||
for (unsigned l = 1; l < level; l++) {
|
||||
currLevel.clear();
|
||||
for (unsigned i = 0; i < prevLevel.size(); i++) {
|
||||
currLevel.push_back (
|
||||
bn->addNode (vidCount, 2, -1, BnNodeSet() = {prevLevel[i]}, dist));
|
||||
vidCount ++;
|
||||
currLevel.push_back (
|
||||
bn->addNode (vidCount, 2, -1, BnNodeSet() = {prevLevel[i]}, dist));
|
||||
vidCount ++;
|
||||
}
|
||||
prevLevel = currLevel;
|
||||
}
|
||||
for (unsigned i = 0; i < prevLevel.size(); i++) {
|
||||
prevLevel[i]->setEvidence (0);
|
||||
}
|
||||
bn->setIndexes();
|
||||
return bn;
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<DConf>
|
||||
getDomainConfigurations (const VarNodes& vars)
|
||||
{
|
||||
// TODO this method must die
|
||||
unsigned nConfs = 1;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
nConfs *= vars[i]->getDomainSize();
|
||||
nConfs *= vars[i]->nrStates();
|
||||
}
|
||||
|
||||
vector<DConf> confs (nConfs);
|
||||
@ -133,59 +266,213 @@ getDomainConfigurations (const VarSet& vars)
|
||||
for (int i = vars.size() - 1; i >= 0; i--) {
|
||||
unsigned index = 0;
|
||||
while (index < nConfs) {
|
||||
for (unsigned j = 0; j < vars[i]->getDomainSize(); j++) {
|
||||
for (unsigned j = 0; j < vars[i]->nrStates(); j++) {
|
||||
for (unsigned r = 0; r < nReps; r++) {
|
||||
confs[index][i] = j;
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
nReps *= vars[i]->getDomainSize();
|
||||
nReps *= vars[i]->nrStates();
|
||||
}
|
||||
return confs;
|
||||
}
|
||||
|
||||
|
||||
|
||||
vector<string>
|
||||
getInstantiations (const VarSet& vars)
|
||||
getJointStateStrings (const VarNodes& vars)
|
||||
{
|
||||
//FIXME handle variables without domain
|
||||
/*
|
||||
char c = 'a' ;
|
||||
const DConf& conf = entries[i].getDomainConfiguration();
|
||||
for (unsigned j = 0; j < conf.size(); j++) {
|
||||
if (j != 0) ss << "," ;
|
||||
ss << c << conf[j] + 1;
|
||||
c ++;
|
||||
}
|
||||
*/
|
||||
unsigned rowSize = 1;
|
||||
StatesIndexer idx (vars);
|
||||
vector<string> jointStrings;
|
||||
while (idx.valid()) {
|
||||
stringstream ss;
|
||||
for (unsigned i = 0; i < vars.size(); i++) {
|
||||
rowSize *= vars[i]->getDomainSize();
|
||||
if (i != 0) ss << ", " ;
|
||||
ss << vars[i]->label() << "=" << vars[i]->states()[(idx[i])];
|
||||
}
|
||||
jointStrings.push_back (ss.str());
|
||||
++ idx;
|
||||
}
|
||||
return jointStrings;
|
||||
}
|
||||
|
||||
vector<string> headers (rowSize);
|
||||
|
||||
unsigned nReps = 1;
|
||||
for (int i = vars.size() - 1; i >= 0; i--) {
|
||||
Domain domain = vars[i]->getDomain();
|
||||
unsigned index = 0;
|
||||
while (index < rowSize) {
|
||||
for (unsigned j = 0; j < vars[i]->getDomainSize(); j++) {
|
||||
for (unsigned r = 0; r < nReps; r++) {
|
||||
if (headers[index] != "") {
|
||||
headers[index] = domain[j] + ", " + headers[index];
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
Statistics::getSolvedNetworksCounting (void)
|
||||
{
|
||||
return netInfo_.size();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Statistics::incrementPrimaryNetworksCounting (void)
|
||||
{
|
||||
primaryNetCount_ ++;
|
||||
}
|
||||
|
||||
|
||||
|
||||
unsigned
|
||||
Statistics::getPrimaryNetworksCounting (void)
|
||||
{
|
||||
return primaryNetCount_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Statistics::updateStatistics (unsigned size, bool loopy,
|
||||
unsigned nIters, double time)
|
||||
{
|
||||
netInfo_.push_back (NetInfo (size, loopy, nIters, time));
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Statistics::printStatistics (void)
|
||||
{
|
||||
cout << getStatisticString();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Statistics::writeStatisticsToFile (const char* fileName)
|
||||
{
|
||||
ofstream out (fileName);
|
||||
if (!out.is_open()) {
|
||||
cerr << "error: cannot open file to write at " ;
|
||||
cerr << "Statistics::writeStatisticsToFile()" << endl;
|
||||
abort();
|
||||
}
|
||||
out << getStatisticString();
|
||||
out.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
Statistics::updateCompressingStatistics (unsigned nGroundVars,
|
||||
unsigned nGroundFactors,
|
||||
unsigned nClusterVars,
|
||||
unsigned nClusterFactors,
|
||||
unsigned nWithoutNeighs) {
|
||||
compressInfo_.push_back (CompressInfo (nGroundVars, nGroundFactors,
|
||||
nClusterVars, nClusterFactors, nWithoutNeighs));
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
Statistics::getStatisticString (void)
|
||||
{
|
||||
stringstream ss2, ss3, ss4, ss1;
|
||||
ss1 << "running mode: " ;
|
||||
switch (InfAlgorithms::infAlgorithm) {
|
||||
case InfAlgorithms::VE: ss1 << "ve" << endl; break;
|
||||
case InfAlgorithms::BN_BP: ss1 << "bn_bp" << endl; break;
|
||||
case InfAlgorithms::FG_BP: ss1 << "fg_bp" << endl; break;
|
||||
case InfAlgorithms::CBP: ss1 << "cbp" << endl; break;
|
||||
}
|
||||
ss1 << "message schedule: " ;
|
||||
switch (BpOptions::schedule) {
|
||||
case BpOptions::Schedule::SEQ_FIXED: ss1 << "sequential fixed" << endl; break;
|
||||
case BpOptions::Schedule::SEQ_RANDOM: ss1 << "sequential random" << endl; break;
|
||||
case BpOptions::Schedule::PARALLEL: ss1 << "parallel" << endl; break;
|
||||
case BpOptions::Schedule::MAX_RESIDUAL: ss1 << "max residual" << endl; break;
|
||||
}
|
||||
ss1 << "max iterations: " << BpOptions::maxIter << endl;
|
||||
ss1 << "accuracy " << BpOptions::accuracy << endl;
|
||||
if (BpOptions::useAlwaysLoopySolver) {
|
||||
ss1 << "always loopy solver: yes" << endl;
|
||||
} else {
|
||||
headers[index] = domain[j];
|
||||
ss1 << "always loopy solver: no" << endl;
|
||||
}
|
||||
index++;
|
||||
ss1 << endl << endl;
|
||||
|
||||
ss2 << "---------------------------------------------------" << endl;
|
||||
ss2 << " Network information" << endl;
|
||||
ss2 << "---------------------------------------------------" << endl;
|
||||
ss2 << left;
|
||||
ss2 << setw (15) << "Network Size" ;
|
||||
ss2 << setw (9) << "Loopy" ;
|
||||
ss2 << setw (15) << "Iterations" ;
|
||||
ss2 << setw (15) << "Solving Time" ;
|
||||
ss2 << endl;
|
||||
unsigned nLoopyNets = 0;
|
||||
unsigned nUnconvergedRuns = 0;
|
||||
double totalSolvingTime = 0.0;
|
||||
for (unsigned i = 0; i < netInfo_.size(); i++) {
|
||||
ss2 << setw (15) << netInfo_[i].size;
|
||||
if (netInfo_[i].loopy) {
|
||||
ss2 << setw (9) << "yes";
|
||||
nLoopyNets ++;
|
||||
} else {
|
||||
ss2 << setw (9) << "no";
|
||||
}
|
||||
if (netInfo_[i].nIters == 0) {
|
||||
ss2 << setw (15) << "n/a" ;
|
||||
} else {
|
||||
ss2 << setw (15) << netInfo_[i].nIters;
|
||||
if (netInfo_[i].nIters > BpOptions::maxIter) {
|
||||
nUnconvergedRuns ++;
|
||||
}
|
||||
}
|
||||
ss2 << setw (15) << netInfo_[i].time;
|
||||
totalSolvingTime += netInfo_[i].time;
|
||||
ss2 << endl;
|
||||
}
|
||||
nReps *= vars[i]->getDomainSize();
|
||||
ss2 << endl << endl;
|
||||
|
||||
unsigned c1 = 0, c2 = 0, c3 = 0, c4 = 0;
|
||||
if (compressInfo_.size() > 0) {
|
||||
ss3 << "---------------------------------------------------" << endl;
|
||||
ss3 << " Compression information" << endl;
|
||||
ss3 << "---------------------------------------------------" << endl;
|
||||
ss3 << left;
|
||||
ss3 << "Ground Cluster Ground Cluster Neighborless" << endl;
|
||||
ss3 << "Vars Vars Factors Factors Vars" << endl;
|
||||
for (unsigned i = 0; i < compressInfo_.size(); i++) {
|
||||
ss3 << setw (9) << compressInfo_[i].nGroundVars;
|
||||
ss3 << setw (10) << compressInfo_[i].nClusterVars;
|
||||
ss3 << setw (10) << compressInfo_[i].nGroundFactors;
|
||||
ss3 << setw (10) << compressInfo_[i].nClusterFactors;
|
||||
ss3 << setw (10) << compressInfo_[i].nWithoutNeighs;
|
||||
ss3 << endl;
|
||||
c1 += compressInfo_[i].nGroundVars - compressInfo_[i].nWithoutNeighs;
|
||||
c2 += compressInfo_[i].nClusterVars;
|
||||
c3 += compressInfo_[i].nGroundFactors - compressInfo_[i].nWithoutNeighs;
|
||||
c4 += compressInfo_[i].nClusterFactors;
|
||||
if (compressInfo_[i].nWithoutNeighs != 0) {
|
||||
c2 --;
|
||||
c4 --;
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
|
||||
}
|
||||
ss3 << endl << endl;
|
||||
}
|
||||
|
||||
ss4 << "primary networks: " << primaryNetCount_ << endl;
|
||||
ss4 << "solved networks: " << netInfo_.size() << endl;
|
||||
ss4 << "loopy networks: " << nLoopyNets << endl;
|
||||
ss4 << "unconverged runs: " << nUnconvergedRuns << endl;
|
||||
ss4 << "total solving time: " << totalSolvingTime << endl;
|
||||
if (compressInfo_.size() > 0) {
|
||||
double pc1 = (1.0 - (c2 / (double)c1)) * 100.0;
|
||||
double pc2 = (1.0 - (c4 / (double)c3)) * 100.0;
|
||||
ss4 << setprecision (5);
|
||||
ss4 << "variable compression: " << pc1 << "%" << endl;
|
||||
ss4 << "factor compression: " << pc2 << "%" << endl;
|
||||
}
|
||||
ss4 << endl << endl;
|
||||
|
||||
ss1 << ss4.str() << ss2.str() << ss3.str();
|
||||
return ss1.str();
|
||||
}
|
||||
|
||||
|
211
packages/CLPBN/clpbn/bp/VarElimSolver.cpp
Normal file
211
packages/CLPBN/clpbn/bp/VarElimSolver.cpp
Normal file
@ -0,0 +1,211 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "VarElimSolver.h"
|
||||
#include "ElimGraph.h"
|
||||
#include "Factor.h"
|
||||
|
||||
|
||||
VarElimSolver::VarElimSolver (const BayesNet& bn) : Solver (&bn)
|
||||
{
|
||||
bayesNet_ = &bn;
|
||||
factorGraph_ = new FactorGraph (bn);
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarElimSolver::VarElimSolver (const FactorGraph& fg) : Solver (&fg)
|
||||
{
|
||||
bayesNet_ = 0;
|
||||
factorGraph_ = &fg;
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarElimSolver::~VarElimSolver (void)
|
||||
{
|
||||
if (bayesNet_) {
|
||||
delete factorGraph_;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
VarElimSolver::getPosterioriOf (VarId vid)
|
||||
{
|
||||
FgVarNode* vn = factorGraph_->getFgVarNode (vid);
|
||||
assert (vn);
|
||||
if (vn->hasEvidence()) {
|
||||
ParamSet params (vn->nrStates(), 0.0);
|
||||
params[vn->getEvidence()] = 1.0;
|
||||
return params;
|
||||
}
|
||||
return getJointDistributionOf (VarIdSet() = {vid});
|
||||
}
|
||||
|
||||
|
||||
|
||||
ParamSet
|
||||
VarElimSolver::getJointDistributionOf (const VarIdSet& vids)
|
||||
{
|
||||
factorList_.clear();
|
||||
varFactors_.clear();
|
||||
elimOrder_.clear();
|
||||
createFactorList();
|
||||
introduceEvidence();
|
||||
chooseEliminationOrder (vids);
|
||||
processFactorList (vids);
|
||||
ParamSet params = factorList_.back()->getParameters();
|
||||
factorList_.back()->freeDistribution();
|
||||
delete factorList_.back();
|
||||
Util::normalize (params);
|
||||
return params;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::createFactorList (void)
|
||||
{
|
||||
const FgFacSet& factorNodes = factorGraph_->getFactorNodes();
|
||||
factorList_.reserve (factorNodes.size() * 2);
|
||||
for (unsigned i = 0; i < factorNodes.size(); i++) {
|
||||
factorList_.push_back (new Factor (*factorNodes[i]->factor()));
|
||||
const FgVarSet& neighs = factorNodes[i]->neighbors();
|
||||
for (unsigned j = 0; j < neighs.size(); j++) {
|
||||
unordered_map<VarId,vector<unsigned> >::iterator it
|
||||
= varFactors_.find (neighs[j]->varId());
|
||||
if (it == varFactors_.end()) {
|
||||
it = varFactors_.insert (make_pair (
|
||||
neighs[j]->varId(), vector<unsigned>())).first;
|
||||
}
|
||||
it->second.push_back (i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::introduceEvidence (void)
|
||||
{
|
||||
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
if (varNodes[i]->hasEvidence()) {
|
||||
const vector<unsigned>& idxs =
|
||||
varFactors_.find (varNodes[i]->varId())->second;
|
||||
for (unsigned j = 0; j < idxs.size(); j++) {
|
||||
Factor* factor = factorList_[idxs[j]];
|
||||
if (factor->nrVariables() == 1) {
|
||||
factorList_[idxs[j]] = 0;
|
||||
} else {
|
||||
factorList_[idxs[j]]->removeInconsistentEntries (
|
||||
varNodes[i]->varId(), varNodes[i]->getEvidence());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::chooseEliminationOrder (const VarIdSet& vids)
|
||||
{
|
||||
if (bayesNet_) {
|
||||
ElimGraph graph = ElimGraph (*bayesNet_);
|
||||
elimOrder_ = graph.getEliminatingOrder (vids);
|
||||
} else {
|
||||
const FgVarSet& varNodes = factorGraph_->getVarNodes();
|
||||
for (unsigned i = 0; i < varNodes.size(); i++) {
|
||||
VarId vid = varNodes[i]->varId();
|
||||
if (std::find (vids.begin(), vids.end(), vid) == vids.end()
|
||||
&& !varNodes[i]->hasEvidence()) {
|
||||
elimOrder_.push_back (vid);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::processFactorList (const VarIdSet& vids)
|
||||
{
|
||||
for (unsigned i = 0; i < elimOrder_.size(); i++) {
|
||||
// cout << "-----------------------------------------" << endl;
|
||||
// cout << "Eliminating " << elimOrder_[i];
|
||||
// cout << " in the following factors:" << endl;
|
||||
// printActiveFactors();
|
||||
eliminate (elimOrder_[i]);
|
||||
}
|
||||
Factor* thisIsTheEnd = new Factor();
|
||||
|
||||
for (unsigned i = 0; i < factorList_.size(); i++) {
|
||||
if (factorList_[i]) {
|
||||
thisIsTheEnd->multiplyByFactor (*factorList_[i]);
|
||||
factorList_[i]->freeDistribution();
|
||||
delete factorList_[i];
|
||||
factorList_[i] = 0;
|
||||
}
|
||||
}
|
||||
VarIdSet vidsWithoutEvidence;
|
||||
for (unsigned i = 0; i < vids.size(); i++) {
|
||||
if (factorGraph_->getFgVarNode (vids[i])->hasEvidence() == false) {
|
||||
vidsWithoutEvidence.push_back (vids[i]);
|
||||
}
|
||||
}
|
||||
thisIsTheEnd->orderVariables (vidsWithoutEvidence);
|
||||
factorList_.push_back (thisIsTheEnd);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::eliminate (VarId elimVar)
|
||||
{
|
||||
FgVarNode* vn = factorGraph_->getFgVarNode (elimVar);
|
||||
Factor* result = 0;
|
||||
vector<unsigned>& idxs = varFactors_.find (elimVar)->second;
|
||||
//cout << "eliminating " << setw (5) << elimVar << ":" ;
|
||||
for (unsigned i = 0; i < idxs.size(); i++) {
|
||||
unsigned idx = idxs[i];
|
||||
if (factorList_[idx]) {
|
||||
if (result == 0) {
|
||||
result = new Factor(*factorList_[idx]);
|
||||
//cout << " " << factorList_[idx]->label();
|
||||
} else {
|
||||
result->multiplyByFactor (*factorList_[idx]);
|
||||
//cout << " x " << factorList_[idx]->label();
|
||||
}
|
||||
factorList_[idx]->freeDistribution();
|
||||
delete factorList_[idx];
|
||||
factorList_[idx] = 0;
|
||||
}
|
||||
}
|
||||
if (result != 0 && result->nrVariables() != 1) {
|
||||
result->removeVariable (vn->varId());
|
||||
factorList_.push_back (result);
|
||||
// cout << endl <<" factor size=" << result->size() << endl;
|
||||
const VarIdSet& resultVarIds = result->getVarIds();
|
||||
for (unsigned i = 0; i < resultVarIds.size(); i++) {
|
||||
vector<unsigned>& idxs =
|
||||
varFactors_.find (resultVarIds[i])->second;
|
||||
idxs.push_back (factorList_.size() - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarElimSolver::printActiveFactors (void)
|
||||
{
|
||||
for (unsigned i = 0; i < factorList_.size(); i++) {
|
||||
if (factorList_[i] != 0) {
|
||||
factorList_[i]->printFactor();
|
||||
cout << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
41
packages/CLPBN/clpbn/bp/VarElimSolver.h
Normal file
41
packages/CLPBN/clpbn/bp/VarElimSolver.h
Normal file
@ -0,0 +1,41 @@
|
||||
#ifndef HORUS_VARELIMSOLVER_H
|
||||
#define HORUS_VARELIMSOLVER_H
|
||||
|
||||
#include "unordered_map"
|
||||
|
||||
#include "Solver.h"
|
||||
#include "FactorGraph.h"
|
||||
#include "BayesNet.h"
|
||||
#include "Shared.h"
|
||||
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
class VarElimSolver : public Solver
|
||||
{
|
||||
public:
|
||||
VarElimSolver (const BayesNet&);
|
||||
VarElimSolver (const FactorGraph&);
|
||||
~VarElimSolver (void);
|
||||
void runSolver (void) { }
|
||||
ParamSet getPosterioriOf (VarId);
|
||||
ParamSet getJointDistributionOf (const VarIdSet&);
|
||||
|
||||
private:
|
||||
void createFactorList (void);
|
||||
void introduceEvidence (void);
|
||||
void chooseEliminationOrder (const VarIdSet&);
|
||||
void processFactorList (const VarIdSet&);
|
||||
void eliminate (VarId);
|
||||
void printActiveFactors (void);
|
||||
|
||||
const BayesNet* bayesNet_;
|
||||
const FactorGraph* factorGraph_;
|
||||
vector<Factor*> factorList_;
|
||||
VarIdSet elimOrder_;
|
||||
unordered_map<VarId, vector<unsigned>> varFactors_;
|
||||
};
|
||||
|
||||
#endif // HORUS_VARELIMSOLVER_H
|
||||
|
100
packages/CLPBN/clpbn/bp/VarNode.cpp
Normal file
100
packages/CLPBN/clpbn/bp/VarNode.cpp
Normal file
@ -0,0 +1,100 @@
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
|
||||
#include "VarNode.h"
|
||||
#include "GraphicalModel.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
VarNode::VarNode (const VarNode* v)
|
||||
{
|
||||
varId_ = v->varId();
|
||||
nrStates_ = v->nrStates();
|
||||
evidence_ = v->getEvidence();
|
||||
index_ = std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
|
||||
|
||||
VarNode::VarNode (VarId varId, unsigned nrStates, int evidence)
|
||||
{
|
||||
assert (nrStates != 0);
|
||||
assert (evidence < (int) nrStates);
|
||||
varId_ = varId;
|
||||
nrStates_ = nrStates;
|
||||
evidence_ = evidence;
|
||||
index_ = std::numeric_limits<unsigned>::max();
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
VarNode::isValidState (int stateIndex)
|
||||
{
|
||||
return stateIndex >= 0 && stateIndex < (int) nrStates_;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool
|
||||
VarNode::isValidState (const string& stateName)
|
||||
{
|
||||
States states = GraphicalModel::getVariableInformation (varId_).states;
|
||||
return find (states.begin(), states.end(), stateName) != states.end();
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarNode::setEvidence (int ev)
|
||||
{
|
||||
assert (ev < (int) nrStates_);
|
||||
evidence_ = ev;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
VarNode::setEvidence (const string& ev)
|
||||
{
|
||||
States states = GraphicalModel::getVariableInformation (varId_).states;
|
||||
for (unsigned i = 0; i < states.size(); i++) {
|
||||
if (states[i] == ev) {
|
||||
evidence_ = i;
|
||||
return;
|
||||
}
|
||||
}
|
||||
assert (false);
|
||||
}
|
||||
|
||||
|
||||
|
||||
string
|
||||
VarNode::label (void) const
|
||||
{
|
||||
if (GraphicalModel::variablesHaveInformation()) {
|
||||
return GraphicalModel::getVariableInformation (varId_).label;
|
||||
}
|
||||
stringstream ss;
|
||||
ss << "x" << varId_;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
States
|
||||
VarNode::states (void) const
|
||||
{
|
||||
if (GraphicalModel::variablesHaveInformation()) {
|
||||
return GraphicalModel::getVariableInformation (varId_).states;
|
||||
}
|
||||
States states;
|
||||
for (unsigned i = 0; i < nrStates_; i++) {
|
||||
stringstream ss;
|
||||
ss << i ;
|
||||
states.push_back (ss.str());
|
||||
}
|
||||
return states;
|
||||
}
|
||||
|
52
packages/CLPBN/clpbn/bp/VarNode.h
Normal file
52
packages/CLPBN/clpbn/bp/VarNode.h
Normal file
@ -0,0 +1,52 @@
|
||||
#ifndef HORUS_VARNODE_H
|
||||
#define HORUS_VARNODE_H
|
||||
|
||||
#include "Shared.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
class VarNode
|
||||
{
|
||||
public:
|
||||
VarNode (const VarNode*);
|
||||
VarNode (VarId, unsigned, int = NO_EVIDENCE);
|
||||
virtual ~VarNode (void) {};
|
||||
|
||||
bool isValidState (int);
|
||||
bool isValidState (const string&);
|
||||
void setEvidence (int);
|
||||
void setEvidence (const string&);
|
||||
string label (void) const;
|
||||
States states (void) const;
|
||||
|
||||
unsigned varId (void) const { return varId_; }
|
||||
unsigned nrStates (void) const { return nrStates_; }
|
||||
bool hasEvidence (void) const { return evidence_ != NO_EVIDENCE; }
|
||||
int getEvidence (void) const { return evidence_; }
|
||||
unsigned getIndex (void) const { return index_; }
|
||||
void setIndex (unsigned idx) { index_ = idx; }
|
||||
|
||||
operator unsigned () const { return index_; }
|
||||
|
||||
bool operator== (const VarNode& var) const
|
||||
{
|
||||
assert (!(varId_ == var.varId() && nrStates_ != var.nrStates()));
|
||||
return varId_ == var.varId();
|
||||
}
|
||||
|
||||
bool operator!= (const VarNode& var) const
|
||||
{
|
||||
assert (!(varId_ == var.varId() && nrStates_ != var.nrStates()));
|
||||
return varId_ != var.varId();
|
||||
}
|
||||
|
||||
private:
|
||||
VarId varId_;
|
||||
unsigned nrStates_;
|
||||
int evidence_;
|
||||
unsigned index_;
|
||||
|
||||
};
|
||||
|
||||
#endif // BP_VARNODE_H
|
||||
|
@ -1,172 +0,0 @@
|
||||
#ifndef BP_VARIABLE_H
|
||||
#define BP_VARIABLE_H
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "Shared.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
class Variable
|
||||
{
|
||||
public:
|
||||
|
||||
Variable (const Variable* v)
|
||||
{
|
||||
vid_ = v->getVarId();
|
||||
dsize_ = v->getDomainSize();
|
||||
if (v->hasDomain()) {
|
||||
domain_ = v->getDomain();
|
||||
dsize_ = domain_.size();
|
||||
} else {
|
||||
dsize_ = v->getDomainSize();
|
||||
}
|
||||
evidence_ = v->getEvidence();
|
||||
if (v->hasLabel()) {
|
||||
label_ = new string (v->getLabel());
|
||||
} else {
|
||||
label_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Variable (Vid vid)
|
||||
{
|
||||
this->vid_ = vid;
|
||||
this->dsize_ = 0;
|
||||
this->evidence_ = NO_EVIDENCE;
|
||||
this->label_ = 0;
|
||||
}
|
||||
|
||||
Variable (Vid vid, unsigned dsize, int evidence = NO_EVIDENCE,
|
||||
const string& lbl = string())
|
||||
{
|
||||
assert (dsize != 0);
|
||||
assert (evidence < (int)dsize);
|
||||
this->vid_ = vid;
|
||||
this->dsize_ = dsize;
|
||||
this->evidence_ = evidence;
|
||||
if (!lbl.empty()) {
|
||||
this->label_ = new string (lbl);
|
||||
} else {
|
||||
this->label_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
Variable (Vid vid, const Domain& domain, int evidence = NO_EVIDENCE,
|
||||
const string& lbl = string())
|
||||
{
|
||||
assert (!domain.empty());
|
||||
assert (evidence < (int)domain.size());
|
||||
this->vid_ = vid;
|
||||
this->dsize_ = domain.size();
|
||||
this->domain_ = domain;
|
||||
this->evidence_ = evidence;
|
||||
if (!lbl.empty()) {
|
||||
this->label_ = new string (lbl);
|
||||
} else {
|
||||
this->label_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
~Variable (void)
|
||||
{
|
||||
delete label_;
|
||||
}
|
||||
|
||||
unsigned getVarId (void) const { return vid_; }
|
||||
unsigned getIndex (void) const { return index_; }
|
||||
void setIndex (unsigned idx) { index_ = idx; }
|
||||
unsigned getDomainSize (void) const { return dsize_; }
|
||||
bool hasEvidence (void) const { return evidence_ != NO_EVIDENCE; }
|
||||
int getEvidence (void) const { return evidence_; }
|
||||
bool hasDomain (void) const { return !domain_.empty(); }
|
||||
bool hasLabel (void) const { return label_ != 0; }
|
||||
|
||||
bool isValidStateIndex (int index)
|
||||
{
|
||||
return index >= 0 && index < (int)dsize_;
|
||||
}
|
||||
|
||||
bool isValidState (const string& state)
|
||||
{
|
||||
return find (domain_.begin(), domain_.end(), state) != domain_.end();
|
||||
}
|
||||
|
||||
Domain getDomain (void) const
|
||||
{
|
||||
assert (dsize_ != 0);
|
||||
if (domain_.size() == 0) {
|
||||
Domain d;
|
||||
for (unsigned i = 0; i < dsize_; i++) {
|
||||
stringstream ss;
|
||||
ss << "x" << i ;
|
||||
d.push_back (ss.str());
|
||||
}
|
||||
return d;
|
||||
} else {
|
||||
return domain_;
|
||||
}
|
||||
}
|
||||
|
||||
void setDomainSize (unsigned dsize)
|
||||
{
|
||||
assert (dsize != 0);
|
||||
dsize_ = dsize;
|
||||
}
|
||||
|
||||
void setDomain (const Domain& domain)
|
||||
{
|
||||
assert (!domain.empty());
|
||||
domain_ = domain;
|
||||
dsize_ = domain.size();
|
||||
}
|
||||
|
||||
void setEvidence (int ev)
|
||||
{
|
||||
assert (ev < dsize_);
|
||||
evidence_ = ev;
|
||||
}
|
||||
|
||||
void setEvidence (const string& ev)
|
||||
{
|
||||
assert (isValidState (ev));
|
||||
for (unsigned i = 0; i < domain_.size(); i++) {
|
||||
if (domain_[i] == ev) {
|
||||
evidence_ = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void setLabel (const string& label)
|
||||
{
|
||||
label_ = new string (label);
|
||||
}
|
||||
|
||||
string getLabel (void) const
|
||||
{
|
||||
if (label_ == 0) {
|
||||
stringstream ss;
|
||||
ss << "v" << vid_;
|
||||
return ss.str();
|
||||
} else {
|
||||
return *label_;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN (Variable);
|
||||
|
||||
Vid vid_;
|
||||
unsigned dsize_;
|
||||
int evidence_;
|
||||
Domain domain_;
|
||||
string* label_;
|
||||
unsigned index_;
|
||||
|
||||
};
|
||||
|
||||
#endif // BP_VARIABLE_H
|
||||
|
@ -1,147 +0,0 @@
|
||||
|
||||
/*
|
||||
----------------------------------------------------------------
|
||||
|
||||
Notice that the following BSD-style license applies to this one
|
||||
file (callgrind.h) only. The rest of Valgrind is licensed under the
|
||||
terms of the GNU General Public License, version 2, unless
|
||||
otherwise indicated. See the COPYING file in the source
|
||||
distribution for details.
|
||||
|
||||
----------------------------------------------------------------
|
||||
|
||||
This file is part of callgrind, a valgrind tool for cache simulation
|
||||
and call tree tracing.
|
||||
|
||||
Copyright (C) 2003-2010 Josef Weidendorfer. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. The origin of this software must not be misrepresented; you must
|
||||
not claim that you wrote the original software. If you use this
|
||||
software in a product, an acknowledgment in the product
|
||||
documentation would be appreciated but is not required.
|
||||
|
||||
3. Altered source versions must be plainly marked as such, and must
|
||||
not be misrepresented as being the original software.
|
||||
|
||||
4. The name of the author may not be used to endorse or promote
|
||||
products derived from this software without specific prior written
|
||||
permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS
|
||||
OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
|
||||
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
|
||||
GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
|
||||
WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
----------------------------------------------------------------
|
||||
|
||||
Notice that the above BSD-style license applies to this one file
|
||||
(callgrind.h) only. The entire rest of Valgrind is licensed under
|
||||
the terms of the GNU General Public License, version 2. See the
|
||||
COPYING file in the source distribution for details.
|
||||
|
||||
----------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#ifndef __CALLGRIND_H
|
||||
#define __CALLGRIND_H
|
||||
|
||||
#include "valgrind.h"
|
||||
|
||||
/* !! ABIWARNING !! ABIWARNING !! ABIWARNING !! ABIWARNING !!
|
||||
This enum comprises an ABI exported by Valgrind to programs
|
||||
which use client requests. DO NOT CHANGE THE ORDER OF THESE
|
||||
ENTRIES, NOR DELETE ANY -- add new ones at the end.
|
||||
|
||||
The identification ('C','T') for Callgrind has historical
|
||||
reasons: it was called "Calltree" before. Besides, ('C','G') would
|
||||
clash with cachegrind.
|
||||
*/
|
||||
|
||||
typedef
|
||||
enum {
|
||||
VG_USERREQ__DUMP_STATS = VG_USERREQ_TOOL_BASE('C','T'),
|
||||
VG_USERREQ__ZERO_STATS,
|
||||
VG_USERREQ__TOGGLE_COLLECT,
|
||||
VG_USERREQ__DUMP_STATS_AT,
|
||||
VG_USERREQ__START_INSTRUMENTATION,
|
||||
VG_USERREQ__STOP_INSTRUMENTATION
|
||||
} Vg_CallgrindClientRequest;
|
||||
|
||||
/* Dump current state of cost centers, and zero them afterwards */
|
||||
#define CALLGRIND_DUMP_STATS \
|
||||
{unsigned int _qzz_res; \
|
||||
VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \
|
||||
VG_USERREQ__DUMP_STATS, \
|
||||
0, 0, 0, 0, 0); \
|
||||
}
|
||||
|
||||
/* Dump current state of cost centers, and zero them afterwards.
|
||||
The argument is appended to a string stating the reason which triggered
|
||||
the dump. This string is written as a description field into the
|
||||
profile data dump. */
|
||||
#define CALLGRIND_DUMP_STATS_AT(pos_str) \
|
||||
{unsigned int _qzz_res; \
|
||||
VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \
|
||||
VG_USERREQ__DUMP_STATS_AT, \
|
||||
pos_str, 0, 0, 0, 0); \
|
||||
}
|
||||
|
||||
/* Zero cost centers */
|
||||
#define CALLGRIND_ZERO_STATS \
|
||||
{unsigned int _qzz_res; \
|
||||
VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \
|
||||
VG_USERREQ__ZERO_STATS, \
|
||||
0, 0, 0, 0, 0); \
|
||||
}
|
||||
|
||||
/* Toggles collection state.
|
||||
The collection state specifies whether the happening of events
|
||||
should be noted or if they are to be ignored. Events are noted
|
||||
by increment of counters in a cost center */
|
||||
#define CALLGRIND_TOGGLE_COLLECT \
|
||||
{unsigned int _qzz_res; \
|
||||
VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \
|
||||
VG_USERREQ__TOGGLE_COLLECT, \
|
||||
0, 0, 0, 0, 0); \
|
||||
}
|
||||
|
||||
/* Start full callgrind instrumentation if not already switched on.
|
||||
When cache simulation is done, it will flush the simulated cache;
|
||||
this will lead to an artifical cache warmup phase afterwards with
|
||||
cache misses which would not have happened in reality. */
|
||||
#define CALLGRIND_START_INSTRUMENTATION \
|
||||
{unsigned int _qzz_res; \
|
||||
VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \
|
||||
VG_USERREQ__START_INSTRUMENTATION, \
|
||||
0, 0, 0, 0, 0); \
|
||||
}
|
||||
|
||||
/* Stop full callgrind instrumentation if not already switched off.
|
||||
This flushes Valgrinds translation cache, and does no additional
|
||||
instrumentation afterwards, which effectivly will run at the same
|
||||
speed as the "none" tool (ie. at minimal slowdown).
|
||||
Use this to bypass Callgrind aggregation for uninteresting code parts.
|
||||
To start Callgrind in this mode to ignore the setup phase, use
|
||||
the option "--instr-atstart=no". */
|
||||
#define CALLGRIND_STOP_INSTRUMENTATION \
|
||||
{unsigned int _qzz_res; \
|
||||
VALGRIND_DO_CLIENT_REQUEST(_qzz_res, 0, \
|
||||
VG_USERREQ__STOP_INSTRUMENTATION, \
|
||||
0, 0, 0, 0, 0); \
|
||||
}
|
||||
|
||||
#endif /* __CALLGRIND_H */
|
14
packages/CLPBN/clpbn/bp/examples/cbp_example.uai
Normal file
14
packages/CLPBN/clpbn/bp/examples/cbp_example.uai
Normal file
@ -0,0 +1,14 @@
|
||||
MARKOV
|
||||
3
|
||||
2 2 2
|
||||
2
|
||||
2 0 1
|
||||
2 2 1
|
||||
|
||||
|
||||
4
|
||||
1.2 1.4 2.0 0.4
|
||||
|
||||
4
|
||||
1.2 1.4 2.0 0.4
|
||||
|
@ -1,53 +0,0 @@
|
||||
|
||||
:- use_module(library(clpbn)).
|
||||
|
||||
:- set_clpbn_flag(solver, bp).
|
||||
|
||||
%
|
||||
% A E
|
||||
% / \ /
|
||||
% / \ /
|
||||
% B C
|
||||
% \ /
|
||||
% \ /
|
||||
% D
|
||||
%
|
||||
|
||||
a(A) :-
|
||||
a_table(ADist),
|
||||
{ A = a with p([a1, a2], ADist) }.
|
||||
|
||||
b(B) :-
|
||||
a(A),
|
||||
b_table(BDist),
|
||||
{ B = b with p([b1, b2], BDist, [A]) }.
|
||||
|
||||
c(C) :-
|
||||
a(A),
|
||||
c_table(CDist),
|
||||
{ C = c with p([c1, c2], CDist, [A]) }.
|
||||
|
||||
d(D) :-
|
||||
b(B),
|
||||
c(C),
|
||||
d_table(DDist),
|
||||
{ D = d with p([d1, d2], DDist, [B, C]) }.
|
||||
|
||||
e(E) :-
|
||||
e_table(EDist),
|
||||
{ E = e with p([e1, e2], EDist) }.
|
||||
|
||||
|
||||
a_table([0.005, 0.995]).
|
||||
|
||||
b_table([0.02, 0.97,
|
||||
0.88, 0.03]).
|
||||
|
||||
c_table([0.55, 0.94,
|
||||
0.45, 0.06]).
|
||||
|
||||
d_table([0.192, 0.98, 0.33, 0.013,
|
||||
0.908, 0.02, 0.77, 0.987]).
|
||||
|
||||
e_table([0.055, 0.945]).
|
||||
|
@ -0,0 +1,60 @@
|
||||
#!/bin/bash
|
||||
|
||||
cp ~/bin/yap ~/bin/town_comp
|
||||
YAP=~/bin/town_comp
|
||||
|
||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
||||
OUT_FILE_NAME=bp_compress.log
|
||||
rm -f $OUT_FILE_NAME
|
||||
rm -f ignore.$OUT_FILE_NAME
|
||||
|
||||
|
||||
function run_solver
|
||||
{
|
||||
if [ $2 = bp ]
|
||||
then
|
||||
extra_flag1=clpbn_bp:set_solver_parameter\(run_mode,$4\)
|
||||
extra_flag2=clpbn_bp:set_solver_parameter\(schedule,$5\)
|
||||
extra_flag3=clpbn_bp:set_solver_parameter\(always_loopy_solver,$6\)
|
||||
else
|
||||
extra_flag1=true
|
||||
extra_flag2=true
|
||||
extra_flag3=true
|
||||
fi
|
||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
||||
[$1].
|
||||
clpbn:set_clpbn_flag(solver,$2),
|
||||
clpbn_bp:use_log_space,
|
||||
$extra_flag1, $extra_flag2, $extra_flag3,
|
||||
run_query(_R),
|
||||
open("$OUT_FILE_NAME", 'append',S),
|
||||
format(S, '$3: ~15+ ',[]),
|
||||
close(S).
|
||||
EOF
|
||||
}
|
||||
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
||||
run_solver town_5000 $1 town_5000 $3 $4 $5
|
||||
run_solver town_10000 $1 town_10000 $3 $4 $5
|
||||
run_solver town_50000 $1 town_50000 $3 $4 $5
|
||||
run_solver town_100000 $1 town_100000 $3 $4 $5
|
||||
run_solver town_500000 $1 town_500000 $3 $4 $5
|
||||
run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
||||
run_solver town_2500000 $1 town_2500000 $3 $4 $5
|
||||
run_solver town_5000000 $1 town_5000000 $3 $4 $5
|
||||
run_solver town_7500000 $1 town_7500000 $3 $4 $5
|
||||
run_solver town_10000000 $1 town_10000000 $3 $4 $5
|
||||
}
|
||||
|
||||
run_solver town_10000 "bp(compress,seq_fixed)" town_10000 compress seq_fixed true
|
||||
exit
|
||||
|
||||
##########
|
||||
run_all_graphs bp "bp(compress,seq_fixed) " compress seq_fixed true
|
||||
|
51
packages/CLPBN/clpbn/bp/examples/town/run_town_bp_convert.sh
Normal file
51
packages/CLPBN/clpbn/bp/examples/town/run_town_bp_convert.sh
Normal file
@ -0,0 +1,51 @@
|
||||
#!/bin/bash
|
||||
|
||||
YAP=~/bin/town_conv
|
||||
|
||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
||||
OUT_FILE_NAME=bp_convert.log
|
||||
rm -f $OUT_FILE_NAME
|
||||
rm -f ignore.$OUT_FILE_NAME
|
||||
|
||||
|
||||
function run_solver
|
||||
{
|
||||
if [ $2 = bp ]
|
||||
then
|
||||
extra_flag1=clpbn_bp:set_solver_parameter\(run_mode,$4\)
|
||||
extra_flag2=clpbn_bp:set_solver_parameter\(schedule,$5\)
|
||||
extra_flag3=clpbn_bp:set_solver_parameter\(always_loopy_solver,$6\)
|
||||
else
|
||||
extra_flag1=true
|
||||
extra_flag2=true
|
||||
extra_flag3=true
|
||||
fi
|
||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
||||
[$1].
|
||||
clpbn:set_clpbn_flag(solver,$2),
|
||||
clpbn_bp:use_log_space,
|
||||
$extra_flag1, $extra_flag2, $extra_flag3,
|
||||
run_query(_R),
|
||||
open("$OUT_FILE_NAME", 'append',S),
|
||||
format(S, '$3: ~15+ ',[]),
|
||||
close(S).
|
||||
EOF
|
||||
}
|
||||
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
||||
run_solver town_5000 $1 town_5000 $3 $4 $5
|
||||
run_solver town_10000 $1 town_10000 $3 $4 $5
|
||||
run_solver town_50000 $1 town_50000 $3 $4 $5
|
||||
run_solver town_100000 $1 town_100000 $3 $4 $5
|
||||
run_solver town_500000 $1 town_500000 $3 $4 $5
|
||||
run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
||||
}
|
||||
|
||||
run_all_graphs bp "bp(convert,seq_fixed) " convert seq_fixed false
|
||||
|
50
packages/CLPBN/clpbn/bp/examples/town/run_town_bp_normal.sh
Normal file
50
packages/CLPBN/clpbn/bp/examples/town/run_town_bp_normal.sh
Normal file
@ -0,0 +1,50 @@
|
||||
#!/bin/bash
|
||||
|
||||
YAP=~/bin/town_norm
|
||||
|
||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
||||
OUT_FILE_NAME=bp_normal.log
|
||||
rm -f $OUT_FILE_NAME
|
||||
rm -f ignore.$OUT_FILE_NAME
|
||||
|
||||
|
||||
function run_solver
|
||||
{
|
||||
if [ $2 = bp ]
|
||||
then
|
||||
extra_flag1=clpbn_bp:set_solver_parameter\(run_mode,$4\)
|
||||
extra_flag2=clpbn_bp:set_solver_parameter\(schedule,$5\)
|
||||
extra_flag3=clpbn_bp:set_solver_parameter\(always_loopy_solver,$6\)
|
||||
else
|
||||
extra_flag1=true
|
||||
extra_flag2=true
|
||||
extra_flag3=true
|
||||
fi
|
||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
||||
[$1].
|
||||
clpbn:set_clpbn_flag(solver,$2),
|
||||
clpbn_bp:use_log_space,
|
||||
$extra_flag1, $extra_flag2, $extra_flag3,
|
||||
run_query(_R),
|
||||
open("$OUT_FILE_NAME", 'append',S),
|
||||
format(S, '$3: ~15+ ',[]),
|
||||
close(S).
|
||||
EOF
|
||||
}
|
||||
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
||||
run_solver town_5000 $1 town_5000 $3 $4 $5
|
||||
run_solver town_10000 $1 town_10000 $3 $4 $5
|
||||
run_solver town_50000 $1 town_50000 $3 $4 $5
|
||||
run_solver town_100000 $1 town_100000 $3 $4 $5
|
||||
run_solver town_500000 $1 town_500000 $3 $4 $5
|
||||
run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
||||
}
|
||||
|
||||
run_all_graphs bp "bp(normal,seq_fixed) " normal seq_fixed false
|
@ -0,0 +1,51 @@
|
||||
#!/bin/bash
|
||||
|
||||
YAP=~/bin/town_gibbs
|
||||
|
||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
||||
OUT_FILE_NAME=gibbs.log
|
||||
rm -f $OUT_FILE_NAME
|
||||
rm -f ignore.$OUT_FILE_NAME
|
||||
|
||||
|
||||
function run_solver
|
||||
{
|
||||
if [ $2 = bp ]
|
||||
then
|
||||
extra_flag1=clpbn_bp:set_solver_parameter\(run_mode,$4\)
|
||||
extra_flag2=clpbn_bp:set_solver_parameter\(schedule,$5\)
|
||||
extra_flag3=clpbn_bp:set_solver_parameter\(always_loopy_solver,$6\)
|
||||
else
|
||||
extra_flag1=true
|
||||
extra_flag2=true
|
||||
extra_flag3=true
|
||||
fi
|
||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
||||
[$1].
|
||||
clpbn:set_clpbn_flag(solver,$2),
|
||||
clpbn_bp:use_log_space,
|
||||
$extra_flag1, $extra_flag2, $extra_flag3,
|
||||
run_query(_R),
|
||||
open("$OUT_FILE_NAME", 'append',S),
|
||||
format(S, '$3: ~15+ ',[]),
|
||||
close(S).
|
||||
EOF
|
||||
}
|
||||
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
||||
run_solver town_5000 $1 town_5000 $3 $4 $5
|
||||
run_solver town_10000 $1 town_10000 $3 $4 $5
|
||||
run_solver town_50000 $1 town_50000 $3 $4 $5
|
||||
run_solver town_100000 $1 town_100000 $3 $4 $5
|
||||
run_solver town_500000 $1 town_500000 $3 $4 $5
|
||||
run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
||||
}
|
||||
|
||||
run_all_graphs gibbs "gibbs "
|
||||
|
51
packages/CLPBN/clpbn/bp/examples/town/run_town_jt_tests.sh
Normal file
51
packages/CLPBN/clpbn/bp/examples/town/run_town_jt_tests.sh
Normal file
@ -0,0 +1,51 @@
|
||||
#!/bin/bash
|
||||
|
||||
YAP=~/bin/town_jt
|
||||
|
||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
||||
OUT_FILE_NAME=jt.log
|
||||
rm -f $OUT_FILE_NAME
|
||||
rm -f ignore.$OUT_FILE_NAME
|
||||
|
||||
|
||||
function run_solver
|
||||
{
|
||||
if [ $2 = bp ]
|
||||
then
|
||||
extra_flag1=clpbn_bp:set_solver_parameter\(run_mode,$4\)
|
||||
extra_flag2=clpbn_bp:set_solver_parameter\(schedule,$5\)
|
||||
extra_flag3=clpbn_bp:set_solver_parameter\(always_loopy_solver,$6\)
|
||||
else
|
||||
extra_flag1=true
|
||||
extra_flag2=true
|
||||
extra_flag3=true
|
||||
fi
|
||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
||||
[$1].
|
||||
clpbn:set_clpbn_flag(solver,$2),
|
||||
clpbn_bp:use_log_space,
|
||||
$extra_flag1, $extra_flag2, $extra_flag3,
|
||||
run_query(_R),
|
||||
open("$OUT_FILE_NAME", 'append',S),
|
||||
format(S, '$3: ~15+ ',[]),
|
||||
close(S).
|
||||
EOF
|
||||
}
|
||||
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
||||
run_solver town_5000 $1 town_5000 $3 $4 $5
|
||||
run_solver town_10000 $1 town_10000 $3 $4 $5
|
||||
run_solver town_50000 $1 town_50000 $3 $4 $5
|
||||
run_solver town_100000 $1 town_100000 $3 $4 $5
|
||||
run_solver town_500000 $1 town_500000 $3 $4 $5
|
||||
run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
||||
}
|
||||
|
||||
run_all_graphs jt "jt "
|
||||
|
51
packages/CLPBN/clpbn/bp/examples/town/run_town_ve_tests.sh
Normal file
51
packages/CLPBN/clpbn/bp/examples/town/run_town_ve_tests.sh
Normal file
@ -0,0 +1,51 @@
|
||||
#!/bin/bash
|
||||
|
||||
YAP=~/bin/town_ve
|
||||
|
||||
#OUT_FILE_NAME=results`date "+ %H:%M:%S %d-%m-%Y"`.log
|
||||
OUT_FILE_NAME=ve.log
|
||||
rm -f $OUT_FILE_NAME
|
||||
rm -f ignore.$OUT_FILE_NAME
|
||||
|
||||
|
||||
function run_solver
|
||||
{
|
||||
if [ $2 = bp ]
|
||||
then
|
||||
extra_flag1=clpbn_bp:set_solver_parameter\(run_mode,$4\)
|
||||
extra_flag2=clpbn_bp:set_solver_parameter\(schedule,$5\)
|
||||
extra_flag3=clpbn_bp:set_solver_parameter\(always_loopy_solver,$6\)
|
||||
else
|
||||
extra_flag1=true
|
||||
extra_flag2=true
|
||||
extra_flag3=true
|
||||
fi
|
||||
/usr/bin/time -o $OUT_FILE_NAME -a -f "real:%E\tuser:%U\tsys:%S" $YAP << EOF >> $OUT_FILE_NAME 2>> ignore.$OUT_FILE_NAME
|
||||
[$1].
|
||||
clpbn:set_clpbn_flag(solver,$2),
|
||||
clpbn_bp:use_log_space,
|
||||
$extra_flag1, $extra_flag2, $extra_flag3,
|
||||
run_query(_R),
|
||||
open("$OUT_FILE_NAME", 'append',S),
|
||||
format(S, '$3: ~15+ ',[]),
|
||||
close(S).
|
||||
EOF
|
||||
}
|
||||
|
||||
|
||||
function run_all_graphs
|
||||
{
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
echo "results for solver $2" >> $OUT_FILE_NAME
|
||||
echo "*******************************************************************" >> "$OUT_FILE_NAME"
|
||||
run_solver town_1000 $1 town_1000 $3 $4 $5
|
||||
run_solver town_5000 $1 town_5000 $3 $4 $5
|
||||
run_solver town_10000 $1 town_10000 $3 $4 $5
|
||||
run_solver town_50000 $1 town_50000 $3 $4 $5
|
||||
run_solver town_100000 $1 town_100000 $3 $4 $5
|
||||
run_solver town_500000 $1 town_500000 $3 $4 $5
|
||||
run_solver town_1000000 $1 town_1000000 $3 $4 $5
|
||||
}
|
||||
|
||||
run_all_graphs ve "ve "
|
||||
|
65
packages/CLPBN/clpbn/bp/examples/town/schema.yap
Normal file
65
packages/CLPBN/clpbn/bp/examples/town/schema.yap
Normal file
@ -0,0 +1,65 @@
|
||||
|
||||
conservative_city(City, Cons) :-
|
||||
cons_table(City, ConsDist),
|
||||
{ Cons = conservative_city(City) with p([y,n], ConsDist) }.
|
||||
|
||||
|
||||
gender(X, Gender) :-
|
||||
gender_table(X, GenderDist),
|
||||
{ Gender = gender(X) with p([m,f], GenderDist) }.
|
||||
|
||||
|
||||
hair_color(X, Color) :-
|
||||
lives(X, City),
|
||||
conservative_city(City, Cons),
|
||||
hair_color_table(X,ColorTable),
|
||||
{ Color = hair_color(X) with
|
||||
p([t,f], ColorTable,[Cons]) }.
|
||||
|
||||
|
||||
car_color(X, Color) :-
|
||||
hair_color(X, HColor),
|
||||
car_color_table(X,CColorTable),
|
||||
{ Color = car_color(X) with
|
||||
p([t,f], CColorTable,[HColor]) }.
|
||||
|
||||
|
||||
height(X, Height) :-
|
||||
gender(X, Gender),
|
||||
height_table(X,HeightTable),
|
||||
{ Height = height(X) with
|
||||
p([t,f], HeightTable,[Gender]) }.
|
||||
|
||||
|
||||
shoe_size(X, Shoesize) :-
|
||||
height(X, Height),
|
||||
shoe_size_table(X,ShoesizeTable),
|
||||
{ Shoesize = shoe_size(X) with
|
||||
p([t,f], ShoesizeTable,[Height]) }.
|
||||
|
||||
|
||||
guilty(X, Guilt) :-
|
||||
guilty_table(X, GuiltDist),
|
||||
{ Guilt = guilty(X) with p([y,n], GuiltDist) }.
|
||||
|
||||
|
||||
descn(X, Descn) :-
|
||||
car_color(X, Car),
|
||||
hair_color(X, Hair),
|
||||
height(X, Height),
|
||||
guilty(X, Guilt),
|
||||
descn_table(X, DescTable),
|
||||
{ Descn = descn(X) with
|
||||
p([t,f], DescTable,[Car,Hair,Height,Guilt]) }.
|
||||
|
||||
|
||||
witness(City, Witness) :-
|
||||
descn(joe, DescnJ),
|
||||
descn(p2, Descn2),
|
||||
wit_table(WitTable),
|
||||
{ Witness = witness(City) with
|
||||
p([t,f], WitTable,[DescnJ, Descn2]) }.
|
||||
|
||||
|
||||
:- ensure_loaded(tables).
|
||||
|
46
packages/CLPBN/clpbn/bp/examples/town/tables.yap
Normal file
46
packages/CLPBN/clpbn/bp/examples/town/tables.yap
Normal file
@ -0,0 +1,46 @@
|
||||
|
||||
cons_table(amsterdam, [0.2, 0.8]) :- !.
|
||||
cons_table(_, [0.8, 0.2]).
|
||||
|
||||
|
||||
gender_table(_, [0.55, 0.44]).
|
||||
|
||||
|
||||
hair_color_table(_,
|
||||
/* conservative_city */
|
||||
/* y n */
|
||||
[ 0.05, 0.1,
|
||||
0.95, 0.9 ]).
|
||||
|
||||
|
||||
car_color_table(_,
|
||||
/* t f */
|
||||
[ 0.9, 0.2,
|
||||
0.1, 0.8 ]).
|
||||
|
||||
|
||||
height_table(_,
|
||||
/* m f */
|
||||
[ 0.6, 0.4,
|
||||
0.4, 0.6 ]).
|
||||
|
||||
|
||||
shoe_size_table(_,
|
||||
/* t f */
|
||||
[ 0.9, 0.1,
|
||||
0.1, 0.9 ]).
|
||||
|
||||
|
||||
guilty_table(_, [0.23, 0.77]).
|
||||
|
||||
|
||||
descn_table(_,
|
||||
/* color, hair, height, guilt */
|
||||
/* ttttt tttf ttft ttff tfttt tftf tfft tfff ttttt fttf ftft ftff ffttt fftf ffft ffff */
|
||||
[ 0.99, 0.5, 0.23, 0.88, 0.41, 0.3, 0.76, 0.87, 0.44, 0.43, 0.29, 0.72, 0.33, 0.91, 0.95, 0.92,
|
||||
0.01, 0.5, 0.77, 0.12, 0.59, 0.7, 0.24, 0.13, 0.56, 0.57, 0.61, 0.28, 0.77, 0.09, 0.05, 0.08]).
|
||||
|
||||
|
||||
wit_table([0.2, 0.45, 0.24, 0.34,
|
||||
0.8, 0.55, 0.76, 0.66]).
|
||||
|
59
packages/CLPBN/clpbn/bp/examples/town/town_generator.sh
Normal file
59
packages/CLPBN/clpbn/bp/examples/town/town_generator.sh
Normal file
@ -0,0 +1,59 @@
|
||||
#!/home/tgomes/bin/yap -L --
|
||||
|
||||
/*
|
||||
Steps:
|
||||
1. generate N facts lives(I, nyc), 0 <= I < N.
|
||||
2. generate evidence on descn for N people, *** except for 1 ***
|
||||
3. Run query ?- guilty(joe, Guilty), witness(joe, t), descn(2,t), descn(3, f), descn(4, f) ...
|
||||
*/
|
||||
|
||||
:- initialization(main).
|
||||
|
||||
|
||||
main :-
|
||||
unix(argv([H])),
|
||||
generate_town(H).
|
||||
|
||||
|
||||
generate_town(N) :-
|
||||
atomic_concat(['town_', N, '.yap'], FileName),
|
||||
open(FileName, 'write', S),
|
||||
write(S, ':- source.\n'),
|
||||
write(S, ':- style_check(all).\n'),
|
||||
write(S, ':- yap_flag(unknown,error).\n'),
|
||||
write(S, ':- yap_flag(write_strings,on).\n'),
|
||||
write(S, ':- use_module(library(clpbn)).\n'),
|
||||
write(S, ':- set_clpbn_flag(solver, bp).\n'),
|
||||
write(S, ':- [-schema].\n\n'),
|
||||
write(S, 'lives(_joe, nyc).\n'),
|
||||
atom_number(N, N2),
|
||||
generate_people(S, N2, 2),
|
||||
write(S, '\nrun_query(Guilty) :- \n'),
|
||||
write(S, '\tguilty(joe, Guilty),\n'),
|
||||
write(S, '\twitness(nyc, t),\n'),
|
||||
write(S, '\trunall(X, ev(X)).\n\n\n'),
|
||||
write(S, 'runall(G, Wrapper) :-\n'),
|
||||
write(S, '\tfindall(G, Wrapper, L),\n'),
|
||||
write(S, '\texecute_all(L).\n\n\n'),
|
||||
write(S, 'execute_all([]).\n'),
|
||||
write(S, 'execute_all(G.L) :-\n'),
|
||||
write(S, '\tcall(G),\n'),
|
||||
write(S, '\texecute_all(L).\n\n\n'),
|
||||
generate_query(S, N2, 2),
|
||||
close(S).
|
||||
|
||||
|
||||
generate_people(_, N, Counting1) :- !.
|
||||
generate_people(S, N, Counting) :-
|
||||
format(S, 'lives(p~w, nyc).~n', [Counting]),
|
||||
Counting1 is Counting + 1,
|
||||
generate_people(S, N, Counting1).
|
||||
|
||||
|
||||
generate_query(S, N, Counting) :-
|
||||
Counting > N, !.
|
||||
generate_query(S, N, Counting) :- !,
|
||||
format(S, 'ev(descn(p~w, t)).~n', [Counting]),
|
||||
Counting1 is Counting + 1,
|
||||
generate_query(S, N, Counting1).
|
||||
|
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user