Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion oshmem/mca/scoll/ucc/scoll_ucc.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ struct mca_scoll_ucc_component_t {
char * cts;
int nr_modules;
bool libucc_initialized;
ucc_context_h ucc_context;
ucc_lib_h ucc_lib;
ucc_lib_attr_t ucc_lib_attr;
ucc_coll_type_t cts_requested;
ucc_context_h ucc_context;
};
typedef struct mca_scoll_ucc_component_t mca_scoll_ucc_component_t;

Expand Down Expand Up @@ -85,6 +85,8 @@ int mca_scoll_ucc_init_query(bool enable_progress_threads, bool enable_mpi_threa
int mca_scoll_ucc_team_create(mca_scoll_ucc_module_t *ucc_module,
oshmem_group_t *osh_group);

int mca_scoll_ucc_init_ctx(oshmem_group_t *osh_group);

mca_scoll_base_module_t* mca_scoll_ucc_comm_query(oshmem_group_t *osh_group, int *priority);

int mca_scoll_ucc_barrier(struct oshmem_group_t *group, long *pSync, int alg);
Expand Down
6 changes: 6 additions & 0 deletions oshmem/mca/scoll/ucc/scoll_ucc_alltoall.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ static inline ucc_status_t mca_scoll_ucc_alltoall_init(const void *sbuf, void *r
.global_work_buffer = ucc_module->pSync,
};

if (NULL == mca_scoll_ucc_component.ucc_context) {
if (OSHMEM_ERROR == mca_scoll_ucc_init_ctx(ucc_module->group)) {
return OSHMEM_ERROR;
}
}

if (NULL == ucc_module->ucc_team) {
if (OSHMEM_ERROR == mca_scoll_ucc_team_create(ucc_module, ucc_module->group)) {
return OSHMEM_ERROR;
Expand Down
9 changes: 8 additions & 1 deletion oshmem/mca/scoll/ucc/scoll_ucc_barrier.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,19 @@ static inline ucc_status_t mca_scoll_ucc_barrier_init(mca_scoll_ucc_module_t * u
.mask = 0,
.coll_type = UCC_COLL_TYPE_BARRIER
};

if (NULL == mca_scoll_ucc_component.ucc_context) {
if (OSHMEM_ERROR == mca_scoll_ucc_init_ctx(ucc_module->group)) {
return OSHMEM_ERROR;
}
}

if (NULL == ucc_module->ucc_team) {
if (OSHMEM_ERROR == mca_scoll_ucc_team_create(ucc_module, ucc_module->group)) {
return OSHMEM_ERROR;
}
}

SCOLL_UCC_REQ_INIT(req, coll, ucc_module);
return UCC_OK;
fallback:
Expand All @@ -49,4 +57,3 @@ int mca_scoll_ucc_barrier(struct oshmem_group_t *group, long *pSync, int alg)
pSync, alg);
return rc;
}

7 changes: 7 additions & 0 deletions oshmem/mca/scoll/ucc/scoll_ucc_broadcast.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ static inline ucc_status_t mca_scoll_ucc_broadcast_init(void * buf, int count,
.mem_type = UCC_MEMORY_TYPE_UNKNOWN
}
};

if (NULL == mca_scoll_ucc_component.ucc_context) {
if (OSHMEM_ERROR == mca_scoll_ucc_init_ctx(ucc_module->group)) {
return OSHMEM_ERROR;
}
}

if (NULL == ucc_module->ucc_team) {
if (OSHMEM_ERROR == mca_scoll_ucc_team_create(ucc_module, ucc_module->group)) {
return OSHMEM_ERROR;
Expand Down
6 changes: 6 additions & 0 deletions oshmem/mca/scoll/ucc/scoll_ucc_collect.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ static inline ucc_status_t mca_scoll_ucc_collect_init(const void * sbuf, void *
},
};

if (NULL == mca_scoll_ucc_component.ucc_context) {
if (OSHMEM_ERROR == mca_scoll_ucc_init_ctx(ucc_module->group)) {
return OSHMEM_ERROR;
}
}

if (NULL == ucc_module->ucc_team) {
if (OSHMEM_ERROR == mca_scoll_ucc_team_create(ucc_module, ucc_module->group)) {
return OSHMEM_ERROR;
Expand Down
3 changes: 2 additions & 1 deletion oshmem/mca/scoll/ucc/scoll_ucc_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ mca_scoll_ucc_component_t mca_scoll_ucc_component = {
"basic", /* cls */
SCOLL_UCC_CTS_STR, /* cts */
0, /* nr_modules */
false /* libucc_initialized */
false, /* libucc_initialized */
NULL /* ucc_context */
};

static int ucc_register(void)
Expand Down
37 changes: 26 additions & 11 deletions oshmem/mca/scoll/ucc/scoll_ucc_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,12 @@ static void mca_scoll_ucc_module_destruct(mca_scoll_ucc_module_t *ucc_module)
}

if (0 == mca_scoll_ucc_component.nr_modules) {
if (mca_scoll_ucc_component.libucc_initialized) {
if (mca_scoll_ucc_component.libucc_initialized) {
if (mca_scoll_ucc_component.ucc_context) {
opal_progress_unregister(mca_scoll_ucc_progress);
ucc_context_destroy(mca_scoll_ucc_component.ucc_context);
}
UCC_VERBOSE(1, "finalizing ucc library");
opal_progress_unregister(mca_scoll_ucc_progress);
ucc_context_destroy(mca_scoll_ucc_component.ucc_context);
ucc_finalize(mca_scoll_ucc_component.ucc_lib);
mca_scoll_ucc_component.libucc_initialized = false;
}
Expand Down Expand Up @@ -199,17 +201,12 @@ static ucc_status_t oob_allgather_test(void *req)
return oob_probe_test(oob_req);
}

static int mca_scoll_ucc_init_ctx(oshmem_group_t *osh_group)
static int mca_scoll_ucc_init(oshmem_group_t *osh_group)
{
mca_scoll_ucc_component_t *cm = &mca_scoll_ucc_component;
ucc_mem_map_t *maps = NULL;
char str_buf[256];
ucc_lib_config_h lib_config;
ucc_context_config_h ctx_config;
ucc_thread_mode_t tm_requested;
ucc_lib_params_t lib_params;
ucc_context_params_t ctx_params;
int segment;

tm_requested = oshmem_mpi_thread_multiple ? UCC_THREAD_MULTIPLE :
UCC_THREAD_SINGLE;
Expand Down Expand Up @@ -247,6 +244,25 @@ static int mca_scoll_ucc_init_ctx(oshmem_group_t *osh_group)
goto cleanup_lib;
}

cm->libucc_initialized = true;
return OSHMEM_SUCCESS;

cleanup_lib:
ucc_finalize(cm->ucc_lib);
cm->ucc_enable = 0;
cm->libucc_initialized = false;
return OSHMEM_ERROR;
}

int mca_scoll_ucc_init_ctx(oshmem_group_t *osh_group)
{
mca_scoll_ucc_component_t *cm = &mca_scoll_ucc_component;
ucc_mem_map_t *maps = NULL;
char str_buf[256];
ucc_context_config_h ctx_config;
ucc_context_params_t ctx_params;
int segment;

maps = (ucc_mem_map_t *)malloc(sizeof(ucc_mem_map_t) *
memheap_map->n_segments);
if (NULL == maps) {
Expand Down Expand Up @@ -398,7 +414,6 @@ static int mca_scoll_ucc_module_enable(mca_scoll_base_module_t *module,
opal_show_help("help-oshmem-scoll-ucc.txt",
"module_enable:fatal", true,
"UCC module enable failed - aborting to prevent inconsistent application state");

goto err;
}
UCC_VERBOSE(1, "ucc enabled");
Expand Down Expand Up @@ -446,7 +461,7 @@ mca_scoll_ucc_comm_query(oshmem_group_t *osh_group, int *priority)

if (!cm->libucc_initialized) {
if (memheap_map && memheap_map->n_segments > 0) {
if (OSHMEM_SUCCESS != mca_scoll_ucc_init_ctx(osh_group)) {
if (OSHMEM_SUCCESS != mca_scoll_ucc_init(osh_group)) {
cm->ucc_enable = 0;
return NULL;
}
Expand Down
7 changes: 7 additions & 0 deletions oshmem/mca/scoll/ucc/scoll_ucc_reduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ static inline ucc_status_t mca_scoll_ucc_reduce_init(const void *sbuf, void *rbu
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
}

if (NULL == mca_scoll_ucc_component.ucc_context) {
if (OSHMEM_ERROR == mca_scoll_ucc_init_ctx(ucc_module->group)) {
return OSHMEM_ERROR;
}
}

if (NULL == ucc_module->ucc_team) {
if (OSHMEM_ERROR == mca_scoll_ucc_team_create(ucc_module, ucc_module->group)) {
return OSHMEM_ERROR;
Expand Down