UCP/RKEY: Acquire context lock when calling ucp_rkey_pack_memh #10462

Draft · wants to merge 4 commits into base: master
Changes from 2 commits
src/ucp/core/ucp_context.c (28 additions, 0 deletions)

@@ -1963,6 +1963,34 @@ static void ucp_apply_params(ucp_context_h context, const ucp_params_t *params,
     }
 }
 
+void ucp_context_set_worker_async(ucp_context_h context,
+                                  ucs_async_context_t *async)
+{
+    if (async != NULL) {
+        /* Set a new worker async mutex */
+        if (context->mt_lock.mt_type == UCP_MT_TYPE_WORKER_ASYNC) {
+            ucs_error("worker async %p is already set for context %p",
+                      context->mt_lock.lock.mt_worker_async, context);
+        } else if (context->mt_lock.mt_type != UCP_MT_TYPE_NONE) {
+            ucs_debug("context %p already has a mutex of mt_type %d",
+                      context, context->mt_lock.mt_type);
+        } else {
+            context->mt_lock.mt_type              = UCP_MT_TYPE_WORKER_ASYNC;
+            context->mt_lock.lock.mt_worker_async = async;
+        }
+    } else {
+        /* Reset the existing worker async mutex */
+        if (context->mt_lock.mt_type == UCP_MT_TYPE_WORKER_ASYNC) {
+            if (context->mt_lock.lock.mt_worker_async != NULL) {
+                context->mt_lock.mt_type              = UCP_MT_TYPE_NONE;
+                context->mt_lock.lock.mt_worker_async = NULL;
+            } else {
+                ucs_error("worker async is not set for context %p", context);
+            }
+        }
+    }
+}
+
 static ucs_status_t
 ucp_fill_rndv_frag_config(const ucp_context_config_names_t *config,
                           const size_t *default_sizes, size_t *sizes)
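The new setter is effectively a small state machine over context->mt_lock: the context adopts the worker's async lock only when it has no lock of its own, and reverts that on the reset call. A minimal standalone model of those transitions (simplified stand-in types, for illustration only; not the real UCX structures):

#include <assert.h>
#include <stddef.h>

/* Simplified stand-ins for the UCX types */
typedef enum { MT_NONE, MT_SPINLOCK, MT_MUTEX, MT_WORKER_ASYNC } mt_type_t;
typedef struct { int unused; } async_t;
typedef struct { mt_type_t mt_type; async_t *worker_async; } context_t;

/* Mirrors the transitions of ucp_context_set_worker_async() */
static void set_worker_async(context_t *ctx, async_t *async)
{
    if (async != NULL) {
        if (ctx->mt_type == MT_NONE) {           /* only an unlocked context */
            ctx->mt_type      = MT_WORKER_ASYNC; /* adopts the worker lock */
            ctx->worker_async = async;
        } /* WORKER_ASYNC: error; SPINLOCK/MUTEX: keep the context lock */
    } else if (ctx->mt_type == MT_WORKER_ASYNC) {
        ctx->mt_type      = MT_NONE;             /* reset on worker destroy */
        ctx->worker_async = NULL;
    }
}

int main(void)
{
    context_t ctx = { MT_NONE, NULL };
    async_t   worker_async;

    set_worker_async(&ctx, &worker_async); /* mt-shared worker created */
    assert(ctx.mt_type == MT_WORKER_ASYNC);

    set_worker_async(&ctx, NULL);          /* worker destroyed */
    assert((ctx.mt_type == MT_NONE) && (ctx.worker_async == NULL));
    return 0;
}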
src/ucp/core/ucp_context.h (4 additions, 0 deletions)

@@ -755,4 +755,8 @@ ucp_config_modify_internal(ucp_config_t *config, const char *name,
 
 void ucp_apply_uct_config_list(ucp_context_h context, void *config);
 
+
+void ucp_context_set_worker_async(ucp_context_h context,
+                                  ucs_async_context_t *async);
+
 #endif
src/ucp/core/ucp_rkey.c (3 additions, 2 deletions)

@@ -125,6 +125,7 @@ ucp_rkey_unpack_distance(const ucp_rkey_packed_distance_t *packed_distance,
     distance->bandwidth = UCS_FP8_UNPACK(BANDWIDTH, packed_distance->bandwidth);
 }
 
+/* context->mt_lock must be held */
 UCS_PROFILE_FUNC(ssize_t, ucp_rkey_pack_memh,
                  (context, md_map, memh, address, length, mem_info, sys_dev_map,
                   sys_distance, uct_flags, buffer),

@@ -641,7 +642,7 @@ ucp_memh_pack_internal(ucp_mem_h memh, const ucp_memh_pack_params_t *params,
         return UCS_OK;
     }
 
-    UCP_THREAD_CS_ENTER(&context->mt_lock);
+    UCP_THREAD_CS_ASYNC_ENTER(&context->mt_lock);
 
     size = ucp_memh_packed_size(memh, flags, rkey_compat);
 

@@ -676,7 +677,7 @@ ucp_memh_pack_internal(ucp_mem_h memh, const ucp_memh_pack_params_t *params,
 err_destroy:
     ucs_free(memh_buffer);
 out:
-    UCP_THREAD_CS_EXIT(&context->mt_lock);
+    UCP_THREAD_CS_ASYNC_EXIT(&context->mt_lock);
     return status;
 }
src/ucp/core/ucp_thread.h (44 additions, 13 deletions)

@@ -22,7 +22,8 @@
 typedef enum ucp_mt_type {
     UCP_MT_TYPE_NONE = 0,
     UCP_MT_TYPE_SPINLOCK,
-    UCP_MT_TYPE_MUTEX
+    UCP_MT_TYPE_MUTEX,
+    UCP_MT_TYPE_WORKER_ASYNC
 } ucp_mt_type_t;
 
 

@@ -36,6 +37,13 @@ typedef struct ucp_mt_lock {
            at one time. Spinlock is the default option. */
         ucs_recursive_spinlock_t mt_spinlock;
         pthread_mutex_t          mt_mutex;
+        /* Lock for the MULTI_THREAD_WORKER case, when an mt-single context
+         * is used by a single mt-shared worker. The worker progress flow is
+         * already protected by the worker mutex, so there is no need to lock
+         * inside that flow; this protects certain API calls that can be
+         * triggered from a user thread without holding the worker mutex.
+         * Essentially this is a pointer to the worker's async mutex. */
+        ucs_async_context_t      *mt_worker_async;
     } lock;
 } ucp_mt_lock_t;

@@ -58,21 +66,44 @@ typedef struct ucp_mt_lock {
             pthread_mutex_destroy(&((_lock_ptr)->lock.mt_mutex)); \
         } \
     } while (0)
-#define UCP_THREAD_CS_ENTER(_lock_ptr) \
-    do { \
-        if ((_lock_ptr)->mt_type == UCP_MT_TYPE_SPINLOCK) { \
-            ucs_recursive_spin_lock(&((_lock_ptr)->lock.mt_spinlock)); \
-        } else if ((_lock_ptr)->mt_type == UCP_MT_TYPE_MUTEX) { \
-            pthread_mutex_lock(&((_lock_ptr)->lock.mt_mutex)); \
-        } \
-    } while (0)
-#define UCP_THREAD_CS_EXIT(_lock_ptr) \
-    do { \
-        if ((_lock_ptr)->mt_type == UCP_MT_TYPE_SPINLOCK) { \
-            ucs_recursive_spin_unlock(&((_lock_ptr)->lock.mt_spinlock)); \
-        } else if ((_lock_ptr)->mt_type == UCP_MT_TYPE_MUTEX) { \
-            pthread_mutex_unlock(&((_lock_ptr)->lock.mt_mutex)); \
-        } \
-    } while (0)
+
+static UCS_F_ALWAYS_INLINE void ucp_mt_lock_lock(ucp_mt_lock_t *lock)
+{
+    if (lock->mt_type == UCP_MT_TYPE_SPINLOCK) {
+        ucs_recursive_spin_lock(&lock->lock.mt_spinlock);
+    } else if (lock->mt_type == UCP_MT_TYPE_MUTEX) {
+        pthread_mutex_lock(&lock->lock.mt_mutex);
+    }
+}
+
+static UCS_F_ALWAYS_INLINE void ucp_mt_lock_unlock(ucp_mt_lock_t *lock)
+{
+    if (lock->mt_type == UCP_MT_TYPE_SPINLOCK) {
+        ucs_recursive_spin_unlock(&lock->lock.mt_spinlock);
+    } else if (lock->mt_type == UCP_MT_TYPE_MUTEX) {
+        pthread_mutex_unlock(&lock->lock.mt_mutex);
+    }
+}
+
+#define UCP_THREAD_CS_ENTER(_lock_ptr) ucp_mt_lock_lock(_lock_ptr)
+#define UCP_THREAD_CS_EXIT(_lock_ptr)  ucp_mt_lock_unlock(_lock_ptr)
+
+#define UCP_THREAD_CS_ASYNC_ENTER(_lock_ptr) \
+    do { \
+        if ((_lock_ptr)->mt_type == UCP_MT_TYPE_WORKER_ASYNC) { \
+            UCS_ASYNC_BLOCK((_lock_ptr)->lock.mt_worker_async); \
+        } else { \
+            ucp_mt_lock_lock(_lock_ptr); \
+        } \
+    } while (0)
+
+#define UCP_THREAD_CS_ASYNC_EXIT(_lock_ptr) \
+    do { \
+        if ((_lock_ptr)->mt_type == UCP_MT_TYPE_WORKER_ASYNC) { \
+            UCS_ASYNC_UNBLOCK((_lock_ptr)->lock.mt_worker_async); \
+        } else { \
+            ucp_mt_lock_unlock(_lock_ptr); \
+        } \
+    } while (0)
 
 #endif
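The intended dispatch of the two macro families: plain UCP_THREAD_CS_ENTER/EXIT stays a no-op for UCP_MT_TYPE_WORKER_ASYNC, because the worker progress flow already runs under the worker lock, while the new ASYNC variants block the worker async so that calls arriving from a user thread (such as ucp_memh_pack_internal above) are serialized against progress. A minimal standalone model of this dispatch, with counters standing in for the real lock and UCS_ASYNC_BLOCK primitives:

#include <assert.h>

typedef enum { MT_NONE, MT_SPINLOCK, MT_MUTEX, MT_WORKER_ASYNC } mt_type_t;

typedef struct {
    mt_type_t mt_type;
    int       locked;   /* models spinlock/mutex acquisition */
    int       blocked;  /* models UCS_ASYNC_BLOCK nesting */
} mt_lock_t;

static void cs_enter(mt_lock_t *l)        /* like UCP_THREAD_CS_ENTER */
{
    if ((l->mt_type == MT_SPINLOCK) || (l->mt_type == MT_MUTEX)) {
        l->locked++;
    } /* MT_WORKER_ASYNC: intentionally a no-op on the progress path */
}

static void cs_async_enter(mt_lock_t *l)  /* like UCP_THREAD_CS_ASYNC_ENTER */
{
    if (l->mt_type == MT_WORKER_ASYNC) {
        l->blocked++;  /* UCS_ASYNC_BLOCK(worker async) */
    } else {
        cs_enter(l);   /* falls back to ucp_mt_lock_lock() */
    }
}

int main(void)
{
    mt_lock_t l = { MT_WORKER_ASYNC, 0, 0 };

    cs_enter(&l);        /* progress path: nothing to take */
    assert((l.locked == 0) && (l.blocked == 0));

    cs_async_enter(&l);  /* user thread: blocks the worker async */
    assert(l.blocked == 1);
    return 0;
}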
src/ucp/core/ucp_worker.c (10 additions, 0 deletions)

@@ -2569,6 +2569,10 @@ ucs_status_t ucp_worker_create(ucp_context_h context,
         goto err_free_tm_offload_stats;
     }
 
+    if (worker->flags & UCP_WORKER_FLAG_THREAD_MULTI) {
+        ucp_context_set_worker_async(context, &worker->async);
+    }
+
     /* Create the underlying UCT worker */
     status = uct_worker_create(&worker->async, uct_thread_mode, &worker->uct);
     if (status != UCS_OK) {

@@ -2668,6 +2672,9 @@ ucs_status_t ucp_worker_create(ucp_context_h context,
 err_destroy_uct_worker:
     uct_worker_destroy(worker->uct);
 err_destroy_async:
+    if (worker->flags & UCP_WORKER_FLAG_THREAD_MULTI) {
+        ucp_context_set_worker_async(context, NULL);
+    }
     ucs_async_context_cleanup(&worker->async);
 err_free_tm_offload_stats:
     UCS_STATS_NODE_FREE(worker->tm_offload_stats);

@@ -2923,6 +2930,9 @@ void ucp_worker_destroy(ucp_worker_h worker)
     ucs_conn_match_cleanup(&worker->conn_match_ctx);
     ucp_worker_wakeup_cleanup(worker);
     uct_worker_destroy(worker->uct);
+    if (worker->flags & UCP_WORKER_FLAG_THREAD_MULTI) {
+        ucp_context_set_worker_async(worker->context, NULL);
+    }

Contributor:
It seems weird that the context points to an async lock that belongs to a worker it created, at least from an object ownership/hierarchy perspective. Also, what happens if there are multiple workers on the context?
I would expect the context to create the async context, and workers of the context to point to it.

Contributor Author:
Yes, I agree, it's better if the context owns this async, but let me make the approach precise below.

> what happens if there are multiple workers on the context?

If there are multiple workers, then the context must be created with the mt_workers_shared flag, which defines context->mt_lock. Once that is defined, ucp_context_set_worker_async has no effect (just a debug message).

> I would expect the context to create the async context, and workers of the context to point to it.

Well, multiple workers should not point to the same context-async object; each worker should own its own async object. Otherwise it would become a single point of contention.
In fact, we need this context-async object in only one case: when the context is mt-single but there is a SINGLE worker which is mt-shared (UCP_WORKER_FLAG_THREAD_MULTI). Only in this case does that single worker point to the context async, right?

As we discussed before:

  • Context is shared, single thread per worker: there is a lock (mutex/spinlock) per context, and workers are lock-free
  • Context is single, and a single worker is used from multiple threads: this is the gap, and only in this case does the worker use the context-async lock (see the sketch after this list)
  • Context is shared, multiple workers/threads: we are covered by the existing locking scheme, a mutex/spinlock per context plus an async lock per worker. No context-async is used in this case

Do we agree on this?
So the context-async is going to be used only in the case of a single MT worker, which is why I initially came up with this solution. But I will do as you suggested; it will not change much.
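For reference, a minimal sketch of that "gap" configuration built with the public UCP API (error handling is mostly elided, and the UCP_FEATURE_RMA choice is arbitrary; this snippet is not part of the patch):

#include <ucp/api/ucp.h>
#include <stdlib.h>

int main(void)
{
    /* mt-single context: no context-wide spinlock/mutex is requested */
    ucp_params_t ctx_params = {
        .field_mask        = UCP_PARAM_FIELD_FEATURES |
                             UCP_PARAM_FIELD_MT_WORKERS_SHARED,
        .features          = UCP_FEATURE_RMA,
        .mt_workers_shared = 0
    };
    /* ...but the single worker is shared between threads */
    ucp_worker_params_t worker_params = {
        .field_mask  = UCP_WORKER_PARAM_FIELD_THREAD_MODE,
        .thread_mode = UCS_THREAD_MODE_MULTI
    };
    ucp_context_h context;
    ucp_worker_h  worker;

    if (ucp_init(&ctx_params, NULL, &context) != UCS_OK) {
        return EXIT_FAILURE;
    }
    if (ucp_worker_create(context, &worker_params, &worker) != UCS_OK) {
        ucp_cleanup(context);
        return EXIT_FAILURE;
    }

    /* With this patch, rkey/memh packing triggered from a user thread is
     * serialized against worker progress via the worker async lock */

    ucp_worker_destroy(worker);
    ucp_cleanup(context);
    return EXIT_SUCCESS;
}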

Contributor Author:
One side effect of this change is that we would create the context async "just in case", because the context doesn't know in advance whether context->async will be used later on by a single mt-worker. Not a big deal though, just to make it clear.

     ucs_async_context_cleanup(&worker->async);
     UCS_STATS_NODE_FREE(worker->tm_offload_stats);
     UCS_STATS_NODE_FREE(worker->stats);
src/ucp/proto/proto_common.inl (12 additions, 9 deletions)

@@ -165,12 +165,12 @@ ucp_proto_request_set_stage(ucp_request_t *req, uint8_t proto_stage)
 {
     const ucp_proto_t *proto = req->send.proto_config->proto;
 
-    ucs_assertv(proto_stage < UCP_PROTO_STAGE_LAST, "stage=%"PRIu8,
+    ucs_assertv(proto_stage < UCP_PROTO_STAGE_LAST, "stage=%" PRIu8,
                 proto_stage);
     ucs_assert(proto->progress[proto_stage] != NULL);
 
     ucp_trace_req(req, "set to stage %u, progress function '%s'", proto_stage,
-                  ucs_debug_get_symbol_name(proto->progress[proto_stage]));
+                  ucs_debug_get_symbol_name((void *)proto->progress[proto_stage]));
     req->send.proto_stage = proto_stage;
 
     /* Set pointer to progress function */

@@ -186,7 +186,7 @@ static void ucp_proto_request_set_proto(ucp_request_t *req,
                                         const ucp_proto_config_t *proto_config,
                                         size_t msg_length)
 {
-    ucs_assertv(req->flags & UCP_REQUEST_FLAG_PROTO_SEND, "flags=0x%"PRIx32,
+    ucs_assertv(req->flags & UCP_REQUEST_FLAG_PROTO_SEND, "flags=0x%" PRIx32,
                 req->flags);
 
     req->send.proto_config = proto_config;

@@ -346,6 +346,7 @@ ucp_proto_request_pack_rkey(ucp_request_t *req, ucp_md_map_t md_map,
                             const ucs_sys_dev_distance_t *dev_distance,
                             void *rkey_buffer)
 {
+    ucp_context_h context = req->send.ep->worker->context;
     const ucp_datatype_iter_t *dt_iter = &req->send.state.dt_iter;
     ucp_mem_h memh;
     ssize_t packed_rkey_size;

@@ -366,17 +367,19 @@
         ucs_unlikely(memh->flags & UCP_MEMH_FLAG_HAS_AUTO_GVA)) {
         ucp_memh_disable_gva(memh, md_map);
     }
+
     if (!ucs_test_all_flags(memh->md_map, md_map)) {
-        ucs_trace("dt_iter_md_map=0x%"PRIx64" md_map=0x%"PRIx64, memh->md_map,
-                  md_map);
+        ucs_trace("dt_iter_md_map=0x%" PRIx64 " md_map=0x%" PRIx64,
+                  memh->md_map, md_map);
     }
 
+    /* TODO: context lock is not scalable. Consider fine-grained lock per memh,
+     * immutable memh with rkey cache, RCU/COW */
+    UCP_THREAD_CS_ENTER(&context->mt_lock);
     packed_rkey_size = ucp_rkey_pack_memh(
-            req->send.ep->worker->context, md_map & memh->md_map, memh,
-            dt_iter->type.contig.buffer, dt_iter->length, &dt_iter->mem_info,
-            distance_dev_map, dev_distance,
+            context, md_map & memh->md_map, memh, dt_iter->type.contig.buffer,
+            dt_iter->length, &dt_iter->mem_info, distance_dev_map, dev_distance,
             ucp_ep_config(req->send.ep)->uct_rkey_pack_flags, rkey_buffer);
+    UCP_THREAD_CS_EXIT(&context->mt_lock);
 
     if (packed_rkey_size < 0) {
         ucs_error("failed to pack remote key: %s",
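The TODO above lists the follow-up directions. As a purely hypothetical illustration of the first one, a fine-grained lock per memh (every type and name below is invented for the sketch and does not exist in UCX):

#include <pthread.h>
#include <stdint.h>
#include <sys/types.h>

/* Invented stand-in for a memory handle: each handle carries its own lock,
 * so packers working on different handles no longer contend on the single
 * context->mt_lock. */
typedef struct hypo_memh {
    pthread_mutex_t pack_lock;
    uint64_t        md_map;
} hypo_memh_t;

static ssize_t hypo_rkey_pack(hypo_memh_t *memh, void *buffer)
{
    ssize_t size = 0;

    (void)buffer;
    pthread_mutex_lock(&memh->pack_lock);   /* per-memh, not per-context */
    /* ... walk memh->md_map and pack one key per MD into buffer ... */
    pthread_mutex_unlock(&memh->pack_lock);
    return size;
}

int main(void)
{
    hypo_memh_t memh = { PTHREAD_MUTEX_INITIALIZER, 0 };
    char        buffer[64];

    return (int)hypo_rkey_pack(&memh, buffer);
}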
src/ucp/rndv/rndv.c (10 additions, 2 deletions)

@@ -178,11 +178,15 @@ size_t ucp_rndv_rts_pack(ucp_request_t *sreq, ucp_rndv_rts_hdr_t *rndv_rts_hdr,
     rndv_rts_hdr->address = (uintptr_t)sreq->send.buffer;
     rkey_buf = UCS_PTR_BYTE_OFFSET(rndv_rts_hdr,
                                    sizeof(*rndv_rts_hdr));
-    packed_rkey_size = ucp_rkey_pack_memh(
+
+    UCP_THREAD_CS_ENTER(&worker->context->mt_lock);
+    packed_rkey_size = ucp_rkey_pack_memh(
             worker->context, sreq->send.rndv.md_map,
             sreq->send.state.dt.dt.contig.memh, sreq->send.buffer,
             sreq->send.length, &mem_info, 0, NULL,
             ucp_ep_config(sreq->send.ep)->uct_rkey_pack_flags, rkey_buf);
+    UCP_THREAD_CS_EXIT(&worker->context->mt_lock);
+
     if (packed_rkey_size < 0) {
         ucs_fatal("failed to pack rendezvous remote key: %s",
                   ucs_status_string((ucs_status_t)packed_rkey_size));

@@ -205,6 +209,7 @@ static size_t ucp_rndv_rtr_pack(void *dest, void *arg)
     ucp_rndv_rtr_hdr_t *rndv_rtr_hdr = dest;
     ucp_request_t *rreq = ucp_request_get_super(rndv_req);
     ucp_ep_h ep = rndv_req->send.ep;
+    ucp_context_h context = ep->worker->context;
     ucp_memory_info_t mem_info;
     ssize_t packed_rkey_size;
 

@@ -221,12 +226,15 @@ static size_t ucp_rndv_rtr_pack(void *dest, void *arg)
     mem_info.type    = rreq->recv.dt_iter.mem_info.type;
     mem_info.sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN;
 
+    UCP_THREAD_CS_ENTER(&context->mt_lock);
     packed_rkey_size = ucp_rkey_pack_memh(
-            ep->worker->context, rndv_req->send.rndv.md_map,
+            context, rndv_req->send.rndv.md_map,
             rreq->recv.dt_iter.type.contig.memh,
             rreq->recv.dt_iter.type.contig.buffer, rndv_req->send.length,
             &mem_info, 0, NULL, ucp_ep_config(ep)->uct_rkey_pack_flags,
             rndv_rtr_hdr + 1);
+    UCP_THREAD_CS_EXIT(&context->mt_lock);
+
     if (packed_rkey_size < 0) {
         return packed_rkey_size;
     }
src/ucp/rndv/rndv_rtr.c (7 additions, 2 deletions)

@@ -252,6 +252,7 @@ static size_t ucp_proto_rndv_rtr_mtype_pack(void *dest, void *arg)
 {
     ucp_rndv_rtr_hdr_t *rtr = dest;
     ucp_request_t *req = arg;
+    ucp_context_h context = req->send.ep->worker->context;
     const ucp_proto_rndv_rtr_priv_t *rpriv = req->send.proto_config->priv;
     ucp_md_map_t md_map = rpriv->super.md_map;
     ucp_mem_desc_t *mdesc = req->send.rndv.mdesc;

@@ -266,10 +267,14 @@ static size_t ucp_proto_rndv_rtr_mtype_pack(void *dest, void *arg)
     /* Pack remote key for the fragment */
     mem_info.type    = mdesc->memh->mem_type;
     mem_info.sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN;
-    packed_rkey_size = ucp_rkey_pack_memh(req->send.ep->worker->context, md_map,
-                                          mdesc->memh, mdesc->ptr,
+
+    UCP_THREAD_CS_ENTER(&context->mt_lock);
+    packed_rkey_size = ucp_rkey_pack_memh(context, md_map, mdesc->memh,
+                                          mdesc->ptr,
                                           req->send.state.dt_iter.length,
                                           &mem_info, 0, NULL, 0, rtr + 1);
+    UCP_THREAD_CS_EXIT(&context->mt_lock);
+
     if (packed_rkey_size < 0) {
         ucs_error("failed to pack remote key: %s",
                   ucs_status_string((ucs_status_t)packed_rkey_size));