prism logical probabilistic system.
This commit is contained in:
21
packages/prism/src/c/mp/mp.h
Normal file
21
packages/prism/src/c/mp/mp.h
Normal file
@@ -0,0 +1,21 @@
|
||||
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
|
||||
|
||||
#ifndef MP_H
|
||||
#define MP_H
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
#include <mpi.h>
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
#define TAG_GOAL_REQ (1)
|
||||
#define TAG_GOAL_LEN (2)
|
||||
#define TAG_GOAL_STR (3)
|
||||
|
||||
#define TAG_SWITCH_REQ (4)
|
||||
#define TAG_SWITCH_RES (5)
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
#endif /* MP_H */
|
||||
101
packages/prism/src/c/mp/mp_core.c
Normal file
101
packages/prism/src/c/mp/mp_core.c
Normal file
@@ -0,0 +1,101 @@
|
||||
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
|
||||
|
||||
/* [27 Aug 2007, by yuizumi]
|
||||
* FIXME: mp_debug() is currently platform-dependent.
|
||||
*/
|
||||
|
||||
#ifdef MPI
|
||||
|
||||
#include "up/up.h"
|
||||
#include "mp/mp.h"
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <stdarg.h>
|
||||
#include <sys/time.h>
|
||||
#include <unistd.h> /* STDOUT_FILENO */
|
||||
#include <mpi.h>
|
||||
|
||||
/* Currently mpprism works only on Linux systems. */
|
||||
#define DEV_NULL "/dev/null"
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
int fd_dup_stdout = -1;
|
||||
|
||||
int mp_size;
|
||||
int mp_rank;
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
static void close_stdout(void)
|
||||
{
|
||||
fd_dup_stdout = dup(STDOUT_FILENO);
|
||||
|
||||
if (fd_dup_stdout < 0)
|
||||
return;
|
||||
|
||||
if (freopen(DEV_NULL, "w", stdout) == NULL) {
|
||||
close(fd_dup_stdout);
|
||||
fd_dup_stdout = -1;
|
||||
}
|
||||
}
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
void mp_init(int *argc, char **argv[])
|
||||
{
|
||||
MPI_Init(argc, argv);
|
||||
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &mp_size);
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &mp_rank);
|
||||
|
||||
if (mp_size < 2) {
|
||||
printf("Two or more processes required to run mpprism.\n");
|
||||
MPI_Finalize();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if (mp_rank > 0) {
|
||||
close_stdout();
|
||||
}
|
||||
}
|
||||
|
||||
void mp_done(void)
|
||||
{
|
||||
MPI_Finalize();
|
||||
}
|
||||
|
||||
NORET mp_quit(int status)
|
||||
{
|
||||
fprintf(stderr, "The system is aborted by Rank #%d.\n", mp_rank);
|
||||
MPI_Abort(MPI_COMM_WORLD, status);
|
||||
exit(status); /* should not reach here */
|
||||
}
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
void mp_debug(const char *fmt, ...)
|
||||
{
|
||||
#ifdef MP_DEBUG
|
||||
char str[1024];
|
||||
va_list ap;
|
||||
struct timeval tv;
|
||||
int s, u;
|
||||
|
||||
va_start(ap, fmt);
|
||||
vsnprintf(str, sizeof(str), fmt, ap);
|
||||
va_end(ap);
|
||||
|
||||
gettimeofday(&tv, NULL);
|
||||
|
||||
s = tv.tv_sec;
|
||||
u = tv.tv_usec;
|
||||
|
||||
fprintf(stderr, "[RANK:%d] %02d:%02d:%02d.%03d -- %s\n",
|
||||
mp_rank, (s / 3600) % 24, (s / 60) % 60, s % 60, u / 1000, str);
|
||||
#endif
|
||||
}
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
#endif /* MPI */
|
||||
19
packages/prism/src/c/mp/mp_core.h
Normal file
19
packages/prism/src/c/mp/mp_core.h
Normal file
@@ -0,0 +1,19 @@
|
||||
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
|
||||
|
||||
#ifndef MP_CORE_H
|
||||
#define MP_CORE_H
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
extern int mp_size;
|
||||
extern int mp_rank;
|
||||
extern int fd_dup_stdout;
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
void mp_debug(const char *, ...);
|
||||
NORET mp_quit(int);
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
#endif /* MP_CORE_H */
|
||||
256
packages/prism/src/c/mp/mp_em_aux.c
Normal file
256
packages/prism/src/c/mp/mp_em_aux.c
Normal file
@@ -0,0 +1,256 @@
|
||||
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
|
||||
|
||||
#ifdef MPI
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
#include "bprolog.h"
|
||||
#include "up/up.h"
|
||||
#include "up/em.h"
|
||||
#include "up/graph.h"
|
||||
#include "mp/mp.h"
|
||||
#include "mp/mp_core.h"
|
||||
#include "mp/mp_sw.h"
|
||||
#include <stdlib.h>
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
int sw_msg_size = 0;
|
||||
static void * sw_msg_send = NULL;
|
||||
static void * sw_msg_recv = NULL;
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
/* mic.c (B-Prolog) */
|
||||
NORET quit(const char *);
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
void alloc_sw_msg_buffers(void)
|
||||
{
|
||||
sw_msg_send = MALLOC(sizeof(double) * sw_msg_size);
|
||||
sw_msg_recv = MALLOC(sizeof(double) * sw_msg_size);
|
||||
}
|
||||
|
||||
void release_sw_msg_buffers(void)
|
||||
{
|
||||
free(sw_msg_send);
|
||||
sw_msg_send = NULL;
|
||||
free(sw_msg_recv);
|
||||
sw_msg_recv = NULL;
|
||||
}
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
void mpm_bcast_fixed(void)
|
||||
{
|
||||
SW_INS_PTR sw_ins_ptr;
|
||||
char *meg_ptr;
|
||||
int i;
|
||||
|
||||
meg_ptr = sw_msg_send;
|
||||
|
||||
for (i = 0; i < occ_switch_tab_size; i++) {
|
||||
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
|
||||
*(meg_ptr++) = (!!sw_ins_ptr->fixed) | ((!!sw_ins_ptr->fixed_h) << 1);
|
||||
}
|
||||
}
|
||||
|
||||
MPI_Bcast(sw_msg_send, sw_msg_size, MPI_CHAR, 0, MPI_COMM_WORLD);
|
||||
mp_debug("mpm_bcast_fixed");
|
||||
}
|
||||
|
||||
void mps_bcast_fixed(void)
|
||||
{
|
||||
SW_INS_PTR sw_ins_ptr;
|
||||
char *meg_ptr;
|
||||
int i;
|
||||
|
||||
MPI_Bcast(sw_msg_recv, sw_msg_size, MPI_CHAR, 0, MPI_COMM_WORLD);
|
||||
mp_debug("mps_bcast_fixed");
|
||||
|
||||
for (i = 0; i < occ_switch_tab_size; i++) {
|
||||
meg_ptr = sw_msg_recv;
|
||||
meg_ptr += occ_position[i];
|
||||
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
|
||||
sw_ins_ptr->fixed = !!(*meg_ptr & 1);
|
||||
sw_ins_ptr->fixed_h = !!(*meg_ptr & 2);
|
||||
meg_ptr++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void mpm_bcast_inside(void)
|
||||
{
|
||||
SW_INS_PTR sw_ins_ptr;
|
||||
double *meg_ptr;
|
||||
int i;
|
||||
|
||||
meg_ptr = sw_msg_send;
|
||||
|
||||
for (i = 0; i < occ_switch_tab_size; i++) {
|
||||
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
|
||||
*(meg_ptr++) = sw_ins_ptr->inside;
|
||||
}
|
||||
}
|
||||
|
||||
MPI_Bcast(sw_msg_send, sw_msg_size, MPI_DOUBLE, 0, MPI_COMM_WORLD);
|
||||
mp_debug("mpm_bcast_inside");
|
||||
}
|
||||
|
||||
void mps_bcast_inside(void)
|
||||
{
|
||||
SW_INS_PTR sw_ins_ptr;
|
||||
double *meg_ptr;
|
||||
int i;
|
||||
|
||||
MPI_Bcast(sw_msg_recv, sw_msg_size, MPI_DOUBLE, 0, MPI_COMM_WORLD);
|
||||
mp_debug("mps_bcast_inside");
|
||||
|
||||
for (i = 0; i < occ_switch_tab_size; i++) {
|
||||
meg_ptr = sw_msg_recv;
|
||||
meg_ptr += occ_position[i];
|
||||
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
|
||||
sw_ins_ptr->inside = *(meg_ptr++);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void mpm_bcast_inside_h(void)
|
||||
{
|
||||
SW_INS_PTR sw_ins_ptr;
|
||||
double *meg_ptr;
|
||||
int i;
|
||||
|
||||
meg_ptr = sw_msg_send;
|
||||
|
||||
for (i = 0; i < occ_switch_tab_size; i++) {
|
||||
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
|
||||
*(meg_ptr++) = sw_ins_ptr->inside_h;
|
||||
}
|
||||
}
|
||||
|
||||
MPI_Bcast(sw_msg_send, sw_msg_size, MPI_DOUBLE, 0, MPI_COMM_WORLD);
|
||||
mp_debug("mpm_bcast_inside_h");
|
||||
}
|
||||
|
||||
void mps_bcast_inside_h(void)
|
||||
{
|
||||
SW_INS_PTR sw_ins_ptr;
|
||||
double *meg_ptr;
|
||||
int i;
|
||||
|
||||
MPI_Bcast(sw_msg_recv, sw_msg_size, MPI_DOUBLE, 0, MPI_COMM_WORLD);
|
||||
mp_debug("mps_bcast_inside_h");
|
||||
|
||||
for (i = 0; i < occ_switch_tab_size; i++) {
|
||||
meg_ptr = sw_msg_recv;
|
||||
meg_ptr += occ_position[i];
|
||||
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
|
||||
sw_ins_ptr->inside_h = *(meg_ptr++);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void mpm_bcast_smooth(void)
|
||||
{
|
||||
SW_INS_PTR sw_ins_ptr;
|
||||
double *meg_ptr;
|
||||
int i;
|
||||
|
||||
meg_ptr = sw_msg_send;
|
||||
|
||||
for (i = 0; i < occ_switch_tab_size; i++) {
|
||||
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
|
||||
*(meg_ptr++) = sw_ins_ptr->smooth;
|
||||
}
|
||||
}
|
||||
|
||||
MPI_Bcast(sw_msg_send, sw_msg_size, MPI_DOUBLE, 0, MPI_COMM_WORLD);
|
||||
mp_debug("mpm_bcast_smooth");
|
||||
}
|
||||
|
||||
void mps_bcast_smooth(void)
|
||||
{
|
||||
SW_INS_PTR sw_ins_ptr;
|
||||
double *meg_ptr;
|
||||
int i;
|
||||
|
||||
MPI_Bcast(sw_msg_recv, sw_msg_size, MPI_DOUBLE, 0, MPI_COMM_WORLD);
|
||||
mp_debug("mps_bcast_smooth");
|
||||
|
||||
for (i = 0; i < occ_switch_tab_size; i++) {
|
||||
meg_ptr = sw_msg_recv;
|
||||
meg_ptr += occ_position[i];
|
||||
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
|
||||
sw_ins_ptr->smooth = *(meg_ptr++);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
void clear_sw_msg_send(void)
|
||||
{
|
||||
double *meg_ptr;
|
||||
double *end_ptr;
|
||||
|
||||
meg_ptr = sw_msg_send;
|
||||
end_ptr = meg_ptr + sw_msg_size;
|
||||
while (meg_ptr != end_ptr) {
|
||||
*(meg_ptr++) = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
void mpm_share_expectation(void)
|
||||
{
|
||||
SW_INS_PTR sw_ins_ptr;
|
||||
double *meg_ptr;
|
||||
int i;
|
||||
|
||||
MPI_Allreduce(sw_msg_send, sw_msg_recv, sw_msg_size, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
|
||||
|
||||
meg_ptr = sw_msg_recv;
|
||||
|
||||
for (i = 0; i < occ_switch_tab_size; i++) {
|
||||
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
|
||||
sw_ins_ptr->total_expect = *(meg_ptr++);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void mps_share_expectation(void)
|
||||
{
|
||||
SW_INS_PTR sw_ins_ptr;
|
||||
double *meg_ptr;
|
||||
int i;
|
||||
|
||||
for (i = 0; i < occ_switch_tab_size; i++) {
|
||||
meg_ptr = sw_msg_send;
|
||||
meg_ptr += occ_position[i];
|
||||
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
|
||||
*(meg_ptr++) = sw_ins_ptr->total_expect;
|
||||
}
|
||||
}
|
||||
|
||||
MPI_Allreduce(sw_msg_send, sw_msg_recv, sw_msg_size, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
|
||||
|
||||
for (i = 0; i < occ_switch_tab_size; i++) {
|
||||
meg_ptr = sw_msg_recv;
|
||||
meg_ptr += occ_position[i];
|
||||
for (sw_ins_ptr = occ_switches[i]; sw_ins_ptr != NULL; sw_ins_ptr = sw_ins_ptr->next) {
|
||||
sw_ins_ptr->total_expect = *(meg_ptr++);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
double mp_sum_value(double value)
|
||||
{
|
||||
double g_value;
|
||||
MPI_Allreduce(&value, &g_value, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
|
||||
return g_value;
|
||||
}
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
#endif /* MPI */
|
||||
29
packages/prism/src/c/mp/mp_em_aux.h
Normal file
29
packages/prism/src/c/mp/mp_em_aux.h
Normal file
@@ -0,0 +1,29 @@
|
||||
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
|
||||
|
||||
#ifndef MP_EM_AUX_H
|
||||
#define MP_EM_AUX_H
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
extern int sw_msg_size;
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
void alloc_sw_msg_buffers(void);
|
||||
void release_sw_msg_buffers(void);
|
||||
void mpm_bcast_fixed(void);
|
||||
void mps_bcast_fixed(void);
|
||||
void mpm_bcast_inside(void);
|
||||
void mps_bcast_inside(void);
|
||||
void mpm_bcast_inside_h(void);
|
||||
void mps_bcast_inside_h(void);
|
||||
void mpm_bcast_smooth(void);
|
||||
void mps_bcast_smooth(void);
|
||||
void clear_sw_msg_send(void);
|
||||
void mpm_share_expectation(void);
|
||||
void mps_share_expectation(void);
|
||||
double mp_sum_value(double);
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
#endif /* MP_EM_AUX_H */
|
||||
265
packages/prism/src/c/mp/mp_em_ml.c
Normal file
265
packages/prism/src/c/mp/mp_em_ml.c
Normal file
@@ -0,0 +1,265 @@
|
||||
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
|
||||
|
||||
#ifdef MPI
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
#include "bprolog.h"
|
||||
#include "core/error.h"
|
||||
#include "up/up.h"
|
||||
#include "up/em.h"
|
||||
#include "up/em_aux.h"
|
||||
#include "up/em_aux_ml.h"
|
||||
#include "up/em_ml.h"
|
||||
#include "up/graph.h"
|
||||
#include "up/flags.h"
|
||||
#include "up/util.h"
|
||||
#include "mp/mp.h"
|
||||
#include "mp/mp_core.h"
|
||||
#include "mp/mp_em_aux.h"
|
||||
#include <mpi.h>
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
void mpm_share_preconds_em(int *smooth)
|
||||
{
|
||||
int ivals[4];
|
||||
int ovals[4];
|
||||
|
||||
ivals[0] = sw_msg_size;
|
||||
ivals[1] = 0;
|
||||
ivals[2] = 0;
|
||||
ivals[3] = *smooth;
|
||||
|
||||
MPI_Allreduce(ivals, ovals, 4, MPI_INT, MPI_SUM, MPI_COMM_WORLD);
|
||||
|
||||
sw_msg_size = ovals[0];
|
||||
num_goals = ovals[1];
|
||||
failure_observed = ovals[2];
|
||||
*smooth = ovals[3];
|
||||
|
||||
mp_debug("msgsize=%d, #goals=%d, failure=%s, smooth = %s",
|
||||
sw_msg_size, num_goals, failure_observed ? "on" : "off", *smooth ? "on" : "off");
|
||||
|
||||
alloc_sw_msg_buffers();
|
||||
mpm_bcast_fixed();
|
||||
if (*smooth) {
|
||||
mpm_bcast_smooth();
|
||||
}
|
||||
}
|
||||
|
||||
void mps_share_preconds_em(int *smooth)
|
||||
{
|
||||
int ivals[4];
|
||||
int ovals[4];
|
||||
|
||||
ivals[0] = 0;
|
||||
ivals[1] = num_goals;
|
||||
ivals[2] = failure_observed;
|
||||
ivals[3] = 0;
|
||||
|
||||
MPI_Allreduce(ivals, ovals, 4, MPI_INT, MPI_SUM, MPI_COMM_WORLD);
|
||||
|
||||
sw_msg_size = ovals[0];
|
||||
num_goals = ovals[1];
|
||||
failure_observed = ovals[2];
|
||||
*smooth = ovals[3];
|
||||
|
||||
mp_debug("msgsize=%d, #goals=%d, failure=%s, smooth = %s",
|
||||
sw_msg_size, num_goals, failure_observed ? "on" : "off", *smooth ? "on" : "off");
|
||||
|
||||
alloc_sw_msg_buffers();
|
||||
mps_bcast_fixed();
|
||||
if (*smooth) {
|
||||
mps_bcast_smooth();
|
||||
}
|
||||
}
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
int mpm_run_em(EM_ENG_PTR emptr)
|
||||
{
|
||||
int r, iterate, old_valid, converged, saved=0;
|
||||
double likelihood, log_prior;
|
||||
double lambda, old_lambda=0.0;
|
||||
|
||||
config_em(emptr);
|
||||
|
||||
for (r = 0; r < num_restart; r++) {
|
||||
SHOW_PROGRESS_HEAD("#em-iters", r);
|
||||
|
||||
initialize_params();
|
||||
mpm_bcast_inside();
|
||||
clear_sw_msg_send();
|
||||
|
||||
itemp = daem ? itemp_init : 1.0;
|
||||
iterate = 0;
|
||||
|
||||
while (1) {
|
||||
if (daem) {
|
||||
SHOW_PROGRESS_TEMP(itemp);
|
||||
}
|
||||
old_valid = 0;
|
||||
|
||||
while (1) {
|
||||
if (CTRLC_PRESSED) {
|
||||
SHOW_PROGRESS_INTR();
|
||||
RET_ERR(err_ctrl_c_pressed);
|
||||
}
|
||||
|
||||
if (failure_observed) {
|
||||
inside_failure = mp_sum_value(0.0);
|
||||
}
|
||||
|
||||
log_prior = emptr->smooth ? emptr->compute_log_prior() : 0.0;
|
||||
lambda = mp_sum_value(log_prior);
|
||||
likelihood = lambda - log_prior;
|
||||
|
||||
mp_debug("local lambda = %.9f, lambda = %.9f", log_prior, lambda);
|
||||
|
||||
if (verb_em) {
|
||||
if (emptr->smooth) {
|
||||
prism_printf("Iteration #%d:\tlog_likelihood=%.9f\tlog_prior=%.9f\tlog_post=%.9f\n", iterate, likelihood, log_prior, lambda);
|
||||
}
|
||||
else {
|
||||
prism_printf("Iteration #%d:\tlog_likelihood=%.9f\n", iterate, likelihood);
|
||||
}
|
||||
}
|
||||
|
||||
if (!isfinite(lambda)) {
|
||||
emit_internal_error("invalid log likelihood or log post: %s (at iterateion #%d)",
|
||||
isnan(lambda) ? "NaN" : "infinity", iterate);
|
||||
RET_ERR(ierr_invalid_likelihood);
|
||||
}
|
||||
if (old_valid && old_lambda - lambda > prism_epsilon) {
|
||||
emit_error("log likelihood or log post decreased [old: %.9f, new: %.9f] (at iteration #%d)",
|
||||
old_lambda, lambda, iterate);
|
||||
RET_ERR(err_invalid_likelihood);
|
||||
}
|
||||
if (itemp == 1.0 && likelihood > 0.0) {
|
||||
emit_error("log likelihood greater than zero [value: %.9f] (at iteration #%d)",
|
||||
likelihood, iterate);
|
||||
RET_ERR(err_invalid_likelihood);
|
||||
}
|
||||
|
||||
converged = (old_valid && lambda - old_lambda <= prism_epsilon);
|
||||
if (converged || REACHED_MAX_ITERATE(iterate)) {
|
||||
break;
|
||||
}
|
||||
|
||||
old_lambda = lambda;
|
||||
old_valid = 1;
|
||||
|
||||
mpm_share_expectation();
|
||||
|
||||
SHOW_PROGRESS(iterate);
|
||||
RET_ON_ERR(emptr->update_params());
|
||||
iterate++;
|
||||
}
|
||||
|
||||
if (itemp == 1.0) {
|
||||
break;
|
||||
}
|
||||
itemp *= itemp_rate;
|
||||
if (itemp >= 1.0) {
|
||||
itemp = 1.0;
|
||||
}
|
||||
}
|
||||
|
||||
SHOW_PROGRESS_TAIL(converged, iterate, lambda);
|
||||
|
||||
if (r == 0 || lambda > emptr->lambda) {
|
||||
emptr->lambda = lambda;
|
||||
emptr->likelihood = likelihood;
|
||||
emptr->iterate = iterate;
|
||||
|
||||
saved = (r < num_restart - 1);
|
||||
if (saved) {
|
||||
save_params();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (saved) {
|
||||
restore_params();
|
||||
}
|
||||
|
||||
emptr->bic = compute_bic(emptr->likelihood);
|
||||
emptr->cs = emptr->smooth ? compute_cs(emptr->likelihood) : 0.0;
|
||||
|
||||
return BP_TRUE;
|
||||
}
|
||||
|
||||
int mps_run_em(EM_ENG_PTR emptr)
|
||||
{
|
||||
int r, iterate, old_valid, converged, saved=0;
|
||||
double likelihood;
|
||||
double lambda, old_lambda=0.0;
|
||||
|
||||
config_em(emptr);
|
||||
|
||||
for (r = 0; r < num_restart; r++) {
|
||||
mps_bcast_inside();
|
||||
clear_sw_msg_send();
|
||||
itemp = daem ? itemp_init : 1.0;
|
||||
iterate = 0;
|
||||
|
||||
while (1) {
|
||||
old_valid = 0;
|
||||
|
||||
while (1) {
|
||||
RET_ON_ERR(emptr->compute_inside());
|
||||
RET_ON_ERR(emptr->examine_inside());
|
||||
|
||||
if (failure_observed) {
|
||||
inside_failure = mp_sum_value(inside_failure);
|
||||
}
|
||||
|
||||
likelihood = emptr->compute_likelihood();
|
||||
lambda = mp_sum_value(likelihood);
|
||||
|
||||
mp_debug("local lambda = %.9f, lambda = %.9f", likelihood, lambda);
|
||||
|
||||
converged = (old_valid && lambda - old_lambda <= prism_epsilon);
|
||||
if (converged || REACHED_MAX_ITERATE(iterate)) {
|
||||
break;
|
||||
}
|
||||
|
||||
old_lambda = lambda;
|
||||
old_valid = 1;
|
||||
|
||||
RET_ON_ERR(emptr->compute_expectation());
|
||||
mps_share_expectation();
|
||||
|
||||
RET_ON_ERR(emptr->update_params());
|
||||
iterate++;
|
||||
}
|
||||
|
||||
if (itemp == 1.0) {
|
||||
break;
|
||||
}
|
||||
itemp *= itemp_rate;
|
||||
if (itemp >= 1.0) {
|
||||
itemp = 1.0;
|
||||
}
|
||||
}
|
||||
|
||||
if (r == 0 || lambda > emptr->lambda) {
|
||||
emptr->lambda = lambda;
|
||||
saved = (r < num_restart - 1);
|
||||
if (saved) {
|
||||
save_params();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (saved) {
|
||||
restore_params();
|
||||
}
|
||||
|
||||
return BP_TRUE;
|
||||
}
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
#endif /* MPI */
|
||||
15
packages/prism/src/c/mp/mp_em_ml.h
Normal file
15
packages/prism/src/c/mp/mp_em_ml.h
Normal file
@@ -0,0 +1,15 @@
|
||||
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
|
||||
|
||||
#ifndef MP_EM_ML_H
|
||||
#define MP_EM_ML_H
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
void mpm_share_preconds_em(int *);
|
||||
void mps_share_preconds_em(int *);
|
||||
int mpm_run_em(EM_ENG_PTR);
|
||||
int mps_run_em(EM_ENG_PTR);
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
#endif /* MP_EM_ML_H */
|
||||
167
packages/prism/src/c/mp/mp_em_preds.c
Normal file
167
packages/prism/src/c/mp/mp_em_preds.c
Normal file
@@ -0,0 +1,167 @@
|
||||
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
|
||||
|
||||
#ifdef MPI
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
#include "bprolog.h"
|
||||
#include "up/up.h"
|
||||
#include "up/em.h"
|
||||
#include "up/em_aux.h"
|
||||
#include "up/em_aux_ml.h"
|
||||
#include "up/em_aux_vb.h"
|
||||
#include "up/graph.h"
|
||||
#include "up/flags.h"
|
||||
#include "mp/mp.h"
|
||||
#include "mp/mp_core.h"
|
||||
#include "mp/mp_em_aux.h"
|
||||
#include "mp/mp_em_ml.h"
|
||||
#include "mp/mp_em_vb.h"
|
||||
#include "mp/mp_sw.h"
|
||||
#include <mpi.h>
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
/* mic.c (B-Prolog) */
|
||||
NORET myquit(int, const char *);
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
int pc_mpm_prism_em_6(void)
|
||||
{
|
||||
struct EM_Engine em_eng;
|
||||
|
||||
/* [28 Aug 2007, by yuizumi]
|
||||
* occ_switches[] will be freed in pc_import_occ_switches/1.
|
||||
* occ_position[] is not allocated.
|
||||
*/
|
||||
RET_ON_ERR(check_smooth(&em_eng.smooth));
|
||||
mpm_share_preconds_em(&em_eng.smooth);
|
||||
RET_ON_ERR(mpm_run_em(&em_eng));
|
||||
release_sw_msg_buffers();
|
||||
release_num_sw_vals();
|
||||
|
||||
return
|
||||
bpx_unify(bpx_get_call_arg(1,6), bpx_build_integer(em_eng.iterate)) &&
|
||||
bpx_unify(bpx_get_call_arg(2,6), bpx_build_float(em_eng.lambda)) &&
|
||||
bpx_unify(bpx_get_call_arg(3,6), bpx_build_float(em_eng.likelihood)) &&
|
||||
bpx_unify(bpx_get_call_arg(4,6), bpx_build_float(em_eng.bic)) &&
|
||||
bpx_unify(bpx_get_call_arg(5,6), bpx_build_float(em_eng.cs)) &&
|
||||
bpx_unify(bpx_get_call_arg(6,6), bpx_build_integer(em_eng.smooth));
|
||||
}
|
||||
|
||||
int pc_mps_prism_em_0(void)
|
||||
{
|
||||
struct EM_Engine em_eng;
|
||||
|
||||
mps_share_preconds_em(&em_eng.smooth);
|
||||
RET_ON_ERR(mps_run_em(&em_eng));
|
||||
release_sw_msg_buffers();
|
||||
release_occ_switches();
|
||||
release_num_sw_vals();
|
||||
release_occ_position();
|
||||
|
||||
return BP_TRUE;
|
||||
}
|
||||
|
||||
int pc_mpm_prism_vbem_2(void)
|
||||
{
|
||||
struct VBEM_Engine vb_eng;
|
||||
|
||||
RET_ON_ERR(check_smooth_vb());
|
||||
mpm_share_preconds_vbem();
|
||||
RET_ON_ERR(mpm_run_vbem(&vb_eng));
|
||||
release_sw_msg_buffers();
|
||||
release_num_sw_vals();
|
||||
|
||||
return
|
||||
bpx_unify(bpx_get_call_arg(1,2), bpx_build_integer(vb_eng.iterate)) &&
|
||||
bpx_unify(bpx_get_call_arg(2,2), bpx_build_float(vb_eng.free_energy));
|
||||
}
|
||||
|
||||
int pc_mps_prism_vbem_0(void)
|
||||
{
|
||||
struct VBEM_Engine vb_eng;
|
||||
|
||||
mps_share_preconds_vbem();
|
||||
RET_ON_ERR(mps_run_vbem(&vb_eng));
|
||||
release_sw_msg_buffers();
|
||||
release_occ_switches();
|
||||
release_num_sw_vals();
|
||||
release_occ_position();
|
||||
|
||||
return BP_TRUE;
|
||||
}
|
||||
|
||||
int pc_mpm_prism_both_em_2(void)
|
||||
{
|
||||
struct VBEM_Engine vb_eng;
|
||||
|
||||
RET_ON_ERR(check_smooth_vb());
|
||||
mpm_share_preconds_vbem();
|
||||
RET_ON_ERR(mpm_run_vbem(&vb_eng));
|
||||
|
||||
get_param_means();
|
||||
|
||||
release_sw_msg_buffers();
|
||||
release_num_sw_vals();
|
||||
|
||||
return
|
||||
bpx_unify(bpx_get_call_arg(1,2), bpx_build_integer(vb_eng.iterate)) &&
|
||||
bpx_unify(bpx_get_call_arg(2,2), bpx_build_float(vb_eng.free_energy));
|
||||
}
|
||||
|
||||
int pc_mps_prism_both_em_0(void)
|
||||
{
|
||||
struct VBEM_Engine vb_eng;
|
||||
|
||||
mps_share_preconds_vbem();
|
||||
RET_ON_ERR(mps_run_vbem(&vb_eng));
|
||||
|
||||
get_param_means();
|
||||
|
||||
release_sw_msg_buffers();
|
||||
release_occ_switches();
|
||||
release_num_sw_vals();
|
||||
release_occ_position();
|
||||
|
||||
return BP_TRUE;
|
||||
}
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
int pc_mpm_import_graph_stats_4(void)
|
||||
{
|
||||
int dummy[4] = { 0 };
|
||||
int stats[4];
|
||||
double avg_shared;
|
||||
|
||||
MPI_Reduce(dummy, stats, 4, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
|
||||
avg_shared = (double)(stats[3]) / stats[0];
|
||||
|
||||
return
|
||||
bpx_unify(bpx_get_call_arg(1,4), bpx_build_integer(stats[0])) &&
|
||||
bpx_unify(bpx_get_call_arg(2,4), bpx_build_integer(stats[1])) &&
|
||||
bpx_unify(bpx_get_call_arg(3,4), bpx_build_integer(stats[2])) &&
|
||||
bpx_unify(bpx_get_call_arg(4,4), bpx_build_float(avg_shared));
|
||||
}
|
||||
|
||||
int pc_mps_import_graph_stats_0(void)
|
||||
{
|
||||
int dummy[4];
|
||||
int stats[4];
|
||||
|
||||
graph_stats(stats);
|
||||
MPI_Reduce(stats, dummy, 4, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
|
||||
|
||||
mp_debug("# subgoals = %d", stats[0]);
|
||||
mp_debug("# goal nodes = %d", stats[1]);
|
||||
mp_debug("# switch nodes = %d", stats[2]);
|
||||
mp_debug("# sharings = %d", stats[3]);
|
||||
|
||||
return BP_TRUE;
|
||||
}
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
#endif /* MPI */
|
||||
19
packages/prism/src/c/mp/mp_em_preds.h
Normal file
19
packages/prism/src/c/mp/mp_em_preds.h
Normal file
@@ -0,0 +1,19 @@
|
||||
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
|
||||
|
||||
#ifndef MP_EM_PREDS_H
|
||||
#define MP_EM_PREDS_H
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
int pc_mpm_prism_em_6(void);
|
||||
int pc_mps_prism_em_0(void);
|
||||
int pc_mpm_prism_vbem_2(void);
|
||||
int pc_mps_prism_vbem_0(void);
|
||||
int pc_mpm_prism_both_em_7(void);
|
||||
int pc_mps_prism_both_em_0(void);
|
||||
int pc_mpm_import_graph_stats_4(void);
|
||||
int pc_mps_import_graph_stats_0(void);
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
#endif /* MP_EM_PREDS_H */
|
||||
256
packages/prism/src/c/mp/mp_em_vb.c
Normal file
256
packages/prism/src/c/mp/mp_em_vb.c
Normal file
@@ -0,0 +1,256 @@
|
||||
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
|
||||
|
||||
#ifdef MPI
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
#include "bprolog.h"
|
||||
#include "up/up.h"
|
||||
#include "up/em.h"
|
||||
#include "up/em_aux.h"
|
||||
#include "up/em_aux_vb.h"
|
||||
#include "up/em_vb.h"
|
||||
#include "up/graph.h"
|
||||
#include "up/flags.h"
|
||||
#include "up/util.h"
|
||||
#include "mp/mp.h"
|
||||
#include "mp/mp_core.h"
|
||||
#include "mp/mp_em_aux.h"
|
||||
#include <mpi.h>
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
void mpm_share_preconds_vbem(void)
|
||||
{
|
||||
int ivals[3];
|
||||
int ovals[3];
|
||||
|
||||
ivals[0] = sw_msg_size;
|
||||
ivals[1] = 0;
|
||||
ivals[2] = 0;
|
||||
|
||||
MPI_Allreduce(ivals, ovals, 3, MPI_INT, MPI_SUM, MPI_COMM_WORLD);
|
||||
|
||||
sw_msg_size = ovals[0];
|
||||
num_goals = ovals[1];
|
||||
failure_observed = ovals[2];
|
||||
|
||||
mp_debug("msgsize=%d, #goals=%d, failure=%s",
|
||||
sw_msg_size, num_goals, failure_observed ? "on" : "off");
|
||||
|
||||
alloc_sw_msg_buffers();
|
||||
mpm_bcast_fixed();
|
||||
}
|
||||
|
||||
void mps_share_preconds_vbem(void)
|
||||
{
|
||||
int ivals[3];
|
||||
int ovals[3];
|
||||
|
||||
ivals[0] = 0;
|
||||
ivals[1] = num_goals;
|
||||
ivals[2] = failure_observed;
|
||||
|
||||
MPI_Allreduce(ivals, ovals, 3, MPI_INT, MPI_SUM, MPI_COMM_WORLD);
|
||||
|
||||
sw_msg_size = ovals[0];
|
||||
num_goals = ovals[1];
|
||||
failure_observed = ovals[2];
|
||||
|
||||
mp_debug("msgsize=%d, #goals=%d, failure=%s",
|
||||
sw_msg_size, num_goals, failure_observed ? "on" : "off");
|
||||
|
||||
alloc_sw_msg_buffers();
|
||||
mps_bcast_fixed();
|
||||
}
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
int mpm_run_vbem(VBEM_ENG_PTR vbptr)
|
||||
{
|
||||
int r, iterate, old_valid, converged, saved=0;
|
||||
double free_energy, old_free_energy=0.0;
|
||||
double l0, l1;
|
||||
|
||||
config_vbem(vbptr);
|
||||
|
||||
for (r = 0; r < num_restart; r++) {
|
||||
SHOW_PROGRESS_HEAD("#vbem-iters", r);
|
||||
|
||||
initialize_hyperparams();
|
||||
mpm_bcast_inside_h();
|
||||
mpm_bcast_smooth();
|
||||
clear_sw_msg_send();
|
||||
|
||||
itemp = daem ? itemp_init : 1.0;
|
||||
iterate = 0;
|
||||
|
||||
while (1) {
|
||||
if (daem) {
|
||||
SHOW_PROGRESS_TEMP(itemp);
|
||||
}
|
||||
old_valid = 0;
|
||||
|
||||
while (1) {
|
||||
if (CTRLC_PRESSED) {
|
||||
SHOW_PROGRESS_INTR();
|
||||
RET_ERR(err_ctrl_c_pressed);
|
||||
}
|
||||
|
||||
RET_ON_ERR(vbptr->compute_pi());
|
||||
|
||||
if (failure_observed) {
|
||||
inside_failure = mp_sum_value(0.0);
|
||||
}
|
||||
|
||||
l0 = vbptr->compute_free_energy_l0();
|
||||
l1 = vbptr->compute_free_energy_l1();
|
||||
free_energy = mp_sum_value(l0 - l1);
|
||||
|
||||
mp_debug("local free_energy = %.9f, free_energy = %.9f", l0 - l1, free_energy);
|
||||
|
||||
if (verb_em) {
|
||||
prism_printf("Iteration #%d:\tfree_energy=%.9f\n", iterate, free_energy);
|
||||
}
|
||||
|
||||
if (!isfinite(free_energy)) {
|
||||
emit_internal_error("invalid variational free energy: %s (at iteration #%d)",
|
||||
isnan(free_energy) ? "NaN" : "infinity", iterate);
|
||||
RET_ERR(err_invalid_free_energy);
|
||||
}
|
||||
if (old_valid && old_free_energy - free_energy > prism_epsilon) {
|
||||
emit_error("variational free energy decreased [old: %.9f, new: %.9f] (at iteration #%d)",
|
||||
old_free_energy, free_energy, iterate);
|
||||
RET_ERR(err_invalid_free_energy);
|
||||
}
|
||||
if (itemp == 1.0 && free_energy > 0.0) {
|
||||
emit_error("variational free energy greater than zero [value: %.9f] (at iteration #%d)",
|
||||
free_energy, iterate);
|
||||
RET_ERR(err_invalid_free_energy);
|
||||
}
|
||||
|
||||
converged = (old_valid && free_energy - old_free_energy <= prism_epsilon);
|
||||
if (converged || REACHED_MAX_ITERATE(iterate)) {
|
||||
break;
|
||||
}
|
||||
|
||||
old_free_energy = free_energy;
|
||||
old_valid = 1;
|
||||
|
||||
mpm_share_expectation();
|
||||
|
||||
SHOW_PROGRESS(iterate);
|
||||
RET_ON_ERR(vbptr->update_hyperparams());
|
||||
|
||||
iterate++;
|
||||
}
|
||||
|
||||
if (itemp == 1.0) {
|
||||
break;
|
||||
}
|
||||
itemp *= itemp_rate;
|
||||
if (itemp >= 1.0) {
|
||||
itemp = 1.0;
|
||||
}
|
||||
}
|
||||
|
||||
SHOW_PROGRESS_TAIL(converged, iterate, free_energy);
|
||||
|
||||
if (r == 0 || free_energy > vbptr->free_energy) {
|
||||
vbptr->free_energy = free_energy;
|
||||
vbptr->iterate = iterate;
|
||||
|
||||
saved = (r < num_restart - 1);
|
||||
if (saved) {
|
||||
save_hyperparams();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (saved) {
|
||||
restore_hyperparams();
|
||||
}
|
||||
|
||||
transfer_hyperparams();
|
||||
|
||||
return BP_TRUE;
|
||||
}
|
||||
|
||||
int mps_run_vbem(VBEM_ENG_PTR vbptr)
|
||||
{
|
||||
int r, iterate, old_valid, converged, saved=0;
|
||||
double free_energy, old_free_energy=0.0;
|
||||
double l2;
|
||||
|
||||
config_vbem(vbptr);
|
||||
|
||||
for (r = 0; r < num_restart; r++) {
|
||||
mps_bcast_inside_h();
|
||||
mps_bcast_smooth();
|
||||
clear_sw_msg_send();
|
||||
|
||||
itemp = daem ? itemp_init : 1.0;
|
||||
iterate = 0;
|
||||
|
||||
while (1) {
|
||||
old_valid = 0;
|
||||
|
||||
while (1) {
|
||||
RET_ON_ERR(vbptr->compute_pi());
|
||||
RET_ON_ERR(vbptr->compute_inside());
|
||||
RET_ON_ERR(vbptr->examine_inside());
|
||||
|
||||
if (failure_observed) {
|
||||
inside_failure = mp_sum_value(inside_failure);
|
||||
}
|
||||
|
||||
l2 = vbptr->compute_likelihood() / itemp;
|
||||
free_energy = mp_sum_value(l2);
|
||||
|
||||
mp_debug("local free_energy = %.9f, free_energy = %.9f", l2, free_energy);
|
||||
|
||||
converged = (old_valid && free_energy - old_free_energy <= prism_epsilon);
|
||||
if (converged || REACHED_MAX_ITERATE(iterate)) {
|
||||
break;
|
||||
}
|
||||
|
||||
old_free_energy = free_energy;
|
||||
old_valid = 1;
|
||||
|
||||
RET_ON_ERR(vbptr->compute_expectation());
|
||||
mps_share_expectation();
|
||||
|
||||
RET_ON_ERR(vbptr->update_hyperparams());
|
||||
iterate++;
|
||||
}
|
||||
|
||||
if (itemp == 1.0) {
|
||||
break;
|
||||
}
|
||||
itemp *= itemp_rate;
|
||||
if (itemp >= 1.0) {
|
||||
itemp = 1.0;
|
||||
}
|
||||
}
|
||||
|
||||
if (r == 0 || free_energy > vbptr->free_energy) {
|
||||
vbptr->free_energy = free_energy;
|
||||
saved = (r < num_restart - 1);
|
||||
if (saved) {
|
||||
save_hyperparams();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (saved) {
|
||||
restore_hyperparams();
|
||||
}
|
||||
|
||||
transfer_hyperparams();
|
||||
|
||||
return BP_TRUE;
|
||||
}
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
#endif /* MPI */
|
||||
15
packages/prism/src/c/mp/mp_em_vb.h
Normal file
15
packages/prism/src/c/mp/mp_em_vb.h
Normal file
@@ -0,0 +1,15 @@
|
||||
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
|
||||
|
||||
#ifndef MP_EM_VB_H
|
||||
#define MP_EM_VB_H
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
void mpm_share_preconds_vbem(void);
|
||||
void mps_share_preconds_vbem(void);
|
||||
int mpm_run_vbem(VBEM_ENG_PTR);
|
||||
int mps_run_vbem(VBEM_ENG_PTR);
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
#endif /* MP_EM_VB_H */
|
||||
77
packages/prism/src/c/mp/mp_flags.c
Normal file
77
packages/prism/src/c/mp/mp_flags.c
Normal file
@@ -0,0 +1,77 @@
|
||||
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
|
||||
|
||||
#ifdef MPI
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
#include "bprolog.h"
|
||||
#include "up/flags.h"
|
||||
#include <mpi.h>
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
#define PUT(msg,pos,type,value) \
|
||||
MPI_Pack(&(value),1,(type),(msg),sizeof(msg),&(pos),MPI_COMM_WORLD)
|
||||
|
||||
#define GET(msg,pos,type,value) \
|
||||
MPI_Unpack((msg),sizeof(msg),&(pos),&(value),1,(type),MPI_COMM_WORLD)
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
int pc_mpm_share_prism_flags_0(void)
|
||||
{
|
||||
char msg[256];
|
||||
int pos = 0;
|
||||
|
||||
PUT( msg , pos , MPI_INT , daem );
|
||||
PUT( msg , pos , MPI_INT , em_message );
|
||||
PUT( msg , pos , MPI_INT , em_progress );
|
||||
PUT( msg , pos , MPI_INT , error_on_cycle );
|
||||
PUT( msg , pos , MPI_INT , fix_init_order );
|
||||
PUT( msg , pos , MPI_INT , init_method );
|
||||
PUT( msg , pos , MPI_DOUBLE , itemp_init );
|
||||
PUT( msg , pos , MPI_DOUBLE , itemp_rate );
|
||||
PUT( msg , pos , MPI_INT , log_scale );
|
||||
PUT( msg , pos , MPI_INT , max_iterate );
|
||||
PUT( msg , pos , MPI_INT , num_restart );
|
||||
PUT( msg , pos , MPI_DOUBLE , prism_epsilon );
|
||||
PUT( msg , pos , MPI_DOUBLE , std_ratio );
|
||||
PUT( msg , pos , MPI_INT , verb_em );
|
||||
PUT( msg , pos , MPI_INT , verb_graph );
|
||||
PUT( msg , pos , MPI_INT , warn );
|
||||
|
||||
MPI_Bcast(msg, sizeof(msg), MPI_PACKED, 0, MPI_COMM_WORLD);
|
||||
|
||||
return BP_TRUE;
|
||||
}
|
||||
|
||||
int pc_mps_share_prism_flags_0(void)
|
||||
{
|
||||
char msg[256];
|
||||
int pos = 0;
|
||||
|
||||
MPI_Bcast(msg, sizeof(msg), MPI_PACKED, 0, MPI_COMM_WORLD);
|
||||
|
||||
GET( msg , pos , MPI_INT , daem );
|
||||
GET( msg , pos , MPI_INT , em_message );
|
||||
GET( msg , pos , MPI_INT , em_progress );
|
||||
GET( msg , pos , MPI_INT , error_on_cycle );
|
||||
GET( msg , pos , MPI_INT , fix_init_order );
|
||||
GET( msg , pos , MPI_INT , init_method );
|
||||
GET( msg , pos , MPI_DOUBLE , itemp_init );
|
||||
GET( msg , pos , MPI_DOUBLE , itemp_rate );
|
||||
GET( msg , pos , MPI_INT , log_scale );
|
||||
GET( msg , pos , MPI_INT , max_iterate );
|
||||
GET( msg , pos , MPI_INT , num_restart );
|
||||
GET( msg , pos , MPI_DOUBLE , prism_epsilon );
|
||||
GET( msg , pos , MPI_DOUBLE , std_ratio );
|
||||
GET( msg , pos , MPI_INT , verb_em );
|
||||
GET( msg , pos , MPI_INT , verb_graph );
|
||||
GET( msg , pos , MPI_INT , warn );
|
||||
|
||||
return BP_TRUE;
|
||||
}
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
#endif /* MPI */
|
||||
13
packages/prism/src/c/mp/mp_flags.h
Normal file
13
packages/prism/src/c/mp/mp_flags.h
Normal file
@@ -0,0 +1,13 @@
|
||||
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
|
||||
|
||||
#ifndef MP_FLAGS_H
|
||||
#define MP_FLAGS_H
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
int pc_mpm_share_prism_flags_0(void);
|
||||
int pc_mps_share_prism_flags_0(void);
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
#endif /* MP_FLAGS_H */
|
||||
191
packages/prism/src/c/mp/mp_preds.c
Normal file
191
packages/prism/src/c/mp/mp_preds.c
Normal file
@@ -0,0 +1,191 @@
|
||||
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
|
||||
|
||||
#ifdef MPI
|
||||
|
||||
#include "bprolog.h"
|
||||
#include "core/error.h"
|
||||
#include "up/up.h"
|
||||
#include "mp/mp.h"
|
||||
#include "mp/mp_core.h"
|
||||
#include <unistd.h> /* STDOUT_FILENO */
|
||||
#include <string.h>
|
||||
#include <mpi.h>
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
/* cpred.c (B-Prolog) */
|
||||
int bp_string_2_term(const char *, TERM, TERM);
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
static char str_prealloc[65536];
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
static int send_term(TERM arg, int mode, int rank)
|
||||
{
|
||||
char *str;
|
||||
int len;
|
||||
|
||||
str = (char *)bpx_term_2_string(arg);
|
||||
len = strlen(str);
|
||||
|
||||
switch (mode) {
|
||||
case 0:
|
||||
MPI_Send (&len, 1 , MPI_INT , rank, TAG_GOAL_LEN, MPI_COMM_WORLD);
|
||||
MPI_Send ( str, len, MPI_CHAR, rank, TAG_GOAL_STR, MPI_COMM_WORLD);
|
||||
break;
|
||||
case 1:
|
||||
MPI_Bcast(&len, 1 , MPI_INT , rank, MPI_COMM_WORLD);
|
||||
MPI_Bcast( str, len, MPI_CHAR, rank, MPI_COMM_WORLD);
|
||||
break;
|
||||
}
|
||||
|
||||
mp_debug("SEND(%d,%d): %s", mode, rank, str);
|
||||
|
||||
return BP_TRUE;
|
||||
}
|
||||
|
||||
static int recv_term(TERM arg, int mode, int rank)
|
||||
{
|
||||
char *str;
|
||||
TERM op1, op2;
|
||||
int len, res;
|
||||
|
||||
switch (mode) {
|
||||
case 0:
|
||||
MPI_Recv (&len, 1, MPI_INT, rank, TAG_GOAL_LEN, MPI_COMM_WORLD, NULL);
|
||||
break;
|
||||
case 1:
|
||||
MPI_Bcast(&len, 1, MPI_INT, rank, MPI_COMM_WORLD);
|
||||
break;
|
||||
}
|
||||
|
||||
if (len < sizeof(str_prealloc))
|
||||
str = str_prealloc;
|
||||
else {
|
||||
str = MALLOC(len + 1);
|
||||
}
|
||||
|
||||
switch (mode) {
|
||||
case 0:
|
||||
MPI_Recv (str, len, MPI_CHAR, rank, TAG_GOAL_STR, MPI_COMM_WORLD, NULL);
|
||||
break;
|
||||
case 1:
|
||||
MPI_Bcast(str, len, MPI_CHAR, rank, MPI_COMM_WORLD);
|
||||
break;
|
||||
}
|
||||
|
||||
*(str + len) = '\0';
|
||||
|
||||
mp_debug("RECV(%d,%d): %s", mode, rank, str);
|
||||
|
||||
op1 = bpx_build_var();
|
||||
op2 = bpx_build_var();
|
||||
|
||||
res = bp_string_2_term(str,op1,op2);
|
||||
if (str != str_prealloc) {
|
||||
free(str);
|
||||
}
|
||||
if (res == BP_TRUE) {
|
||||
return bpx_unify(arg, op1);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
int pc_mp_size_1(void)
|
||||
{
|
||||
return bpx_unify(bpx_get_call_arg(1,1), bpx_build_integer(mp_size));
|
||||
}
|
||||
|
||||
int pc_mp_rank_1(void)
|
||||
{
|
||||
return bpx_unify(bpx_get_call_arg(1,1), bpx_build_integer(mp_rank));
|
||||
}
|
||||
|
||||
int pc_mp_master_0(void)
|
||||
{
|
||||
return (mp_rank == 0) ? BP_TRUE : BP_FALSE;
|
||||
}
|
||||
|
||||
int pc_mp_abort_0(void)
|
||||
{
|
||||
mp_quit(0);
|
||||
}
|
||||
|
||||
int pc_mp_wtime_1(void)
|
||||
{
|
||||
return bpx_unify(bpx_get_call_arg(1,1), bpx_build_float(MPI_Wtime()));
|
||||
}
|
||||
|
||||
int pc_mp_sync_2(void)
|
||||
{
|
||||
int args[2], amin[2], amax[2];
|
||||
|
||||
args[0] = bpx_get_integer(bpx_get_call_arg(1,2)); /* tag */
|
||||
args[1] = bpx_get_integer(bpx_get_call_arg(2,2)); /* sync-id */
|
||||
|
||||
mp_debug("SYNC(%d,%d): BGN", args[0], args[1]);
|
||||
|
||||
MPI_Allreduce(args, amin, 2, MPI_INT, MPI_MIN, MPI_COMM_WORLD);
|
||||
MPI_Allreduce(args, amax, 2, MPI_INT, MPI_MAX, MPI_COMM_WORLD);
|
||||
|
||||
if (amin[0] != amax[0]) {
|
||||
emit_internal_error("failure on sync (%d,%d)", args[0], args[1]);
|
||||
RET_INTERNAL_ERR;
|
||||
}
|
||||
|
||||
if (amin[1] < 0) {
|
||||
return BP_FALSE;
|
||||
}
|
||||
|
||||
if (amin[1] != amax[1]) {
|
||||
emit_internal_error("failure on sync (%d,%d)", args[0], args[1]);
|
||||
RET_INTERNAL_ERR;
|
||||
}
|
||||
|
||||
mp_debug("SYNC(%d,%d): END", args[0], args[1]);
|
||||
|
||||
return BP_TRUE;
|
||||
}
|
||||
|
||||
int pc_mp_send_goal_1(void)
|
||||
{
|
||||
MPI_Status status;
|
||||
|
||||
MPI_Recv(NULL, 0, MPI_INT, MPI_ANY_SOURCE, TAG_GOAL_REQ, MPI_COMM_WORLD, &status);
|
||||
return send_term(bpx_get_call_arg(1,1), 0, status.MPI_SOURCE);
|
||||
}
|
||||
|
||||
int pc_mp_recv_goal_1(void)
|
||||
{
|
||||
MPI_Send(NULL, 0, MPI_INT, 0, TAG_GOAL_REQ, MPI_COMM_WORLD);
|
||||
return recv_term(bpx_get_call_arg(1,1), 0, 0);
|
||||
}
|
||||
|
||||
int pc_mpm_bcast_command_1(void)
|
||||
{
|
||||
return send_term(bpx_get_call_arg(1,1), 1, 0);
|
||||
}
|
||||
|
||||
int pc_mps_bcast_command_1(void)
|
||||
{
|
||||
return recv_term(bpx_get_call_arg(1,1), 1, 0);
|
||||
}
|
||||
|
||||
int pc_mps_revert_stdout_0(void)
|
||||
{
|
||||
if (fd_dup_stdout >= 0) {
|
||||
dup2(fd_dup_stdout, STDOUT_FILENO);
|
||||
close(fd_dup_stdout);
|
||||
fd_dup_stdout = -1;
|
||||
}
|
||||
|
||||
return BP_TRUE;
|
||||
}
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
#endif /* MPI */
|
||||
22
packages/prism/src/c/mp/mp_preds.h
Normal file
22
packages/prism/src/c/mp/mp_preds.h
Normal file
@@ -0,0 +1,22 @@
|
||||
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
|
||||
|
||||
#ifndef MP_PREDS_H
|
||||
#define MP_PREDS_H
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
int pc_mp_size_1(void);
|
||||
int pc_mp_rank_1(void);
|
||||
int pc_mp_master_0(void);
|
||||
int pc_mp_abort_0(void);
|
||||
int pc_mp_wtime_1(void);
|
||||
int pc_mp_sync_2(void);
|
||||
int pc_mp_send_goal_1(void);
|
||||
int pc_mp_recv_goal_1(void);
|
||||
int pc_mpm_bcast_command_1(void);
|
||||
int pc_mps_bcast_command_1(void);
|
||||
int pc_mps_revert_stdout_0(void);
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
#endif /* MP_PREDS_H */
|
||||
206
packages/prism/src/c/mp/mp_sw.c
Normal file
206
packages/prism/src/c/mp/mp_sw.c
Normal file
@@ -0,0 +1,206 @@
|
||||
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
|
||||
|
||||
#ifdef MPI
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
#include "bprolog.h"
|
||||
#include "core/idtable.h"
|
||||
#include "core/idtable_preds.h"
|
||||
#include "up/up.h"
|
||||
#include "up/em_aux.h"
|
||||
#include "up/graph.h"
|
||||
#include "up/flags.h"
|
||||
#include "mp/mp.h"
|
||||
#include "mp/mp_core.h"
|
||||
#include "mp/mp_em_aux.h"
|
||||
#include <mpi.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
int *occ_position = NULL;
|
||||
static int * sizes = NULL;
|
||||
static int ** swids = NULL;
|
||||
|
||||
#define L(i) (sizes[i * 2 + 0]) /* length of the message from RANK #i */
|
||||
#define N(i) (sizes[i * 2 + 1]) /* number of switches in RANK #i*/
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
/* cpred.c (B-Prolog) */
|
||||
int bp_string_2_term(const char *, TERM, TERM);
|
||||
|
||||
/* mic.c (B-Prolog) */
|
||||
NORET quit(const char *);
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
static void parse_switch_req(const char *msg, int src)
|
||||
{
|
||||
const char *p;
|
||||
TERM op1, op2;
|
||||
int i;
|
||||
|
||||
swids[src] = MALLOC(sizeof(int) * N(src));
|
||||
|
||||
p = msg;
|
||||
|
||||
for (i = 0; i < N(src); i++) {
|
||||
op1 = bpx_build_var();
|
||||
op2 = bpx_build_var();
|
||||
bp_string_2_term(p, op1, op2);
|
||||
swids[src][i] = prism_sw_id_register(op1);
|
||||
while (*(p++) != '\0') ;
|
||||
}
|
||||
}
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
int pc_mp_send_switches_0(void)
|
||||
{
|
||||
char *msg, *str;
|
||||
TERM msw;
|
||||
int msglen, msgsiz;
|
||||
int vals[2];
|
||||
int i, n;
|
||||
|
||||
msglen = 0;
|
||||
msgsiz = 65536;
|
||||
msg = MALLOC(msgsiz);
|
||||
|
||||
for (i = 0; i < occ_switch_tab_size; i++) {
|
||||
msw = bpx_get_arg(1, prism_sw_ins_term(occ_switches[i]->id));
|
||||
str = (char *)bpx_term_2_string(msw);
|
||||
|
||||
n = strlen(str) + 1;
|
||||
|
||||
if (msgsiz <= msglen + n) {
|
||||
msgsiz = (msglen + n + 65536) & ~65535;
|
||||
msg = REALLOC(msg, msgsiz);
|
||||
}
|
||||
|
||||
strcpy(msg + msglen, str);
|
||||
msglen += n;
|
||||
}
|
||||
|
||||
msg[msglen++] = '\0'; /* this is safe */
|
||||
|
||||
vals[0] = msglen;
|
||||
vals[1] = occ_switch_tab_size;
|
||||
|
||||
MPI_Gather(vals, 2, MPI_INT, NULL, 0, MPI_INT, 0, MPI_COMM_WORLD);
|
||||
MPI_Send(msg, msglen, MPI_CHAR, 0, TAG_SWITCH_REQ, MPI_COMM_WORLD);
|
||||
|
||||
free(msg);
|
||||
|
||||
return BP_TRUE;
|
||||
}
|
||||
|
||||
int pc_mp_recv_switches_0(void)
|
||||
{
|
||||
int i, lmax, vals[2];
|
||||
char *msg;
|
||||
|
||||
sizes = MALLOC(sizeof(int) * 2 * mp_size);
|
||||
swids = MALLOC(sizeof(int *) * mp_size);
|
||||
|
||||
MPI_Gather(vals, 2, MPI_INT, sizes, 2, MPI_INT, 0, MPI_COMM_WORLD);
|
||||
|
||||
lmax = 0;
|
||||
|
||||
for (i = 1; i < mp_size; i++) {
|
||||
if (lmax < L(i)) {
|
||||
lmax = L(i);
|
||||
}
|
||||
}
|
||||
|
||||
msg = MALLOC(lmax);
|
||||
|
||||
for (i = 1; i < mp_size; i++) {
|
||||
MPI_Recv(msg, L(i), MPI_CHAR, i, TAG_SWITCH_REQ, MPI_COMM_WORLD, NULL);
|
||||
parse_switch_req(msg, i);
|
||||
}
|
||||
|
||||
free(msg);
|
||||
|
||||
return BP_TRUE;
|
||||
}
|
||||
|
||||
int pc_mp_send_swlayout_0(void)
|
||||
{
|
||||
int i, j, *msg, *pos;
|
||||
|
||||
msg = MALLOC(sizeof(int) * sw_tab_size);
|
||||
pos = MALLOC(sizeof(int) * sw_ins_tab_size);
|
||||
|
||||
j = 0;
|
||||
|
||||
for (i = 0; i < occ_switch_tab_size; i++) {
|
||||
pos[occ_switches[i]->id] = j;
|
||||
j += num_sw_vals[i];
|
||||
}
|
||||
|
||||
sw_msg_size = j;
|
||||
|
||||
for (i = 1; i < mp_size; i++) {
|
||||
for (j = 0; j < N(i); j++) {
|
||||
msg[j] = pos[switches[swids[i][j]]->id];
|
||||
}
|
||||
|
||||
MPI_Send(msg, N(i), MPI_INT, i, TAG_SWITCH_RES, MPI_COMM_WORLD);
|
||||
free(swids[i]);
|
||||
}
|
||||
|
||||
free(pos);
|
||||
free(msg);
|
||||
|
||||
free(sizes);
|
||||
free(swids);
|
||||
|
||||
return BP_TRUE;
|
||||
}
|
||||
|
||||
int pc_mp_recv_swlayout_0(void)
|
||||
{
|
||||
occ_position = MALLOC(sizeof(int) * occ_switch_tab_size);
|
||||
|
||||
MPI_Recv(occ_position, occ_switch_tab_size, MPI_INT, 0, TAG_SWITCH_RES, MPI_COMM_WORLD, NULL);
|
||||
|
||||
/* debug */
|
||||
{
|
||||
int i;
|
||||
TERM msw;
|
||||
for (i = 0; i < occ_switch_tab_size; i++) {
|
||||
msw = bpx_get_arg(1, prism_sw_ins_term(occ_switches[i]->id));
|
||||
mp_debug("%s -> %d", bpx_term_2_string(msw), occ_position[i]);
|
||||
}
|
||||
}
|
||||
|
||||
return BP_TRUE;
|
||||
}
|
||||
|
||||
int pc_mpm_alloc_occ_switches_0(void)
|
||||
{
|
||||
occ_switches = MALLOC(sizeof(SW_INS_PTR) * sw_tab_size);
|
||||
|
||||
occ_switch_tab_size = sw_tab_size;
|
||||
memcpy(occ_switches, switches, sizeof(SW_INS_PTR) * sw_tab_size);
|
||||
if (fix_init_order) {
|
||||
sort_occ_switches();
|
||||
}
|
||||
alloc_num_sw_vals();
|
||||
|
||||
return BP_TRUE;
|
||||
}
|
||||
|
||||
void release_occ_position(void)
|
||||
{
|
||||
free(occ_position);
|
||||
occ_position = NULL;
|
||||
}
|
||||
|
||||
/*------------------------------------------------------------------------*/
|
||||
|
||||
#endif /* MPI */
|
||||
22
packages/prism/src/c/mp/mp_sw.h
Normal file
22
packages/prism/src/c/mp/mp_sw.h
Normal file
@@ -0,0 +1,22 @@
|
||||
/* -*- c-basic-offset: 4 ; tab-width: 4 -*- */
|
||||
|
||||
#ifndef MP_SW_H
|
||||
#define MP_SW_H
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
extern int *occ_position;
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
int pc_mp_send_switches_0(void);
|
||||
int pc_mp_recv_switches_0(void);
|
||||
int pc_mp_send_swlayout_0(void);
|
||||
int pc_mp_recv_swlayout_0(void);
|
||||
int pc_mpm_alloc_occ_switches_0(void);
|
||||
|
||||
void release_occ_position(void);
|
||||
|
||||
/*-------------------------------------------------------------------------*/
|
||||
|
||||
#endif /* MP_SW_H */
|
||||
Reference in New Issue
Block a user