diff --git a/src/api/cpp/nixl.h b/src/api/cpp/nixl.h index 3cc18a222..32299e948 100644 --- a/src/api/cpp/nixl.h +++ b/src/api/cpp/nixl.h @@ -325,6 +325,8 @@ class nixlAgent { * @param req_hndl [in] Transfer request obtained from makeXferReq/createXferReq * @param gpu_req_hndl [out] GPU transfer request handle * @return nixl_status_t Error code if call was not successful + * + * @note This call may block until the associated connection is established. */ nixl_status_t createGpuXferReq(const nixlXferReqH &req_hndl, nixlGpuXferReqH &gpu_req_hndl) const; diff --git a/src/plugins/ucx/ucx_backend.cpp b/src/plugins/ucx/ucx_backend.cpp index 0456c20b4..7f10700ff 100644 --- a/src/plugins/ucx/ucx_backend.cpp +++ b/src/plugins/ucx/ucx_backend.cpp @@ -1688,7 +1688,8 @@ nixlUcxEngine::createGpuXferReq(const nixlBackendReqH &req_hndl, } try { - gpu_req_hndl = nixl::ucx::createGpuXferReq(*ep, local_mems, remote_rkeys, remote_addrs); + gpu_req_hndl = nixl::ucx::createGpuXferReq( + *ep, uws, local_mems, remote_rkeys, remote_addrs); NIXL_TRACE << "Created device memory list: ep=" << ep->getEp() << " handle=" << gpu_req_hndl << " worker_id=" << workerId << " num_elements=" << local_mems.size(); return NIXL_SUCCESS; diff --git a/src/utils/ucx/gpu_xfer_req_h.cpp b/src/utils/ucx/gpu_xfer_req_h.cpp index 3055219b4..68b85f0e2 100644 --- a/src/utils/ucx/gpu_xfer_req_h.cpp +++ b/src/utils/ucx/gpu_xfer_req_h.cpp @@ -21,6 +21,8 @@ #include "rkey.h" #include "config.h" +#include + extern "C" { #ifdef HAVE_UCX_GPU_DEVICE_API #include @@ -33,6 +35,7 @@ namespace nixl::ucx { nixlGpuXferReqH createGpuXferReq(const nixlUcxEp &ep, + const std::vector> &all_workers, const std::vector &local_mems, const std::vector &remote_rkeys, const std::vector &remote_addrs) { @@ -74,8 +77,23 @@ createGpuXferReq(const nixlUcxEp &ep, params.element_size = sizeof(ucp_device_mem_list_elem_t); params.num_elements = ucp_elements.size(); + const auto start = std::chrono::steady_clock::now(); + constexpr auto timeout = 
std::chrono::seconds(5); ucp_device_mem_list_handle_h ucx_handle; - ucs_status_t ucs_status = ucp_device_mem_list_create(ep.getEp(), &params, &ucx_handle); + ucs_status_t ucs_status; + // Workaround: loop until wireup is completed + while ((ucs_status = ucp_device_mem_list_create(ep.getEp(), &params, &ucx_handle)) == + UCS_ERR_NOT_CONNECTED) { + for (const auto &w : all_workers) { + w->progress(); + } + + if (std::chrono::steady_clock::now() - start > timeout) { + throw std::runtime_error( + "Timeout waiting for endpoint wireup completion has been exceeded"); + } + } + if (ucs_status != UCS_OK) { throw std::runtime_error(std::string("Failed to create device memory list: ") + ucs_status_string(ucs_status)); @@ -96,6 +114,7 @@ releaseGpuXferReq(nixlGpuXferReqH gpu_req) noexcept { nixlGpuXferReqH createGpuXferReq(const nixlUcxEp &ep, + const std::vector> &all_workers, const std::vector &local_mems, const std::vector &remote_rkeys, const std::vector &remote_addrs) { diff --git a/src/utils/ucx/gpu_xfer_req_h.h b/src/utils/ucx/gpu_xfer_req_h.h index 9798308b1..8b700fd0f 100644 --- a/src/utils/ucx/gpu_xfer_req_h.h +++ b/src/utils/ucx/gpu_xfer_req_h.h @@ -19,17 +19,20 @@ #define NIXL_SRC_UTILS_UCX_GPU_XFER_REQ_H_H #include +#include #include "nixl_types.h" class nixlUcxEp; class nixlUcxMem; +class nixlUcxWorker; namespace nixl::ucx { class rkey; nixlGpuXferReqH createGpuXferReq(const nixlUcxEp &ep, + const std::vector> &all_workers, const std::vector &local_mems, const std::vector &remote_rkeys, const std::vector &remote_addrs); diff --git a/test/gtest/device_api/single_write_test.cu b/test/gtest/device_api/single_write_test.cu index eec111fe8..bd93747bc 100644 --- a/test/gtest/device_api/single_write_test.cu +++ b/test/gtest/device_api/single_write_test.cu @@ -194,50 +194,8 @@ protected: agent.registerMem(reg_list); } - // TODO: remove this function once a blocking CreateGpuXferReq is implemented - void - completeWireup(size_t from_agent, size_t to_agent, - const std::vector 
&wireup_src, - const std::vector &wireup_dst) { - nixl_opt_args_t wireup_params; - - for (size_t worker_id = 0; worker_id < numWorkers; worker_id++) { - wireup_params.customParam = "worker_id=" + std::to_string(worker_id); - - nixlXferReqH *wireup_req; - nixl_status_t status = getAgent(from_agent) - .createXferReq(NIXL_WRITE, - makeDescList(wireup_src, VRAM_SEG), - makeDescList(wireup_dst, VRAM_SEG), - getAgentName(to_agent), - wireup_req, - &wireup_params); - - ASSERT_EQ(status, NIXL_SUCCESS) << "Failed to create wireup request for worker " << worker_id; - - status = getAgent(from_agent).postXferReq(wireup_req); - ASSERT_TRUE(status == NIXL_SUCCESS || status == NIXL_IN_PROG) - << "Failed to post wireup for worker " << worker_id; - - nixl_status_t xfer_status; - do { - xfer_status = getAgent(from_agent).getXferStatus(wireup_req); - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } while (xfer_status == NIXL_IN_PROG); - - ASSERT_EQ(xfer_status, NIXL_SUCCESS) << "Warmup failed for worker " << worker_id; - - status = getAgent(from_agent).releaseXferReq(wireup_req); - ASSERT_EQ(status, NIXL_SUCCESS); - } - } - void exchangeMD(size_t from_agent, size_t to_agent) { - std::vector wireup_src, wireup_dst; - createRegisteredMem(getAgent(from_agent), 64, 1, VRAM_SEG, wireup_src); - createRegisteredMem(getAgent(to_agent), 64, 1, VRAM_SEG, wireup_dst); - for (size_t i = 0; i < agents.size(); i++) { nixl_blob_t md; nixl_status_t status = agents[i]->getLocalMD(md); @@ -251,8 +209,6 @@ protected: EXPECT_EQ(remote_agent_name, getAgentName(i)); } } - - completeWireup(from_agent, to_agent, wireup_src, wireup_dst); } void