Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
6519034
Cuda Arena migration plan
yuslepukhin Apr 1, 2026
26fcaae
Update the design
yuslepukhin Apr 1, 2026
9dad919
Clarify IArena inheritance
yuslepukhin Apr 1, 2026
0027c19
Address review comments
yuslepukhin Apr 1, 2026
ad48120
Clarify Environment::CreateAndRegisterAllocatorV2()
yuslepukhin Apr 1, 2026
93850d9
Address review comments
yuslepukhin Apr 1, 2026
318edae
Re-design for an in-plugin arena using examples as a base
yuslepukhin Apr 2, 2026
b6973b6
Address review comments
yuslepukhin Apr 2, 2026
6748f7d
Re-work inheritance of Cuda Arena allocators
yuslepukhin Apr 2, 2026
2bcd8d3
Adjust CudaMempoolOrtAllocator
yuslepukhin Apr 2, 2026
4730e8d
Address review comments
yuslepukhin Apr 2, 2026
d335e7b
Address comments
yuslepukhin Apr 2, 2026
71c3ec5
Implement Phase I
yuslepukhin Apr 2, 2026
32f1fbc
lintrunner
yuslepukhin Apr 2, 2026
a19d9d3
Address review comments and make this build and test run. Phase I
yuslepukhin Apr 3, 2026
7b3bb5f
Address review comments
yuslepukhin Apr 3, 2026
a71b93a
Address comments
yuslepukhin Apr 4, 2026
1ea0d94
Address comments
yuslepukhin Apr 4, 2026
8f850a3
Address review comments
yuslepukhin Apr 4, 2026
27c3bc4
Integrate CudaMempoolAllocator
yuslepukhin Apr 4, 2026
9c60d8a
Merge branch 'main' into yuslepukhin/cuda_arena_ep
yuslepukhin Apr 6, 2026
2cde673
Address review comments
yuslepukhin Apr 6, 2026
8f81a39
Address review comments, add public Reserve API, improve test coverage
yuslepukhin Apr 6, 2026
552d0e6
address comments
yuslepukhin Apr 6, 2026
700eb6c
Address review issues
yuslepukhin Apr 6, 2026
5a73a66
Add Shrink API
yuslepukhin Apr 6, 2026
c60b59b
Address review comments
yuslepukhin Apr 6, 2026
9961b56
Address review comments
yuslepukhin Apr 7, 2026
121d53b
Merge branch 'main' into yuslepukhin/cuda_arena_ep
yuslepukhin Apr 7, 2026
982eb6a
Add ArenaAllocator wrapper for Shrink and ReleaseStreamBuffers
yuslepukhin Apr 7, 2026
540962d
Address review comments
yuslepukhin Apr 7, 2026
6151008
Update docs
yuslepukhin Apr 7, 2026
9aebc8c
address review comments
yuslepukhin Apr 7, 2026
1c612cc
Address most recent comments
yuslepukhin Apr 7, 2026
da13dd5
Address compile issues. Add test.
yuslepukhin Apr 7, 2026
4eb4238
Merge branch 'main' into yuslepukhin/cuda_arena_ep
yuslepukhin Apr 8, 2026
e0204a8
Address review comments
yuslepukhin Apr 8, 2026
65769d5
Build error
yuslepukhin Apr 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions cmake/onnxruntime_providers_cuda_plugin.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_plugin
${CUDA_PLUGIN_EP_CC_SRCS}
${CUDA_PLUGIN_EP_CU_SRCS}
)

# Mirror directory structure in the Visual Studio solution tree under "onnxruntime".
source_group(TREE ${ONNXRUNTIME_ROOT} PREFIX "onnxruntime" FILES ${CUDA_EP_CC_SRCS} ${CUDA_EP_CU_SRCS})
source_group(TREE ${ONNXRUNTIME_ROOT} PREFIX "onnxruntime" FILES ${CUDA_CONTRIB_OPS_CC_SRCS} ${CUDA_CONTRIB_OPS_CU_SRCS})
# Keep the plugin CUDA target aligned with the repo-wide C++20 baseline.
# Forcing CUDA C++17 here breaks newer protobuf/absl headers used by the plugin
# build, as absl::compare expects standard ordering support in this configuration.
Expand Down Expand Up @@ -143,22 +147,14 @@ target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE
"$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--std c++20>"
"$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr;-Xcudafe;--diag_suppress=550>"
"$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcudafe --diag_suppress=2810>"
# Force-include adapters.h and cuda_kernel_adapter.h for CXX sources.
# GCC/Clang use -include, MSVC uses /FI.
"$<$<AND:$<COMPILE_LANGUAGE:CXX>,$<NOT:$<CXX_COMPILER_ID:MSVC>>>:-include;${REPO_ROOT}/include/onnxruntime/ep/adapters.h>"
"$<$<AND:$<COMPILE_LANGUAGE:CXX>,$<NOT:$<CXX_COMPILER_ID:MSVC>>>:SHELL:-include ${CUDA_PLUGIN_EP_DIR}/cuda_kernel_adapter.h>"
"$<$<AND:$<COMPILE_LANGUAGE:CXX>,$<CXX_COMPILER_ID:MSVC>>:/FI${REPO_ROOT}/include/onnxruntime/ep/adapters.h>"
"$<$<AND:$<COMPILE_LANGUAGE:CXX>,$<CXX_COMPILER_ID:MSVC>>:/FI${CUDA_PLUGIN_EP_DIR}/cuda_kernel_adapter.h>"
)

# Force-include adapter headers for CXX files.
# MSVC uses /FI; GCC/Clang use -include.
if (MSVC)
target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE
"$<$<COMPILE_LANGUAGE:CXX>:SHELL:/FI \"${REPO_ROOT}/include/onnxruntime/ep/adapters.h\">"
"$<$<COMPILE_LANGUAGE:CXX>:SHELL:/FI \"${CUDA_PLUGIN_EP_DIR}/cuda_kernel_adapter.h\">"
)
else()
target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE
"$<$<COMPILE_LANGUAGE:CXX>:SHELL:-include ${REPO_ROOT}/include/onnxruntime/ep/adapters.h>"
"$<$<COMPILE_LANGUAGE:CXX>:SHELL:-include ${CUDA_PLUGIN_EP_DIR}/cuda_kernel_adapter.h>"
)
endif()

if (MSVC)
target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE
"$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler /permissive>"
Expand All @@ -170,6 +166,11 @@ if (MSVC)
)

target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE
# /permissive is required for CUTLASS cute headers (cute::stride.hpp, cute::Layout etc.)
"$<$<COMPILE_LANGUAGE:CXX>:/permissive>"
Comment thread
yuslepukhin marked this conversation as resolved.
# /permissive disables C++ alternative tokens (or, and, not, etc.).
# Force-include iso646.h to restore them as macros.
"$<$<COMPILE_LANGUAGE:CXX>:/FIiso646.h>"
"$<$<COMPILE_LANGUAGE:CXX>:/wd4127>"
)
endif()
Expand Down Expand Up @@ -287,9 +288,10 @@ endif()



# Set output name
# Set output name and solution folder
set_target_properties(onnxruntime_providers_cuda_plugin PROPERTIES
OUTPUT_NAME "onnxruntime_providers_cuda_plugin"
FOLDER "ONNXRuntime"
)

# Install
Expand Down
7 changes: 7 additions & 0 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,13 @@ if (onnxruntime_USE_CUDA AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_R
)
list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_cuda_src})

if (onnxruntime_BUILD_CUDA_EP_AS_PLUGIN)
# When the CUDA EP is built as a plugin, also compile the plugin-specific
# provider tests so they run as part of the main provider test suite.
file(GLOB onnxruntime_test_providers_cuda_plugin_src CONFIGURE_DEPENDS
"${TEST_SRC_DIR}/providers/cuda/plugin/*.cc"
)
list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_cuda_plugin_src})
endif()

if (onnxruntime_USE_CUDA_NHWC_OPS AND CUDNN_MAJOR_VERSION GREATER 8)
file(GLOB onnxruntime_test_providers_cuda_nhwc_src CONFIGURE_DEPENDS
"${TEST_SRC_DIR}/providers/cuda/nhwc/*.cc"
Expand Down
757 changes: 757 additions & 0 deletions docs/cuda_plugin_ep/arena_allocator_migration_design.md

Large diffs are not rendered by default.

118 changes: 50 additions & 68 deletions docs/cuda_plugin_ep/cuda_plugin_ep_design.md

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions include/onnxruntime/core/framework/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,11 @@ class IAllocator {
*stats = {};
}

// Returns a pointer to this allocator as an IArena if it is one, nullptr otherwise.
// Used by SafeArenaCast to avoid dependency on RTTI.
virtual class IArena* AsArena() { return nullptr; }
virtual const class IArena* AsArena() const { return nullptr; }

static bool CalcMemSizeForArray(size_t nmemb, size_t size, size_t* out) noexcept {
return CalcMemSizeForArrayWithAlignment(nmemb, size, 0, out);
}
Expand Down Expand Up @@ -364,6 +369,8 @@ class IArena : public IAllocator {
virtual Status Shrink() = 0;
// Only implemented when IsStreamAware() returns true
virtual void ReleaseStreamBuffers(Stream* /*stream*/) {}
IArena* AsArena() override { return this; }
const IArena* AsArena() const override { return this; }
static IArena* SafeArenaCast(IAllocator* allocator);
};

Expand Down
16 changes: 16 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,22 @@ typedef struct OrtAllocator {
* \since 1.23
*/
void*(ORT_API_CALL* AllocOnStream)(struct OrtAllocator* this_, size_t size, OrtSyncStream* stream);

/** \brief Release unused memory held by the allocator back to the system.
*
* For arena-based allocators, this frees allocation regions that are completely unused.
* For mempool-based allocators, this trims the pool to a configured minimum.
* For non-arena allocators this is a no-op.
*
* \param[in] this_ OrtAllocator instance
*
* \return nullptr on success, or an OrtStatus* on failure.
*
* \note Implementation of this function is optional and Shrink may be set to a nullptr.
* Callers must check for nullptr before invoking.
* \since 1.25
*/
ORT_API2_STATUS(Shrink, _In_ struct OrtAllocator* this_);
Comment thread
yuslepukhin marked this conversation as resolved.
} OrtAllocator;

typedef void(ORT_API_CALL* OrtLoggingFunction)(
Expand Down
7 changes: 7 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1048,6 +1048,7 @@ struct AllocatorImpl : Base<T> {
using B::B;

void* Alloc(size_t size);
void* Reserve(size_t size);
MemoryAllocation GetAllocation(size_t size);
void Free(void* p);
ConstMemoryInfo GetInfo() const;
Expand All @@ -1057,6 +1058,12 @@ struct AllocatorImpl : Base<T> {
* \return A pointer to a KeyValuePairs object that will be filled with the allocator statistics.
*/
KeyValuePairs GetStats() const;

/** \brief Release unused memory held by the allocator.
*
* Calls the optional Shrink function pointer if available; does nothing otherwise.
*/
void Shrink();
};
} // namespace detail

Expand Down
22 changes: 22 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,19 @@ inline void* AllocatorImpl<T>::Alloc(size_t size) {
return out;
}

/// \brief Reserve memory through the underlying OrtAllocator.
///
/// Dispatches to the allocator's optional Reserve callback when available,
/// otherwise degrades to a plain Alloc call.
template <typename T>
inline void* AllocatorImpl<T>::Reserve(size_t size) {
  auto* ort_allocator = this->p_;
  // The Reserve slot was introduced with API version 18. On an allocator
  // built against an older header that field may be uninitialized memory,
  // so the version must be checked before the pointer is even read.
  const bool has_reserve = ort_allocator->version >= 18 && ort_allocator->Reserve != nullptr;
  if (!has_reserve) {
    // Mirror the ORT-core adapter behavior (IAllocatorImplWrappingOrtAllocator,
    // IArenaImplWrappingOrtAllocator): treat Reserve as an ordinary allocation.
    return ort_allocator->Alloc(ort_allocator, size);
  }
  return ort_allocator->Reserve(ort_allocator, size);
}

template <typename T>
inline MemoryAllocation AllocatorImpl<T>::GetAllocation(size_t size) {
void* out;
Expand All @@ -250,6 +263,15 @@ inline KeyValuePairs AllocatorImpl<T>::GetStats() const {
ThrowOnError(GetApi().AllocatorGetStats(this->p_, &out));
return KeyValuePairs(out);
}

/// \brief Ask the allocator to release unused memory back to the system.
///
/// Silently does nothing when the underlying OrtAllocator predates the
/// Shrink callback or did not implement it; throws if the callback
/// reports a failure.
template <typename T>
inline void AllocatorImpl<T>::Shrink() {
  auto* ort_allocator = this->p_;
  // Shrink joined OrtAllocator at API version 25. For older allocators the
  // field may be uninitialized, so never dereference it without the version gate.
  if (ort_allocator->version < 25 || ort_allocator->Shrink == nullptr) {
    return;
  }
  ThrowOnError(ort_allocator->Shrink(ort_allocator));
}
} // namespace detail

inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() {
Expand Down
6 changes: 5 additions & 1 deletion onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {

using namespace onnxruntime::cuda;
#ifndef BUILD_CUDA_EP_AS_PLUGIN
using onnxruntime::OpKernelContext;
using onnxruntime::OpKernelInfo;
#endif
using onnxruntime::cuda::CudaKernel;
class DynamicTimeWarping final : public CudaKernel {
public:
DynamicTimeWarping(const OpKernelInfo& info) : CudaKernel(info) {}
Expand Down
6 changes: 5 additions & 1 deletion onnxruntime/contrib_ops/cuda/tensor/unfold.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {

using namespace onnxruntime::cuda;
#ifndef BUILD_CUDA_EP_AS_PLUGIN
using onnxruntime::OpKernelContext;
using onnxruntime::OpKernelInfo;
#endif
using onnxruntime::cuda::CudaKernel;
class UnfoldTensor final : public CudaKernel {
public:
UnfoldTensor(const OpKernelInfo& info) : CudaKernel(info) {
Expand Down
7 changes: 1 addition & 6 deletions onnxruntime/core/framework/allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,7 @@ void* AllocateBufferWithOptions(IAllocator& alloc, size_t size, bool use_reserve
}

// Returns the IArena interface of `allocator` when it is an arena, nullptr otherwise.
// Delegates to the virtual AsArena() hook instead of dynamic_cast so the check
// also works in builds compiled without RTTI (the old ORT_NO_RTTI path performed
// an unchecked static_cast, which was unsafe for non-arena allocators).
// NOTE: the scraped diff left both the removed RTTI-based branch and the new
// AsArena() call in this span; this is the reconstructed post-merge function.
IArena* IArena::SafeArenaCast(IAllocator* allocator) {
  return allocator ? allocator->AsArena() : nullptr;
}

} // namespace onnxruntime
Expand Down
13 changes: 4 additions & 9 deletions onnxruntime/core/framework/device_stream_collection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,10 @@ class DeviceStreamCollectionImpl {
void ReleaseSingleStreamBuffers(Stream* stream) {
if (!stream) return;
for (const auto& it : allocators_) {
if (it.second->Info().device == stream->GetDevice() &&
it.second->Info().alloc_type == OrtArenaAllocator) {
if (it.second->IsStreamAware()) {
// Previously we only had one StreamAwareBFCArena. We need to guard
// against multiple allocators now.
auto* arena_alloc = IArena::SafeArenaCast(it.second.get());
if (arena_alloc) {
arena_alloc->ReleaseStreamBuffers(stream);
}
if (it.second->Info().device == stream->GetDevice()) {
auto* arena = it.second->AsArena();
if (arena && arena->IsStreamAware()) {
arena->ReleaseStreamBuffers(stream);
}
}
}
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@

#include "core/providers/shared_library/provider_api.h"
#include "shared_inc/cuda_call.h"
#ifdef BUILD_CUDA_EP_AS_PLUGIN
#include "ep/adapters.h"
#include "plugin/provider_api_shims.h"
#else
#include <core/platform/env.h>
#endif

#ifdef _WIN32
Comment thread
yuslepukhin marked this conversation as resolved.
#else // POSIX
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/cuda/cuda_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ struct CudaOrtAllocator : OrtAllocator {
Reserve = AllocImpl; // no special behavior for Reserve so use AllocImpl
GetStats = nullptr; // GetStatsImpl. The CUDA allocators don't have stats currently so we can skip.
AllocOnStream = nullptr; // TODO. Plugin EP arena to provide this.
Shrink = nullptr;

const OrtEpApi& ep_api = *api.GetEpApi();
const OrtMemoryDevice* mem_device = ep_api.MemoryInfo_GetMemoryDevice(mem_info);
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_stream_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ DeferredCpuAllocator::DeferredCpuAllocator(CudaStream& cuda_stream) : cuda_strea
auto self = reinterpret_cast<const DeferredCpuAllocator*>(this_);
return &self->cuda_stream_.GetCpuAllocator()->Info();
};
OrtAllocator::Reserve = nullptr;
OrtAllocator::GetStats = nullptr;
OrtAllocator::AllocOnStream = nullptr;
OrtAllocator::Shrink = nullptr;
}

struct CudaNotification : public synchronize::Notification {
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/cuda/cudnn_fe_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@

#include "core/providers/cuda/shared_inc/cudnn_fe_call.h"
#include "core/providers/shared_library/provider_api.h"
#ifdef BUILD_CUDA_EP_AS_PLUGIN
#include "ep/adapters.h"
#include "plugin/provider_api_shims.h"
#else
#include <core/platform/env.h>
#endif
Comment thread
yuslepukhin marked this conversation as resolved.
#if !defined(__CUDACC__) && !defined(USE_CUDA_MINIMAL)
#include <cudnn_frontend.h>
#endif
Expand Down
51 changes: 51 additions & 0 deletions onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@

#include "cuda_plugin_utils.h"

#include <algorithm>
#include <sstream>
#include <string>
#include <type_traits>

namespace onnxruntime {
namespace cuda_plugin {

Expand All @@ -35,6 +40,52 @@ class CudaAllocatorBase : public OrtAllocator {
const OrtMemoryInfo* memory_info_;
};

// CudaAllocatorBase derives from OrtAllocator via single non-virtual inheritance.
// This guarantees OrtAllocator sits at offset 0 in the derived layout, so
// static_cast between OrtAllocator* and CudaAllocatorBase* is safe.
static_assert(!std::is_polymorphic_v<CudaAllocatorBase>,
"CudaAllocatorBase must not be polymorphic (no virtual functions) "
"to ensure OrtAllocator is at offset 0.");

/// Counters maintained by arena-style allocators: allocation traffic,
/// arena growth/shrink events, and byte-usage high-water marks.
struct AllocatorStats {
  int64_t num_allocs = 0;
  int64_t num_reserves = 0;
  int64_t num_arena_extensions = 0;
  int64_t num_arena_shrinkages = 0;
  int64_t bytes_in_use = 0;
  int64_t total_allocated_bytes = 0;
  int64_t max_bytes_in_use = 0;
  int64_t max_alloc_size = 0;
  int64_t bytes_limit = 0;

  // Visit every stat as a (label, value) pair in the canonical reporting
  // order shared by ToKeyValuePairs() and DebugString().
  template <typename Visitor>
  void VisitFields(Visitor&& visit) const {
    visit("Limit", bytes_limit);
    visit("InUse", bytes_in_use);
    visit("TotalAllocated", total_allocated_bytes);
    visit("MaxInUse", max_bytes_in_use);
    visit("NumAllocs", num_allocs);
    visit("NumReserves", num_reserves);
    visit("NumArenaExtensions", num_arena_extensions);
    visit("NumArenaShrinkages", num_arena_shrinkages);
    visit("MaxAllocSize", max_alloc_size);
  }

  /// Export the stats into an OrtKeyValuePairs container via the C API.
  void ToKeyValuePairs(const OrtApi& api, OrtKeyValuePairs* kvps) const {
    VisitFields([&](const char* label, int64_t value) {
      api.AddKeyValuePair(kvps, label, std::to_string(value).c_str());
    });
  }

  /// Human-readable dump, one "<Label>: <value>" stat per line.
  std::string DebugString() const {
    std::ostringstream stream;
    VisitFields([&](const char* label, int64_t value) {
      stream << label << ": " << value << "\n";
    });
    return stream.str();
  }
};

/// CUDA device memory allocator using cudaMalloc/cudaFree.
/// Lifetime is managed by the EP factory (ReleaseAllocatorImpl), not by a Release callback.
class CudaDeviceAllocator final : public CudaAllocatorBase {
Expand Down
Loading
Loading