Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
06c380a
UCP/DEVICE: Make memh and local_addr optional for counter operations
michal-shalev Oct 11, 2025
a494c52
Merge branch 'master' into optional-memh-localaddr-counter-ops
michal-shalev Oct 13, 2025
60a1131
UCP/DEVICE: PR fixes
michal-shalev Oct 13, 2025
6a3eb78
UCP/DEVICE: PR fixes 2.0
michal-shalev Oct 13, 2025
205b01e
UCP/DEVICE: PR fixes 3.0
michal-shalev Oct 13, 2025
589b0b4
UCP/DEVICE: PR fixes 4.0
michal-shalev Oct 13, 2025
4573621
UCP/DEVICE: PR fixes 5.0
michal-shalev Oct 13, 2025
b0eff73
UCP/DEVICE: Update perftest
michal-shalev Oct 13, 2025
10646cd
UCP/DEVICE: Update perftest 2.0
michal-shalev Oct 13, 2025
534057b
UCP/DEVICE: PR fixes 6.0
michal-shalev Oct 13, 2025
dc11fd6
UCP/DEVICE: CI fix
michal-shalev Oct 13, 2025
c718cec
UCP/DEVICE: PR fixes 7.0
michal-shalev Oct 13, 2025
2288f1b
UCP/DEVICE: PR fixes 8.0
michal-shalev Oct 15, 2025
b1f2ad0
UCP/DEVICE: PR fixes 9.0
michal-shalev Oct 15, 2025
5c3b023
UCP/DEVICE: PR fixes 10.0
michal-shalev Oct 16, 2025
06a662e
UCP/DEVICE: PR fixes 11.0
michal-shalev Oct 20, 2025
777d06a
UCP/DEVICE: PR fixes 12.0
michal-shalev Oct 20, 2025
62e6732
UCP/DEVICE: PR fixes 13.0
michal-shalev Oct 20, 2025
8219c47
UCP/DEVICE: Fix code style
michal-shalev Oct 20, 2025
025a1a0
UCP/DEVICE: PR fixes 14.0
michal-shalev Oct 20, 2025
e46a302
UCP/DEVICE: Fix documentation
michal-shalev Oct 20, 2025
c70739c
UCP/DEVICE: Add ucp_device_detect_uct_memh
michal-shalev Oct 20, 2025
af75f5c
UCP/DEVICE: Add tests
michal-shalev Oct 20, 2025
c967767
UCP/DEVICE: PR fixes 15.0
michal-shalev Oct 22, 2025
c444251
UCP/DEVICE: PR fixes 16.0
michal-shalev Oct 22, 2025
a07231c
Merge branch 'master' into optional-memh-localaddr-counter-ops
michal-shalev Oct 22, 2025
1f0d285
UCP/DEVICE: PR fixes 17.0
michal-shalev Oct 25, 2025
0d30594
UCP/DEVICE: PR fixes 18.0
michal-shalev Oct 25, 2025
e9c1382
UCP/DEVICE: PR fixes 19.0
michal-shalev Oct 26, 2025
2bbfa15
UCP/DEVICE: PR fixes 20.0
michal-shalev Oct 26, 2025
438346f
UCP/DEVICE: PR fixes 21.0
michal-shalev Oct 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/ucp/api/device/ucp_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ BEGIN_C_DECLS
enum ucp_device_mem_list_elem_field {
UCP_DEVICE_MEM_LIST_ELEM_FIELD_MEMH = UCS_BIT(0), /**< Source memory handle */
UCP_DEVICE_MEM_LIST_ELEM_FIELD_RKEY = UCS_BIT(1), /**< Unpacked remote memory key */
UCP_DEVICE_MEM_LIST_ELEM_FIELD_LOCAL_ADDR = UCS_BIT(2), /**< Local address */
UCP_DEVICE_MEM_LIST_ELEM_FIELD_REMOTE_ADDR = UCS_BIT(3), /**< Remote address */
UCP_DEVICE_MEM_LIST_ELEM_FIELD_LENGTH = UCS_BIT(4) /**< Length of the local buffer in bytes */
UCP_DEVICE_MEM_LIST_ELEM_FIELD_LOCAL_ADDR = UCS_BIT(2), /**< Local address (optional for counter elements) */
UCP_DEVICE_MEM_LIST_ELEM_FIELD_REMOTE_ADDR = UCS_BIT(3), /**< Remote address */
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw, it is also always required, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not always. See remote_offset in the partial API: users can pass NULL as the remote address and then supply the address as remote_offset.

UCP_DEVICE_MEM_LIST_ELEM_FIELD_LENGTH = UCS_BIT(4) /**< Length of the local buffer in bytes (optional for counter elements) */
};


Expand All @@ -48,6 +48,14 @@ enum ucp_device_mem_list_elem_field {
*
* This describes a pair of local and remote memory for which a memory operation
* can later be performed multiple times, possibly with varying memory offsets.
*
* @note The @a memh and @a local_addr fields are optional for elements
* that are only used for remote addressing (e.g., counter elements):
* - @ref ucp_device_counter_inc: All elements may omit these fields
* - @ref ucp_device_put_multi: The last element (counter) may omit these
* fields
* - @ref ucp_device_put_multi_partial: The element at counter_index may
* omit these fields if not also in mem_list_indices
*/
typedef struct ucp_device_mem_list_elem {
/**
Expand All @@ -60,11 +68,13 @@ typedef struct ucp_device_mem_list_elem {

/**
* Local memory registration handle.
* Optional for elements used only for remote addressing (e.g., counters).
*/
ucp_mem_h memh;

/**
* Local memory address for the device transfer operations.
* Optional for elements used only for remote addressing (e.g., counters).
*/
void* local_addr;

Expand Down
100 changes: 82 additions & 18 deletions src/ucp/core/ucp_device.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ KHASH_IMPL(ucp_device_handle_allocs, ucp_device_mem_list_handle_h,
static khash_t(ucp_device_handle_allocs) ucp_device_handle_hash;
static ucs_spinlock_t ucp_device_handle_hash_lock;

/* Size of temporary allocation for local sys_dev detection */
#define UCP_DEVICE_LOCAL_SYS_DEV_DETECT_SIZE 64


void ucp_device_init(void)
{
Expand Down Expand Up @@ -121,34 +124,46 @@ ucp_device_mem_list_params_check(const ucp_device_mem_list_params_t *params,
RKEY, NULL);

/* TODO: Delegate most of checks below to proto selection */
if ((rkey == NULL) || (memh == NULL)) {
ucs_error("element[%lu] rkey=%p, memh=%p", i, rkey, memh);
if (rkey == NULL) {
ucs_error("element[%lu] rkey is NULL", i);
return UCS_ERR_INVALID_PARAM;
}

if (i == 0) {
*local_sys_dev = memh->sys_dev;
*local_md_map = memh->md_map;
*mem_type = memh->mem_type;
if (memh != NULL) {
*local_sys_dev = memh->sys_dev;
*local_md_map = memh->md_map;
*mem_type = memh->mem_type;
} else {
*mem_type = rkey->mem_type;
*local_md_map = UINT64_MAX;
}
*rkey_cfg_index = rkey->cfg_index;
if (*rkey_cfg_index == UCP_WORKER_CFG_INDEX_NULL) {
ucs_debug("invalid first rkey: cfg_index=%d", *rkey_cfg_index);
return UCS_ERR_INVALID_PARAM;
}
} else {
*local_md_map &= memh->md_map;
if (rkey->cfg_index != *rkey_cfg_index) {
ucs_debug("mismatched rkey config index: "
"ucp_rkey[%lu]->cfg_index=%u cfg_index=%u",
i, rkey->cfg_index, *rkey_cfg_index);
return UCS_ERR_UNSUPPORTED;
}

if (memh->sys_dev != *local_sys_dev) {
ucs_debug("mismatched local sys_dev: ucp_memh[%zu].sys_dev=%u "
"first_sys_dev=%u",
i, memh->sys_dev, *local_sys_dev);
return UCS_ERR_UNSUPPORTED;
if (memh != NULL) {
if (*local_sys_dev == UCS_SYS_DEVICE_ID_UNKNOWN) {
*local_sys_dev = memh->sys_dev;
*local_md_map = memh->md_map;
} else {
*local_md_map &= memh->md_map;
if (memh->sys_dev != *local_sys_dev) {
ucs_debug("mismatched local sys_dev: ucp_memh[%zu].sys_dev=%u "
"first_sys_dev=%u",
i, memh->sys_dev, *local_sys_dev);
return UCS_ERR_UNSUPPORTED;
}
}
}
}
}
Expand Down Expand Up @@ -230,6 +245,42 @@ static void ucp_device_mem_list_lane_lookup(
}
}

/**
 * Determine the local system device associated with a memory type.
 *
 * Makes a temporary allocation of UCP_DEVICE_LOCAL_SYS_DEV_DETECT_SIZE with
 * the requested @a mem_type, runs memory detection on that buffer, and reports
 * the sys_dev found in the resulting memory info. The temporary allocation is
 * released before the result is validated.
 *
 * @param [in]  context          UCP context used for allocation and detection.
 * @param [in]  mem_type         Memory type of the temporary probe buffer.
 * @param [out] local_sys_dev_p  Receives the detected system device. It is
 *                               written whenever the allocation succeeds,
 *                               including when UCS_ERR_UNSUPPORTED is returned.
 *
 * @return UCS_OK when a known system device was detected, the allocation
 *         status if the probe buffer could not be allocated, or
 *         UCS_ERR_UNSUPPORTED if the detected device is
 *         UCS_SYS_DEVICE_ID_UNKNOWN.
 */
static ucs_status_t
ucp_device_detect_local_sys_dev(ucp_context_h context,
                                ucs_memory_type_t mem_type,
                                ucs_sys_device_t *local_sys_dev_p)
{
    uct_allocated_memory_t probe;
    ucs_memory_info_t probe_info;
    ucs_status_t rc;

    rc = ucp_mem_do_alloc(context, NULL, UCP_DEVICE_LOCAL_SYS_DEV_DETECT_SIZE,
                          UCT_MD_MEM_ACCESS_LOCAL_READ |
                          UCT_MD_MEM_ACCESS_LOCAL_WRITE,
                          mem_type, UCS_SYS_DEVICE_ID_UNKNOWN,
                          "local_sys_dev_detect", &probe);
    if (rc != UCS_OK) {
        ucs_error("failed to allocate memory for sys_dev detection: %s",
                  ucs_status_string(rc));
        return rc;
    }

    /* Query the probe buffer, then drop it; only the sys_dev is kept */
    ucp_memory_detect_internal(context, probe.address, probe.length,
                               &probe_info);
    *local_sys_dev_p = probe_info.sys_dev;
    uct_mem_free(&probe);

    if (*local_sys_dev_p != UCS_SYS_DEVICE_ID_UNKNOWN) {
        ucs_trace("detected local_sys_dev=%u", *local_sys_dev_p);
        return UCS_OK;
    }

    ucs_error("detected unknown local_sys_dev");
    return UCS_ERR_UNSUPPORTED;
}

static ucs_status_t ucp_device_mem_list_create_handle(
ucp_ep_h ep, ucs_sys_device_t local_sys_dev,
const ucp_device_mem_list_params_t *params,
Expand Down Expand Up @@ -361,13 +412,17 @@ static ucs_status_t ucp_device_mem_list_create_handle(
ucp_ep_get_rsc_index(ep, lanes[i]));
ucp_element = params->elements;
for (j = 0; j < params->num_elements; j++) {
/* Local registration */
uct_memh = ucp_element->memh->uct[local_md_index];
ucs_assertv((ucp_element->memh->md_map & UCS_BIT(local_md_index)) !=
0,
"uct_memh=%p md_map=0x%lx local_md_index=%u", uct_memh,
ucp_element->memh->md_map, local_md_index);
ucs_assert(uct_memh != UCT_MEM_HANDLE_NULL);
if (ucp_element->memh != NULL) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems we can't assume it is NULL without checking UCP_DEVICE_MEM_LIST_ELEM_FIELD_MEMH

/* Local registration */
uct_memh = ucp_element->memh->uct[local_md_index];
ucs_assertv(
(ucp_element->memh->md_map & UCS_BIT(local_md_index)) != 0,
"uct_memh=%p md_map=0x%lx local_md_index=%u", uct_memh,
ucp_element->memh->md_map, local_md_index);
ucs_assert(uct_memh != UCT_MEM_HANDLE_NULL);
} else {
uct_memh = UCT_MEM_HANDLE_NULL;
}

/* Remote registration */
rkey_index =
Expand Down Expand Up @@ -430,6 +485,15 @@ ucp_device_mem_list_create(ucp_ep_h ep,
return status;
}

if (local_sys_dev == UCS_SYS_DEVICE_ID_UNKNOWN) {
status = ucp_device_detect_local_sys_dev(ep->worker->context, mem_type,
&local_sys_dev);
if (status != UCS_OK) {
ucs_error("failed to detect local_sys_dev: %s", ucs_status_string(status));
return status;
}
}

/* Perform pseudo lane selection without size */
rkey_config = &ep->worker->rkey_config[rkey_cfg_index];
ep_config = ucp_worker_ep_config(ep->worker, rkey_config->key.ep_cfg_index);
Expand Down
63 changes: 46 additions & 17 deletions test/gtest/ucp/test_ucp_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,15 @@ class test_ucp_device : public ucp_test {
static constexpr uint64_t SEED_SRC = 0x1234;
static constexpr uint64_t SEED_DST = 0x4321;

enum mem_list_mode_t {
MODE_DATA_ONLY,
MODE_COUNTER_ONLY,
MODE_LAST_ELEM_COUNTER
};

mem_list(entity &sender, entity &receiver, size_t size, unsigned count,
ucs_memory_type_t mem_type = UCS_MEMORY_TYPE_CUDA);
ucs_memory_type_t mem_type = UCS_MEMORY_TYPE_CUDA,
mem_list_mode_t mode = MODE_DATA_ONLY);
~mem_list();

void *src_ptr(unsigned index) const;
Expand All @@ -49,6 +56,7 @@ class test_ucp_device : public ucp_test {
std::vector<std::unique_ptr<mapped_buffer>> m_src, m_dst;
std::vector<ucs::handle<ucp_rkey_h>> m_rkeys;
ucp_device_mem_list_handle_h m_mem_list_h;
mem_list_mode_t m_mode;
};

size_t counter_size();
Expand Down Expand Up @@ -86,32 +94,50 @@ void test_ucp_device::init()

test_ucp_device::mem_list::mem_list(entity &sender, entity &receiver,
size_t size, unsigned count,
ucs_memory_type_t mem_type) :
m_receiver(receiver)
ucs_memory_type_t mem_type,
mem_list_mode_t mode) :
m_receiver(receiver), m_mode(mode)
{
// Prepare src and dst buffers
for (auto i = 0; i < count; ++i) {
m_src.emplace_back(new mapped_buffer(size, sender, 0, mem_type));
bool need_src = (mode == MODE_DATA_ONLY) ||
(mode == MODE_LAST_ELEM_COUNTER && i < count - 1);
if (need_src) {
m_src.emplace_back(new mapped_buffer(size, sender, 0, mem_type));
m_src.back()->pattern_fill(SEED_SRC, size);
}
m_dst.emplace_back(new mapped_buffer(size, receiver, 0, mem_type));
m_rkeys.push_back(m_dst.back()->rkey(sender));
m_src.back()->pattern_fill(SEED_SRC, size);
m_dst.back()->pattern_fill(SEED_DST, size);
}

// Initialize elements
std::vector<ucp_device_mem_list_elem_t> elems(count);
for (auto i = 0; i < count; ++i) {
auto &elem = elems[i];
elem.field_mask = UCP_DEVICE_MEM_LIST_ELEM_FIELD_MEMH |
UCP_DEVICE_MEM_LIST_ELEM_FIELD_RKEY |
UCP_DEVICE_MEM_LIST_ELEM_FIELD_LOCAL_ADDR |
UCP_DEVICE_MEM_LIST_ELEM_FIELD_REMOTE_ADDR |
UCP_DEVICE_MEM_LIST_ELEM_FIELD_LENGTH;
elem.memh = m_src[i]->memh();
auto &elem = elems[i];
bool is_counter = (mode == MODE_COUNTER_ONLY) ||
(mode == MODE_LAST_ELEM_COUNTER && i == count - 1);

if (is_counter) {
elem.field_mask = UCP_DEVICE_MEM_LIST_ELEM_FIELD_RKEY |
UCP_DEVICE_MEM_LIST_ELEM_FIELD_REMOTE_ADDR |
UCP_DEVICE_MEM_LIST_ELEM_FIELD_LENGTH;
elem.memh = NULL;
elem.local_addr = NULL;
elem.length = size;
} else {
/* Data element: with memh and local_addr */
elem.field_mask = UCP_DEVICE_MEM_LIST_ELEM_FIELD_MEMH |
UCP_DEVICE_MEM_LIST_ELEM_FIELD_RKEY |
UCP_DEVICE_MEM_LIST_ELEM_FIELD_LOCAL_ADDR |
UCP_DEVICE_MEM_LIST_ELEM_FIELD_REMOTE_ADDR |
UCP_DEVICE_MEM_LIST_ELEM_FIELD_LENGTH;
elem.memh = m_src[i]->memh();
elem.local_addr = m_src[i]->ptr();
elem.length = m_src[i]->size();
}
elem.rkey = m_rkeys[i];
elem.local_addr = m_src[i]->ptr();
elem.remote_addr = reinterpret_cast<uint64_t>(m_dst[i]->ptr());
elem.length = m_src[i]->size();
}

// Initialize parameters
Expand Down Expand Up @@ -466,7 +492,8 @@ UCS_TEST_P(test_ucp_device_xfer, put_multi)
{
static constexpr size_t size = 32 * UCS_KBYTE;
unsigned count = get_multi_elem_count();
mem_list list(sender(), receiver(), size, count + 1);
mem_list list(sender(), receiver(), size, count + 1, UCS_MEMORY_TYPE_CUDA,
mem_list::MODE_LAST_ELEM_COUNTER);

const unsigned counter_index = count;
list.dst_counter_init(counter_index);
Expand All @@ -490,7 +517,8 @@ UCS_TEST_P(test_ucp_device_xfer, put_multi_partial)
{
static constexpr size_t size = 32 * UCS_KBYTE;
unsigned total_count = get_multi_elem_count() * 2;
mem_list list(sender(), receiver(), size, total_count + 1);
mem_list list(sender(), receiver(), size, total_count + 1, UCS_MEMORY_TYPE_CUDA,
mem_list::MODE_LAST_ELEM_COUNTER);

const unsigned counter_index = total_count;
list.dst_counter_init(counter_index);
Expand Down Expand Up @@ -540,7 +568,8 @@ UCS_TEST_P(test_ucp_device_xfer, put_multi_partial)
UCS_TEST_P(test_ucp_device_xfer, counter)
{
const size_t size = counter_size();
mem_list list(sender(), receiver(), size, 1);
mem_list list(sender(), receiver(), size, 1, UCS_MEMORY_TYPE_CUDA,
mem_list::MODE_COUNTER_ONLY);

static constexpr unsigned mem_list_index = 0;
list.dst_counter_init(mem_list_index);
Expand Down
Loading