prism logical probabilistic system.

This commit is contained in:
Vítor Santos Costa
2011-11-10 12:24:47 +00:00
parent d971219b7e
commit e865248dce
127 changed files with 22788 additions and 0 deletions

View File

@@ -0,0 +1,21 @@
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
#ifndef MP_H
#define MP_H
/*-------------------------------------------------------------------------*/
#include <mpi.h>
/*-------------------------------------------------------------------------*/
#define TAG_GOAL_REQ (1)
#define TAG_GOAL_LEN (2)
#define TAG_GOAL_STR (3)
#define TAG_SWITCH_REQ (4)
#define TAG_SWITCH_RES (5)
/*-------------------------------------------------------------------------*/
#endif /* MP_H */

View File

@@ -0,0 +1,101 @@
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
/* [27 Aug 2007, by yuizumi]
* FIXME: mp_debug() is currently platform-dependent.
*/
#ifdef MPI
#include "up/up.h"
#include "mp/mp.h"
#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>
#include <sys/time.h>
#include <unistd.h> /* STDOUT_FILENO */
#include <mpi.h>
/* Currently mpprism works only on Linux systems. */
#define DEV_NULL "/dev/null"
/*-------------------------------------------------------------------------*/
int fd_dup_stdout = -1;
int mp_size;
int mp_rank;
/*-------------------------------------------------------------------------*/
static void close_stdout(void)
{
fd_dup_stdout = dup(STDOUT_FILENO);
if (fd_dup_stdout < 0)
return;
if (freopen(DEV_NULL, "w", stdout) == NULL) {
close(fd_dup_stdout);
fd_dup_stdout = -1;
}
}
/*-------------------------------------------------------------------------*/
void mp_init(int *argc, char **argv[])
{
MPI_Init(argc, argv);
MPI_Comm_size(MPI_COMM_WORLD, &mp_size);
MPI_Comm_rank(MPI_COMM_WORLD, &mp_rank);
if (mp_size < 2) {
printf("Two or more processes required to run mpprism.\n");
MPI_Finalize();
exit(1);
}
if (mp_rank > 0) {
close_stdout();
}
}
void mp_done(void)
{
MPI_Finalize();
}
NORET mp_quit(int status)
{
fprintf(stderr, "The system is aborted by Rank #%d.\n", mp_rank);
MPI_Abort(MPI_COMM_WORLD, status);
exit(status); /* should not reach here */
}
/*-------------------------------------------------------------------------*/
void mp_debug(const char *fmt, ...)
{
#ifdef MP_DEBUG
char str[1024];
va_list ap;
struct timeval tv;
int s, u;
va_start(ap, fmt);
vsnprintf(str, sizeof(str), fmt, ap);
va_end(ap);
gettimeofday(&tv, NULL);
s = tv.tv_sec;
u = tv.tv_usec;
fprintf(stderr, "[RANK:%d] %02d:%02d:%02d.%03d -- %s\n",
mp_rank, (s / 3600) % 24, (s / 60) % 60, s % 60, u / 1000, str);
#endif
}
/*-------------------------------------------------------------------------*/
#endif /* MPI */

View File

@@ -0,0 +1,19 @@
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
#ifndef MP_CORE_H
#define MP_CORE_H
/*-------------------------------------------------------------------------*/
extern int mp_size;
extern int mp_rank;
extern int fd_dup_stdout;
/*-------------------------------------------------------------------------*/
void mp_debug(const char *, ...);
NORET mp_quit(int);
/*-------------------------------------------------------------------------*/
#endif /* MP_CORE_H */

View File

@@ -0,0 +1,256 @@
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
#ifdef MPI
/*------------------------------------------------------------------------*/
#include "bprolog.h"
#include "up/up.h"
#include "up/em.h"
#include "up/graph.h"
#include "mp/mp.h"
#include "mp/mp_core.h"
#include "mp/mp_sw.h"
#include <stdlib.h>
/*------------------------------------------------------------------------*/
int sw_msg_size = 0;
static void * sw_msg_send = NULL;
static void * sw_msg_recv = NULL;
/*------------------------------------------------------------------------*/
/* mic.c (B-Prolog) */
NORET quit(const char *);
/*------------------------------------------------------------------------*/
void alloc_sw_msg_buffers(void)
{
sw_msg_send = MALLOC(sizeof(double) * sw_msg_size);
sw_msg_recv = MALLOC(sizeof(double) * sw_msg_size);
}
void release_sw_msg_buffers(void)
{
free(sw_msg_send);
sw_msg_send = NULL;
free(sw_msg_recv);
sw_msg_recv = NULL;
}
/*------------------------------------------------------------------------*/
void mpm_bcast_fixed(void)
{
SW_INS_PTR sw_ins_ptr;
char *meg_ptr;
int i;
meg_ptr = sw_msg_send;
for (i = 0; i < occ_switch_tab_size; i++) {
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
*(meg_ptr++) = (!!sw_ins_ptr->fixed) | ((!!sw_ins_ptr->fixed_h) << 1);
}
}
MPI_Bcast(sw_msg_send, sw_msg_size, MPI_CHAR, 0, MPI_COMM_WORLD);
mp_debug("mpm_bcast_fixed");
}
void mps_bcast_fixed(void)
{
SW_INS_PTR sw_ins_ptr;
char *meg_ptr;
int i;
MPI_Bcast(sw_msg_recv, sw_msg_size, MPI_CHAR, 0, MPI_COMM_WORLD);
mp_debug("mps_bcast_fixed");
for (i = 0; i < occ_switch_tab_size; i++) {
meg_ptr = sw_msg_recv;
meg_ptr += occ_position[i];
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
sw_ins_ptr->fixed = !!(*meg_ptr & 1);
sw_ins_ptr->fixed_h = !!(*meg_ptr & 2);
meg_ptr++;
}
}
}
void mpm_bcast_inside(void)
{
SW_INS_PTR sw_ins_ptr;
double *meg_ptr;
int i;
meg_ptr = sw_msg_send;
for (i = 0; i < occ_switch_tab_size; i++) {
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
*(meg_ptr++) = sw_ins_ptr->inside;
}
}
MPI_Bcast(sw_msg_send, sw_msg_size, MPI_DOUBLE, 0, MPI_COMM_WORLD);
mp_debug("mpm_bcast_inside");
}
void mps_bcast_inside(void)
{
SW_INS_PTR sw_ins_ptr;
double *meg_ptr;
int i;
MPI_Bcast(sw_msg_recv, sw_msg_size, MPI_DOUBLE, 0, MPI_COMM_WORLD);
mp_debug("mps_bcast_inside");
for (i = 0; i < occ_switch_tab_size; i++) {
meg_ptr = sw_msg_recv;
meg_ptr += occ_position[i];
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
sw_ins_ptr->inside = *(meg_ptr++);
}
}
}
void mpm_bcast_inside_h(void)
{
SW_INS_PTR sw_ins_ptr;
double *meg_ptr;
int i;
meg_ptr = sw_msg_send;
for (i = 0; i < occ_switch_tab_size; i++) {
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
*(meg_ptr++) = sw_ins_ptr->inside_h;
}
}
MPI_Bcast(sw_msg_send, sw_msg_size, MPI_DOUBLE, 0, MPI_COMM_WORLD);
mp_debug("mpm_bcast_inside_h");
}
void mps_bcast_inside_h(void)
{
SW_INS_PTR sw_ins_ptr;
double *meg_ptr;
int i;
MPI_Bcast(sw_msg_recv, sw_msg_size, MPI_DOUBLE, 0, MPI_COMM_WORLD);
mp_debug("mps_bcast_inside_h");
for (i = 0; i < occ_switch_tab_size; i++) {
meg_ptr = sw_msg_recv;
meg_ptr += occ_position[i];
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
sw_ins_ptr->inside_h = *(meg_ptr++);
}
}
}
void mpm_bcast_smooth(void)
{
SW_INS_PTR sw_ins_ptr;
double *meg_ptr;
int i;
meg_ptr = sw_msg_send;
for (i = 0; i < occ_switch_tab_size; i++) {
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
*(meg_ptr++) = sw_ins_ptr->smooth;
}
}
MPI_Bcast(sw_msg_send, sw_msg_size, MPI_DOUBLE, 0, MPI_COMM_WORLD);
mp_debug("mpm_bcast_smooth");
}
void mps_bcast_smooth(void)
{
SW_INS_PTR sw_ins_ptr;
double *meg_ptr;
int i;
MPI_Bcast(sw_msg_recv, sw_msg_size, MPI_DOUBLE, 0, MPI_COMM_WORLD);
mp_debug("mps_bcast_smooth");
for (i = 0; i < occ_switch_tab_size; i++) {
meg_ptr = sw_msg_recv;
meg_ptr += occ_position[i];
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
sw_ins_ptr->smooth = *(meg_ptr++);
}
}
}
/*------------------------------------------------------------------------*/
void clear_sw_msg_send(void)
{
double *meg_ptr;
double *end_ptr;
meg_ptr = sw_msg_send;
end_ptr = meg_ptr + sw_msg_size;
while (meg_ptr != end_ptr) {
*(meg_ptr++) = 0.0;
}
}
void mpm_share_expectation(void)
{
SW_INS_PTR sw_ins_ptr;
double *meg_ptr;
int i;
MPI_Allreduce(sw_msg_send, sw_msg_recv, sw_msg_size, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
meg_ptr = sw_msg_recv;
for (i = 0; i < occ_switch_tab_size; i++) {
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
sw_ins_ptr->total_expect = *(meg_ptr++);
}
}
}
void mps_share_expectation(void)
{
SW_INS_PTR sw_ins_ptr;
double *meg_ptr;
int i;
for (i = 0; i < occ_switch_tab_size; i++) {
meg_ptr = sw_msg_send;
meg_ptr += occ_position[i];
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
*(meg_ptr++) = sw_ins_ptr->total_expect;
}
}
MPI_Allreduce(sw_msg_send, sw_msg_recv, sw_msg_size, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
for (i = 0; i < occ_switch_tab_size; i++) {
meg_ptr = sw_msg_recv;
meg_ptr += occ_position[i];
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
sw_ins_ptr->total_expect = *(meg_ptr++);
}
}
}
double mp_sum_value(double value)
{
double g_value;
MPI_Allreduce(&value, &g_value, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
return g_value;
}
/*------------------------------------------------------------------------*/
#endif /* MPI */

View File

@@ -0,0 +1,29 @@
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
#ifndef MP_EM_AUX_H
#define MP_EM_AUX_H
/*-------------------------------------------------------------------------*/
extern int sw_msg_size;
/*-------------------------------------------------------------------------*/
void alloc_sw_msg_buffers(void);
void release_sw_msg_buffers(void);
void mpm_bcast_fixed(void);
void mps_bcast_fixed(void);
void mpm_bcast_inside(void);
void mps_bcast_inside(void);
void mpm_bcast_inside_h(void);
void mps_bcast_inside_h(void);
void mpm_bcast_smooth(void);
void mps_bcast_smooth(void);
void clear_sw_msg_send(void);
void mpm_share_expectation(void);
void mps_share_expectation(void);
double mp_sum_value(double);
/*-------------------------------------------------------------------------*/
#endif /* MP_EM_AUX_H */

View File

@@ -0,0 +1,265 @@
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
#ifdef MPI
/*------------------------------------------------------------------------*/
#include "bprolog.h"
#include "core/error.h"
#include "up/up.h"
#include "up/em.h"
#include "up/em_aux.h"
#include "up/em_aux_ml.h"
#include "up/em_ml.h"
#include "up/graph.h"
#include "up/flags.h"
#include "up/util.h"
#include "mp/mp.h"
#include "mp/mp_core.h"
#include "mp/mp_em_aux.h"
#include <mpi.h>
/*------------------------------------------------------------------------*/
void mpm_share_preconds_em(int *smooth)
{
int ivals[4];
int ovals[4];
ivals[0] = sw_msg_size;
ivals[1] = 0;
ivals[2] = 0;
ivals[3] = *smooth;
MPI_Allreduce(ivals, ovals, 4, MPI_INT, MPI_SUM, MPI_COMM_WORLD);
sw_msg_size = ovals[0];
num_goals = ovals[1];
failure_observed = ovals[2];
*smooth = ovals[3];
mp_debug("msgsize=%d, #goals=%d, failure=%s, smooth = %s",
sw_msg_size, num_goals, failure_observed ? "on" : "off", *smooth ? "on" : "off");
alloc_sw_msg_buffers();
mpm_bcast_fixed();
if (*smooth) {
mpm_bcast_smooth();
}
}
void mps_share_preconds_em(int *smooth)
{
int ivals[4];
int ovals[4];
ivals[0] = 0;
ivals[1] = num_goals;
ivals[2] = failure_observed;
ivals[3] = 0;
MPI_Allreduce(ivals, ovals, 4, MPI_INT, MPI_SUM, MPI_COMM_WORLD);
sw_msg_size = ovals[0];
num_goals = ovals[1];
failure_observed = ovals[2];
*smooth = ovals[3];
mp_debug("msgsize=%d, #goals=%d, failure=%s, smooth = %s",
sw_msg_size, num_goals, failure_observed ? "on" : "off", *smooth ? "on" : "off");
alloc_sw_msg_buffers();
mps_bcast_fixed();
if (*smooth) {
mps_bcast_smooth();
}
}
/*------------------------------------------------------------------------*/
int mpm_run_em(EM_ENG_PTR emptr)
{
int r, iterate, old_valid, converged, saved=0;
double likelihood, log_prior;
double lambda, old_lambda=0.0;
config_em(emptr);
for (r = 0; r < num_restart; r++) {
SHOW_PROGRESS_HEAD("#em-iters", r);
initialize_params();
mpm_bcast_inside();
clear_sw_msg_send();
itemp = daem ? itemp_init : 1.0;
iterate = 0;
while (1) {
if (daem) {
SHOW_PROGRESS_TEMP(itemp);
}
old_valid = 0;
while (1) {
if (CTRLC_PRESSED) {
SHOW_PROGRESS_INTR();
RET_ERR(err_ctrl_c_pressed);
}
if (failure_observed) {
inside_failure = mp_sum_value(0.0);
}
log_prior = emptr->smooth ? emptr->compute_log_prior() : 0.0;
lambda = mp_sum_value(log_prior);
likelihood = lambda - log_prior;
mp_debug("local lambda = %.9f, lambda = %.9f", log_prior, lambda);
if (verb_em) {
if (emptr->smooth) {
prism_printf("Iteration #%d:\tlog_likelihood=%.9f\tlog_prior=%.9f\tlog_post=%.9f\n", iterate, likelihood, log_prior, lambda);
}
else {
prism_printf("Iteration #%d:\tlog_likelihood=%.9f\n", iterate, likelihood);
}
}
if (!isfinite(lambda)) {
emit_internal_error("invalid log likelihood or log post: %s (at iterateion #%d)",
isnan(lambda) ? "NaN" : "infinity", iterate);
RET_ERR(ierr_invalid_likelihood);
}
if (old_valid && old_lambda - lambda > prism_epsilon) {
emit_error("log likelihood or log post decreased [old: %.9f, new: %.9f] (at iteration #%d)",
old_lambda, lambda, iterate);
RET_ERR(err_invalid_likelihood);
}
if (itemp == 1.0 && likelihood > 0.0) {
emit_error("log likelihood greater than zero [value: %.9f] (at iteration #%d)",
likelihood, iterate);
RET_ERR(err_invalid_likelihood);
}
converged = (old_valid && lambda - old_lambda <= prism_epsilon);
if (converged || REACHED_MAX_ITERATE(iterate)) {
break;
}
old_lambda = lambda;
old_valid = 1;
mpm_share_expectation();
SHOW_PROGRESS(iterate);
RET_ON_ERR(emptr->update_params());
iterate++;
}
if (itemp == 1.0) {
break;
}
itemp *= itemp_rate;
if (itemp >= 1.0) {
itemp = 1.0;
}
}
SHOW_PROGRESS_TAIL(converged, iterate, lambda);
if (r == 0 || lambda > emptr->lambda) {
emptr->lambda = lambda;
emptr->likelihood = likelihood;
emptr->iterate = iterate;
saved = (r < num_restart - 1);
if (saved) {
save_params();
}
}
}
if (saved) {
restore_params();
}
emptr->bic = compute_bic(emptr->likelihood);
emptr->cs = emptr->smooth ? compute_cs(emptr->likelihood) : 0.0;
return BP_TRUE;
}
int mps_run_em(EM_ENG_PTR emptr)
{
int r, iterate, old_valid, converged, saved=0;
double likelihood;
double lambda, old_lambda=0.0;
config_em(emptr);
for (r = 0; r < num_restart; r++) {
mps_bcast_inside();
clear_sw_msg_send();
itemp = daem ? itemp_init : 1.0;
iterate = 0;
while (1) {
old_valid = 0;
while (1) {
RET_ON_ERR(emptr->compute_inside());
RET_ON_ERR(emptr->examine_inside());
if (failure_observed) {
inside_failure = mp_sum_value(inside_failure);
}
likelihood = emptr->compute_likelihood();
lambda = mp_sum_value(likelihood);
mp_debug("local lambda = %.9f, lambda = %.9f", likelihood, lambda);
converged = (old_valid && lambda - old_lambda <= prism_epsilon);
if (converged || REACHED_MAX_ITERATE(iterate)) {
break;
}
old_lambda = lambda;
old_valid = 1;
RET_ON_ERR(emptr->compute_expectation());
mps_share_expectation();
RET_ON_ERR(emptr->update_params());
iterate++;
}
if (itemp == 1.0) {
break;
}
itemp *= itemp_rate;
if (itemp >= 1.0) {
itemp = 1.0;
}
}
if (r == 0 || lambda > emptr->lambda) {
emptr->lambda = lambda;
saved = (r < num_restart - 1);
if (saved) {
save_params();
}
}
}
if (saved) {
restore_params();
}
return BP_TRUE;
}
/*------------------------------------------------------------------------*/
#endif /* MPI */

View File

@@ -0,0 +1,15 @@
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
#ifndef MP_EM_ML_H
#define MP_EM_ML_H
/*-------------------------------------------------------------------------*/
void mpm_share_preconds_em(int *);
void mps_share_preconds_em(int *);
int mpm_run_em(EM_ENG_PTR);
int mps_run_em(EM_ENG_PTR);
/*-------------------------------------------------------------------------*/
#endif /* MP_EM_ML_H */

View File

@@ -0,0 +1,167 @@
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
#ifdef MPI
/*------------------------------------------------------------------------*/
#include "bprolog.h"
#include "up/up.h"
#include "up/em.h"
#include "up/em_aux.h"
#include "up/em_aux_ml.h"
#include "up/em_aux_vb.h"
#include "up/graph.h"
#include "up/flags.h"
#include "mp/mp.h"
#include "mp/mp_core.h"
#include "mp/mp_em_aux.h"
#include "mp/mp_em_ml.h"
#include "mp/mp_em_vb.h"
#include "mp/mp_sw.h"
#include <mpi.h>
/*------------------------------------------------------------------------*/
/* mic.c (B-Prolog) */
NORET myquit(int, const char *);
/*------------------------------------------------------------------------*/
int pc_mpm_prism_em_6(void)
{
struct EM_Engine em_eng;
/* [28 Aug 2007, by yuizumi]
* occ_switches[] will be freed in pc_import_occ_switches/1.
* occ_position[] is not allocated.
*/
RET_ON_ERR(check_smooth(&em_eng.smooth));
mpm_share_preconds_em(&em_eng.smooth);
RET_ON_ERR(mpm_run_em(&em_eng));
release_sw_msg_buffers();
release_num_sw_vals();
return
bpx_unify(bpx_get_call_arg(1,6), bpx_build_integer(em_eng.iterate)) &&
bpx_unify(bpx_get_call_arg(2,6), bpx_build_float(em_eng.lambda)) &&
bpx_unify(bpx_get_call_arg(3,6), bpx_build_float(em_eng.likelihood)) &&
bpx_unify(bpx_get_call_arg(4,6), bpx_build_float(em_eng.bic)) &&
bpx_unify(bpx_get_call_arg(5,6), bpx_build_float(em_eng.cs)) &&
bpx_unify(bpx_get_call_arg(6,6), bpx_build_integer(em_eng.smooth));
}
int pc_mps_prism_em_0(void)
{
struct EM_Engine em_eng;
mps_share_preconds_em(&em_eng.smooth);
RET_ON_ERR(mps_run_em(&em_eng));
release_sw_msg_buffers();
release_occ_switches();
release_num_sw_vals();
release_occ_position();
return BP_TRUE;
}
int pc_mpm_prism_vbem_2(void)
{
struct VBEM_Engine vb_eng;
RET_ON_ERR(check_smooth_vb());
mpm_share_preconds_vbem();
RET_ON_ERR(mpm_run_vbem(&vb_eng));
release_sw_msg_buffers();
release_num_sw_vals();
return
bpx_unify(bpx_get_call_arg(1,2), bpx_build_integer(vb_eng.iterate)) &&
bpx_unify(bpx_get_call_arg(2,2), bpx_build_float(vb_eng.free_energy));
}
int pc_mps_prism_vbem_0(void)
{
struct VBEM_Engine vb_eng;
mps_share_preconds_vbem();
RET_ON_ERR(mps_run_vbem(&vb_eng));
release_sw_msg_buffers();
release_occ_switches();
release_num_sw_vals();
release_occ_position();
return BP_TRUE;
}
int pc_mpm_prism_both_em_2(void)
{
struct VBEM_Engine vb_eng;
RET_ON_ERR(check_smooth_vb());
mpm_share_preconds_vbem();
RET_ON_ERR(mpm_run_vbem(&vb_eng));
get_param_means();
release_sw_msg_buffers();
release_num_sw_vals();
return
bpx_unify(bpx_get_call_arg(1,2), bpx_build_integer(vb_eng.iterate)) &&
bpx_unify(bpx_get_call_arg(2,2), bpx_build_float(vb_eng.free_energy));
}
int pc_mps_prism_both_em_0(void)
{
struct VBEM_Engine vb_eng;
mps_share_preconds_vbem();
RET_ON_ERR(mps_run_vbem(&vb_eng));
get_param_means();
release_sw_msg_buffers();
release_occ_switches();
release_num_sw_vals();
release_occ_position();
return BP_TRUE;
}
/*------------------------------------------------------------------------*/
int pc_mpm_import_graph_stats_4(void)
{
int dummy[4] = { 0 };
int stats[4];
double avg_shared;
MPI_Reduce(dummy, stats, 4, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
avg_shared = (double)(stats[3]) / stats[0];
return
bpx_unify(bpx_get_call_arg(1,4), bpx_build_integer(stats[0])) &&
bpx_unify(bpx_get_call_arg(2,4), bpx_build_integer(stats[1])) &&
bpx_unify(bpx_get_call_arg(3,4), bpx_build_integer(stats[2])) &&
bpx_unify(bpx_get_call_arg(4,4), bpx_build_float(avg_shared));
}
int pc_mps_import_graph_stats_0(void)
{
int dummy[4];
int stats[4];
graph_stats(stats);
MPI_Reduce(stats, dummy, 4, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
mp_debug("# subgoals = %d", stats[0]);
mp_debug("# goal nodes = %d", stats[1]);
mp_debug("# switch nodes = %d", stats[2]);
mp_debug("# sharings = %d", stats[3]);
return BP_TRUE;
}
/*------------------------------------------------------------------------*/
#endif /* MPI */

View File

@@ -0,0 +1,19 @@
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
#ifndef MP_EM_PREDS_H
#define MP_EM_PREDS_H
/*-------------------------------------------------------------------------*/
int pc_mpm_prism_em_6(void);
int pc_mps_prism_em_0(void);
int pc_mpm_prism_vbem_2(void);
int pc_mps_prism_vbem_0(void);
int pc_mpm_prism_both_em_7(void);
int pc_mps_prism_both_em_0(void);
int pc_mpm_import_graph_stats_4(void);
int pc_mps_import_graph_stats_0(void);
/*-------------------------------------------------------------------------*/
#endif /* MP_EM_PREDS_H */

View File

@@ -0,0 +1,256 @@
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
#ifdef MPI
/*------------------------------------------------------------------------*/
#include "bprolog.h"
#include "up/up.h"
#include "up/em.h"
#include "up/em_aux.h"
#include "up/em_aux_vb.h"
#include "up/em_vb.h"
#include "up/graph.h"
#include "up/flags.h"
#include "up/util.h"
#include "mp/mp.h"
#include "mp/mp_core.h"
#include "mp/mp_em_aux.h"
#include <mpi.h>
/*------------------------------------------------------------------------*/
void mpm_share_preconds_vbem(void)
{
int ivals[3];
int ovals[3];
ivals[0] = sw_msg_size;
ivals[1] = 0;
ivals[2] = 0;
MPI_Allreduce(ivals, ovals, 3, MPI_INT, MPI_SUM, MPI_COMM_WORLD);
sw_msg_size = ovals[0];
num_goals = ovals[1];
failure_observed = ovals[2];
mp_debug("msgsize=%d, #goals=%d, failure=%s",
sw_msg_size, num_goals, failure_observed ? "on" : "off");
alloc_sw_msg_buffers();
mpm_bcast_fixed();
}
void mps_share_preconds_vbem(void)
{
int ivals[3];
int ovals[3];
ivals[0] = 0;
ivals[1] = num_goals;
ivals[2] = failure_observed;
MPI_Allreduce(ivals, ovals, 3, MPI_INT, MPI_SUM, MPI_COMM_WORLD);
sw_msg_size = ovals[0];
num_goals = ovals[1];
failure_observed = ovals[2];
mp_debug("msgsize=%d, #goals=%d, failure=%s",
sw_msg_size, num_goals, failure_observed ? "on" : "off");
alloc_sw_msg_buffers();
mps_bcast_fixed();
}
/*------------------------------------------------------------------------*/
int mpm_run_vbem(VBEM_ENG_PTR vbptr)
{
int r, iterate, old_valid, converged, saved=0;
double free_energy, old_free_energy=0.0;
double l0, l1;
config_vbem(vbptr);
for (r = 0; r < num_restart; r++) {
SHOW_PROGRESS_HEAD("#vbem-iters", r);
initialize_hyperparams();
mpm_bcast_inside_h();
mpm_bcast_smooth();
clear_sw_msg_send();
itemp = daem ? itemp_init : 1.0;
iterate = 0;
while (1) {
if (daem) {
SHOW_PROGRESS_TEMP(itemp);
}
old_valid = 0;
while (1) {
if (CTRLC_PRESSED) {
SHOW_PROGRESS_INTR();
RET_ERR(err_ctrl_c_pressed);
}
RET_ON_ERR(vbptr->compute_pi());
if (failure_observed) {
inside_failure = mp_sum_value(0.0);
}
l0 = vbptr->compute_free_energy_l0();
l1 = vbptr->compute_free_energy_l1();
free_energy = mp_sum_value(l0 - l1);
mp_debug("local free_energy = %.9f, free_energy = %.9f", l0 - l1, free_energy);
if (verb_em) {
prism_printf("Iteration #%d:\tfree_energy=%.9f\n", iterate, free_energy);
}
if (!isfinite(free_energy)) {
emit_internal_error("invalid variational free energy: %s (at iteration #%d)",
isnan(free_energy) ? "NaN" : "infinity", iterate);
RET_ERR(err_invalid_free_energy);
}
if (old_valid && old_free_energy - free_energy > prism_epsilon) {
emit_error("variational free energy decreased [old: %.9f, new: %.9f] (at iteration #%d)",
old_free_energy, free_energy, iterate);
RET_ERR(err_invalid_free_energy);
}
if (itemp == 1.0 && free_energy > 0.0) {
emit_error("variational free energy greater than zero [value: %.9f] (at iteration #%d)",
free_energy, iterate);
RET_ERR(err_invalid_free_energy);
}
converged = (old_valid && free_energy - old_free_energy <= prism_epsilon);
if (converged || REACHED_MAX_ITERATE(iterate)) {
break;
}
old_free_energy = free_energy;
old_valid = 1;
mpm_share_expectation();
SHOW_PROGRESS(iterate);
RET_ON_ERR(vbptr->update_hyperparams());
iterate++;
}
if (itemp == 1.0) {
break;
}
itemp *= itemp_rate;
if (itemp >= 1.0) {
itemp = 1.0;
}
}
SHOW_PROGRESS_TAIL(converged, iterate, free_energy);
if (r == 0 || free_energy > vbptr->free_energy) {
vbptr->free_energy = free_energy;
vbptr->iterate = iterate;
saved = (r < num_restart - 1);
if (saved) {
save_hyperparams();
}
}
}
if (saved) {
restore_hyperparams();
}
transfer_hyperparams();
return BP_TRUE;
}
int mps_run_vbem(VBEM_ENG_PTR vbptr)
{
int r, iterate, old_valid, converged, saved=0;
double free_energy, old_free_energy=0.0;
double l2;
config_vbem(vbptr);
for (r = 0; r < num_restart; r++) {
mps_bcast_inside_h();
mps_bcast_smooth();
clear_sw_msg_send();
itemp = daem ? itemp_init : 1.0;
iterate = 0;
while (1) {
old_valid = 0;
while (1) {
RET_ON_ERR(vbptr->compute_pi());
RET_ON_ERR(vbptr->compute_inside());
RET_ON_ERR(vbptr->examine_inside());
if (failure_observed) {
inside_failure = mp_sum_value(inside_failure);
}
l2 = vbptr->compute_likelihood() / itemp;
free_energy = mp_sum_value(l2);
mp_debug("local free_energy = %.9f, free_energy = %.9f", l2, free_energy);
converged = (old_valid && free_energy - old_free_energy <= prism_epsilon);
if (converged || REACHED_MAX_ITERATE(iterate)) {
break;
}
old_free_energy = free_energy;
old_valid = 1;
RET_ON_ERR(vbptr->compute_expectation());
mps_share_expectation();
RET_ON_ERR(vbptr->update_hyperparams());
iterate++;
}
if (itemp == 1.0) {
break;
}
itemp *= itemp_rate;
if (itemp >= 1.0) {
itemp = 1.0;
}
}
if (r == 0 || free_energy > vbptr->free_energy) {
vbptr->free_energy = free_energy;
saved = (r < num_restart - 1);
if (saved) {
save_hyperparams();
}
}
}
if (saved) {
restore_hyperparams();
}
transfer_hyperparams();
return BP_TRUE;
}
/*------------------------------------------------------------------------*/
#endif /* MPI */

View File

@@ -0,0 +1,15 @@
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
#ifndef MP_EM_VB_H
#define MP_EM_VB_H
/*-------------------------------------------------------------------------*/
void mpm_share_preconds_vbem(void);
void mps_share_preconds_vbem(void);
int mpm_run_vbem(VBEM_ENG_PTR);
int mps_run_vbem(VBEM_ENG_PTR);
/*-------------------------------------------------------------------------*/
#endif /* MP_EM_VB_H */

View File

@@ -0,0 +1,77 @@
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
#ifdef MPI
/*------------------------------------------------------------------------*/
#include "bprolog.h"
#include "up/flags.h"
#include <mpi.h>
/*------------------------------------------------------------------------*/
#define PUT(msg,pos,type,value) \
MPI_Pack(&(value),1,(type),(msg),sizeof(msg),&(pos),MPI_COMM_WORLD)
#define GET(msg,pos,type,value) \
MPI_Unpack((msg),sizeof(msg),&(pos),&(value),1,(type),MPI_COMM_WORLD)
/*------------------------------------------------------------------------*/
int pc_mpm_share_prism_flags_0(void)
{
char msg[256];
int pos = 0;
PUT( msg , pos , MPI_INT , daem );
PUT( msg , pos , MPI_INT , em_message );
PUT( msg , pos , MPI_INT , em_progress );
PUT( msg , pos , MPI_INT , error_on_cycle );
PUT( msg , pos , MPI_INT , fix_init_order );
PUT( msg , pos , MPI_INT , init_method );
PUT( msg , pos , MPI_DOUBLE , itemp_init );
PUT( msg , pos , MPI_DOUBLE , itemp_rate );
PUT( msg , pos , MPI_INT , log_scale );
PUT( msg , pos , MPI_INT , max_iterate );
PUT( msg , pos , MPI_INT , num_restart );
PUT( msg , pos , MPI_DOUBLE , prism_epsilon );
PUT( msg , pos , MPI_DOUBLE , std_ratio );
PUT( msg , pos , MPI_INT , verb_em );
PUT( msg , pos , MPI_INT , verb_graph );
PUT( msg , pos , MPI_INT , warn );
MPI_Bcast(msg, sizeof(msg), MPI_PACKED, 0, MPI_COMM_WORLD);
return BP_TRUE;
}
int pc_mps_share_prism_flags_0(void)
{
char msg[256];
int pos = 0;
MPI_Bcast(msg, sizeof(msg), MPI_PACKED, 0, MPI_COMM_WORLD);
GET( msg , pos , MPI_INT , daem );
GET( msg , pos , MPI_INT , em_message );
GET( msg , pos , MPI_INT , em_progress );
GET( msg , pos , MPI_INT , error_on_cycle );
GET( msg , pos , MPI_INT , fix_init_order );
GET( msg , pos , MPI_INT , init_method );
GET( msg , pos , MPI_DOUBLE , itemp_init );
GET( msg , pos , MPI_DOUBLE , itemp_rate );
GET( msg , pos , MPI_INT , log_scale );
GET( msg , pos , MPI_INT , max_iterate );
GET( msg , pos , MPI_INT , num_restart );
GET( msg , pos , MPI_DOUBLE , prism_epsilon );
GET( msg , pos , MPI_DOUBLE , std_ratio );
GET( msg , pos , MPI_INT , verb_em );
GET( msg , pos , MPI_INT , verb_graph );
GET( msg , pos , MPI_INT , warn );
return BP_TRUE;
}
/*------------------------------------------------------------------------*/
#endif /* MPI */

View File

@@ -0,0 +1,13 @@
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
#ifndef MP_FLAGS_H
#define MP_FLAGS_H
/*-------------------------------------------------------------------------*/
int pc_mpm_share_prism_flags_0(void);
int pc_mps_share_prism_flags_0(void);
/*-------------------------------------------------------------------------*/
#endif /* MP_FLAGS_H */

View File

@@ -0,0 +1,191 @@
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
#ifdef MPI
#include "bprolog.h"
#include "core/error.h"
#include "up/up.h"
#include "mp/mp.h"
#include "mp/mp_core.h"
#include <unistd.h> /* STDOUT_FILENO */
#include <string.h>
#include <mpi.h>
/*------------------------------------------------------------------------*/
/* cpred.c (B-Prolog) */
int bp_string_2_term(const char *, TERM, TERM);
/*------------------------------------------------------------------------*/
static char str_prealloc[65536];
/*------------------------------------------------------------------------*/
static int send_term(TERM arg, int mode, int rank)
{
char *str;
int len;
str = (char *)bpx_term_2_string(arg);
len = strlen(str);
switch (mode) {
case 0:
MPI_Send (&len, 1 , MPI_INT , rank, TAG_GOAL_LEN, MPI_COMM_WORLD);
MPI_Send ( str, len, MPI_CHAR, rank, TAG_GOAL_STR, MPI_COMM_WORLD);
break;
case 1:
MPI_Bcast(&len, 1 , MPI_INT , rank, MPI_COMM_WORLD);
MPI_Bcast( str, len, MPI_CHAR, rank, MPI_COMM_WORLD);
break;
}
mp_debug("SEND(%d,%d): %s", mode, rank, str);
return BP_TRUE;
}
static int recv_term(TERM arg, int mode, int rank)
{
char *str;
TERM op1, op2;
int len, res;
switch (mode) {
case 0:
MPI_Recv (&len, 1, MPI_INT, rank, TAG_GOAL_LEN, MPI_COMM_WORLD, NULL);
break;
case 1:
MPI_Bcast(&len, 1, MPI_INT, rank, MPI_COMM_WORLD);
break;
}
if (len < sizeof(str_prealloc))
str = str_prealloc;
else {
str = MALLOC(len + 1);
}
switch (mode) {
case 0:
MPI_Recv (str, len, MPI_CHAR, rank, TAG_GOAL_STR, MPI_COMM_WORLD, NULL);
break;
case 1:
MPI_Bcast(str, len, MPI_CHAR, rank, MPI_COMM_WORLD);
break;
}
*(str + len) = '\0';
mp_debug("RECV(%d,%d): %s", mode, rank, str);
op1 = bpx_build_var();
op2 = bpx_build_var();
res = bp_string_2_term(str,op1,op2);
if (str != str_prealloc) {
free(str);
}
if (res == BP_TRUE) {
return bpx_unify(arg, op1);
}
return res;
}
/*------------------------------------------------------------------------*/
int pc_mp_size_1(void)
{
return bpx_unify(bpx_get_call_arg(1,1), bpx_build_integer(mp_size));
}
int pc_mp_rank_1(void)
{
return bpx_unify(bpx_get_call_arg(1,1), bpx_build_integer(mp_rank));
}
int pc_mp_master_0(void)
{
return (mp_rank == 0) ? BP_TRUE : BP_FALSE;
}
int pc_mp_abort_0(void)
{
mp_quit(0);
}
int pc_mp_wtime_1(void)
{
return bpx_unify(bpx_get_call_arg(1,1), bpx_build_float(MPI_Wtime()));
}
int pc_mp_sync_2(void)
{
int args[2], amin[2], amax[2];
args[0] = bpx_get_integer(bpx_get_call_arg(1,2)); /* tag */
args[1] = bpx_get_integer(bpx_get_call_arg(2,2)); /* sync-id */
mp_debug("SYNC(%d,%d): BGN", args[0], args[1]);
MPI_Allreduce(args, amin, 2, MPI_INT, MPI_MIN, MPI_COMM_WORLD);
MPI_Allreduce(args, amax, 2, MPI_INT, MPI_MAX, MPI_COMM_WORLD);
if (amin[0] != amax[0]) {
emit_internal_error("failure on sync (%d,%d)", args[0], args[1]);
RET_INTERNAL_ERR;
}
if (amin[1] < 0) {
return BP_FALSE;
}
if (amin[1] != amax[1]) {
emit_internal_error("failure on sync (%d,%d)", args[0], args[1]);
RET_INTERNAL_ERR;
}
mp_debug("SYNC(%d,%d): END", args[0], args[1]);
return BP_TRUE;
}
int pc_mp_send_goal_1(void)
{
MPI_Status status;
MPI_Recv(NULL, 0, MPI_INT, MPI_ANY_SOURCE, TAG_GOAL_REQ, MPI_COMM_WORLD, &status);
return send_term(bpx_get_call_arg(1,1), 0, status.MPI_SOURCE);
}
int pc_mp_recv_goal_1(void)
{
MPI_Send(NULL, 0, MPI_INT, 0, TAG_GOAL_REQ, MPI_COMM_WORLD);
return recv_term(bpx_get_call_arg(1,1), 0, 0);
}
int pc_mpm_bcast_command_1(void)
{
return send_term(bpx_get_call_arg(1,1), 1, 0);
}
int pc_mps_bcast_command_1(void)
{
return recv_term(bpx_get_call_arg(1,1), 1, 0);
}
int pc_mps_revert_stdout_0(void)
{
if (fd_dup_stdout >= 0) {
dup2(fd_dup_stdout, STDOUT_FILENO);
close(fd_dup_stdout);
fd_dup_stdout = -1;
}
return BP_TRUE;
}
/*------------------------------------------------------------------------*/
#endif /* MPI */

View File

@@ -0,0 +1,22 @@
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
#ifndef MP_PREDS_H
#define MP_PREDS_H
/*-------------------------------------------------------------------------*/
int pc_mp_size_1(void);
int pc_mp_rank_1(void);
int pc_mp_master_0(void);
int pc_mp_abort_0(void);
int pc_mp_wtime_1(void);
int pc_mp_sync_2(void);
int pc_mp_send_goal_1(void);
int pc_mp_recv_goal_1(void);
int pc_mpm_bcast_command_1(void);
int pc_mps_bcast_command_1(void);
int pc_mps_revert_stdout_0(void);
/*-------------------------------------------------------------------------*/
#endif /* MP_PREDS_H */

View File

@@ -0,0 +1,206 @@
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
#ifdef MPI
/*------------------------------------------------------------------------*/
#include "bprolog.h"
#include "core/idtable.h"
#include "core/idtable_preds.h"
#include "up/up.h"
#include "up/em_aux.h"
#include "up/graph.h"
#include "up/flags.h"
#include "mp/mp.h"
#include "mp/mp_core.h"
#include "mp/mp_em_aux.h"
#include <mpi.h>
#include <stdlib.h>
#include <string.h>
/*------------------------------------------------------------------------*/
int *occ_position = NULL;
static int * sizes = NULL;
static int ** swids = NULL;
#define L(i) (sizes[i * 2 + 0]) /* length of the message from RANK #i */
#define N(i) (sizes[i * 2 + 1]) /* number of switches in RANK #i*/
/*------------------------------------------------------------------------*/
/* cpred.c (B-Prolog) */
int bp_string_2_term(const char *, TERM, TERM);
/* mic.c (B-Prolog) */
NORET quit(const char *);
/*------------------------------------------------------------------------*/
static void parse_switch_req(const char *msg, int src)
{
const char *p;
TERM op1, op2;
int i;
swids[src] = MALLOC(sizeof(int) * N(src));
p = msg;
for (i = 0; i < N(src); i++) {
op1 = bpx_build_var();
op2 = bpx_build_var();
bp_string_2_term(p, op1, op2);
swids[src][i] = prism_sw_id_register(op1);
while (*(p++) != '\0') ;
}
}
/*------------------------------------------------------------------------*/
int pc_mp_send_switches_0(void)
{
char *msg, *str;
TERM msw;
int msglen, msgsiz;
int vals[2];
int i, n;
msglen = 0;
msgsiz = 65536;
msg = MALLOC(msgsiz);
for (i = 0; i < occ_switch_tab_size; i++) {
msw = bpx_get_arg(1, prism_sw_ins_term(occ_switches[i]->id));
str = (char *)bpx_term_2_string(msw);
n = strlen(str) + 1;
if (msgsiz <= msglen + n) {
msgsiz = (msglen + n + 65536) & ~65535;
msg = REALLOC(msg, msgsiz);
}
strcpy(msg + msglen, str);
msglen += n;
}
msg[msglen++] = '\0'; /* this is safe */
vals[0] = msglen;
vals[1] = occ_switch_tab_size;
MPI_Gather(vals, 2, MPI_INT, NULL, 0, MPI_INT, 0, MPI_COMM_WORLD);
MPI_Send(msg, msglen, MPI_CHAR, 0, TAG_SWITCH_REQ, MPI_COMM_WORLD);
free(msg);
return BP_TRUE;
}
int pc_mp_recv_switches_0(void)
{
int i, lmax, vals[2];
char *msg;
sizes = MALLOC(sizeof(int) * 2 * mp_size);
swids = MALLOC(sizeof(int *) * mp_size);
MPI_Gather(vals, 2, MPI_INT, sizes, 2, MPI_INT, 0, MPI_COMM_WORLD);
lmax = 0;
for (i = 1; i < mp_size; i++) {
if (lmax < L(i)) {
lmax = L(i);
}
}
msg = MALLOC(lmax);
for (i = 1; i < mp_size; i++) {
MPI_Recv(msg, L(i), MPI_CHAR, i, TAG_SWITCH_REQ, MPI_COMM_WORLD, NULL);
parse_switch_req(msg, i);
}
free(msg);
return BP_TRUE;
}
int pc_mp_send_swlayout_0(void)
{
int i, j, *msg, *pos;
msg = MALLOC(sizeof(int) * sw_tab_size);
pos = MALLOC(sizeof(int) * sw_ins_tab_size);
j = 0;
for (i = 0; i < occ_switch_tab_size; i++) {
pos[occ_switches[i]->id] = j;
j += num_sw_vals[i];
}
sw_msg_size = j;
for (i = 1; i < mp_size; i++) {
for (j = 0; j < N(i); j++) {
msg[j] = pos[switches[swids[i][j]]->id];
}
MPI_Send(msg, N(i), MPI_INT, i, TAG_SWITCH_RES, MPI_COMM_WORLD);
free(swids[i]);
}
free(pos);
free(msg);
free(sizes);
free(swids);
return BP_TRUE;
}
int pc_mp_recv_swlayout_0(void)
{
occ_position = MALLOC(sizeof(int) * occ_switch_tab_size);
MPI_Recv(occ_position, occ_switch_tab_size, MPI_INT, 0, TAG_SWITCH_RES, MPI_COMM_WORLD, NULL);
/* debug */
{
int i;
TERM msw;
for (i = 0; i < occ_switch_tab_size; i++) {
msw = bpx_get_arg(1, prism_sw_ins_term(occ_switches[i]->id));
mp_debug("%s -> %d", bpx_term_2_string(msw), occ_position[i]);
}
}
return BP_TRUE;
}
int pc_mpm_alloc_occ_switches_0(void)
{
occ_switches = MALLOC(sizeof(SW_INS_PTR) * sw_tab_size);
occ_switch_tab_size = sw_tab_size;
memcpy(occ_switches, switches, sizeof(SW_INS_PTR) * sw_tab_size);
if (fix_init_order) {
sort_occ_switches();
}
alloc_num_sw_vals();
return BP_TRUE;
}
void release_occ_position(void)
{
free(occ_position);
occ_position = NULL;
}
/*------------------------------------------------------------------------*/
#endif /* MPI */

View File

@@ -0,0 +1,22 @@
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
#ifndef MP_SW_H
#define MP_SW_H
/*-------------------------------------------------------------------------*/
extern int *occ_position;
/*-------------------------------------------------------------------------*/
int pc_mp_send_switches_0(void);
int pc_mp_recv_switches_0(void);
int pc_mp_send_swlayout_0(void);
int pc_mp_recv_swlayout_0(void);
int pc_mpm_alloc_occ_switches_0(void);
void release_occ_position(void);
/*-------------------------------------------------------------------------*/
#endif /* MP_SW_H */