-
Notifications
You must be signed in to change notification settings - Fork 291
NIXL EP: Use VMM API for device memory allocation. #1415
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
itayalroy
merged 32 commits into
ai-dynamo:main
from
ofirfarjun7:topic/nixl-ep-use-vmm-api
Apr 5, 2026
Merged
Changes from all commits
Commits
Show all changes
32 commits
Select commit
Hold shift + click to select a range
b6cb209
NIXL/EP: Use vmm API instead of cudaMalloc
ofirfarjun7 c36a4c2
NIXL/EP: Use vmm API instead of cudaMalloc
29c2c7a
NIXL/EP: revert
daecca9
NIXL/EP: Support gdr copy with vmm.
3f5f55f
NIXL/EP: Improve.
ofirfarjun7 07674ee
NIXL/EP: Format.
ofirfarjun7 0771375
NIXL/EP: Merge branch 'main' into topic/nixl-ep-use-vmm-api
ofirfarjun7 bbb8d27
NIXL/EP: check return val.
ofirfarjun7 7b60106
NIXL/EP: Format.
ofirfarjun7 25b2e5b
NIXL/EP:Improve.
ofirfarjun7 02e379b
NIXL/EP: Format.
ofirfarjun7 2c701be
NIXL/EP: Improve.
ofirfarjun7 329fb22
NIXL/EP: Format.
ofirfarjun7 ee2fd64
NIXL/EP: fallback to cudaMalloc if fabric not supported
ofirfarjun7 7677e59
NIXL/EP: set default vals
ofirfarjun7 8b4b567
NIXL/EP: Fix.
ofirfarjun7 cbe8d1d
NIXL/EP: Fix comments.
ofirfarjun7 088c960
NIXL/EP: Fix.
ofirfarjun7 15f9174
NIXL/EP: Fix.
ofirfarjun7 a44a094
NIXL/EP: not needed.
ofirfarjun7 f120473
NIXL/EP: new files terms.
ofirfarjun7 9ef9dca
NIXL/EP: Fix comments.
ofirfarjun7 eab7509
NIXL/EP: Merge main
ofirfarjun7 22c59f7
NIXL/EP: Fix comment.
ofirfarjun7 0a58f96
NIXL/EP: Fix.
ofirfarjun7 d2bdfea
NIXL/EP: fix.
ofirfarjun7 0405899
NIXL/EP: Fix comment.
ofirfarjun7 e82d422
NIXL/EP: Fix comment.
ofirfarjun7 5b30ea4
NIXL/EP: Fix comment.
ofirfarjun7 437b97a
NIXL/EP: Merge branch 'main'
ofirfarjun7 32cd739
NIXL/EP: Fix comment.
ofirfarjun7 b1cab66
NIXL/EP: Fix comment.
ofirfarjun7 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,174 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| #include <iostream> | ||
| #include <stdexcept> | ||
|
|
||
| #include "config.hpp" | ||
| #include "vmm.hpp" | ||
|
|
||
| namespace { | ||
|
|
||
| constexpr const char *k_vmm_ctx = "vmm_region"; | ||
|
|
||
| /** Log a non-fatal warning if a CUDA driver API call failed (e.g. during teardown). */ | ||
| void | ||
| warn_cu_api(CUresult status, const char *context, const char *operation) noexcept { | ||
| if (status != CUDA_SUCCESS) { | ||
| const char *msg = nullptr; | ||
| if (cuGetErrorString(status, &msg) != CUDA_SUCCESS || msg == nullptr) { | ||
| msg = "unknown CUDA driver error"; | ||
| } | ||
| std::cerr << "WARNING: " << context << " failed to " << operation << ": " << msg << '\n'; | ||
| } | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| namespace nixl_ep { | ||
|
|
||
/**
 * Release every resource this region owns. Safe to call more than once and on a
 * partially-constructed region (the constructor calls it on its error paths).
 * Driver errors are only logged via warn_cu_api — this runs from the destructor
 * and must not throw.
 */
void
vmm_region::release() noexcept {
    // Fallback path: the region is a single cuMemAlloc allocation.
    if (is_cuda_malloc_) {
        if (ptr_) {
            warn_cu_api(cuMemFree(ptr_), k_vmm_ctx, "cuMemFree");
        }
        ptr_ = 0;
        return;
    }

    // VMM path — teardown must mirror construction in reverse order:
    // unmap first, then free the VA reservation, then release the physical handle.
    if (vmm_mapped_) {
        warn_cu_api(cuMemUnmap(ptr_, size_), k_vmm_ctx, "cuMemUnmap");
        vmm_mapped_ = false;
    }
    if (ptr_) {
        warn_cu_api(cuMemAddressFree(ptr_, size_), k_vmm_ctx, "cuMemAddressFree");
        ptr_ = 0;
    }
    if (handle_) {
        warn_cu_api(cuMemRelease(handle_), k_vmm_ctx, "cuMemRelease");
        handle_ = 0;
    }
}
|
|
||
/** Destroy the region; teardown errors are logged by release(), never thrown. */
vmm_region::~vmm_region() {
    release();
}
|
|
||
| vmm_region::vmm_region(size_t size) { | ||
| if (size == 0) { | ||
| throw std::invalid_argument("vmm_region: size must be non-zero"); | ||
| } | ||
|
|
||
| struct cuda_alloc_ctx { | ||
| bool fabric_supported; | ||
| CUmemAllocationProp prop; | ||
| size_t granularity; | ||
| CUdevice device; | ||
| CUmemAccessDesc access_desc = {}; | ||
|
|
||
| cuda_alloc_ctx() : fabric_supported(false), prop({}), granularity(0) { | ||
| int version; | ||
|
|
||
| if (cuCtxGetDevice(&device) != CUDA_SUCCESS) { | ||
| throw std::runtime_error("CUDA device should be set before creating a vmm_region"); | ||
| } | ||
|
|
||
| if (cuDriverGetVersion(&version) != CUDA_SUCCESS) { | ||
| throw std::runtime_error("Failed to get CUDA driver version"); | ||
| } | ||
|
|
||
| if (version < 11000) { | ||
| return; /* too old — fall back to cudaMalloc */ | ||
| } | ||
|
|
||
| int fab = 0; | ||
| if ((cuDeviceGetAttribute(&fab, | ||
| CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, | ||
| device) != CUDA_SUCCESS) || | ||
| (!fab)) { | ||
| return; /* no fabric — fall back to cudaMalloc */ | ||
| } | ||
|
|
||
| int rdma_vmm_supported = 0; | ||
| if (cuDeviceGetAttribute(&rdma_vmm_supported, | ||
| CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED, | ||
| device) != CUDA_SUCCESS) { | ||
| throw std::runtime_error( | ||
| "Failed to query GPUDirect RDMA with VMM support attribute"); | ||
| } | ||
|
|
||
| if (!rdma_vmm_supported) { | ||
| std::cerr << "DIAG: " << k_vmm_ctx | ||
| << " - GPUDirect RDMA with CUDA VMM not supported; falling back to " | ||
| "cuMemAlloc\n"; | ||
| return; | ||
| } | ||
|
|
||
| prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; | ||
| prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; | ||
| prop.location.id = device; | ||
| prop.allocFlags.gpuDirectRDMACapable = 1; | ||
| prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC; | ||
|
|
||
| if (cuMemGetAllocationGranularity( | ||
| &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM) != CUDA_SUCCESS) { | ||
| throw std::runtime_error("Failed to get CUDA allocation granularity"); | ||
| } | ||
|
|
||
| access_desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; | ||
| access_desc.location.id = device; | ||
| access_desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; | ||
| fabric_supported = true; | ||
| } | ||
| }; | ||
|
|
||
| static cuda_alloc_ctx ctx; | ||
|
|
||
| if (!ctx.fabric_supported) { | ||
| size_ = size; | ||
| is_cuda_malloc_ = true; | ||
| if (cuMemAlloc(&ptr_, size) != CUDA_SUCCESS) { | ||
| throw std::runtime_error("cuMemAlloc fallback failed"); | ||
| } | ||
| return; | ||
| } | ||
|
|
||
| size_ = nixl_ep::align_up<size_t>(size, ctx.granularity); | ||
|
|
||
| if (cuMemCreate(&handle_, size_, &ctx.prop, 0) != CUDA_SUCCESS) { | ||
| throw std::runtime_error("Failed to create CUDA VMM allocation"); | ||
| } | ||
|
|
||
| if (cuMemAddressReserve(&ptr_, size_, 0, 0, 0) != CUDA_SUCCESS) { | ||
| release(); | ||
| throw std::runtime_error("Failed to reserve CUDA virtual address"); | ||
| } | ||
|
|
||
| if (cuMemMap(ptr_, size_, 0, handle_, 0) != CUDA_SUCCESS) { | ||
| release(); | ||
| throw std::runtime_error("Failed to map CUDA VMM memory"); | ||
| } | ||
| vmm_mapped_ = true; | ||
|
|
||
| if (cuMemSetAccess(ptr_, size_, &ctx.access_desc, 1) != CUDA_SUCCESS) { | ||
| release(); | ||
| throw std::runtime_error("Failed to set CUDA memory access"); | ||
| } | ||
| } | ||
|
|
||
| } // namespace nixl_ep | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <cuda.h> | ||
| #include <cstddef> | ||
| #include <cstdint> | ||
|
|
||
| namespace nixl_ep { | ||
|
|
||
| class vmm_region { | ||
rakhmets marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| public: | ||
| explicit vmm_region(size_t size); | ||
|
|
||
| ~vmm_region(); | ||
|
|
||
| vmm_region(const vmm_region &) = delete; | ||
| vmm_region & | ||
| operator=(const vmm_region &) = delete; | ||
| vmm_region(vmm_region &&) = delete; | ||
| vmm_region & | ||
| operator=(vmm_region &&) = delete; | ||
|
|
||
| [[nodiscard]] void * | ||
| ptr() const noexcept { | ||
| return reinterpret_cast<void *>(static_cast<std::uintptr_t>(ptr_)); | ||
| } | ||
|
|
||
| private: | ||
| void | ||
| release() noexcept; | ||
|
|
||
| CUdeviceptr ptr_ = 0; | ||
| size_t size_ = 0; | ||
| CUmemGenericAllocationHandle handle_ = 0; | ||
| bool is_cuda_malloc_ = false; | ||
| bool vmm_mapped_ = false; | ||
| }; | ||
|
|
||
| } // namespace nixl_ep | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -63,6 +63,7 @@ endif | |
|
|
||
nixl_ep_sources = [
    'csrc/nixl_ep.cpp',
    # VMM-based device memory allocator (vmm_region) with cuMemAlloc fallback.
    'csrc/vmm.cpp',
    'csrc/kernels/nixl_ep.cu',
]
|
|
||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.