diff --git a/oshmem/mca/atomic/ucx/atomic_ucx_module.c b/oshmem/mca/atomic/ucx/atomic_ucx_module.c index 172626dd0c6..ed95431f5a4 100644 --- a/oshmem/mca/atomic/ucx/atomic_ucx_module.c +++ b/oshmem/mca/atomic/ucx/atomic_ucx_module.c @@ -29,12 +29,19 @@ static ucp_request_param_t mca_spml_ucp_request_params[] = { }; #endif +static int mca_atomic_ucx_nb_support = 0; + /* * Initial query function that is invoked during initialization, allowing * this module to indicate what level of thread support it provides. */ int mca_atomic_ucx_startup(bool enable_progress_threads, bool enable_threads) { + unsigned major, minor, release_number; + ucp_get_version(&major, &minor, &release_number); + + mca_atomic_ucx_nb_support = UCX_VERSION(major, minor, release_number) >= UCX_VERSION(1, 20, 0); + return OSHMEM_SUCCESS; } @@ -73,8 +80,15 @@ int mca_atomic_ucx_op(shmem_ctx_t ctx, status_ptr = ucp_atomic_op_nbx(ucx_ctx->ucp_peers[pe].ucp_conn, op, &value, 1, rva, ucx_mkey->rkey, &mca_spml_ucp_request_params[size >> 3]); - res = opal_common_ucx_wait_request(status_ptr, ucx_ctx->ucp_worker[0], - "ucp_atomic_op_nbx post"); + if (mca_atomic_ucx_nb_support) { + /* UCX is packing (copying) the value pointer, so there's no need to wait for completion + (no stack corruption concerns). Additionally, there's no need to free the status pointer + as its already freed by ucp_atomic_op_nbx when a reply buffer is not provided. */ + res = UCS_PTR_IS_ERR(status_ptr) ? OSHMEM_ERROR : OSHMEM_SUCCESS; + } else { + res = opal_common_ucx_wait_request(status_ptr, ucx_ctx->ucp_worker[0], + "ucp_atomic_op_nbx"); + } #else status = ucp_atomic_post(ucx_ctx->ucp_peers[pe].ucp_conn, op, value, size, rva, diff --git a/oshmem/mca/spml/ucx/spml_ucx.c b/oshmem/mca/spml/ucx/spml_ucx.c index e73653b75e1..7973721f429 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.c +++ b/oshmem/mca/spml/ucx/spml_ucx.c @@ -1644,19 +1644,77 @@ int mca_spml_ucx_put_all_nb(void *dest, const void *source, size_t size, long *c return OSHMEM_SUCCESS; } -/* This routine is not implemented */ + +static inline int mca_spml_ucx_signal(shmem_ctx_t ctx, + uint64_t *sig_addr, + uint64_t signal, + int sig_op, + int dst) +{ + uint64_t dummy_prev; + + if (sig_op == SHMEM_SIGNAL_SET) { + return MCA_ATOMIC_CALL(swap(ctx, (void*)sig_addr, (void*)&dummy_prev, + signal, sizeof(uint64_t), dst)); + } else if (sig_op == SHMEM_SIGNAL_ADD) { + return MCA_ATOMIC_CALL(add(ctx, (void*)sig_addr, signal, + sizeof(uint64_t), dst)); + } + + SPML_UCX_ERROR("Invalid signal operation: %d", sig_op); + return OSHMEM_ERR_NOT_IMPLEMENTED; +} + +static inline int mca_spml_ucx_signal_nb(shmem_ctx_t ctx, + uint64_t *sig_addr, + uint64_t signal, + int sig_op, + int dst) +{ + uint64_t dummy_prev, dummy_fetch; + + if (sig_op == SHMEM_SIGNAL_SET) { + return MCA_ATOMIC_CALL(swap_nb(ctx, &dummy_fetch, (void*)sig_addr, (void*)&dummy_prev, + signal, sizeof(uint64_t), dst)); + } else if (sig_op == SHMEM_SIGNAL_ADD) { + return MCA_ATOMIC_CALL(fadd_nb(ctx, &dummy_fetch, (void*)sig_addr, (void*)&dummy_prev, + signal, sizeof(uint64_t), dst)); + } + + SPML_UCX_ERROR("Invalid signal operation: %d", sig_op); + return OSHMEM_ERR_NOT_IMPLEMENTED; +} + int mca_spml_ucx_put_signal(shmem_ctx_t ctx, void* dst_addr, size_t size, void* src_addr, uint64_t *sig_addr, uint64_t signal, int sig_op, int dst) { - return OSHMEM_ERR_NOT_IMPLEMENTED; + int res; + + res = mca_spml_ucx_put(ctx, dst_addr, size, src_addr, dst); + if (OPAL_UNLIKELY(OSHMEM_SUCCESS != res)) { + return res; + } + + return mca_spml_ucx_signal(ctx, sig_addr, signal, sig_op, dst); } -/* This routine is not implemented */ int mca_spml_ucx_put_signal_nb(shmem_ctx_t ctx, void* dst_addr, size_t size, void* src_addr, uint64_t *sig_addr, uint64_t signal, int sig_op, int dst) { - return OSHMEM_ERR_NOT_IMPLEMENTED; + int res; + + res = mca_spml_ucx_put_nb(ctx, dst_addr, size, src_addr, dst, NULL); + if (OPAL_UNLIKELY(OSHMEM_SUCCESS != res)) { + return res; + } + + res = mca_spml_ucx_fence(ctx); + if (OPAL_UNLIKELY(OSHMEM_SUCCESS != res)) { + return res; + } + + return mca_spml_ucx_signal_nb(ctx, sig_addr, signal, sig_op, dst); } /* This routine is not implemented */