diff --git a/oshmem/mca/atomic/atomic.h b/oshmem/mca/atomic/atomic.h index d4683ee4afd..c60c855108d 100644 --- a/oshmem/mca/atomic/atomic.h +++ b/oshmem/mca/atomic/atomic.h @@ -296,6 +296,11 @@ struct mca_atomic_base_module_1_0_0_t { uint64_t value, size_t size, int pe); + int (*atomic_set)(shmem_ctx_t ctx, + void *target, + uint64_t value, + size_t size, + int pe); }; typedef struct mca_atomic_base_module_1_0_0_t mca_atomic_base_module_1_0_0_t; diff --git a/oshmem/mca/atomic/ucx/atomic_ucx_module.c b/oshmem/mca/atomic/ucx/atomic_ucx_module.c index 172626dd0c6..d0f31a21f91 100644 --- a/oshmem/mca/atomic/ucx/atomic_ucx_module.c +++ b/oshmem/mca/atomic/ucx/atomic_ucx_module.c @@ -73,8 +73,7 @@ int mca_atomic_ucx_op(shmem_ctx_t ctx, status_ptr = ucp_atomic_op_nbx(ucx_ctx->ucp_peers[pe].ucp_conn, op, &value, 1, rva, ucx_mkey->rkey, &mca_spml_ucp_request_params[size >> 3]); - res = opal_common_ucx_wait_request(status_ptr, ucx_ctx->ucp_worker[0], - "ucp_atomic_op_nbx post"); + res = UCS_PTR_IS_ERR(status_ptr) ? OSHMEM_ERROR : OSHMEM_SUCCESS; #else status = ucp_atomic_post(ucx_ctx->ucp_peers[pe].ucp_conn, op, value, size, rva, @@ -336,9 +335,18 @@ static int mca_atomic_ucx_cswap_nb(shmem_ctx_t ctx, return OSHMEM_ERR_NOT_IMPLEMENTED; } - - - +static int mca_atomic_ucx_set(shmem_ctx_t ctx, + void *target, + uint64_t value, + size_t size, + int pe) +{ +#if HAVE_DECL_UCP_ATOMIC_OP_NBX + return mca_atomic_ucx_op(ctx, target, value, size, pe, UCP_ATOMIC_OP_SWAP); +#else + return OSHMEM_ERR_NOT_IMPLEMENTED; +#endif +} mca_atomic_base_module_t * mca_atomic_ucx_query(int *priority) @@ -365,6 +373,7 @@ mca_atomic_ucx_query(int *priority) module->super.atomic_fxor_nb = mca_atomic_ucx_fxor_nb; module->super.atomic_swap_nb = mca_atomic_ucx_swap_nb; module->super.atomic_cswap_nb = mca_atomic_ucx_cswap_nb; + module->super.atomic_set = mca_atomic_ucx_set; return &(module->super); } diff --git a/oshmem/mca/spml/ucx/spml_ucx.c b/oshmem/mca/spml/ucx/spml_ucx.c index 9fc1d965524..7cd50c9aaf6 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.c +++ b/oshmem/mca/spml/ucx/spml_ucx.c @@ -1644,19 +1644,60 @@ int mca_spml_ucx_put_all_nb(void *dest, const void *source, size_t size, long *c return OSHMEM_SUCCESS; } -/* This routine is not implemented */ + +static inline int mca_spml_ucx_signal(shmem_ctx_t ctx, + uint64_t *sig_addr, + uint64_t signal, + int sig_op, + int dst) +{ + if (sig_op == SHMEM_SIGNAL_SET) { + return MCA_ATOMIC_CALL(set(ctx, (void*)sig_addr, signal, + sizeof(uint64_t), dst)); + } else if (sig_op == SHMEM_SIGNAL_ADD) { + return MCA_ATOMIC_CALL(add(ctx, (void*)sig_addr, signal, + sizeof(uint64_t), dst)); + } + + SPML_UCX_ERROR("Invalid signal operation: %d", sig_op); + return OSHMEM_ERR_NOT_IMPLEMENTED; +} + int mca_spml_ucx_put_signal(shmem_ctx_t ctx, void* dst_addr, size_t size, void* src_addr, uint64_t *sig_addr, uint64_t signal, int sig_op, int dst) { - return OSHMEM_ERR_NOT_IMPLEMENTED; + int res; + + res = mca_spml_ucx_put(ctx, dst_addr, size, src_addr, dst); + if (OPAL_UNLIKELY(OSHMEM_SUCCESS != res)) { + return res; + } + + res = mca_spml_ucx_fence(ctx); + if (OPAL_UNLIKELY(OSHMEM_SUCCESS != res)) { + return res; + } + + return mca_spml_ucx_signal(ctx, sig_addr, signal, sig_op, dst); } -/* This routine is not implemented */ int mca_spml_ucx_put_signal_nb(shmem_ctx_t ctx, void* dst_addr, size_t size, void* src_addr, uint64_t *sig_addr, uint64_t signal, int sig_op, int dst) { - return OSHMEM_ERR_NOT_IMPLEMENTED; + int res; + + res = mca_spml_ucx_put_nb(ctx, dst_addr, size, src_addr, dst, NULL); + if (OPAL_UNLIKELY(OSHMEM_SUCCESS != res)) { + return res; + } + + res = mca_spml_ucx_fence(ctx); + if (OPAL_UNLIKELY(OSHMEM_SUCCESS != res)) { + return res; + } + + return mca_spml_ucx_signal(ctx, sig_addr, signal, sig_op, dst); } /* This routine is not implemented */ diff --git a/oshmem/shmem/c/shmem_set.c b/oshmem/shmem/c/shmem_set.c index 268ba8221bd..179f2776dd8 100644 --- a/oshmem/shmem/c/shmem_set.c +++ b/oshmem/shmem/c/shmem_set.c @@ -25,19 +25,16 @@ */ #define DO_SHMEM_TYPE_ATOMIC_SET(ctx, type, target, value, pe) do { \ int rc = OSHMEM_SUCCESS; \ - size_t size = 0; \ - type out_value; \ + size_t size = sizeof(type); \ uint64_t value_tmp; \ RUNTIME_CHECK_INIT(); \ RUNTIME_CHECK_PE(pe); \ RUNTIME_CHECK_ADDR(target); \ \ - size = sizeof(out_value); \ memcpy(&value_tmp, &value, size); \ - rc = MCA_ATOMIC_CALL(swap( \ + rc = MCA_ATOMIC_CALL(set( \ ctx, \ (void*)target, \ - (void*)&out_value, \ value_tmp, \ size, \ pe)); \