diff --git a/src/api/cpp/nixl.h b/src/api/cpp/nixl.h
index 3cc18a222..337aeb547 100644
--- a/src/api/cpp/nixl.h
+++ b/src/api/cpp/nixl.h
@@ -322,12 +322,22 @@ class nixlAgent {
     /**
-     * @brief Create a GPU transfer request from a transfer request.
+     * @brief Create a GPU transfer request from local and remote descriptor lists.
      *
-     * @param req_hndl     [in]  Transfer request obtained from makeXferReq/createXferReq
-     * @param gpu_req_hndl [out] GPU transfer request handle
-     * @return nixl_status_t Error code if call was not successful
+     *
+     * @param local_descs  [in]  Local descriptor list (empty for signal-only case)
+     * @param remote_descs [in]  Remote descriptor list
+     * @param remote_agent [in]  Remote agent name for accessing the remote data
+     * @param gpu_req_hndl [out] GPU transfer request handle
+     * @param req_hndl     [out] Transfer request handle
+     * @param extra_params [in]  Optional extra parameters
+     * @return nixl_status_t Error code if call was not successful
      */
     nixl_status_t
-    createGpuXferReq(const nixlXferReqH &req_hndl, nixlGpuXferReqH &gpu_req_hndl) const;
+    createGpuXferReq(const nixl_xfer_dlist_t &local_descs,
+                     const nixl_xfer_dlist_t &remote_descs,
+                     const std::string &remote_agent,
+                     nixlGpuXferReqH &gpu_req_hndl,
+                     nixlXferReqH *&req_hndl,
+                     const nixl_opt_args_t *extra_params = nullptr) const;
 
     /**
      * @brief Release transfer request from GPU memory
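For orientation, a minimal caller-side sketch of the reworked `createGpuXferReq()` follows; it mirrors the updated gtest later in this patch. The agent setup, memory registration, and the `src_descs`/`dst_descs` descriptor lists are assumed to already exist on the caller's side, and error handling is reduced to a single status check.

```cpp
// Sketch only: assumes `agent` is a connected nixlAgent, metadata for "receiver" has been
// loaded, and src_descs/dst_descs are equally sized nixl_xfer_dlist_t lists over registered memory.
nixl_opt_args_t extra_params;
extra_params.hasNotif = true;
extra_params.notifMsg = "xfer-done";

nixlXferReqH *xfer_req = nullptr;
nixlGpuXferReqH gpu_req_hndl;

nixl_status_t status = agent.createGpuXferReq(src_descs,
                                              dst_descs,
                                              "receiver",
                                              gpu_req_hndl,
                                              xfer_req,
                                              &extra_params);
if (status == NIXL_SUCCESS) {
    // gpu_req_hndl can now be handed to the GPU-side device API; xfer_req is the host-side
    // handle that must eventually be released (GPU request first, then releaseXferReq).
}
```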
diff --git a/src/core/nixl_agent.cpp b/src/core/nixl_agent.cpp
index 2d8bff7a3..830ffb699 100644
--- a/src/core/nixl_agent.cpp
+++ b/src/core/nixl_agent.cpp
@@ -879,43 +879,33 @@ nixlAgent::createXferReq(const nixl_xfer_op_t &operation,
     if (!extra_params || extra_params->backends.size() == 0) {
         // Finding backends that support the corresponding memories
         // locally and remotely, and find the common ones.
-        backend_set_t* local_set =
-            data->memorySection->queryBackends(local_descs.getType());
-        backend_set_t* remote_set =
-            data->remoteSections[remote_agent]->queryBackends(
-                remote_descs.getType());
+        backend_set_t *local_set = data->memorySection->queryBackends(local_descs.getType());
+        backend_set_t *remote_set =
+            data->remoteSections[remote_agent]->queryBackends(remote_descs.getType());
         if (!local_set || !remote_set) {
             NIXL_ERROR_FUNC << "no backends found for local or remote for their "
                                "corresponding memory type";
             return NIXL_ERR_NOT_FOUND;
         }
 
-        for (auto & elm : *local_set)
-            if (remote_set->count(elm) != 0)
-                backend_set->insert(elm);
+        for (auto &elm : *local_set)
+            if (remote_set->count(elm) != 0) backend_set->insert(elm);
 
         if (backend_set->empty()) {
             NIXL_ERROR_FUNC << "no potential backend found to be able to do the transfer";
             return NIXL_ERR_NOT_FOUND;
         }
     } else {
-        for (auto & elm : extra_params->backends)
+        for (auto &elm : extra_params->backends)
             backend_set->insert(elm->engine);
     }
 
-    // TODO: when central KV is supported, add a call to fetchRemoteMD
-    // TODO: merge descriptors back to back in memory (like makeXferReq).
-    // TODO [Perf]: Avoid heap allocation on the datapath, maybe use a mem pool
-
     std::unique_ptr<nixlXferReqH> handle = std::make_unique<nixlXferReqH>();
     handle->initiatorDescs = new nixl_meta_dlist_t(local_descs.getType());
 
     handle->targetDescs = new nixl_meta_dlist_t(remote_descs.getType());
 
-    // Currently we loop through and find first local match. Can use a
-    // preference list or more exhaustive search.
     for (auto & backend : *backend_set) {
-        // If populate fails, it clears the resp before return
         ret1 = data->memorySection->populate(
             local_descs, backend, *handle->initiatorDescs);
         ret2 = data->remoteSections[remote_agent]->populate(
@@ -1217,25 +1207,139 @@ nixlAgent::releaseXferReq(nixlXferReqH *req_hndl) const {
 }
 
 nixl_status_t
-nixlAgent::createGpuXferReq(const nixlXferReqH &req_hndl, nixlGpuXferReqH &gpu_req_hndl) const {
-    if (!req_hndl.engine) {
-        NIXL_ERROR_FUNC << "Invalid request handle[" << &req_hndl << "]: engine is null";
-        return NIXL_ERR_INVALID_PARAM;
+nixlAgent::createGpuXferReq(const nixl_xfer_dlist_t &local_descs,
+                            const nixl_xfer_dlist_t &remote_descs,
+                            const std::string &remote_agent,
+                            nixlGpuXferReqH &gpu_req_hndl,
+                            nixlXferReqH *&req_hndl,
+                            const nixl_opt_args_t *extra_params) const {
+    nixl_status_t ret1, ret2;
+    nixl_opt_b_args_t opt_args;
+
+    std::unique_ptr<backend_set_t> backend_set = std::make_unique<backend_set_t>();
+
+    req_hndl = nullptr;
+
+    NIXL_SHARED_LOCK_GUARD(data->lock);
+
+    if (data->remoteSections.count(remote_agent) == 0) {
+        NIXL_ERROR_FUNC << "metadata for remote agent '" << remote_agent << "' not found";
+        data->addErrorTelemetry(NIXL_ERR_NOT_FOUND);
+        return NIXL_ERR_NOT_FOUND;
     }
 
-    if (!req_hndl.backendHandle) {
-        NIXL_ERROR_FUNC << "Invalid request handle[" << &req_hndl << "]: backendHandle is null";
+    size_t total_bytes = 0;
+    if (local_descs.descCount() != remote_descs.descCount()) {
+        NIXL_ERROR_FUNC << "different descriptor list sizes (local=" << local_descs.descCount()
+                        << ", remote=" << remote_descs.descCount() << ")";
         return NIXL_ERR_INVALID_PARAM;
     }
+    for (int i = 0; i < local_descs.descCount(); ++i) {
+        if (local_descs[i].len != remote_descs[i].len) {
+            NIXL_ERROR_FUNC << "length mismatch at index " << i;
+            return NIXL_ERR_INVALID_PARAM;
+        }
+        total_bytes += local_descs[i].len;
+    }
 
-    NIXL_SHARED_LOCK_GUARD(data->lock);
-    const auto status = req_hndl.engine->createGpuXferReq(
-        *req_hndl.backendHandle, *req_hndl.initiatorDescs, *req_hndl.targetDescs, gpu_req_hndl);
+    if (!extra_params || extra_params->backends.size() == 0) {
+        backend_set_t* local_set =
+            data->memorySection->queryBackends(local_descs.getType());
+        backend_set_t* remote_set =
+            data->remoteSections[remote_agent]->queryBackends(
+                remote_descs.getType());
+        if (!local_set || !remote_set) {
+            NIXL_ERROR_FUNC << "no backends found for local or remote for their "
+                               "corresponding memory type";
+            return NIXL_ERR_NOT_FOUND;
+        }
+
+        for (auto & elm : *local_set)
+            if (remote_set->count(elm) != 0)
+                backend_set->insert(elm);
+
+        if (backend_set->empty()) {
+            NIXL_ERROR_FUNC << "no potential backend found to be able to do the transfer";
+            return NIXL_ERR_NOT_FOUND;
+        }
+    } else {
+        for (auto & elm : extra_params->backends)
+            backend_set->insert(elm->engine);
+    }
+
+    std::unique_ptr<nixlXferReqH> handle = std::make_unique<nixlXferReqH>();
+    handle->initiatorDescs = new nixl_meta_dlist_t(local_descs.getType());
+
+    handle->targetDescs = new nixl_meta_dlist_t(remote_descs.getType());
+
+    for (auto &backend : *backend_set) {
+        ret1 = data->memorySection->populate(local_descs, backend, *handle->initiatorDescs);
+        ret2 = data->remoteSections[remote_agent]->populate(
+            remote_descs, backend, *handle->targetDescs);
+
+        if ((ret1 == NIXL_SUCCESS) && (ret2 == NIXL_SUCCESS)) {
+            NIXL_INFO << "Selected backend: " << backend->getType();
+            handle->engine = backend;
+            break;
+        }
+    }
+
+    if (!handle->engine) {
+        NIXL_ERROR_FUNC << "no specified or potential backend had the required "
+                           "registrations to be able to do the transfer";
+        data->addErrorTelemetry(NIXL_ERR_NOT_FOUND);
+        return NIXL_ERR_NOT_FOUND;
+    }
+
+    if (extra_params) {
+        if (extra_params->hasNotif) {
+            opt_args.notifMsg = extra_params->notifMsg;
+            opt_args.hasNotif = true;
+        }
+
+        if (extra_params->customParam.length() > 0)
+            opt_args.customParam = extra_params->customParam;
+    }
+
+    if (opt_args.hasNotif && (!handle->engine->supportsNotif())) {
+        NIXL_ERROR_FUNC << "the selected backend '" << handle->engine->getType()
+                        << "' does not support notifications";
+        data->addErrorTelemetry(NIXL_ERR_BACKEND);
+        return NIXL_ERR_BACKEND;
+    }
+
+    handle->remoteAgent = remote_agent;
+    handle->status = NIXL_ERR_NOT_POSTED;
+    handle->notifMsg = opt_args.notifMsg;
+    handle->hasNotif = opt_args.hasNotif;
+
+    if (data->telemetryEnabled) {
+        handle->telemetry.totalBytes = total_bytes;
+        handle->telemetry.descCount = handle->initiatorDescs->descCount();
+    }
+
+    ret1 = handle->engine->prepXfer (handle->backendOp,
+                                     *handle->initiatorDescs,
+                                     *handle->targetDescs,
+                                     handle->remoteAgent,
+                                     handle->backendHandle,
+                                     &opt_args);
+    if (ret1 != NIXL_SUCCESS) {
+        NIXL_ERROR_FUNC << "backend '" << handle->engine->getType()
+                        << "' failed to prepare the transfer request with status " << ret1;
+        data->addErrorTelemetry(ret1);
+        return ret1;
+    }
+
+    req_hndl = handle.release();
+
+    const auto status = req_hndl->engine->createGpuXferReq(
+        *req_hndl->backendHandle, *req_hndl->initiatorDescs, *req_hndl->targetDescs, gpu_req_hndl);
     if (status == NIXL_SUCCESS) {
-        data->gpuReqToEngine.emplace(gpu_req_hndl, req_hndl.engine);
+        data->gpuReqToEngine.emplace(gpu_req_hndl, req_hndl->engine);
     }
     return status;
 }
 
 void
diff --git a/src/plugins/ucx/ucx_backend.cpp b/src/plugins/ucx/ucx_backend.cpp
index 37bac433a..247d5d7b2 100644
--- a/src/plugins/ucx/ucx_backend.cpp
+++ b/src/plugins/ucx/ucx_backend.cpp
@@ -38,14 +38,14 @@
 #endif
 
 namespace {
-    void moveNotifList(notif_list_t &src, notif_list_t &tgt)
-    {
-        if (src.size() > 0) {
-            std::move(src.begin(), src.end(), std::back_inserter(tgt));
-            src.clear();
-        }
+void
+moveNotifList(notif_list_t &src, notif_list_t &tgt) {
+    if (src.size() > 0) {
+        std::move(src.begin(), src.end(), std::back_inserter(tgt));
+        src.clear();
     }
 }
+} // namespace
 
 /****************************************
  * CUDA related code
@@ -62,24 +62,33 @@ class nixlUcxCudaCtx {
         myDevId = -1;
     }
 #endif
-    void cudaResetCtxPtr();
-    int cudaUpdateCtxPtr(void *address, int expected_dev, bool &was_updated);
-    int cudaSetCtx();
+    void
+    cudaResetCtxPtr();
+    int
+    cudaUpdateCtxPtr(void *address, int expected_dev, bool &was_updated);
+    int
+    cudaSetCtx();
 };
 
 class nixlUcxCudaDevicePrimaryCtx {
 #ifndef HAVE_CUDA
 public:
-    bool push() { return false; }
-    void pop() {};
+    bool
+    push() {
+        return false;
+    }
+
+    void
+    pop() {};
 #else
     static constexpr int defaultCudaDeviceOrdinal = 0;
     int m_ordinal{defaultCudaDeviceOrdinal};
     CUdevice m_device{CU_DEVICE_INVALID};
     CUcontext m_context{nullptr};
-public:
-    bool push() {
+public:
+    bool
+    push() {
         CUcontext context;
         const auto res = cuCtxGetCurrent(&context);
 
@@ -103,7 +112,8 @@ class nixlUcxCudaDevicePrimaryCtx {
         return cuCtxPushCurrent(m_context) == CUDA_SUCCESS;
     }
 
-    void pop() {
+    void
+    pop() {
         cuCtxPopCurrent(nullptr);
     }
 
@@ -117,13 +127,14 @@ class nixlUcxCudaDevicePrimaryCtx {
 
 class nixlUcxCudaCtxGuard {
     nixlUcxCudaDevicePrimaryCtxPtr m_primary;
+
 public:
-    nixlUcxCudaCtxGuard(nixl_mem_t nixl_mem,
-                        nixlUcxCudaDevicePrimaryCtxPtr primary) {
+    nixlUcxCudaCtxGuard(nixl_mem_t nixl_mem, nixlUcxCudaDevicePrimaryCtxPtr primary) {
         if (nixl_mem == VRAM_SEG && primary && primary->push()) {
             m_primary = primary;
         }
     }
+
     ~nixlUcxCudaCtxGuard() {
         if (m_primary) {
             m_primary->pop();
@@ -133,9 +144,8 @@ class nixlUcxCudaCtxGuard {
 
 #ifdef HAVE_CUDA
 
-static int cudaQueryAddr(void *address, bool &is_dev,
-                         CUdevice &dev, CUcontext &ctx)
-{
+static int
+cudaQueryAddr(void *address, bool &is_dev, CUdevice &dev, CUcontext &ctx) {
     CUmemorytype mem_type = CU_MEMORYTYPE_HOST;
     uint32_t is_managed = 0;
 #define NUM_ATTRS 4
@@ -159,8 +169,8 @@ static int cudaQueryAddr(void *address, bool &is_dev,
     return (CUDA_SUCCESS != result);
 }
 
-int nixlUcxCudaCtx::cudaUpdateCtxPtr(void *address, int expected_dev, bool &was_updated)
-{
+int
+nixlUcxCudaCtx::cudaUpdateCtxPtr(void *address, int expected_dev, bool &was_updated) {
     bool is_dev;
     CUdevice dev;
     CUcontext ctx;
@@ -169,12 +179,10 @@ int nixlUcxCudaCtx::cudaUpdateCtxPtr(void *address, int expected_dev, bool &was_
     was_updated = false;
 
     /* TODO: proper error codes and log outputs through this method */
-    if (expected_dev == -1)
-        return -1;
+    if (expected_dev == -1) return -1;
 
     // incorrect dev id from first registration
-    if (myDevId != -1 && expected_dev != myDevId)
-        return -1;
+    if (myDevId != -1 && expected_dev != myDevId) return -1;
 
     ret = cudaQueryAddr(address, is_dev, dev, ctx);
     if (ret) {
@@ -205,8 +213,8 @@ int nixlUcxCudaCtx::cudaUpdateCtxPtr(void *address, int expected_dev, bool &was_
     return 0;
 }
 
-int nixlUcxCudaCtx::cudaSetCtx()
-{
+int
+nixlUcxCudaCtx::cudaSetCtx() {
     CUresult result;
     if (NULL == pthrCudaCtx) {
         return 0;
@@ -219,21 +227,22 @@ int nixlUcxCudaCtx::cudaSetCtx()
 
 #else
 
-int nixlUcxCudaCtx::cudaUpdateCtxPtr(void *address, int expected_dev, bool &was_updated)
-{
+int
+nixlUcxCudaCtx::cudaUpdateCtxPtr(void *address, int expected_dev, bool &was_updated) {
     was_updated = false;
     return 0;
 }
 
-int nixlUcxCudaCtx::cudaSetCtx() {
+int
+nixlUcxCudaCtx::cudaSetCtx() {
     return 0;
 }
 
 #endif
 
-void nixlUcxEngine::vramInitCtx()
-{
+void
+nixlUcxEngine::vramInitCtx() {
    cudaCtx = std::make_unique<nixlUcxCudaCtx>();
 }
 
@@ -244,7 +253,7 @@ nixlUcxEngine::vramUpdateCtx(void *address, uint64_t dev_id, bool &restart_reqd)
 
     restart_reqd = false;
 
-    if(!cuda_addr_wa) {
+    if (!cuda_addr_wa) {
         // Nothing to do
         return 0;
     }
@@ -259,9 +268,9 @@ nixlUcxEngine::vramUpdateCtx(void *address, uint64_t dev_id, bool &restart_reqd)
     return 0;
 }
 
-int nixlUcxEngine::vramApplyCtx()
-{
-    if(!cuda_addr_wa) {
+int
+nixlUcxEngine::vramApplyCtx() {
+    if (!cuda_addr_wa) {
         // Nothing to do
         return 0;
     }
@@ -269,14 +278,14 @@ int nixlUcxEngine::vramApplyCtx()
     return cudaCtx->cudaSetCtx();
 }
 
-void nixlUcxEngine::vramFiniCtx()
-{
+void
+nixlUcxEngine::vramFiniCtx() {
     cudaCtx.reset();
 }
 
 /****************************************
  * UCX request management
-*****************************************/
+ *****************************************/
 
 class nixlUcxIntReq {
 
@@ -315,28 +324,28 @@ nixlUcxReqSetConnection(nixlUcxReq req, ucx_connection_ptr_t conn) {
     req_int->setConnection(conn);
 }
 
-static void _internalRequestInit(void *request)
-{
+static void
+_internalRequestInit(void *request) {
     /* Initialize request in-place (aka "placement new")*/
-    new(request) nixlUcxIntReq;
+    new (request) nixlUcxIntReq;
 }
 
-static void _internalRequestFini(void *request)
-{
+static void
+_internalRequestFini(void *request) {
     /* Finalize request */
-    nixlUcxIntReq *req = (nixlUcxIntReq*)request;
+    nixlUcxIntReq *req = (nixlUcxIntReq *)request;
     req->~nixlUcxIntReq();
 }
 
-
-static void _internalRequestReset(nixlUcxIntReq *req) {
+static void
+_internalRequestReset(nixlUcxIntReq *req) {
     _internalRequestFini((void *)req);
     _internalRequestInit((void *)req);
 }
 
 /****************************************
  * Backend request management
-*****************************************/
+ *****************************************/
 
 class nixlUcxBackendH : public nixlBackendReqH {
 private:
@@ -346,15 +355,19 @@ class nixlUcxBackendH : public nixlBackendReqH {
 
     // Notification to be sent after completion of all requests
     struct Notif {
-        std::string agent;
-        nixl_blob_t payload;
-        Notif(const std::string& remote_agent, const nixl_blob_t& msg)
-            : agent(remote_agent), payload(msg) {}
+        std::string agent;
+        nixl_blob_t payload;
+
+        Notif(const std::string &remote_agent, const nixl_blob_t &msg)
+            : agent(remote_agent),
+              payload(msg) {}
     };
+
     std::optional<Notif> notif;
 
 public:
-    auto& notification() {
+    auto &
+    notification() {
         return notif;
     }
 
@@ -451,14 +464,15 @@ class nixlUcxBackendH : public nixlBackendReqH {
         return worker;
     }
 
-    size_t getWorkerId() const {
+    size_t
+    getWorkerId() const {
         return worker_id;
     }
 };
 
 /****************************************
  * Progress thread management
-*****************************************/
+ *****************************************/
 
 /*
  * This class encapsulates a thread that polls one or multiple UCX workers
@@ -1118,7 +1132,7 @@ nixlUcxEngine::nixlUcxEngine(const nixlBackendInitParams &init_params)
     std::vector<std::string> devs; /* Empty vector */
     nixl_b_params_t *custom_params = init_params.customParams;
-    if (custom_params->count("device_list")!=0)
+    if (custom_params->count("device_list") != 0)
         devs = str_split((*custom_params)["device_list"], ", ");
 
     size_t num_workers = nixl_b_params_get(custom_params, "num_workers", 1);
@@ -1166,7 +1180,8 @@ nixlUcxEngine::nixlUcxEngine(const nixlBackendInitParams &init_params)
     vramInitCtx();
 }
 
-nixl_mem_list_t nixlUcxEngine::getSupportedMems () const {
+nixl_mem_list_t
+nixlUcxEngine::getSupportedMems() const {
     nixl_mem_list_t mems;
     mems.push_back(DRAM_SEG);
     mems.push_back(VRAM_SEG);
@@ -1187,19 +1202,22 @@ nixlUcxEngine::~nixlUcxEngine() {
 
 /****************************************
  * Connection management
-*****************************************/
+ *****************************************/
 
-nixl_status_t nixlUcxEngine::checkConn(const std::string &remote_agent) {
+nixl_status_t
+nixlUcxEngine::checkConn(const std::string &remote_agent) {
     return remoteConnMap.count(remote_agent) ? NIXL_SUCCESS : NIXL_ERR_NOT_FOUND;
 }
 
-nixl_status_t nixlUcxEngine::getConnInfo(std::string &str) const {
+nixl_status_t
+nixlUcxEngine::getConnInfo(std::string &str) const {
     str = workerAddr;
     return NIXL_SUCCESS;
 }
 
-nixl_status_t nixlUcxEngine::connect(const std::string &remote_agent) {
-    if(remote_agent == localAgent) {
+nixl_status_t
+nixlUcxEngine::connect(const std::string &remote_agent) {
+    if (remote_agent == localAgent) {
        return loadRemoteConnInfo(remote_agent, workerAddr);
     }
 
@@ -1207,7 +1225,8 @@ nixl_status_t nixlUcxEngine::connect(const std::string &remote_agent) {
                NIXL_SUCCESS;
 }
 
-nixl_status_t nixlUcxEngine::disconnect(const std::string &remote_agent) {
+nixl_status_t
+nixlUcxEngine::disconnect(const std::string &remote_agent) {
     auto search = remoteConnMap.find(remote_agent);
 
     if (search == remoteConnMap.end()) {
@@ -1219,20 +1238,20 @@ nixl_status_t nixlUcxEngine::disconnect(const std::string &remote_agent) {
     return NIXL_SUCCESS;
 }
 
-nixl_status_t nixlUcxEngine::loadRemoteConnInfo (const std::string &remote_agent,
-                                                 const std::string &remote_conn_info)
-{
+nixl_status_t
+nixlUcxEngine::loadRemoteConnInfo(const std::string &remote_agent,
                                   const std::string &remote_conn_info) {
     size_t size = remote_conn_info.size();
     std::vector<char> addr(size);
 
-    if(remoteConnMap.count(remote_agent)) {
+    if (remoteConnMap.count(remote_agent)) {
        return NIXL_ERR_INVALID_PARAM;
     }
 
     nixlSerDes::_stringToBytes(addr.data(), remote_conn_info, size);
     std::shared_ptr<nixlUcxConnection> conn = std::make_shared<nixlUcxConnection>();
     bool error = false;
-    for (auto &uw: uws) {
+    for (auto &uw : uws) {
         auto result = uw->connect(addr.data(), size);
         if (!result.ok()) {
             error = true;
@@ -1241,8 +1260,7 @@ nixl_status_t nixlUcxEngine::loadRemoteConnInfo (const std::string &remote_agent
         conn->eps.push_back(std::move(*result));
     }
 
-    if (error)
-        return NIXL_ERR_BACKEND;
+    if (error) return NIXL_ERR_BACKEND;
 
     conn->remoteAgent = remote_agent;
 
@@ -1253,18 +1271,18 @@ nixl_status_t nixlUcxEngine::loadRemoteConnInfo (const std::string &remote_agent
 
 /****************************************
  * Memory management
-*****************************************/
-nixl_status_t nixlUcxEngine::registerMem (const nixlBlobDesc &mem,
-                                          const nixl_mem_t &nixl_mem,
-                                          nixlBackendMD* &out)
-{
+ *****************************************/
+nixl_status_t
+nixlUcxEngine::registerMem(const nixlBlobDesc &mem,
                            const nixl_mem_t &nixl_mem,
                            nixlBackendMD *&out) {
     auto priv = std::make_unique<nixlUcxPrivateMetadata>();
 
     if (nixl_mem == VRAM_SEG) {
         bool need_restart;
-        if (vramUpdateCtx((void*)mem.addr, mem.devId, need_restart)) {
+        if (vramUpdateCtx((void *)mem.addr, mem.devId, need_restart)) {
             return NIXL_ERR_NOT_SUPPORTED;
-            //TODO Add to logging
+            // TODO Add to logging
         }
         if (need_restart) {
             vramApplyCtx();
@@ -1272,7 +1290,7 @@ nixl_status_t nixlUcxEngine::registerMem (const nixlBlobDesc &mem,
     }
 
     // TODO: Add nixl_mem check?
-    const int ret = uc->memReg((void*) mem.addr, mem.len, priv->mem, nixl_mem);
+    const int ret = uc->memReg((void *)mem.addr, mem.len, priv->mem, nixl_mem);
     if (ret) {
         return NIXL_ERR_BACKEND;
     }
@@ -1285,27 +1303,26 @@ nixl_status_t nixlUcxEngine::registerMem (const nixlBlobDesc &mem,
     return NIXL_SUCCESS;
 }
 
-nixl_status_t nixlUcxEngine::deregisterMem (nixlBackendMD* meta)
-{
-    nixlUcxPrivateMetadata *priv = (nixlUcxPrivateMetadata*) meta;
+nixl_status_t
+nixlUcxEngine::deregisterMem(nixlBackendMD *meta) {
+    nixlUcxPrivateMetadata *priv = (nixlUcxPrivateMetadata *)meta;
     uc->memDereg(priv->mem);
     delete priv;
     return NIXL_SUCCESS;
 }
 
-nixl_status_t nixlUcxEngine::getPublicData (const nixlBackendMD* meta,
-                                            std::string &str) const {
-    const nixlUcxPrivateMetadata *priv = (nixlUcxPrivateMetadata*) meta;
+nixl_status_t
+nixlUcxEngine::getPublicData(const nixlBackendMD *meta, std::string &str) const {
+    const nixlUcxPrivateMetadata *priv = (nixlUcxPrivateMetadata *)meta;
     str = priv->get();
     return NIXL_SUCCESS;
 }
 
-
 // To be cleaned up
 nixl_status_t
-nixlUcxEngine::internalMDHelper (const nixl_blob_t &blob,
-                                 const std::string &agent,
-                                 nixlBackendMD* &output) {
+nixlUcxEngine::internalMDHelper(const nixl_blob_t &blob,
+                                const std::string &agent,
+                                nixlBackendMD *&output) {
     try {
         auto md = std::make_unique<nixlUcxPublicMetadata>();
         size_t size = blob.size();
@@ -1336,27 +1353,26 @@ nixlUcxEngine::internalMDHelper (const nixl_blob_t &blob,
 }
 
 nixl_status_t
-nixlUcxEngine::loadLocalMD (nixlBackendMD* input,
-                            nixlBackendMD* &output)
-{
-    nixlUcxPrivateMetadata* input_md = (nixlUcxPrivateMetadata*) input;
+nixlUcxEngine::loadLocalMD(nixlBackendMD *input, nixlBackendMD *&output) {
+    nixlUcxPrivateMetadata *input_md = (nixlUcxPrivateMetadata *)input;
 
     return internalMDHelper(input_md->rkeyStr, localAgent, output);
 }
 
 // To be cleaned up
-nixl_status_t nixlUcxEngine::loadRemoteMD (const nixlBlobDesc &input,
-                                           const nixl_mem_t &nixl_mem,
-                                           const std::string &remote_agent,
-                                           nixlBackendMD* &output)
-{
+nixl_status_t
+nixlUcxEngine::loadRemoteMD(const nixlBlobDesc &input,
+                            const nixl_mem_t &nixl_mem,
+                            const std::string &remote_agent,
+                            nixlBackendMD *&output) {
     // Set CUDA context of first device, UCX will anyways detect proper device when sending
     nixlUcxCudaCtxGuard guard(nixl_mem, m_cudaPrimaryCtx);
 
     return internalMDHelper(input.metaInfo, remote_agent, output);
 }
 
-nixl_status_t nixlUcxEngine::unloadMD (nixlBackendMD* input) {
+nixl_status_t
+nixlUcxEngine::unloadMD(nixlBackendMD *input) {
 
-    nixlUcxPublicMetadata *md = (nixlUcxPublicMetadata*) input; //typecast?
+    nixlUcxPublicMetadata *md = (nixlUcxPublicMetadata *)input; // typecast?
     delete md;
 
     return NIXL_SUCCESS;
@@ -1364,12 +1380,12 @@ nixl_status_t nixlUcxEngine::unloadMD (nixlBackendMD* input) {
 
 /****************************************
  * Data movement
-*****************************************/
+ *****************************************/
 
 static nixl_status_t _retHelper(nixl_status_t ret, nixlUcxBackendH *hndl, nixlUcxReq &req, ucx_connection_ptr_t conn) {
     /* if transfer wasn't immediately completed */
-    switch(ret) {
+    switch (ret) {
     case NIXL_IN_PROG:
         // TODO: this cast does not look safe
         // We need to allocate a vector of nixlUcxIntReq and set nixlUcxReqt
@@ -1399,13 +1415,13 @@ nixlUcxEngine::getWorkerId() const {
     return it->second;
 }
 
-nixl_status_t nixlUcxEngine::prepXfer (const nixl_xfer_op_t &operation,
-                                       const nixl_meta_dlist_t &local,
-                                       const nixl_meta_dlist_t &remote,
-                                       const std::string &remote_agent,
-                                       nixlBackendReqH* &handle,
-                                       const nixl_opt_b_args_t* opt_args) const
-{
+nixl_status_t
+nixlUcxEngine::prepXfer(const nixl_xfer_op_t &operation,
+                        const nixl_meta_dlist_t &local,
+                        const nixl_meta_dlist_t &remote,
+                        const std::string &remote_agent,
+                        nixlBackendReqH *&handle,
+                        const nixl_opt_b_args_t *opt_args) const {
     if (local.descCount() == 0 || remote.descCount() == 0) {
         NIXL_ERROR << "Local or remote descriptor list is empty";
         return NIXL_ERR_INVALID_PARAM;
@@ -1420,16 +1436,16 @@ nixl_status_t nixlUcxEngine::prepXfer (const nixl_xfer_op_t &operation,
     return NIXL_SUCCESS;
 }
 
-nixl_status_t nixlUcxEngine::estimateXferCost (const nixl_xfer_op_t &operation,
-                                               const nixl_meta_dlist_t &local,
-                                               const nixl_meta_dlist_t &remote,
-                                               const std::string &remote_agent,
-                                               nixlBackendReqH* const &handle,
-                                               std::chrono::microseconds &duration,
-                                               std::chrono::microseconds &err_margin,
-                                               nixl_cost_t &method,
-                                               const nixl_opt_args_t* opt_args) const
-{
+nixl_status_t
+nixlUcxEngine::estimateXferCost(const nixl_xfer_op_t &operation,
+                                const nixl_meta_dlist_t &local,
+                                const nixl_meta_dlist_t &remote,
+                                const std::string &remote_agent,
+                                nixlBackendReqH *const &handle,
+                                std::chrono::microseconds &duration,
+                                std::chrono::microseconds &err_margin,
+                                nixl_cost_t &method,
+                                const nixl_opt_args_t *opt_args) const {
     nixlUcxBackendH *intHandle = (nixlUcxBackendH *)handle;
     size_t workerId = intHandle->getWorkerId();
 
@@ -1452,17 +1468,19 @@ nixl_status_t nixlUcxEngine::estimateXferCost (const nixl_xfer_op_t &operation,
         size_t lsize = local[i].len;
         size_t rsize = remote[i].len;
 
-        nixlUcxPrivateMetadata *lmd = static_cast<nixlUcxPrivateMetadata*>(local[i].metadataP);
-        nixlUcxPublicMetadata *rmd = static_cast<nixlUcxPublicMetadata*>(remote[i].metadataP);
+        nixlUcxPrivateMetadata *lmd = static_cast<nixlUcxPrivateMetadata *>(local[i].metadataP);
+        nixlUcxPublicMetadata *rmd = static_cast<nixlUcxPublicMetadata *>(remote[i].metadataP);
 
-        NIXL_ASSERT(lmd && rmd) << "No metadata found in descriptor lists at index " << i << " during cost estimation";
+        NIXL_ASSERT(lmd && rmd) << "No metadata found in descriptor lists at index " << i
+                                << " during cost estimation";
         NIXL_ASSERT(lsize == rsize) << "Local size (" << lsize << ") != Remote size (" << rsize
                                     << ") at index " << i << " during cost estimation";
 
         std::chrono::microseconds msg_duration;
         std::chrono::microseconds msg_err_margin;
         nixl_cost_t msg_method;
-        nixl_status_t ret = rmd->conn->getEp(workerId)->estimateCost(lsize, msg_duration, msg_err_margin, msg_method);
+        nixl_status_t ret = rmd->conn->getEp(workerId)->estimateCost(
+            lsize, msg_duration, msg_err_margin, msg_method);
         if (ret != NIXL_SUCCESS) {
             NIXL_ERROR << "Worker failed to estimate cost for segment " << i << " status: " << ret;
             return ret;
@@ -1495,13 +1513,13 @@ nixlUcxEngine::sendXferRange(const nixl_xfer_op_t &operation,
     intHandle->reserve(end_idx - start_idx + 2);
 
     for (size_t i = start_idx; i < end_idx; i++) {
-        void *laddr = (void*) local[i].addr;
+        void *laddr = (void *)local[i].addr;
         size_t lsize = local[i].len;
         uint64_t raddr = (uint64_t)remote[i].addr;
         size_t rsize = remote[i].len;
 
-        lmd = (nixlUcxPrivateMetadata*) local[i].metadataP;
-        rmd = (nixlUcxPublicMetadata*) remote[i].metadataP;
+        lmd = (nixlUcxPrivateMetadata *)local[i].metadataP;
+        rmd = (nixlUcxPublicMetadata *)remote[i].metadataP;
         auto &ep = rmd->conn->getEp(workerId);
 
         if (lsize != rsize) {
@@ -1583,10 +1601,10 @@ nixlUcxEngine::postXfer(const nixl_xfer_op_t &operation,
     return ret;
 }
 
-nixl_status_t nixlUcxEngine::checkXfer (nixlBackendReqH* handle) const
-{
+nixl_status_t
+nixlUcxEngine::checkXfer(nixlBackendReqH *handle) const {
     nixlUcxBackendH *intHandle = (nixlUcxBackendH *)handle;
-    auto& notif = intHandle->notification();
+    auto &notif = intHandle->notification();
     nixl_status_t handle_status = intHandle->status();
 
     if ((handle_status != NIXL_SUCCESS) || !notif.has_value()) {
@@ -1615,8 +1633,8 @@ nixl_status_t nixlUcxEngine::checkXfer (nixlBackendReqH* handle) const
     return intHandle->status();
 }
 
-nixl_status_t nixlUcxEngine::releaseReqH(nixlBackendReqH* handle) const
-{
+nixl_status_t
+nixlUcxEngine::releaseReqH(nixlBackendReqH *handle) const {
     nixlUcxBackendH *intHandle = (nixlUcxBackendH *)handle;
 
     nixl_status_t status = intHandle->release();
@@ -1633,16 +1651,6 @@ nixlUcxEngine::createGpuXferReq(const nixlBackendReqH &req_hndl,
                                 nixlGpuXferReqH &gpu_req_hndl) const {
     auto intHandle = static_cast<const nixlUcxBackendH *>(&req_hndl);
 
-    if (local_descs.descCount() != remote_descs.descCount()) {
-        NIXL_ERROR << "Mismatch between local and remote descriptor counts";
-        return NIXL_ERR_INVALID_PARAM;
-    }
-
-    if (local_descs.descCount() == 0) {
-        NIXL_ERROR << "Empty descriptor lists";
-        return NIXL_ERR_INVALID_PARAM;
-    }
-
     auto remoteMd = static_cast<nixlUcxPublicMetadata *>(remote_descs[0].metadataP);
     if (!remoteMd || !remoteMd->conn) {
         NIXL_ERROR << "No connection found in remote metadata";
@@ -1655,9 +1663,11 @@ nixlUcxEngine::createGpuXferReq(const nixlBackendReqH &req_hndl,
     std::vector<nixlUcxMem> local_mems;
     std::vector<const nixl::ucx::rkey *> remote_rkeys;
     std::vector<uint64_t> remote_addrs;
+    std::vector<size_t> remote_lengths;
     local_mems.reserve(local_descs.descCount());
     remote_rkeys.reserve(remote_descs.descCount());
     remote_addrs.reserve(remote_descs.descCount());
+    remote_lengths.reserve(remote_descs.descCount());
 
     for (size_t i = 0; i < static_cast<size_t>(local_descs.descCount()); i++) {
         auto localMd = static_cast<nixlUcxPrivateMetadata *>(local_descs[i].metadataP);
@@ -1666,10 +1676,12 @@ nixlUcxEngine::createGpuXferReq(const nixlBackendReqH &req_hndl,
         local_mems.push_back(localMd->mem);
         remote_rkeys.push_back(&remoteMdDesc->getRkey(workerId));
         remote_addrs.push_back(static_cast<uint64_t>(remote_descs[i].addr));
+        remote_lengths.push_back(remote_descs[i].len);
     }
 
     try {
-        gpu_req_hndl = nixl::ucx::createGpuXferReq(*ep, local_mems, remote_rkeys, remote_addrs);
+        gpu_req_hndl = nixl::ucx::createGpuXferReq(
+            *ep, local_mems, remote_rkeys, remote_addrs, remote_lengths);
         return NIXL_SUCCESS;
     } catch (const std::exception &e) {
@@ -1713,19 +1725,20 @@ nixlUcxEngine::prepGpuSignal(const nixlBackendMD &meta, void *signal) const {
     }
 }
 
-int nixlUcxEngine::progress() {
+int
+nixlUcxEngine::progress() {
     // TODO: add listen for connection handling if necessary
     int ret = 0;
-    for (auto &uw: uws)
+    for (auto &uw : uws)
         ret += uw->progress();
 
     return ret;
 }
 
 /****************************************
  * Notifications
-*****************************************/
+ *****************************************/
 
-//agent will provide cached msg
+// agent will provide cached msg
 nixl_status_t
 nixlUcxEngine::notifSendPriv(const std::string &remote_agent,
                              const std::string &msg,
@@ -1742,7 +1755,7 @@ nixlUcxEngine::notifSendPriv(const std::string &remote_agent,
     ret = ep->sendAm(
         NOTIF_STR, NULL, 0, (void *)buffer->data(), buffer->size(), UCP_AM_SEND_FLAG_EAGER, req);
     if (ret == NIXL_IN_PROG) {
-        nixlUcxIntReq* nReq = (nixlUcxIntReq*)req;
+        nixlUcxIntReq *nReq = (nixlUcxIntReq *)req;
         nReq->amBuffer = std::move(buffer);
     }
     return ret;
@@ -1760,15 +1773,16 @@ nixlUcxEngine::appendNotif(std::string remote_name, std::string msg) {
 }
 
 ucs_status_t
-nixlUcxEngine::notifAmCb(void *arg, const void *header,
-                         size_t header_length, void *data,
+nixlUcxEngine::notifAmCb(void *arg,
+                         const void *header,
+                         size_t header_length,
+                         void *data,
                          size_t length,
-                         const ucp_am_recv_param_t *param)
-{
+                         const ucp_am_recv_param_t *param) {
     nixlSerDes ser_des;
 
-    std::string ser_str( (char*) data, length);
-    nixlUcxEngine* engine = (nixlUcxEngine*) arg;
+    std::string ser_str((char *)data, length);
+    nixlUcxEngine *engine = (nixlUcxEngine *)arg;
 
     // send_am should be forcing EAGER protocol
     NIXL_ASSERT(!(param->recv_attr & UCP_AM_RECV_ATTR_FLAG_RNDV));
@@ -1787,8 +1801,8 @@ nixlUcxEngine::getNotifsImpl(notif_list_t &notif_list) {
     moveNotifList(notifMainList, notif_list);
 }
 
-nixl_status_t nixlUcxEngine::getNotifs(notif_list_t &notif_list)
-{
+nixl_status_t
+nixlUcxEngine::getNotifs(notif_list_t &notif_list) {
     if (!notif_list.empty()) return NIXL_ERR_INVALID_PARAM;
 
     while (progress())
@@ -1797,8 +1811,8 @@ nixl_status_t nixlUcxEngine::getNotifs(notif_list_t &notif_list)
     return NIXL_SUCCESS;
 }
 
-nixl_status_t nixlUcxEngine::genNotif(const std::string &remote_agent, const std::string &msg) const
-{
+nixl_status_t
+nixlUcxEngine::genNotif(const std::string &remote_agent, const std::string &msg) const {
     nixl_status_t ret;
     nixlUcxReq req;
 
@@ -1808,7 +1822,7 @@ nixl_status_t nixlUcxEngine::genNotif(const std::string &remote_agent, const std
     }
 
     ret = notifSendPriv(remote_agent, msg, req, conn->getEp(getWorkerId()));
-    switch(ret) {
+    switch (ret) {
     case NIXL_IN_PROG:
         /* do not track the request */
         getWorker(getWorkerId())->reqRelease(req);
diff --git a/src/utils/ucx/gpu_xfer_req_h.cpp b/src/utils/ucx/gpu_xfer_req_h.cpp
index a75974f09..5106f221f 100644
--- a/src/utils/ucx/gpu_xfer_req_h.cpp
+++ b/src/utils/ucx/gpu_xfer_req_h.cpp
@@ -35,34 +35,40 @@ nixlGpuXferReqH
 createGpuXferReq(const nixlUcxEp &ep,
                  const std::vector<nixlUcxMem> &local_mems,
                  const std::vector<const nixl::ucx::rkey *> &remote_rkeys,
-                 const std::vector<uint64_t> &remote_addrs) {
+                 const std::vector<uint64_t> &remote_addrs,
+                 const std::vector<size_t> &remote_lengths) {
     nixl_status_t status = ep.checkTxState();
     if (status != NIXL_SUCCESS) {
         throw std::runtime_error("Endpoint not in valid state for creating memory list");
     }
 
-    if (local_mems.empty() || remote_rkeys.empty() || remote_addrs.empty()) {
-        throw std::invalid_argument("Empty memory, rkey, or address lists provided");
-    }
-
-    if (local_mems.size() != remote_rkeys.size() || local_mems.size() != remote_addrs.size()) {
-        throw std::invalid_argument(
-            "Local memory, remote rkey, and remote address lists must have same size");
-    }
 
     std::vector<ucp_device_mem_list_elem_t> ucp_elements;
     ucp_elements.reserve(local_mems.size());
 
     for (size_t i = 0; i < local_mems.size(); i++) {
         ucp_device_mem_list_elem_t ucp_elem;
-        ucp_elem.field_mask = UCP_DEVICE_MEM_LIST_ELEM_FIELD_MEMH |
-            UCP_DEVICE_MEM_LIST_ELEM_FIELD_RKEY | UCP_DEVICE_MEM_LIST_ELEM_FIELD_LOCAL_ADDR |
-            UCP_DEVICE_MEM_LIST_ELEM_FIELD_REMOTE_ADDR | UCP_DEVICE_MEM_LIST_ELEM_FIELD_LENGTH;
-        ucp_elem.memh = local_mems[i].getMemh();
+        bool has_local_mem = local_mems[i].getMemh() != nullptr;
+
+        if (has_local_mem) {
+            /* Data element with local memory */
+            ucp_elem.field_mask = UCP_DEVICE_MEM_LIST_ELEM_FIELD_MEMH |
+                UCP_DEVICE_MEM_LIST_ELEM_FIELD_RKEY | UCP_DEVICE_MEM_LIST_ELEM_FIELD_LOCAL_ADDR |
+                UCP_DEVICE_MEM_LIST_ELEM_FIELD_REMOTE_ADDR | UCP_DEVICE_MEM_LIST_ELEM_FIELD_LENGTH;
+            ucp_elem.memh = local_mems[i].getMemh();
+            ucp_elem.local_addr = local_mems[i].getBase();
+            ucp_elem.length = local_mems[i].getSize();
+        } else {
+            /* Signal element without local memory */
+            ucp_elem.field_mask = UCP_DEVICE_MEM_LIST_ELEM_FIELD_RKEY |
+                UCP_DEVICE_MEM_LIST_ELEM_FIELD_REMOTE_ADDR | UCP_DEVICE_MEM_LIST_ELEM_FIELD_LENGTH;
+            ucp_elem.memh = nullptr;
+            ucp_elem.local_addr = nullptr;
+            ucp_elem.length = remote_lengths[i];
+        }
+
         ucp_elem.rkey = remote_rkeys[i]->get();
-        ucp_elem.local_addr = local_mems[i].getBase();
         ucp_elem.remote_addr = remote_addrs[i];
-        ucp_elem.length = local_mems[i].getSize();
         ucp_elements.push_back(ucp_elem);
     }
 
@@ -97,7 +103,8 @@ nixlGpuXferReqH
 createGpuXferReq(const nixlUcxEp &ep,
                  const std::vector<nixlUcxMem> &local_mems,
                  const std::vector<const nixl::ucx::rkey *> &remote_rkeys,
-                 const std::vector<uint64_t> &remote_addrs) {
+                 const std::vector<uint64_t> &remote_addrs,
+                 const std::vector<size_t> &remote_lengths) {
     NIXL_ERROR << "UCX GPU device API not supported";
     throw std::runtime_error("UCX GPU device API not available");
 }
diff --git a/src/utils/ucx/gpu_xfer_req_h.h b/src/utils/ucx/gpu_xfer_req_h.h
index 9798308b1..e7f4b2bc6 100644
--- a/src/utils/ucx/gpu_xfer_req_h.h
+++ b/src/utils/ucx/gpu_xfer_req_h.h
@@ -32,7 +32,8 @@ nixlGpuXferReqH
 createGpuXferReq(const nixlUcxEp &ep,
                  const std::vector<nixlUcxMem> &local_mems,
                  const std::vector<const nixl::ucx::rkey *> &remote_rkeys,
-                 const std::vector<uint64_t> &remote_addrs);
+                 const std::vector<uint64_t> &remote_addrs,
+                 const std::vector<size_t> &remote_lengths);
 
 void releaseGpuXferReq(nixlGpuXferReqH gpu_req) noexcept;
 
diff --git a/test/gtest/device_api/single_write_test.cu b/test/gtest/device_api/single_write_test.cu
index 0425150f2..a1da58990 100644
--- a/test/gtest/device_api/single_write_test.cu
+++ b/test/gtest/device_api/single_write_test.cu
@@ -385,7 +385,7 @@ TEST_P(SingleWriteTest, BasicSingleWriteTest) {
     constexpr size_t count = 1;
     nixl_mem_t mem_type = VRAM_SEG;
     size_t num_threads = 32;
-    const size_t num_iters = 10000;
+    const size_t num_iters = 10;
     constexpr unsigned index = 0;
     const bool is_no_delay = true;
 
@@ -405,22 +405,21 @@ TEST_P(SingleWriteTest, BasicSingleWriteTest) {
     extra_params.notifMsg = NOTIF_MSG;
 
     nixlXferReqH *xfer_req = nullptr;
+    nixlGpuXferReqH gpu_req_hndl;
+
     nixl_status_t status = getAgent(SENDER_AGENT)
-                               .createXferReq(NIXL_WRITE,
+                               .createGpuXferReq(
                                               makeDescList(src_buffers, mem_type),
                                               makeDescList(dst_buffers, mem_type),
                                               getAgentName(RECEIVER_AGENT),
+                                              gpu_req_hndl,
                                               xfer_req,
                                               &extra_params);
 
     ASSERT_EQ(status, NIXL_SUCCESS)
-        << "Failed to create xfer request " << nixlEnumStrings::statusStr(status);
+        << "Failed to create GPU xfer request " << nixlEnumStrings::statusStr(status);
     EXPECT_NE(xfer_req, nullptr);
 
-    nixlGpuXferReqH gpu_req_hndl;
-    status = getAgent(SENDER_AGENT).createGpuXferReq(*xfer_req, gpu_req_hndl);
-    ASSERT_EQ(status, NIXL_SUCCESS) << "Failed to create GPU xfer request";
-
     ASSERT_NE(gpu_req_hndl, nullptr) << "GPU request handle is null after createGpuXferReq";
 
     size_t src_offset = 0;
@@ -485,7 +484,7 @@ TEST_P(SingleWriteTest, VariableSizeTest) {
     constexpr size_t count = 1;
     nixl_mem_t mem_type = VRAM_SEG;
     size_t num_threads = 32;
-    const size_t num_iters = 50000;
+    const size_t num_iters = 10;
     constexpr unsigned index = 0;
     const bool is_no_delay = true;
 
@@ -507,19 +506,18 @@ TEST_P(SingleWriteTest, VariableSizeTest) {
     extra_params.notifMsg = NOTIF_MSG;
 
     nixlXferReqH *xfer_req = nullptr;
+    nixlGpuXferReqH gpu_req_hndl;
+
     nixl_status_t status = getAgent(SENDER_AGENT)
-                               .createXferReq(NIXL_WRITE,
+                               .createGpuXferReq(
                                               makeDescList(src_buffers, mem_type),
                                               makeDescList(dst_buffers, mem_type),
                                               getAgentName(RECEIVER_AGENT),
+                                              gpu_req_hndl,
                                               xfer_req,
                                               &extra_params);
-    ASSERT_EQ(status, NIXL_SUCCESS) << "Failed to create xfer request for size " << test_size;
-
-    nixlGpuXferReqH gpu_req_hndl;
-    status = getAgent(SENDER_AGENT).createGpuXferReq(*xfer_req, gpu_req_hndl);
     ASSERT_EQ(status, NIXL_SUCCESS) << "Failed to create GPU xfer request for size " << test_size;