Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,85 @@ CudaDeviceAllocator::CudaDeviceAllocator(const OrtMemoryInfo* memory_info, int d
return AllocImpl(this_ptr, size);
}

// ---------------------------------------------------------------------------
// CudaExternalDeviceAllocator — delegates to user-provided function pointers.
// ---------------------------------------------------------------------------

CudaExternalDeviceAllocator::CudaExternalDeviceAllocator(const OrtMemoryInfo* memory_info, int device_id,
void* alloc_fn, void* free_fn, void* empty_cache_fn)
: CudaAllocatorBase(CudaAllocatorKind::kDevice, memory_info),
device_id_(device_id),
alloc_fn_(reinterpret_cast<ExternalAlloc>(alloc_fn)),
free_fn_(reinterpret_cast<ExternalFree>(free_fn)),
empty_cache_fn_(reinterpret_cast<ExternalEmptyCache>(empty_cache_fn)) {
version = kCudaPluginEpMinOrtApiVersion;
Alloc = AllocImpl;
Free = FreeImpl;
Info = InfoImpl;
Reserve = ReserveImpl;
GetStats = nullptr;
AllocOnStream = nullptr;
}
Comment thread
tianleiwu marked this conversation as resolved.

/*static*/ void* ORT_API_CALL CudaExternalDeviceAllocator::AllocImpl(OrtAllocator* this_ptr, size_t size) noexcept {
auto* alloc = static_cast<CudaExternalDeviceAllocator*>(this_ptr);
if (size == 0) return nullptr;
if (alloc->alloc_fn_ == nullptr) return nullptr;

int prev_device = -1;
const bool restore_prev_device = cudaGetDevice(&prev_device) == cudaSuccess;
if (cudaSetDevice(alloc->device_id_) != cudaSuccess) {
RestoreDeviceIfKnown(restore_prev_device, prev_device);
return nullptr;
}

void* p = alloc->alloc_fn_(size);
RestoreDeviceIfKnown(restore_prev_device, prev_device);
return p;
}

/*static*/ void ORT_API_CALL CudaExternalDeviceAllocator::FreeImpl(OrtAllocator* this_ptr, void* p) noexcept {
auto* alloc = static_cast<CudaExternalDeviceAllocator*>(this_ptr);
if (p != nullptr && alloc->free_fn_ != nullptr) {
int prev_device = -1;
const bool restore_prev_device = cudaGetDevice(&prev_device) == cudaSuccess;
if (cudaSetDevice(alloc->device_id_) != cudaSuccess) {
RestoreDeviceIfKnown(restore_prev_device, prev_device);
return;
}

alloc->free_fn_(p);
RestoreDeviceIfKnown(restore_prev_device, prev_device);

// If this was a reserved allocation, invoke empty_cache to release cached memory
// (matching bundled CUDA EP behavior).
std::lock_guard<std::mutex> lock(alloc->lock_);
auto it = alloc->reserved_.find(p);
if (it != alloc->reserved_.end()) {
alloc->reserved_.erase(it);
if (alloc->empty_cache_fn_) {
alloc->empty_cache_fn_();
}
}
}
}

/*static*/ const OrtMemoryInfo* ORT_API_CALL CudaExternalDeviceAllocator::InfoImpl(
const OrtAllocator* this_ptr) noexcept {
const auto* alloc = static_cast<const CudaExternalDeviceAllocator*>(this_ptr);
return alloc->GetMemoryInfo();
}

/*static*/ void* ORT_API_CALL CudaExternalDeviceAllocator::ReserveImpl(OrtAllocator* this_ptr, size_t size) noexcept {
void* p = AllocImpl(this_ptr, size);
if (p != nullptr) {
auto* alloc = static_cast<CudaExternalDeviceAllocator*>(this_ptr);
std::lock_guard<std::mutex> lock(alloc->lock_);
alloc->reserved_.insert(p);
}
return p;
}

// ---------------------------------------------------------------------------
// CudaPinnedAllocator — uses cudaHostAlloc/cudaFreeHost for page-locked
// host memory visible to the GPU.
Expand Down
29 changes: 29 additions & 0 deletions onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
#include "cuda_plugin_utils.h"

#include <algorithm>
#include <mutex>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_set>

namespace onnxruntime {
namespace cuda_plugin {
Expand Down Expand Up @@ -102,6 +104,33 @@ class CudaDeviceAllocator final : public CudaAllocatorBase {
int device_id_;
};

/// CUDA device memory allocator using external user-provided function pointers.
/// Delegates alloc/free/empty_cache to the caller-supplied callbacks.
/// Lifetime is managed by the EP factory (ReleaseAllocatorImpl), not by a Release callback.
class CudaExternalDeviceAllocator final : public CudaAllocatorBase {
typedef void* (*ExternalAlloc)(size_t size);
typedef void (*ExternalFree)(void* p);
typedef void (*ExternalEmptyCache)();

public:
CudaExternalDeviceAllocator(const OrtMemoryInfo* memory_info, int device_id,
void* alloc_fn, void* free_fn, void* empty_cache_fn);
~CudaExternalDeviceAllocator() = default;

private:
static void* ORT_API_CALL AllocImpl(OrtAllocator* this_ptr, size_t size) noexcept;
static void ORT_API_CALL FreeImpl(OrtAllocator* this_ptr, void* p) noexcept;
static const OrtMemoryInfo* ORT_API_CALL InfoImpl(const OrtAllocator* this_ptr) noexcept;
static void* ORT_API_CALL ReserveImpl(OrtAllocator* this_ptr, size_t size) noexcept;

int device_id_;
ExternalAlloc alloc_fn_;
ExternalFree free_fn_;
ExternalEmptyCache empty_cache_fn_;
mutable std::mutex lock_;
std::unordered_set<void*> reserved_;
};

/// CUDA pinned (host) memory allocator using cudaHostAlloc/cudaFreeHost.
/// Lifetime is managed by the EP factory (ReleaseAllocatorImpl), not by a Release callback.
class CudaPinnedAllocator final : public CudaAllocatorBase {
Expand Down
14 changes: 11 additions & 3 deletions onnxruntime/core/providers/cuda/plugin/cuda_ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ CudaEp::CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& lo
adapter_config.fuse_conv_bias = config_.fuse_conv_bias;
adapter_config.sdpa_kernel = config_.sdpa_kernel;
adapter_config.device_id = config_.device_id;
adapter_config.do_copy_in_default_stream = config_.do_copy_in_default_stream;
onnxruntime::cuda::SetCudaKernelAdapterRuntimeConfigForProvider(
static_cast<const void*>(EpImpl()), adapter_config);

Expand Down Expand Up @@ -369,7 +370,11 @@ OrtStatus* ORT_API_CALL CudaEp::CreateSyncStreamForDeviceImpl(

auto cuda_stream = std::make_unique<CudaSyncStream>(ep->factory_, device_id, this_ptr);

if (ep->config_.enable_cuda_graph) {
if (ep->config_.has_user_compute_stream && ep->config_.user_compute_stream != nullptr) {
// Wrap the user-provided external CUDA stream with full cuBLAS/cuDNN handles.
RETURN_IF_ERROR(cuda_stream->InitHandlesWithUserStream(
static_cast<cudaStream_t>(ep->config_.user_compute_stream)));
} else if (ep->config_.enable_cuda_graph) {
// When CUDA graph capture is enabled, all operations on this thread must go
// through the thread's graph stream so capture/replay sees the same stream
// as kernels.
Expand Down Expand Up @@ -417,8 +422,11 @@ OrtStatus* ORT_API_CALL CudaEp::IsConcurrentRunSupportedImpl(
return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "is_supported must not be null.");
}

ORT_UNUSED_PARAMETER(this_ptr);
*is_supported = true;
auto* ep = static_cast<CudaEp*>(this_ptr);
// When a unified stream is in use (either from user_compute_stream, external
// allocator, or explicit use_ep_level_unified_stream), all operations share a
// single stream so concurrent runs are not safe.
*is_supported = !ep->config_.use_ep_level_unified_stream;
return nullptr;
}

Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/cuda/plugin/cuda_ep.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ class CudaEp : public onnxruntime::ep::adapter::Ep {
int sdpa_kernel = 0; ///< Attention backend bitmask override.
bool enable_cuda_graph = false; ///< Enable CUDA graph capture and replay.
int min_num_runs_before_cuda_graph_capture = 2; ///< Warm-up runs before graph capture begins.
bool has_user_compute_stream = false; ///< Whether user provided an external CUDA stream.
void* user_compute_stream = nullptr; ///< User-provided CUDA stream (cudaStream_t cast to void*).
bool do_copy_in_default_stream = true; ///< Use default stream for H2D/D2H copies.
bool use_ep_level_unified_stream = false; ///< Force all ops to share one stream (no concurrency).
void* external_alloc = nullptr; ///< External GPU memory allocation function pointer.
void* external_free = nullptr; ///< External GPU memory deallocation function pointer.
void* external_empty_cache = nullptr; ///< External GPU memory cache-clear function pointer.
};

CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& logger);
Expand Down
138 changes: 138 additions & 0 deletions onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdint>

Check warning on line 12 in onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C++ system header after other header. Should be: cuda_ep_factory.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc:12: Found C++ system header after other header. Should be: cuda_ep_factory.h, c system, c++ system, other. [build/include_order] [4]
#include <cstring>
#include <climits>
#include <cstdlib>
Expand Down Expand Up @@ -490,6 +491,13 @@
const std::string sdpa_kernel_key = ep_options_prefix + "sdpa_kernel";
const std::string enable_cuda_graph_key = ep_options_prefix + "enable_cuda_graph";
const std::string min_runs_key = ep_options_prefix + "min_num_runs_before_cuda_graph_capture";
const std::string has_user_compute_stream_key = ep_options_prefix + "has_user_compute_stream";
const std::string user_compute_stream_key = ep_options_prefix + "user_compute_stream";
const std::string do_copy_in_default_stream_key = ep_options_prefix + "do_copy_in_default_stream";
const std::string use_ep_level_unified_stream_key = ep_options_prefix + "use_ep_level_unified_stream";
const std::string gpu_external_alloc_key = ep_options_prefix + "gpu_external_alloc";
const std::string gpu_external_free_key = ep_options_prefix + "gpu_external_free";
const std::string gpu_external_empty_cache_key = ep_options_prefix + "gpu_external_empty_cache";

// Prefer plugin-provider-option keys, then fall back to the legacy ep.cuda.*
// aliases and finally to the historical flat session config names.
Expand Down Expand Up @@ -523,6 +531,113 @@
{min_runs_key, "ep.cuda.min_num_runs_before_cuda_graph_capture"},
config.min_num_runs_before_cuda_graph_capture);

// --- Stream and allocator options ---
read_session_config_bool(
{has_user_compute_stream_key, "ep.cuda.has_user_compute_stream", "has_user_compute_stream"},
config.has_user_compute_stream);
read_session_config_bool(
{do_copy_in_default_stream_key, "ep.cuda.do_copy_in_default_stream", "do_copy_in_default_stream"},
config.do_copy_in_default_stream);
read_session_config_bool(
{use_ep_level_unified_stream_key, "ep.cuda.use_ep_level_unified_stream", "use_ep_level_unified_stream"},
config.use_ep_level_unified_stream);

// Parse user_compute_stream as a pointer-sized integer (address of a cudaStream_t).
// Uses base 0 so that "0x..." hex strings are auto-detected, and validates that
// the entire string was consumed (matching the bundled EP's ParseStringWithClassicLocale behavior).
auto read_session_config_pointer = [&](std::initializer_list<std::string_view> keys, void*& value) {
for (const auto& key : keys) {
auto raw_value = try_get_session_config(key);
if (!raw_value.has_value()) {
continue;
}

ORT_TRY {
size_t pos = 0;
unsigned long long address = std::stoull(*raw_value, &pos, 0);

Check warning on line 557 in onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Use int16_t/int64_t/etc, rather than the C type long [runtime/int] [4] Raw Output: onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc:557: Use int16_t/int64_t/etc, rather than the C type long [runtime/int] [4]
if (pos == raw_value->size()) {
if (address > std::numeric_limits<std::uintptr_t>::max()) {
log_invalid_session_config(key, "a pointer-sized integer (value exceeds address space)");
return;
}
value = reinterpret_cast<void*>(static_cast<std::uintptr_t>(address));
return;
Comment thread
tianleiwu marked this conversation as resolved.
}
}
ORT_CATCH(const std::exception&) {
}

log_invalid_session_config(key, "a pointer-sized integer (decimal or 0x-prefixed hex address)");
return;
}
};

read_session_config_pointer(
{user_compute_stream_key, "ep.cuda.user_compute_stream", "user_compute_stream"},
config.user_compute_stream);

// If user_compute_stream is provided, force has_user_compute_stream to true.
if (config.user_compute_stream != nullptr) {
config.has_user_compute_stream = true;
}

// Parse external allocator function pointers.
read_session_config_pointer(
{gpu_external_alloc_key, "ep.cuda.gpu_external_alloc", "gpu_external_alloc"},
config.external_alloc);
read_session_config_pointer(
{gpu_external_free_key, "ep.cuda.gpu_external_free", "gpu_external_free"},
config.external_free);
read_session_config_pointer(
{gpu_external_empty_cache_key, "ep.cuda.gpu_external_empty_cache", "gpu_external_empty_cache"},
config.external_empty_cache);

// Warn if only one of alloc/free is provided (both are required for external allocator).
if ((config.external_alloc == nullptr) != (config.external_free == nullptr)) {
LogWarning(factory->ort_api_, factory->default_logger_, ORT_FILE, __LINE__, "CudaEpFactory::CreateEpImpl",
"Only one of gpu_external_alloc/gpu_external_free is set. "
"Both must be provided for the external allocator to be used. Ignoring.");
config.external_alloc = nullptr;
config.external_free = nullptr;
config.external_empty_cache = nullptr;
}

// Validate: user_compute_stream and external allocator cannot both be active.
if (config.has_user_compute_stream && config.external_alloc != nullptr && config.external_free != nullptr) {
return factory->ort_api_.CreateStatus(
ORT_INVALID_ARGUMENT,
"CUDA plugin EP does not support using both user_compute_stream and external allocator simultaneously.");
}

// Validate: user_compute_stream and cuda graph cannot both be active.
if (config.has_user_compute_stream && config.enable_cuda_graph) {
return factory->ort_api_.CreateStatus(
ORT_INVALID_ARGUMENT,
"CUDA plugin EP does not support using both user_compute_stream and enable_cuda_graph simultaneously.");
}

// When user_compute_stream is set, force unified stream mode (matches bundled EP behavior).
if (config.has_user_compute_stream) {
config.use_ep_level_unified_stream = true;
}

// When external allocator is used, force unified stream mode (matches bundled EP behavior).
if (config.external_alloc != nullptr && config.external_free != nullptr) {
config.use_ep_level_unified_stream = true;
}

// Store external allocator info in the device cache entry so CreateAllocatorImpl can use it.
if (config.external_alloc != nullptr && config.external_free != nullptr) {
std::lock_guard<std::mutex> lock(factory->device_cache_mutex_);
auto* entry = factory->FindDeviceCacheEntryByOrdinalLocked(config.device_id);
if (entry) {
std::lock_guard<std::mutex> arena_lock(entry->arena_mutex);
entry->external_alloc = config.external_alloc;
entry->external_free = config.external_free;
entry->external_empty_cache = config.external_empty_cache;
}
}

const OrtLogger& ep_logger = logger ? *logger : factory->default_logger_;
auto actual_ep = std::make_unique<CudaEp>(*factory, config, ep_logger);
*ep = actual_ep.release();
Expand Down Expand Up @@ -581,6 +696,19 @@

std::lock_guard<std::mutex> lock{entry->arena_mutex};

// If external allocator function pointers are configured, use those directly
// (no arena, no mempool — the external allocator manages its own caching).
if (entry->UseExternalAllocator()) {
if (!entry->external_device_allocator) {
entry->external_device_allocator = std::make_unique<CudaExternalDeviceAllocator>(
memory_info, req_device_id,
entry->external_alloc, entry->external_free, entry->external_empty_cache);
}
++entry->num_external_allocator_users;
*allocator = entry->external_device_allocator.get();
return nullptr;
}

if (use_mempool) {
if (!entry->mempool_allocator) {
status = CudaMempoolOrtAllocator::Create(memory_info, allocator_options,
Expand Down Expand Up @@ -704,6 +832,16 @@
if (--entry.num_mempool_users == 0) entry.mempool_allocator.reset();
return;
}
if (allocator == entry.external_device_allocator.get()) {
if (entry.num_external_allocator_users <= 0) {
LogWarning(factory->ort_api_, factory->default_logger_, ORT_FILE, __LINE__,
"CudaEpFactory::ReleaseAllocatorImpl",
"Refcount underflow in ReleaseAllocatorImpl (external_device_allocator). Ignoring release.");
return;
}
if (--entry.num_external_allocator_users == 0) entry.external_device_allocator.reset();
return;
}
}
}

Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,20 @@
std::unique_ptr<CudaArenaAllocator> device_arena;
std::unique_ptr<CudaArenaAllocator> pinned_arena;
std::unique_ptr<CudaMempoolOrtAllocator> mempool_allocator;
std::unique_ptr<CudaExternalDeviceAllocator> external_device_allocator;

Check warning on line 115 in onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h:115: Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4]
int num_device_arena_users = 0;
int num_pinned_arena_users = 0;
int num_mempool_users = 0;
int num_external_allocator_users = 0;

// External allocator function pointers (set during CreateEpImpl when configured).
void* external_alloc = nullptr;
void* external_free = nullptr;
void* external_empty_cache = nullptr;

bool UseExternalAllocator() const {
return external_alloc != nullptr && external_free != nullptr;
}
};

struct HardwareDeviceKey {
Expand Down
Loading
Loading