Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions oshmem/mca/atomic/ucx/atomic_ucx_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,19 @@ static ucp_request_param_t mca_spml_ucp_request_params[] = {
};
#endif

static int mca_atomic_ucx_nb_support = 0;

/*
* Initial query function that is invoked during initialization, allowing
* this module to indicate what level of thread support it provides.
*/
int mca_atomic_ucx_startup(bool enable_progress_threads, bool enable_threads)
{
unsigned major, minor, release_number;
ucp_get_version(&major, &minor, &release_number);

mca_atomic_ucx_nb_support = UCX_VERSION(major, minor, release_number) >= UCX_VERSION(1, 20, 0);

return OSHMEM_SUCCESS;
}

Expand Down Expand Up @@ -73,8 +80,15 @@ int mca_atomic_ucx_op(shmem_ctx_t ctx,
status_ptr = ucp_atomic_op_nbx(ucx_ctx->ucp_peers[pe].ucp_conn,
op, &value, 1, rva, ucx_mkey->rkey,
&mca_spml_ucp_request_params[size >> 3]);
res = opal_common_ucx_wait_request(status_ptr, ucx_ctx->ucp_worker[0],
"ucp_atomic_op_nbx post");
if (mca_atomic_ucx_nb_support) {
/* UCX is packing (copying) the value pointer, so there's no need to wait for completion
(no stack corruption concerns). Additionally, there's no need to free the status pointer
as its already freed by ucp_atomic_op_nbx when a reply buffer is not provided. */
res = UCS_PTR_IS_ERR(status_ptr) ? OSHMEM_ERROR : OSHMEM_SUCCESS;
} else {
res = opal_common_ucx_wait_request(status_ptr, ucx_ctx->ucp_worker[0],
"ucp_atomic_op_nbx");
}
#else
status = ucp_atomic_post(ucx_ctx->ucp_peers[pe].ucp_conn,
op, value, size, rva,
Expand Down
66 changes: 62 additions & 4 deletions oshmem/mca/spml/ucx/spml_ucx.c
Original file line number Diff line number Diff line change
Expand Up @@ -1644,19 +1644,77 @@ int mca_spml_ucx_put_all_nb(void *dest, const void *source, size_t size, long *c
return OSHMEM_SUCCESS;
}

/* This routine is not implemented */

static inline int mca_spml_ucx_signal(shmem_ctx_t ctx,
uint64_t *sig_addr,
uint64_t signal,
int sig_op,
int dst)
{
uint64_t dummy_prev;

if (sig_op == SHMEM_SIGNAL_SET) {
return MCA_ATOMIC_CALL(swap(ctx, (void*)sig_addr, (void*)&dummy_prev,
signal, sizeof(uint64_t), dst));
} else if (sig_op == SHMEM_SIGNAL_ADD) {
return MCA_ATOMIC_CALL(add(ctx, (void*)sig_addr, signal,
sizeof(uint64_t), dst));
}

SPML_UCX_ERROR("Invalid signal operation: %d", sig_op);
return OSHMEM_ERR_NOT_IMPLEMENTED;
}

static inline int mca_spml_ucx_signal_nb(shmem_ctx_t ctx,
uint64_t *sig_addr,
uint64_t signal,
int sig_op,
int dst)
{
uint64_t dummy_prev, dummy_fetch;

if (sig_op == SHMEM_SIGNAL_SET) {
return MCA_ATOMIC_CALL(swap_nb(ctx, &dummy_fetch, (void*)sig_addr, (void*)&dummy_prev,
signal, sizeof(uint64_t), dst));
} else if (sig_op == SHMEM_SIGNAL_ADD) {
return MCA_ATOMIC_CALL(fadd_nb(ctx, &dummy_fetch, (void*)sig_addr, (void*)&dummy_prev,
signal, sizeof(uint64_t), dst));
}

SPML_UCX_ERROR("Invalid signal operation: %d", sig_op);
return OSHMEM_ERR_NOT_IMPLEMENTED;
}

int mca_spml_ucx_put_signal(shmem_ctx_t ctx, void* dst_addr, size_t size, void*
src_addr, uint64_t *sig_addr, uint64_t signal, int sig_op, int dst)
{
return OSHMEM_ERR_NOT_IMPLEMENTED;
int res;

res = mca_spml_ucx_put(ctx, dst_addr, size, src_addr, dst);
if (OPAL_UNLIKELY(OSHMEM_SUCCESS != res)) {
return res;
}

return mca_spml_ucx_signal(ctx, sig_addr, signal, sig_op, dst);
}

/* This routine is not implemented */
int mca_spml_ucx_put_signal_nb(shmem_ctx_t ctx, void* dst_addr, size_t size,
void* src_addr, uint64_t *sig_addr, uint64_t signal, int sig_op, int
dst)
{
return OSHMEM_ERR_NOT_IMPLEMENTED;
int res;

res = mca_spml_ucx_put_nb(ctx, dst_addr, size, src_addr, dst, NULL);
if (OPAL_UNLIKELY(OSHMEM_SUCCESS != res)) {
return res;
}

res = mca_spml_ucx_fence(ctx);
if (OPAL_UNLIKELY(OSHMEM_SUCCESS != res)) {
return res;
}

return mca_spml_ucx_signal_nb(ctx, sig_addr, signal, sig_op, dst);
}

/* This routine is not implemented */
Expand Down