diff --git a/oshmem/mca/scoll/ucc/scoll_ucc.h b/oshmem/mca/scoll/ucc/scoll_ucc.h index fa2aa04f855..17117dcfe44 100644 --- a/oshmem/mca/scoll/ucc/scoll_ucc.h +++ b/oshmem/mca/scoll/ucc/scoll_ucc.h @@ -44,10 +44,10 @@ struct mca_scoll_ucc_component_t { char * cts; int nr_modules; bool libucc_initialized; + ucc_context_h ucc_context; ucc_lib_h ucc_lib; ucc_lib_attr_t ucc_lib_attr; ucc_coll_type_t cts_requested; - ucc_context_h ucc_context; }; typedef struct mca_scoll_ucc_component_t mca_scoll_ucc_component_t; @@ -85,6 +85,8 @@ int mca_scoll_ucc_init_query(bool enable_progress_threads, bool enable_mpi_threa int mca_scoll_ucc_team_create(mca_scoll_ucc_module_t *ucc_module, oshmem_group_t *osh_group); +int mca_scoll_ucc_init_ctx(oshmem_group_t *osh_group); + mca_scoll_base_module_t* mca_scoll_ucc_comm_query(oshmem_group_t *osh_group, int *priority); int mca_scoll_ucc_barrier(struct oshmem_group_t *group, long *pSync, int alg); diff --git a/oshmem/mca/scoll/ucc/scoll_ucc_alltoall.c b/oshmem/mca/scoll/ucc/scoll_ucc_alltoall.c index 07ad22fa6fc..3b4b1f48ca6 100644 --- a/oshmem/mca/scoll/ucc/scoll_ucc_alltoall.c +++ b/oshmem/mca/scoll/ucc/scoll_ucc_alltoall.c @@ -46,6 +46,12 @@ static inline ucc_status_t mca_scoll_ucc_alltoall_init(const void *sbuf, void *r .global_work_buffer = ucc_module->pSync, }; + if (NULL == mca_scoll_ucc_component.ucc_context) { + if (OSHMEM_ERROR == mca_scoll_ucc_init_ctx(ucc_module->group)) { + return OSHMEM_ERROR; + } + } + if (NULL == ucc_module->ucc_team) { if (OSHMEM_ERROR == mca_scoll_ucc_team_create(ucc_module, ucc_module->group)) { return OSHMEM_ERROR; diff --git a/oshmem/mca/scoll/ucc/scoll_ucc_barrier.c b/oshmem/mca/scoll/ucc/scoll_ucc_barrier.c index 8f7a7d5ae97..6fc88ea4a07 100644 --- a/oshmem/mca/scoll/ucc/scoll_ucc_barrier.c +++ b/oshmem/mca/scoll/ucc/scoll_ucc_barrier.c @@ -19,11 +19,19 @@ static inline ucc_status_t mca_scoll_ucc_barrier_init(mca_scoll_ucc_module_t * u .mask = 0, .coll_type = UCC_COLL_TYPE_BARRIER }; + + if (NULL == mca_scoll_ucc_component.ucc_context) { + if (OSHMEM_ERROR == mca_scoll_ucc_init_ctx(ucc_module->group)) { + return OSHMEM_ERROR; + } + } + if (NULL == ucc_module->ucc_team) { if (OSHMEM_ERROR == mca_scoll_ucc_team_create(ucc_module, ucc_module->group)) { return OSHMEM_ERROR; } } + SCOLL_UCC_REQ_INIT(req, coll, ucc_module); return UCC_OK; fallback: @@ -49,4 +57,3 @@ int mca_scoll_ucc_barrier(struct oshmem_group_t *group, long *pSync, int alg) pSync, alg); return rc; } - diff --git a/oshmem/mca/scoll/ucc/scoll_ucc_broadcast.c b/oshmem/mca/scoll/ucc/scoll_ucc_broadcast.c index bc3f08fcde8..d3838e2574e 100644 --- a/oshmem/mca/scoll/ucc/scoll_ucc_broadcast.c +++ b/oshmem/mca/scoll/ucc/scoll_ucc_broadcast.c @@ -28,6 +28,13 @@ static inline ucc_status_t mca_scoll_ucc_broadcast_init(void * buf, int count, .mem_type = UCC_MEMORY_TYPE_UNKNOWN } }; + + if (NULL == mca_scoll_ucc_component.ucc_context) { + if (OSHMEM_ERROR == mca_scoll_ucc_init_ctx(ucc_module->group)) { + return OSHMEM_ERROR; + } + } + if (NULL == ucc_module->ucc_team) { if (OSHMEM_ERROR == mca_scoll_ucc_team_create(ucc_module, ucc_module->group)) { return OSHMEM_ERROR; diff --git a/oshmem/mca/scoll/ucc/scoll_ucc_collect.c b/oshmem/mca/scoll/ucc/scoll_ucc_collect.c index b25f6e38222..3f3aa8a3d13 100644 --- a/oshmem/mca/scoll/ucc/scoll_ucc_collect.c +++ b/oshmem/mca/scoll/ucc/scoll_ucc_collect.c @@ -34,6 +34,12 @@ static inline ucc_status_t mca_scoll_ucc_collect_init(const void * sbuf, void * }, }; + if (NULL == mca_scoll_ucc_component.ucc_context) { + if (OSHMEM_ERROR == mca_scoll_ucc_init_ctx(ucc_module->group)) { + return OSHMEM_ERROR; + } + } + if (NULL == ucc_module->ucc_team) { if (OSHMEM_ERROR == mca_scoll_ucc_team_create(ucc_module, ucc_module->group)) { return OSHMEM_ERROR; diff --git a/oshmem/mca/scoll/ucc/scoll_ucc_component.c b/oshmem/mca/scoll/ucc/scoll_ucc_component.c index a63e78799a4..3f916d636ea 100644 --- a/oshmem/mca/scoll/ucc/scoll_ucc_component.c +++ b/oshmem/mca/scoll/ucc/scoll_ucc_component.c @@ -63,7 +63,8 @@ mca_scoll_ucc_component_t mca_scoll_ucc_component = { "basic", /* cls */ SCOLL_UCC_CTS_STR, /* cts */ 0, /* nr_modules */ - false /* libucc_initialized */ + false, /* libucc_initialized */ + NULL /* ucc_context */ }; static int ucc_register(void) diff --git a/oshmem/mca/scoll/ucc/scoll_ucc_module.c b/oshmem/mca/scoll/ucc/scoll_ucc_module.c index f70ecbdb4ba..45d9896ee6f 100644 --- a/oshmem/mca/scoll/ucc/scoll_ucc_module.c +++ b/oshmem/mca/scoll/ucc/scoll_ucc_module.c @@ -57,10 +57,12 @@ static void mca_scoll_ucc_module_destruct(mca_scoll_ucc_module_t *ucc_module) } if (0 == mca_scoll_ucc_component.nr_modules) { - if (mca_scoll_ucc_component.libucc_initialized) { + if (mca_scoll_ucc_component.libucc_initialized) { + if (mca_scoll_ucc_component.ucc_context) { + opal_progress_unregister(mca_scoll_ucc_progress); + ucc_context_destroy(mca_scoll_ucc_component.ucc_context); + } UCC_VERBOSE(1, "finalizing ucc library"); - opal_progress_unregister(mca_scoll_ucc_progress); - ucc_context_destroy(mca_scoll_ucc_component.ucc_context); ucc_finalize(mca_scoll_ucc_component.ucc_lib); mca_scoll_ucc_component.libucc_initialized = false; } @@ -199,17 +201,12 @@ static ucc_status_t oob_allgather_test(void *req) return oob_probe_test(oob_req); } -static int mca_scoll_ucc_init_ctx(oshmem_group_t *osh_group) +static int mca_scoll_ucc_init(oshmem_group_t *osh_group) { mca_scoll_ucc_component_t *cm = &mca_scoll_ucc_component; - ucc_mem_map_t *maps = NULL; - char str_buf[256]; ucc_lib_config_h lib_config; - ucc_context_config_h ctx_config; ucc_thread_mode_t tm_requested; ucc_lib_params_t lib_params; - ucc_context_params_t ctx_params; - int segment; tm_requested = oshmem_mpi_thread_multiple ? UCC_THREAD_MULTIPLE : UCC_THREAD_SINGLE; @@ -247,6 +244,25 @@ static int mca_scoll_ucc_init_ctx(oshmem_group_t *osh_group) goto cleanup_lib; } + cm->libucc_initialized = true; + return OSHMEM_SUCCESS; + +cleanup_lib: + ucc_finalize(cm->ucc_lib); + cm->ucc_enable = 0; + cm->libucc_initialized = false; + return OSHMEM_ERROR; +} + +int mca_scoll_ucc_init_ctx(oshmem_group_t *osh_group) +{ + mca_scoll_ucc_component_t *cm = &mca_scoll_ucc_component; + ucc_mem_map_t *maps = NULL; + char str_buf[256]; + ucc_context_config_h ctx_config; + ucc_context_params_t ctx_params; + int segment; + maps = (ucc_mem_map_t *)malloc(sizeof(ucc_mem_map_t) * memheap_map->n_segments); if (NULL == maps) { @@ -398,7 +414,6 @@ static int mca_scoll_ucc_module_enable(mca_scoll_base_module_t *module, opal_show_help("help-oshmem-scoll-ucc.txt", "module_enable:fatal", true, "UCC module enable failed - aborting to prevent inconsistent application state"); - goto err; } UCC_VERBOSE(1, "ucc enabled"); @@ -446,7 +461,7 @@ mca_scoll_ucc_comm_query(oshmem_group_t *osh_group, int *priority) if (!cm->libucc_initialized) { if (memheap_map && memheap_map->n_segments > 0) { - if (OSHMEM_SUCCESS != mca_scoll_ucc_init_ctx(osh_group)) { + if (OSHMEM_SUCCESS != mca_scoll_ucc_init(osh_group)) { cm->ucc_enable = 0; return NULL; } diff --git a/oshmem/mca/scoll/ucc/scoll_ucc_reduce.c b/oshmem/mca/scoll/ucc/scoll_ucc_reduce.c index 368cd479e1b..6b99f6ea983 100644 --- a/oshmem/mca/scoll/ucc/scoll_ucc_reduce.c +++ b/oshmem/mca/scoll/ucc/scoll_ucc_reduce.c @@ -54,6 +54,13 @@ static inline ucc_status_t mca_scoll_ucc_reduce_init(const void *sbuf, void *rbu coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } + + if (NULL == mca_scoll_ucc_component.ucc_context) { + if (OSHMEM_ERROR == mca_scoll_ucc_init_ctx(ucc_module->group)) { + return OSHMEM_ERROR; + } + } + if (NULL == ucc_module->ucc_team) { if (OSHMEM_ERROR == mca_scoll_ucc_team_create(ucc_module, ucc_module->group)) { return OSHMEM_ERROR;