Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion config/ompi_check_ucx.m4
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[
[have ucp_tag_send_nbr()])], [],
[#include <ucp/api/ucp.h>])
AC_CHECK_DECLS([ucp_ep_flush_nb, ucp_worker_flush_nb,
ucp_request_check_status, ucp_put_nb, ucp_get_nb],
ucp_request_check_status, ucp_put_nb, ucp_get_nb,
ucp_put_nbx, ucp_get_nbx, ucp_atomic_op_nbx],
[], [],
[#include <ucp/api/ucp.h>])
AC_CHECK_DECLS([ucm_test_events,
Expand Down
18 changes: 18 additions & 0 deletions oshmem/mca/atomic/ucx/atomic_ucx_cswap.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ int mca_atomic_ucx_cswap(shmem_ctx_t ctx,
spml_ucx_mkey_t *ucx_mkey;
uint64_t rva;
mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx;
#if HAVE_DECL_UCP_ATOMIC_OP_NBX
ucp_request_param_t param = {
.op_attr_mask = UCP_OP_ATTR_FIELD_DATATYPE |
UCP_OP_ATTR_FIELD_REPLY_BUFFER,
.datatype = ucp_dt_make_contig(size),
.reply_buffer = prev
};
#endif

if ((8 != size) && (4 != size)) {
ATOMIC_ERROR("[#%d] Type size must be 4 or 8 bytes.", my_pe);
Expand All @@ -41,15 +49,25 @@ int mca_atomic_ucx_cswap(shmem_ctx_t ctx,

*prev = value;
ucx_mkey = mca_spml_ucx_get_mkey(ctx, pe, target, (void *)&rva, mca_spml_self);
#if HAVE_DECL_UCP_ATOMIC_OP_NBX
status_ptr = ucp_atomic_op_nbx(ucx_ctx->ucp_peers[pe].ucp_conn,
UCP_ATOMIC_OP_CSWAP, &cond, 1, rva,
ucx_mkey->rkey, &param);
#else
status_ptr = ucp_atomic_fetch_nb(ucx_ctx->ucp_peers[pe].ucp_conn,
UCP_ATOMIC_FETCH_OP_CSWAP, cond, prev, size,
rva, ucx_mkey->rkey,
opal_common_ucx_empty_complete_cb);
#endif

if (OPAL_LIKELY(!UCS_PTR_IS_ERR(status_ptr))) {
mca_spml_ucx_remote_op_posted(ucx_ctx, pe);
}

return opal_common_ucx_wait_request(status_ptr, ucx_ctx->ucp_worker[0],
#if HAVE_DECL_UCP_ATOMIC_OP_NBX
"ucp_atomic_op_nbx");
#else
"ucp_atomic_fetch_nb");
#endif
}
85 changes: 78 additions & 7 deletions oshmem/mca/atomic/ucx/atomic_ucx_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@
#include "oshmem/proc/proc.h"
#include "atomic_ucx.h"

#if HAVE_DECL_UCP_ATOMIC_OP_NBX
/*
* A static params array, for datatypes of size 4 and 8. "size >> 3" is used to
* access the corresponding offset.
*/
static ucp_request_param_t mca_spml_ucp_request_params[] = {
{.op_attr_mask = UCP_OP_ATTR_FIELD_DATATYPE, .datatype = ucp_dt_make_contig(4)},
{.op_attr_mask = UCP_OP_ATTR_FIELD_DATATYPE, .datatype = ucp_dt_make_contig(8)}
};
#endif

/*
* Initial query function that is invoked during initialization, allowing
* this module to indicate what level of thread support it provides.
Expand All @@ -38,20 +49,37 @@ int mca_atomic_ucx_op(shmem_ctx_t ctx,
uint64_t value,
size_t size,
int pe,
#if HAVE_DECL_UCP_ATOMIC_OP_NBX
ucp_atomic_op_t op)
#else
ucp_atomic_post_op_t op)
#endif
{
ucs_status_t status;
spml_ucx_mkey_t *ucx_mkey;
uint64_t rva;
mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx;
#if HAVE_DECL_UCP_ATOMIC_OP_NBX
ucs_status_ptr_t status_ptr;
#endif

assert((8 == size) || (4 == size));

ucx_mkey = mca_spml_ucx_get_mkey(ctx, pe, target, (void *)&rva, mca_spml_self);

#if HAVE_DECL_UCP_ATOMIC_OP_NBX
status_ptr = ucp_atomic_op_nbx(ucx_ctx->ucp_peers[pe].ucp_conn,
op, &value, 1, rva, ucx_mkey->rkey,
&mca_spml_ucp_request_params[size >> 3]);
if (OPAL_LIKELY(!UCS_PTR_IS_ERR(status_ptr))) {
mca_spml_ucx_remote_op_posted(ucx_ctx, pe);
}
status = UCS_PTR_STATUS(status_ptr);
#else
status = ucp_atomic_post(ucx_ctx->ucp_peers[pe].ucp_conn,
op, value, size, rva,
ucx_mkey->rkey);

#endif
if (OPAL_LIKELY(UCS_OK == status)) {
mca_spml_ucx_remote_op_posted(ucx_ctx, pe);
}
Expand All @@ -66,22 +94,41 @@ int mca_atomic_ucx_fop(shmem_ctx_t ctx,
uint64_t value,
size_t size,
int pe,
#if HAVE_DECL_UCP_ATOMIC_OP_NBX
ucp_atomic_op_t op)
#else
ucp_atomic_fetch_op_t op)
#endif
{
ucs_status_ptr_t status_ptr;
spml_ucx_mkey_t *ucx_mkey;
uint64_t rva;
mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx;
#if HAVE_DECL_UCP_ATOMIC_OP_NBX
ucp_request_param_t param = {
.op_attr_mask = UCP_OP_ATTR_FIELD_DATATYPE |
UCP_OP_ATTR_FIELD_REPLY_BUFFER,
.datatype = ucp_dt_make_contig(size),
.reply_buffer = prev
};
#endif

assert((8 == size) || (4 == size));

ucx_mkey = mca_spml_ucx_get_mkey(ctx, pe, target, (void *)&rva, mca_spml_self);
#if HAVE_DECL_UCP_ATOMIC_OP_NBX
status_ptr = ucp_atomic_op_nbx(ucx_ctx->ucp_peers[pe].ucp_conn, op, &value, 1,
rva, ucx_mkey->rkey, &param);
return opal_common_ucx_wait_request(status_ptr, ucx_ctx->ucp_worker[0],
"ucp_atomic_op_nbx");
#else
status_ptr = ucp_atomic_fetch_nb(ucx_ctx->ucp_peers[pe].ucp_conn,
op, value, prev, size,
rva, ucx_mkey->rkey,
opal_common_ucx_empty_complete_cb);
return opal_common_ucx_wait_request(status_ptr, ucx_ctx->ucp_worker[0],
"ucp_atomic_fetch_nb");
#endif
}

static int mca_atomic_ucx_add(shmem_ctx_t ctx,
Expand All @@ -90,7 +137,11 @@ static int mca_atomic_ucx_add(shmem_ctx_t ctx,
size_t size,
int pe)
{
#if HAVE_DECL_UCP_ATOMIC_OP_NBX
return mca_atomic_ucx_op(ctx, target, value, size, pe, UCP_ATOMIC_OP_ADD);
#else
return mca_atomic_ucx_op(ctx, target, value, size, pe, UCP_ATOMIC_POST_OP_ADD);
#endif
}

static int mca_atomic_ucx_and(shmem_ctx_t ctx,
Expand All @@ -99,7 +150,9 @@ static int mca_atomic_ucx_and(shmem_ctx_t ctx,
size_t size,
int pe)
{
#if HAVE_DECL_UCP_ATOMIC_POST_OP_AND
#if HAVE_DECL_UCP_ATOMIC_OP_NBX
return mca_atomic_ucx_op(ctx, target, value, size, pe, UCP_ATOMIC_OP_AND);
#elif HAVE_DECL_UCP_ATOMIC_POST_OP_AND
return mca_atomic_ucx_op(ctx, target, value, size, pe, UCP_ATOMIC_POST_OP_AND);
#else
return OSHMEM_ERR_NOT_IMPLEMENTED;
Expand All @@ -112,7 +165,9 @@ static int mca_atomic_ucx_or(shmem_ctx_t ctx,
size_t size,
int pe)
{
#if HAVE_DECL_UCP_ATOMIC_POST_OP_OR
#if HAVE_DECL_UCP_ATOMIC_OP_NBX
return mca_atomic_ucx_op(ctx, target, value, size, pe, UCP_ATOMIC_OP_OR);
#elif HAVE_DECL_UCP_ATOMIC_POST_OP_OR
return mca_atomic_ucx_op(ctx, target, value, size, pe, UCP_ATOMIC_POST_OP_OR);
#else
return OSHMEM_ERR_NOT_IMPLEMENTED;
Expand All @@ -125,7 +180,9 @@ static int mca_atomic_ucx_xor(shmem_ctx_t ctx,
size_t size,
int pe)
{
#if HAVE_DECL_UCP_ATOMIC_POST_OP_XOR
#if HAVE_DECL_UCP_ATOMIC_OP_NBX
return mca_atomic_ucx_op(ctx, target, value, size, pe, UCP_ATOMIC_OP_XOR);
#elif HAVE_DECL_UCP_ATOMIC_POST_OP_XOR
return mca_atomic_ucx_op(ctx, target, value, size, pe, UCP_ATOMIC_POST_OP_XOR);
#else
return OSHMEM_ERR_NOT_IMPLEMENTED;
Expand All @@ -139,7 +196,11 @@ static int mca_atomic_ucx_fadd(shmem_ctx_t ctx,
size_t size,
int pe)
{
#if HAVE_DECL_UCP_ATOMIC_OP_NBX
return mca_atomic_ucx_fop(ctx, target, prev, value, size, pe, UCP_ATOMIC_OP_ADD);
#else
return mca_atomic_ucx_fop(ctx, target, prev, value, size, pe, UCP_ATOMIC_FETCH_OP_FADD);
#endif
}

static int mca_atomic_ucx_fand(shmem_ctx_t ctx,
Expand All @@ -149,7 +210,9 @@ static int mca_atomic_ucx_fand(shmem_ctx_t ctx,
size_t size,
int pe)
{
#if HAVE_DECL_UCP_ATOMIC_FETCH_OP_FAND
#if HAVE_DECL_UCP_ATOMIC_OP_NBX
return mca_atomic_ucx_fop(ctx, target, prev, value, size, pe, UCP_ATOMIC_OP_AND);
#elif HAVE_DECL_UCP_ATOMIC_FETCH_OP_FAND
return mca_atomic_ucx_fop(ctx, target, prev, value, size, pe, UCP_ATOMIC_FETCH_OP_FAND);
#else
return OSHMEM_ERR_NOT_IMPLEMENTED;
Expand All @@ -163,7 +226,9 @@ static int mca_atomic_ucx_for(shmem_ctx_t ctx,
size_t size,
int pe)
{
#if HAVE_DECL_UCP_ATOMIC_FETCH_OP_FOR
#if HAVE_DECL_UCP_ATOMIC_OP_NBX
return mca_atomic_ucx_fop(ctx, target, prev, value, size, pe, UCP_ATOMIC_OP_OR);
#elif HAVE_DECL_UCP_ATOMIC_FETCH_OP_FOR
return mca_atomic_ucx_fop(ctx, target, prev, value, size, pe, UCP_ATOMIC_FETCH_OP_FOR);
#else
return OSHMEM_ERR_NOT_IMPLEMENTED;
Expand All @@ -177,7 +242,9 @@ static int mca_atomic_ucx_fxor(shmem_ctx_t ctx,
size_t size,
int pe)
{
#if HAVE_DECL_UCP_ATOMIC_FETCH_OP_FXOR
#if HAVE_DECL_UCP_ATOMIC_OP_NBX
return mca_atomic_ucx_fop(ctx, target, prev, value, size, pe, UCP_ATOMIC_OP_XOR);
#elif HAVE_DECL_UCP_ATOMIC_FETCH_OP_FXOR
return mca_atomic_ucx_fop(ctx, target, prev, value, size, pe, UCP_ATOMIC_FETCH_OP_FXOR);
#else
return OSHMEM_ERR_NOT_IMPLEMENTED;
Expand All @@ -191,7 +258,11 @@ static int mca_atomic_ucx_swap(shmem_ctx_t ctx,
size_t size,
int pe)
{
#if HAVE_DECL_UCP_ATOMIC_OP_NBX
return mca_atomic_ucx_fop(ctx, target, prev, value, size, pe, UCP_ATOMIC_OP_SWAP);
#else
return mca_atomic_ucx_fop(ctx, target, prev, value, size, pe, UCP_ATOMIC_FETCH_OP_SWAP);
#endif
}


Expand Down
Loading