From 83946f3e344338133245f642c5ce86f402d2016c Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 20 Mar 2026 16:42:12 -0700 Subject: [PATCH 01/48] Cuda Plug EP Core --- cmake/CMakeLists.txt | 9 + cmake/onnxruntime_providers_cuda.cmake | 5 + cmake/onnxruntime_providers_cuda_plugin.cmake | 358 ++++++++ include/onnxruntime/core/framework/tensor.h | 10 + .../ep/adapter/kernel_def_builder.h | 2 + include/onnxruntime/ep/adapter/node.h | 5 + include/onnxruntime/ep/adapter/op_kernel.h | 17 +- include/onnxruntime/ep/api.h | 2 + onnxruntime/core/providers/cuda/cuda_common.h | 7 + onnxruntime/core/providers/cuda/cuda_kernel.h | 34 + .../core/providers/cuda/cuda_nhwc_kernels.cc | 2 +- .../core/providers/cuda/cuda_nhwc_kernels.h | 4 +- .../core/providers/cuda/cudnn_common.h | 2 + .../cuda/plugin/cuda_allocator_plugin.cc | 101 +++ .../cuda/plugin/cuda_allocator_plugin.h | 45 + .../cuda/plugin/cuda_controlflow_plugin.cc | 372 ++++++++ .../cuda/plugin/cuda_controlflow_plugin.cu | 115 +++ .../cuda/plugin/cuda_controlflow_plugin.h | 97 +++ .../cuda/plugin/cuda_data_transfer_plugin.cc | 136 +++ .../cuda/plugin/cuda_data_transfer_plugin.h | 39 + .../core/providers/cuda/plugin/cuda_ep.cc | 309 +++++++ .../core/providers/cuda/plugin/cuda_ep.h | 84 ++ .../providers/cuda/plugin/cuda_ep_factory.cc | 344 ++++++++ .../providers/cuda/plugin/cuda_ep_factory.h | 108 +++ .../cuda/plugin/cuda_graph_plugin.cc | 144 ++++ .../providers/cuda/plugin/cuda_graph_plugin.h | 70 ++ .../cuda/plugin/cuda_kernel_adapter.h | 810 ++++++++++++++++++ .../providers/cuda/plugin/cuda_plugin_ep.cc | 89 ++ .../cuda/plugin/cuda_plugin_ep_symbols.def | 4 + .../cuda/plugin/cuda_plugin_kernels.cu | 59 ++ .../cuda/plugin/cuda_plugin_kernels.h | 20 + .../providers/cuda/plugin/cuda_plugin_utils.h | 77 ++ .../cuda/plugin/cuda_stream_plugin.cc | 194 +++++ .../cuda/plugin/cuda_stream_plugin.h | 77 ++ .../cuda/plugin/provider_api_shims.cc | 30 + .../providers/shared_library/provider_api.h | 20 +- 36 files changed, 3792 insertions(+), 9 deletions(-) create mode 100644 cmake/onnxruntime_providers_cuda_plugin.cmake create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.h create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.h create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.h create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_ep.cc create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_ep.h create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.cc create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.h create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_plugin_ep.cc create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_plugin_ep_symbols.def create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.cu create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.h create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_plugin_utils.h create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h create mode 100644 onnxruntime/core/providers/cuda/plugin/provider_api_shims.cc diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 385342479913a..16b00a089a6b5 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -84,6 +84,7 @@ option(onnxruntime_USE_CUDA "Build with CUDA support" OFF) cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS" OFF) cmake_dependent_option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" ON "onnxruntime_USE_CUDA" OFF) +cmake_dependent_option(onnxruntime_BUILD_CUDA_EP_AS_PLUGIN "Build CUDA EP as a separate plugin shared library" OFF "onnxruntime_USE_CUDA" OFF) option(onnxruntime_CUDA_MINIMAL "Build CUDA without any operations apart from memcpy ops. Usefuel for a very minial TRT build" OFF) option(onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO "When building with CUDA support, generate device code line number information." OFF) option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF) @@ -1431,6 +1432,9 @@ if (Git_FOUND) if (onnxruntime_USE_FP8_KV_CACHE) string(APPEND ORT_BUILD_INFO "fp8-kv-cache=1, ") endif() + if (onnxruntime_BUILD_CUDA_EP_AS_PLUGIN) + string(APPEND ORT_BUILD_INFO "cuda-plugin-ep=1, ") + endif() if (onnxruntime_DUMP_TENSOR) string(APPEND ORT_BUILD_INFO "dump-tensor=1, ") endif() @@ -1763,6 +1767,11 @@ endif() foreach(onnxruntime_cmake_file ${ONNXRUNTIME_CMAKE_FILES}) include(${onnxruntime_cmake_file}.cmake) endforeach() + +# CUDA EP Plugin build (independent shared library) +if (onnxruntime_BUILD_CUDA_EP_AS_PLUGIN) + include(onnxruntime_providers_cuda_plugin.cmake) +endif() if (UNIX) option(BUILD_PKGCONFIG_FILES "Build and install pkg-config files" ON) else() diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index ff667d8f117e0..6336070e836c3 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -20,6 +20,9 @@ "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc" ) endif() + # Exclude plugin directory if it was picked up by GLOB_RECURSE + list(FILTER onnxruntime_providers_cuda_cc_srcs EXCLUDE REGEX "core/providers/cuda/plugin/.*") + # Remove pch files list(REMOVE_ITEM onnxruntime_providers_cuda_cc_srcs "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_pch.h" @@ -43,6 +46,8 @@ "${ONNXRUNTIME_ROOT}/core/providers/cuda/math/unary_elementwise_ops_impl.cu" ) endif() + # Exclude plugin directory if it was picked up by GLOB_RECURSE + list(FILTER onnxruntime_providers_cuda_cu_srcs EXCLUDE REGEX "core/providers/cuda/plugin/.*") source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs}) set(onnxruntime_providers_cuda_src ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs}) diff --git a/cmake/onnxruntime_providers_cuda_plugin.cmake b/cmake/onnxruntime_providers_cuda_plugin.cmake new file mode 100644 index 0000000000000..1acc9cc133024 --- /dev/null +++ b/cmake/onnxruntime_providers_cuda_plugin.cmake @@ -0,0 +1,358 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +# Build the CUDA Execution Provider as a plugin shared library. +# This file is included from the main CMakeLists.txt when onnxruntime_BUILD_CUDA_EP_AS_PLUGIN=ON. + +message(STATUS "Building CUDA EP as plugin shared library") + + + +set(CUDA_PLUGIN_EP_DIR "${ONNXRUNTIME_ROOT}/core/providers/cuda/plugin") + +# --- Collect standard CUDA EP sources --- +file(GLOB_RECURSE CUDA_EP_CC_SRCS CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc" +) + +file(GLOB_RECURSE CUDA_EP_CU_SRCS CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu" +) + +# --- Collect contrib ops sources --- +file(GLOB_RECURSE CUDA_CONTRIB_OPS_CC_SRCS CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cc" +) + +file(GLOB_RECURSE CUDA_CONTRIB_OPS_CU_SRCS CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cu" +) + +list(APPEND CUDA_PLUGIN_EP_CC_SRCS + ${CUDA_EP_CC_SRCS} + ${CUDA_CONTRIB_OPS_CC_SRCS} +) + +list(APPEND CUDA_PLUGIN_EP_CU_SRCS + ${CUDA_EP_CU_SRCS} + ${CUDA_CONTRIB_OPS_CU_SRCS} +) + +list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX "onnxruntime/contrib_ops/cuda/aten_ops/.*") +list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX "onnxruntime/contrib_ops/cuda/collective/.*") + +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX "onnxruntime/contrib_ops/cuda/aten_ops/.*") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX "onnxruntime/contrib_ops/cuda/collective/.*") + +# Exclude files that include cuda_execution_provider.h (directly or transitively), +# which conflicts with the adapter shim CUDAExecutionProvider class. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_execution_provider\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_provider_factory\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_provider_interface\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/math/einsum\\.cc$") + +# Exclude the framework controlflow/ subdirectory — these inherit from CPU base +# classes (If, Loop, Scan). The plugin has its own control flow wrappers in +# plugin/cuda_controlflow_plugin.cc that delegate to OrtEpApi. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/core/providers/cuda/controlflow/.*") + +# Exclude the entire tunable/ subdirectory — it depends on the real CudaTuningContext +# and CUDAExecutionProvider which are not available in the plugin build. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tunable/.*") + +# Exclude real EP infrastructure files (replaced by plugin/ equivalents). +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_stream_handle\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_execution_provider_info\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_graph\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_mempool_arena\\.cc$") + +# Exclude cuda_common.cc — its HalfGemmOptions definitions conflict with the +# adapter's inline shim. Utility functions are replaced or not needed. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_common\\.cc$") + +# Exclude cuda_nhwc_kernels.cc and cuda_contrib_kernels.cc — these files contain +# explicit BuildKernelCreateInfo<> registration tables that reference ALL kernel +# classes (including those in excluded source files like space_depth_ops.cc, +# controlflow/, transformers/, etc.), causing undefined symbols at link time. +# With PluginKernelCollector, individual kernel files self-register via macro +# overrides, so these centralized tables are not needed in the plugin build. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_nhwc_kernels\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_contrib_kernels\\.cc$") + +# integer_gemm.cc: dynamic_cast replaced with GetCublasHandle(cudaStream_t). +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/integer_gemm\\.cc$") # REMOVED in Stage 5 + +# RNN ops: dual-build-compatible signatures are in place (void* alloc_stream, +# cudaStream_t, cudnnHandle_t), but the ORT C API lacks KernelInfoGetAttributeArray_string +# which rnn.h uses via GetAttrs("activations", ...). +# Re-excluded until the C API is extended to support string array attributes. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/rnn/.*") + +list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/rnn/.*") + +# Exclude files that use TensorSeq (incomplete type in plugin build). +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/identity_op\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/sequence_op\\.cc$") + +# Permanently excluded — pure CPU ops, handled by GetCpuPreferredNodes. +# size.cc registers onnxruntime::Size (CPU op) whose Compute() body lives +# in the CPU provider and is not linked into the plugin. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/size\\.cc$") + +# scatter_nd.cc: ValidateShapes inlined for plugin, GetComputeStream fixed. +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/scatter_nd\\.cc$") # REMOVED in Stage 5 + +# Exclude llm/ — attention.cc calls QkvToContext which dereferences +# onnxruntime::Stream* (not available in plugin build's adapter OpKernelContext). +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/llm/.*") +list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/llm/.*") + +# Exclude constant_of_shape — inherits from ConstantOfShapeBase (CPU provider) +# which is not linked into the plugin. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/generator/constant_of_shape\\.cc$") + +# matmul_integer.cc: GetComputeStream fixed, GemmInt8 signature updated. +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/math/matmul_integer\.cc$") # REMOVED in Stage 5 + +# matmul.cc: GetComputeStream fixed, GetTuningContext guarded. +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/math/matmul\.cc$") # REMOVED in Stage 5 + +# variadic_elementwise_ops.cc: adapter InputCount/RequiredInput/RequiredOutput supported. +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/math/variadic_elementwise_ops\\.cc$") # REMOVED in Stage 5C + +# slice.cc: plugin-local wrappers added for SliceBase::PrepareForCompute/FlattenOutputDims. +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/slice\\.cc$") # REMOVED in Stage 5C.3 + +# Exclude space_depth_ops — inherits from SpaceDepthBase (CPU provider). +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/space_depth_ops\\.cc$") + +# concat.cc: InputArgCount/GetComputeStream usage fixed for adapter. +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/concat\\.cc$") # REMOVED in Stage 5A + +# gather.cc: switched to GatherBase::PrepareForComputeImpl for adapter context. +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/gather\\.cc$") # REMOVED in Stage 5B + +# gather_nd.cc: PrepareCompute signature changed to void*/cudaStream_t, GetComputeStream fixed. +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/gather_nd\\.cc$") # REMOVED in Stage 5 + +# pad.cc: plugin-local wrappers added for PadBase static helpers. +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/pad\\.cc$") # REMOVED in Stage 5C.2 + +# reshape.cc: GetComputeStream/CopyTensor framework dependency fixed for adapter. +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/reshape\\.cc$") # REMOVED in Stage 5A + +# split.cc: GetComputeStream usage fixed for adapter via CudaKernel::GetComputeStream. +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/split\\.cc$") # REMOVED in Stage 5A + +# Exclude object_detection/ — NonMaxSuppression and RoiAlign inherit from CPU +# base classes (NonMaxSuppressionBase, RoiAlignBase) not linked into the plugin. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/object_detection/.*") +list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/object_detection/.*") + +# Exclude upsample.cc — UpsampleBase uses InputDefs() and +# OpKernelInfo::GetAllocator() not available in adapter. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/upsample\\.cc$") + +# Exclude resize.cc — Resize inherits from Upsample (excluded above). +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/resize\\.cc$") + +# Exclude einsum — einsum_auxiliary_ops.cc calls ReductionOps::ReduceCompute +# which is framework-only (guarded by #ifndef BUILD_CUDA_EP_AS_PLUGIN). +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/math/einsum_utils/.*") +list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/math/einsum_utils/.*") + +# unsqueeze.cc: plugin-local PrepareCompute path added for adapter context. +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/unsqueeze\\.cc$") # REMOVED in Stage 5B + +# Permanently excluded — pure CPU ops, handled by GetCpuPreferredNodes. +# shape_op.cc inherits from onnxruntime::OpKernel (framework) +# which cannot convert to ep::adapter::OpKernel in the plugin build. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/shape_op\\.cc$") + +# cumsum.cc: axis parsing helper inlined for plugin build. +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/math/cumsum\\.cc$") # REMOVED in Stage 5B + +# tile.cc: plugin-local IsTileMemcpy helper added. +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/tile\\.cc$") # REMOVED in Stage 5B + +# --- Contrib op exclusions --- +# Exclude contrib ops that have dependencies not available in the plugin build. +# Note: aten_ops/ and collective/ exclusions are applied above (near the glob). + +# Exclude contrib llm/ — uses onnxruntime::Stream* in QkvToContext. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/llm/.*") +list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/llm/.*") + +# Exclude contrib training ops (shrunken_gather depends on provider_api.h in header). +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/shrunken_gather\\.cc$") + +# Exclude contrib bert ops that use GetComputeStream() or framework OpKernelContext. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/attention\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/decoder_attention\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/decoder_masked_self_attention\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/embed_layer_norm\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/fast_gelu\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/group_query_attention\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/longformer_attention\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/multihead_attention\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/packed_attention\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/packed_multihead_attention\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/paged_attention\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/relative_attn_bias\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/remove_padding\\.cc$") + +# Exclude contrib ops using GetComputeStream() or framework type deps. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/diffusion/group_norm\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/fused_conv\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/inverse\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/math/bias_dropout\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/math/fft_ops\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/moe/moe\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/sparse/sparse_attention\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/crop\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/dynamic_time_warping\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/dynamicslice\\.cc$") + +# Exclude contrib quantization ops with GetComputeStream() deps. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/quantization/attention_quantization\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/quantization/matmul_bnb4\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/quantization/matmul_nbits\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/quantization/moe_quantization\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/quantization/qordered_ops/.*") + +# Exclude contrib transformers/ (beam search, greedy search, sampling). +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/transformers/.*") +list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/transformers/.*") + +# Exclude gemm_float8.cc/.cu — ComputeInternal is in .cu which uses GetComputeStream(). +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/math/gemm_float8\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/math/gemm_float8\\.cu$") + +# fused_matmul.cc: matmul.cc is now included, so fused_matmul can be too. +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/math/fused_matmul\\.cc$") # REMOVED in Stage 5 + +# Create shared library target using the ORT helper function for plugins +onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_plugin + ${CUDA_PLUGIN_EP_CC_SRCS} + ${CUDA_PLUGIN_EP_CU_SRCS} +) +# Set CUDA standard and flags +set_target_properties(onnxruntime_providers_cuda_plugin PROPERTIES + CUDA_STANDARD 17 + CUDA_STANDARD_REQUIRED ON +) + +# Suppress -Werror=maybe-uninitialized for local variables written by +# adapter OpKernelInfo::GetAttr<> (GCC falsely warns about variables that are +# initialised inside GetAttr’s output parameter path). +target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE + $<$:-Wno-maybe-uninitialized> +) +target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE + # Flash-attention, XQA, MoE, and other pure CUDA kernel .cu files must NOT + # receive the ORT-framework force-include (it conflicts with cute::Tensor etc.). + # cuda_plugin_kernels.cu already #include "cuda_kernel_adapter.h" directly. + # Op-registration .cc files do not include it directly, so they need it here. + # Suppress NVCC cudafe warnings: + # 550 - variable set but never used (in adapter headers) + # 2810 - [[nodiscard]] false positive on Status assignments in op_kernel.h / kernel_registry.h + "$<$:--expt-relaxed-constexpr;-Xcudafe;--diag_suppress=550>" + "$<$:SHELL:-Xcudafe --diag_suppress=2810>" + "$<$:-include;${REPO_ROOT}/include/onnxruntime/ep/adapters.h>" + "$<$:SHELL:-include ${CUDA_PLUGIN_EP_DIR}/cuda_kernel_adapter.h>" +) + +include(cudnn_frontend) +include(cutlass) + +# --- Find cuDNN (may be at a custom path via onnxruntime_CUDNN_HOME) --- +set(_CUDNN_SEARCH_PATHS "") +if(onnxruntime_CUDNN_HOME) + list(APPEND _CUDNN_SEARCH_PATHS "${onnxruntime_CUDNN_HOME}") +endif() +if(DEFINED ENV{CUDNN_HOME}) + list(APPEND _CUDNN_SEARCH_PATHS "$ENV{CUDNN_HOME}") +endif() + +set(CUDA_PLUGIN_CUDNN_INCLUDE_DIR ${CUDNN_INCLUDE_DIR}) +set(CUDA_PLUGIN_CUDNN_LIBRARY ${cudnn_LIBRARY}) + +if(NOT CUDA_PLUGIN_CUDNN_INCLUDE_DIR OR NOT CUDA_PLUGIN_CUDNN_LIBRARY) + message(FATAL_ERROR "cuDNN not found (from main ORT search) for CUDA Plugin EP.") +endif() + +message(STATUS "CUDA Plugin EP: cuDNN include: ${CUDA_PLUGIN_CUDNN_INCLUDE_DIR}") +message(STATUS "CUDA Plugin EP: cuDNN library: ${CUDA_PLUGIN_CUDNN_LIBRARY}") + +# Include directories — only public ORT headers + CUDA toolkit + cuDNN + internal headers for adapter +target_include_directories(onnxruntime_providers_cuda_plugin PRIVATE + ${REPO_ROOT}/include + ${REPO_ROOT}/include/onnxruntime/core/session + ${ONNXRUNTIME_ROOT} + ${CUDAToolkit_INCLUDE_DIRS} + ${CUDA_PLUGIN_CUDNN_INCLUDE_DIR} + ${Eigen3_SOURCE_DIR} + ${cutlass_SOURCE_DIR}/include + ${cutlass_SOURCE_DIR}/examples + ${cutlass_SOURCE_DIR}/tools/util/include +) + +onnxruntime_add_include_to_target( + onnxruntime_providers_cuda_plugin + onnxruntime_common + onnx + onnx_proto + ${PROTOBUF_LIB} + flatbuffers::flatbuffers +) + +# Link libraries +target_link_libraries(onnxruntime_providers_cuda_plugin PRIVATE + CUDA::cudart + CUDA::cublas + CUDA::cublasLt + CUDNN::cudnn_all + cudnn_frontend + ${CUDA_PLUGIN_CUDNN_LIBRARY} + Boost::mp11 + safeint_interface + onnxruntime_framework + onnxruntime_graph + onnxruntime_mlas + onnxruntime_flatbuffers + onnxruntime_common + cpuinfo::cpuinfo + onnx + onnx_proto + ${PROTOBUF_LIB} +) + +# Symbol visibility — only export CreateEpFactories and ReleaseEpFactory +target_compile_definitions(onnxruntime_providers_cuda_plugin PRIVATE ORT_API_MANUAL_INIT BUILD_CUDA_EP_AS_PLUGIN ONNX_ML=1 ONNX_NAMESPACE=onnx ONNX_USE_LITE_PROTO=1) + +if(WIN32) + # Windows: use .def file for symbol exports + set(CUDA_PLUGIN_DEF_FILE ${CUDA_PLUGIN_EP_DIR}/cuda_plugin_ep_symbols.def) + if(EXISTS ${CUDA_PLUGIN_DEF_FILE}) + target_sources(onnxruntime_providers_cuda_plugin PRIVATE ${CUDA_PLUGIN_DEF_FILE}) + endif() +else() + # Linux/macOS: hide all symbols by default, explicitly export via __attribute__((visibility("default"))) + set_target_properties(onnxruntime_providers_cuda_plugin PROPERTIES + C_VISIBILITY_PRESET hidden + CXX_VISIBILITY_PRESET hidden + ) +endif() + + + +# Set output name +set_target_properties(onnxruntime_providers_cuda_plugin PROPERTIES + OUTPUT_NAME "onnxruntime_providers_cuda_plugin" +) + +# Install +install(TARGETS onnxruntime_providers_cuda_plugin + LIBRARY DESTINATION lib + RUNTIME DESTINATION bin +) diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h index c7f7f23f70334..bd87c49d39ea5 100644 --- a/include/onnxruntime/core/framework/tensor.h +++ b/include/onnxruntime/core/framework/tensor.h @@ -43,6 +43,16 @@ class Tensor final { // Strive not to allocate Tensor with new/delete as it is a shallow class and using it by value is just fine. // Use InitOrtValue() methods to allocate for OrtValue. +#ifdef BUILD_CUDA_EP_AS_PLUGIN + /// Static factory kept for plugin EP kernels that still call Tensor::Create(). + /// The main tree deprecated these in favor of constructors, but dynamically-linked + /// plugin code relies on the static method. + static std::unique_ptr Create(MLDataType elt_type, const TensorShape& shape, + std::shared_ptr allocator) { + return std::make_unique(elt_type, shape, std::move(allocator)); + } +#endif + Tensor() = default; // to allow creating vector to support seq(tensor) /** diff --git a/include/onnxruntime/ep/adapter/kernel_def_builder.h b/include/onnxruntime/ep/adapter/kernel_def_builder.h index 279c2782ef8eb..276c793a4311c 100644 --- a/include/onnxruntime/ep/adapter/kernel_def_builder.h +++ b/include/onnxruntime/ep/adapter/kernel_def_builder.h @@ -130,6 +130,8 @@ struct KernelDefBuilder { return *this; } + // ExecQueueId is intentionally a no-op. The plugin EP manages stream + // assignment externally; the queue id hint is not needed. KernelDefBuilder& ExecQueueId(int /*queue_id*/) { return *this; } Ort::KernelDef Build() { return builder_.Build(); } diff --git a/include/onnxruntime/ep/adapter/node.h b/include/onnxruntime/ep/adapter/node.h index 8510f7bc5031b..cdeb9209389a3 100644 --- a/include/onnxruntime/ep/adapter/node.h +++ b/include/onnxruntime/ep/adapter/node.h @@ -26,6 +26,11 @@ struct Node { return kernel_info_.GetOperatorType(); } + /** Gets the Node's domain. */ + std::string Domain() const { + return kernel_info_.GetOperatorDomain(); + } + /** Gets the since version of the operator. */ int SinceVersion() const noexcept { return kernel_info_.GetOperatorSinceVersion(); diff --git a/include/onnxruntime/ep/adapter/op_kernel.h b/include/onnxruntime/ep/adapter/op_kernel.h index 0a4b16321a8eb..fece85435733d 100644 --- a/include/onnxruntime/ep/adapter/op_kernel.h +++ b/include/onnxruntime/ep/adapter/op_kernel.h @@ -35,7 +35,7 @@ struct OpKernel { explicit OpKernel(const OpKernelInfo& info) : op_kernel_info_{info} {} virtual ~OpKernel() {} - Node Node() const { + adapter::Node Node() const { return op_kernel_info_.node(); } const OpKernelInfo& Info() const { @@ -93,6 +93,13 @@ struct OpKernelContext { input_tensors_[index] = CreateTensorFromApiValue(const_cast(static_cast(input))); return &input_tensors_[index]; } + template >> + const T& RequiredInput(int index) const { + auto* input = Input(index); + ORT_ENFORCE(input != nullptr, "Required input ", index, " is null"); + return *input; + } Tensor* Output(int index, const TensorShape& shape) { if (index < 0 || static_cast(index) >= output_tensors_.size()) { return nullptr; @@ -109,6 +116,11 @@ struct OpKernelContext { output_tensors_[index] = CreateTensorFromApiValue(output); return &output_tensors_[index]; } + Tensor& RequiredOutput(int index, const TensorShape& shape) { + auto* output = Output(index, shape); + ORT_ENFORCE(output != nullptr, "Required output ", index, " is null"); + return *output; + } Tensor* Output(int index, const std::vector& shape) { return Output(index, TensorShape{shape}); } @@ -131,7 +143,6 @@ struct OpKernelContext { // TODO(fs-eire): Implement GetUseDeterministicCompute(). return false; } - void* GetGPUComputeStream() const { return context_.GetGPUComputeStream(); } @@ -146,7 +157,7 @@ struct OpKernelContext { }; /// -/// A bridge class between `onnxruntime::ep::adapter::OpKernel` and `::OrtKernelImpl`. +/// A bridge class between `onnxruntime::ep::adapter::OpKernel` and `onnxruntime::OrtKernelImpl`. /// struct KernelImpl : OrtKernelImpl { explicit KernelImpl(std::unique_ptr impl) diff --git a/include/onnxruntime/ep/api.h b/include/onnxruntime/ep/api.h index 36d99e5d44d45..8c8951490a55d 100644 --- a/include/onnxruntime/ep/api.h +++ b/include/onnxruntime/ep/api.h @@ -15,6 +15,8 @@ namespace onnxruntime { namespace ep { struct ApiPtrs { + ApiPtrs(const OrtApi& ort_, const OrtEpApi& ep_, const OrtModelEditorApi& model_editor_) + : ort(ort_), ep(ep_), model_editor(model_editor_) {} const OrtApi& ort; const OrtEpApi& ep; const OrtModelEditorApi& model_editor; diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index 42989d09cfa85..f90eb2813afc4 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -3,6 +3,8 @@ #pragma once +#ifndef BUILD_CUDA_EP_AS_PLUGIN + // The following three lines were copied from ABSL // cutlass needs them, because cutlass uses "and"/"or" keywords #ifdef __cplusplus @@ -52,6 +54,7 @@ namespace cuda { #define CUDNN2_RETURN_IF_ERROR(expr, m) ORT_RETURN_IF_ERROR(CUDNN_CALL2(expr, m)) #define CUFFT_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUFFT_CALL(expr)) #endif + // Type mapping for MLFloat16 to half template class ToCudaType { @@ -238,3 +241,7 @@ class HalfGemmOptions { } // namespace onnxruntime #include "core/providers/cuda/cuda_common_type_helpers.h" +#else +// Define shims and basic types needed by kernels in plugin build when cuda_common.h is included +#include "core/providers/cuda/plugin/cuda_kernel_adapter.h" +#endif diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index 70ba7657e6747..2aaff1192073b 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -3,6 +3,8 @@ #pragma once +#ifndef BUILD_CUDA_EP_AS_PLUGIN + #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cuda_execution_provider.h" #include "core/providers/cuda/cuda_fwd.h" @@ -12,6 +14,11 @@ namespace onnxruntime { namespace cuda { +#ifndef CUDA_STREAM_FROM_CTX +// Helper for kernels that need a cudaStream_t from OpKernelContext in both framework and plugin builds. +#define CUDA_STREAM_FROM_CTX(ctx) static_cast(GetComputeStream(ctx)) +#endif + // ----------------------------------------------------------------------- // Base class for CUDA kernels // ----------------------------------------------------------------------- @@ -49,6 +56,19 @@ class CudaKernel : public OpKernel { stream); } + // void* overload for dual-build compatibility with the plugin EP. + // In the framework build, the void* is always a static_cast(onnxruntime::Stream*). + template + inline IAllocatorUniquePtr GetScratchBuffer(size_t count_or_bytes, void* stream) const { + return GetScratchBuffer(count_or_bytes, static_cast(stream)); + } + + // Resolve nullptr ambiguity between Stream* and void* overloads. + template + inline IAllocatorUniquePtr GetScratchBuffer(size_t count_or_bytes, std::nullptr_t) const { + return GetScratchBuffer(count_or_bytes, static_cast(nullptr)); + } + // Different from GetScratchBuffer which use IAllocator::Alloc() to allocate memory, // this GetTransientScratchBuffer will call IAllocator::Reserve() to allocate memory. // IAllocator::Reserve() optionally implement some allocation logic that by-passes any arena-based @@ -65,6 +85,11 @@ class CudaKernel : public OpKernel { cuda_ep_stream->EnqueDeferredCPUBuffer(p); } + // void* overload for dual-build compatibility with the plugin EP. + inline void AddDeferredReleaseCPUPtr(void* p, void* stream) const { + AddDeferredReleaseCPUPtr(p, static_cast(stream)); + } + template inline IAllocatorUniquePtr AllocateBufferOnCPUPinned(size_t count_or_bytes) const { if (count_or_bytes == 0) return nullptr; @@ -174,6 +199,11 @@ class CudaKernel : public OpKernel { return Status::OK(); } + // void* overload for dual-build compatibility with the plugin EP. + Status CopyToGpu(void* stream) { + return CopyToGpu(static_cast(stream)); + } + T* CpuPtr() const { return cpu_pinned_copy_.get(); } @@ -234,3 +264,7 @@ class CudaKernel : public OpKernel { } // namespace cuda } // namespace onnxruntime + +#else +#include "core/providers/cuda/plugin/cuda_kernel_adapter.h" +#endif diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc index 5839d10b4345f..e36f745a8fdfa 100755 --- a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc +++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc @@ -176,7 +176,7 @@ class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(16, 19, float, GridSample); class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(20, 21, float, GridSample); class CUDA_NHWC_OP_TYPED_CLASS_NAME(22, float, GridSample); -onnxruntime::common::Status RegisterCudaNhwcContribKernels(onnxruntime::KernelRegistry& kernel_registry) { +onnxruntime::common::Status RegisterCudaNhwcContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn nhwc_function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.h b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.h index 5a4d3493fdae6..f554d48f4d1f5 100644 --- a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.h +++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.h @@ -8,14 +8,14 @@ namespace onnxruntime::cuda { -onnxruntime::common::Status RegisterCudaNhwcKernels(onnxruntime::KernelRegistry& kernel_registry); +onnxruntime::common::Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry); } // namespace onnxruntime::cuda #ifndef DISABLE_CONTRIB_OPS namespace onnxruntime::contrib::cuda { -onnxruntime::common::Status RegisterCudaNhwcContribKernels(onnxruntime::KernelRegistry& kernel_registry); +onnxruntime::common::Status RegisterCudaNhwcContribKernels(KernelRegistry& kernel_registry); } // namespace onnxruntime::contrib::cuda #endif diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h index b267ef6bed64f..6725d1120be23 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.h +++ b/onnxruntime/core/providers/cuda/cudnn_common.h @@ -149,9 +149,11 @@ struct Consts { inline double ClampCudnnBatchNormEpsilon(double epsilon) { if (epsilon < CUDNN_BN_MIN_EPSILON) { +#ifndef BUILD_CUDA_EP_AS_PLUGIN if (CUDNN_BN_MIN_EPSILON - epsilon > FLT_EPSILON) LOGS_DEFAULT(WARNING) << "Provided epsilon is smaller than CUDNN_BN_MIN_EPSILON. " << "Setting it to CUDNN_BN_MIN_EPSILON"; +#endif return CUDNN_BN_MIN_EPSILON; } return epsilon; diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc new file mode 100644 index 0000000000000..5da80f4121c9a --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "cuda_allocator_plugin.h" + +namespace onnxruntime { +namespace cuda_plugin { + +// --------------------------------------------------------------------------- +// CudaDeviceAllocator — uses cudaMalloc/cudaFree for GPU device memory. +// Note: No arena or caching layer — every allocation goes directly to CUDA. +// --------------------------------------------------------------------------- + +CudaDeviceAllocator::CudaDeviceAllocator(const OrtMemoryInfo* memory_info, int device_id) + : OrtAllocator{}, + memory_info_(memory_info), + device_id_(device_id) { + version = ORT_API_VERSION; + Alloc = AllocImpl; + Free = FreeImpl; + Info = InfoImpl; + Reserve = ReserveImpl; + GetStats = nullptr; + AllocOnStream = nullptr; +} + +/*static*/ void* ORT_API_CALL CudaDeviceAllocator::AllocImpl(OrtAllocator* this_ptr, size_t size) noexcept { + auto* alloc = static_cast(this_ptr); + void* p = nullptr; + if (size == 0) return nullptr; + cudaSetDevice(alloc->device_id_); + cudaError_t err = cudaMalloc(&p, size); + if (err != cudaSuccess) { + return nullptr; + } + return p; +} + +/*static*/ void ORT_API_CALL CudaDeviceAllocator::FreeImpl(OrtAllocator* this_ptr, void* p) noexcept { + auto* alloc = static_cast(this_ptr); + if (p != nullptr) { + cudaSetDevice(alloc->device_id_); + cudaFree(p); + } +} + +/*static*/ const OrtMemoryInfo* ORT_API_CALL CudaDeviceAllocator::InfoImpl(const OrtAllocator* this_ptr) noexcept { + const auto* alloc = static_cast(this_ptr); + return alloc->memory_info_; +} + +/*static*/ void* ORT_API_CALL CudaDeviceAllocator::ReserveImpl(OrtAllocator* this_ptr, size_t size) noexcept { + // Reserve currently delegates to Alloc (no separate reservation pool). + return AllocImpl(this_ptr, size); +} + +// --------------------------------------------------------------------------- +// CudaPinnedAllocator — uses cudaHostAlloc/cudaFreeHost for page-locked +// host memory visible to the GPU. +// --------------------------------------------------------------------------- + +CudaPinnedAllocator::CudaPinnedAllocator(const OrtMemoryInfo* memory_info) + : OrtAllocator{}, + memory_info_(memory_info) { + version = ORT_API_VERSION; + Alloc = AllocImpl; + Free = FreeImpl; + Info = InfoImpl; + Reserve = ReserveImpl; + GetStats = nullptr; + AllocOnStream = nullptr; +} + +/*static*/ void* ORT_API_CALL CudaPinnedAllocator::AllocImpl(OrtAllocator* /*this_ptr*/, size_t size) noexcept { + void* p = nullptr; + if (size == 0) return nullptr; + cudaError_t err = cudaHostAlloc(&p, size, cudaHostAllocDefault); + if (err != cudaSuccess) { + return nullptr; + } + return p; +} + +/*static*/ void ORT_API_CALL CudaPinnedAllocator::FreeImpl(OrtAllocator* /*this_ptr*/, void* p) noexcept { + if (p != nullptr) { + cudaFreeHost(p); + } +} + +/*static*/ const OrtMemoryInfo* ORT_API_CALL CudaPinnedAllocator::InfoImpl(const OrtAllocator* this_ptr) noexcept { + const auto* alloc = static_cast(this_ptr); + return alloc->memory_info_; +} + +/*static*/ void* ORT_API_CALL CudaPinnedAllocator::ReserveImpl(OrtAllocator* this_ptr, size_t size) noexcept { + // Reserve currently delegates to Alloc (no separate reservation pool). + return AllocImpl(this_ptr, size); +} + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.h b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.h new file mode 100644 index 0000000000000..c39270774b992 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.h @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "cuda_plugin_utils.h" + +namespace onnxruntime { +namespace cuda_plugin { + +/// CUDA device memory allocator using cudaMalloc/cudaFree. +/// Lifetime is managed by the EP factory (ReleaseAllocatorImpl), not by a Release callback. +class CudaDeviceAllocator : public OrtAllocator { + public: + CudaDeviceAllocator(const OrtMemoryInfo* memory_info, int device_id); + ~CudaDeviceAllocator() = default; + + private: + static void* ORT_API_CALL AllocImpl(OrtAllocator* this_ptr, size_t size) noexcept; + static void ORT_API_CALL FreeImpl(OrtAllocator* this_ptr, void* p) noexcept; + static const OrtMemoryInfo* ORT_API_CALL InfoImpl(const OrtAllocator* this_ptr) noexcept; + static void* ORT_API_CALL ReserveImpl(OrtAllocator* this_ptr, size_t size) noexcept; + + const OrtMemoryInfo* memory_info_; + int device_id_; +}; + +/// CUDA pinned (host) memory allocator using cudaHostAlloc/cudaFreeHost. +/// Lifetime is managed by the EP factory (ReleaseAllocatorImpl), not by a Release callback. +class CudaPinnedAllocator : public OrtAllocator { + public: + CudaPinnedAllocator(const OrtMemoryInfo* memory_info); + ~CudaPinnedAllocator() = default; + + private: + static void* ORT_API_CALL AllocImpl(OrtAllocator* this_ptr, size_t size) noexcept; + static void ORT_API_CALL FreeImpl(OrtAllocator* this_ptr, void* p) noexcept; + static const OrtMemoryInfo* ORT_API_CALL InfoImpl(const OrtAllocator* this_ptr) noexcept; + static void* ORT_API_CALL ReserveImpl(OrtAllocator* this_ptr, size_t size) noexcept; + + const OrtMemoryInfo* memory_info_; +}; + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc new file mode 100644 index 0000000000000..bc2ba2b8c6f8a --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc @@ -0,0 +1,372 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Plugin EP control flow kernel implementations for If, Loop, and Scan. +// These delegate to OrtEpApi::CreateIfKernel/CreateLoopKernel/CreateScanKernel +// instead of inheriting from CPU base classes. + +#include "core/providers/cuda/plugin/cuda_controlflow_plugin.h" +#include + +namespace onnxruntime { +namespace cuda { +namespace plugin { + +// =================================================================== +// If kernel +// =================================================================== + +Status PluginIfKernel::CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) { + OrtStatus* status = Ort::GetEpApi().CreateIfKernel(info, impl); + if (status) { + std::string msg = Ort::GetApi().GetErrorMessage(status); + Ort::GetApi().ReleaseStatus(status); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, msg); + } + return Status::OK(); +} + +// =================================================================== +// Loop kernel helper +// =================================================================== + +PluginLoopHelper::PluginLoopHelper() : OrtLoopKernelHelper{} { + ort_version_supported = ORT_API_VERSION; + Release = ReleaseImpl; + ConcatOutput = ConcatOutputImpl; +} + +/*static*/ +void ORT_API_CALL PluginLoopHelper::ReleaseImpl(_In_ OrtLoopKernelHelper* this_ptr) noexcept { + delete static_cast(this_ptr); +} + +/*static*/ +OrtStatus* ORT_API_CALL PluginLoopHelper::ConcatOutputImpl( + _In_ OrtLoopKernelHelper* /*this_ptr*/, + _In_opt_ void* stream_handle, + _In_reads_(num_per_iteration_outputs) const OrtValue* const* per_iteration_outputs, + _In_ size_t num_per_iteration_outputs, + _Out_writes_bytes_all_(output_size_in_bytes) void* output, + _In_ size_t output_size_in_bytes) noexcept { + try { + if (num_per_iteration_outputs == 0) return nullptr; + + cudaStream_t cuda_stream = static_cast(stream_handle); + + Ort::ConstValue first_output(per_iteration_outputs[0]); + size_t bytes_per_iteration = first_output.GetTensorSizeInBytes(); + + char* cur = static_cast(output); + for (size_t i = 0; i < num_per_iteration_outputs; i++) { + Ort::ConstValue val(per_iteration_outputs[i]); + size_t cur_bytes = val.GetTensorSizeInBytes(); + if (cur_bytes != bytes_per_iteration) { + return Ort::Status("Inconsistent size in loop output iteration", ORT_FAIL).release(); + } + cudaError_t err = cudaMemcpyAsync(cur, val.GetTensorRawData(), bytes_per_iteration, + cudaMemcpyDeviceToDevice, cuda_stream); + if (err != cudaSuccess) { + return Ort::Status((std::string("cudaMemcpyAsync failed in Loop ConcatOutput: ") + + cudaGetErrorString(err)) + .c_str(), + ORT_FAIL) + .release(); + } + cur += bytes_per_iteration; + } + + if (static_cast(cur - static_cast(output)) != output_size_in_bytes) { + return Ort::Status("Loop ConcatOutput: output buffer not fully filled", ORT_FAIL).release(); + } + + return nullptr; + } catch (const std::exception& ex) { + return Ort::Status(ex.what(), ORT_RUNTIME_EXCEPTION).release(); + } +} + +// =================================================================== +// Loop kernel +// =================================================================== + +Status PluginLoopKernel::CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) { + auto helper = std::make_unique(); + OrtStatus* status = Ort::GetEpApi().CreateLoopKernel(info, helper.get(), impl); + if (status) { + std::string msg = Ort::GetApi().GetErrorMessage(status); + Ort::GetApi().ReleaseStatus(status); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, msg); + } + helper.release(); // ORT takes ownership on success + return Status::OK(); +} + +// =================================================================== +// Scan kernel helper +// =================================================================== + +PluginScanHelper::PluginScanHelper() : OrtScanKernelHelper{} { + ort_version_supported = ORT_API_VERSION; + Release = ReleaseImpl; + Transpose = TransposeImpl; +} + +/*static*/ +void ORT_API_CALL PluginScanHelper::ReleaseImpl(_In_ OrtScanKernelHelper* this_ptr) noexcept { + delete static_cast(this_ptr); +} + +/*static*/ +OrtStatus* ORT_API_CALL PluginScanHelper::TransposeImpl( + _In_ OrtScanKernelHelper* /*this_ptr*/, + _In_reads_(num_permutation_elems) const size_t* permutation, + _In_ size_t num_permutation_elems, + _In_ const OrtValue* ort_input, + _In_opt_ OrtSyncStream* stream, + _Inout_ OrtValue* ort_output) noexcept { + try { + // Get the CUDA stream from the OrtSyncStream + cudaStream_t cuda_stream = nullptr; + if (stream) { + const OrtSyncStreamImpl* impl = Ort::GetEpApi().SyncStream_GetImpl(stream); + if (impl) { + // GetHandle is a function pointer on OrtSyncStreamImpl + cuda_stream = static_cast( + const_cast(impl)->GetHandle(const_cast(impl))); + } + } + + Ort::ConstValue input(ort_input); + Ort::UnownedValue output(ort_output); + + Ort::TensorTypeAndShapeInfo input_info = input.GetTensorTypeAndShapeInfo(); + std::vector input_shape = input_info.GetShape(); + size_t num_dims = input_shape.size(); + size_t total_elements = input_info.GetElementCount(); + + if (num_dims != num_permutation_elems) { + return Ort::Status("Scan Transpose: permutation size does not match input rank", ORT_FAIL).release(); + } + + if (total_elements == 0) return nullptr; + + // Determine element size from the data type + ONNXTensorElementDataType elem_type = input_info.GetElementType(); + size_t element_size = 0; + switch (elem_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + element_size = sizeof(float); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + element_size = sizeof(double); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + element_size = 2; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + element_size = 1; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + element_size = 2; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + element_size = 4; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + element_size = 8; + break; + default: + return Ort::Status("Scan Transpose: unsupported element type", ORT_FAIL).release(); + } + + const void* input_data = input.GetTensorRawData(); + void* output_data = output.GetTensorMutableData(); + + // Launch the GPU transpose kernel + LaunchTransposeKernel(input_data, output_data, + input_shape.data(), permutation, + num_dims, element_size, total_elements, + cuda_stream); + + return nullptr; + } catch (const std::exception& ex) { + return Ort::Status(ex.what(), ORT_RUNTIME_EXCEPTION).release(); + } +} + +// =================================================================== +// Scan kernel +// =================================================================== + +Status PluginScanKernel::CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) { + auto helper = std::make_unique(); + OrtStatus* status = Ort::GetEpApi().CreateScanKernel(info, helper.get(), impl); + if (status) { + std::string msg = Ort::GetApi().GetErrorMessage(status); + Ort::GetApi().ReleaseStatus(status); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, msg); + } + helper.release(); // ORT takes ownership on success + return Status::OK(); +} + +} // namespace plugin +} // namespace cuda +} // namespace onnxruntime + +// =================================================================== +// Kernel Registrations — same opset versions as the framework CUDA EP +// =================================================================== + +using namespace onnxruntime::cuda::plugin; + +namespace onnxruntime { +namespace cuda { + +// --- If --- + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 1, 10, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + PluginIfKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 11, 12, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + PluginIfKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 13, 18, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + PluginIfKernel); + +ONNX_OPERATOR_KERNEL_EX(If, + kOnnxDomain, + 19, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypesIRv9()), + PluginIfKernel); + +// --- Loop --- + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Loop, + kOnnxDomain, + 1, 10, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("I", DataTypeImpl::GetTensorType()) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + PluginLoopKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Loop, + kOnnxDomain, + 11, 12, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("I", DataTypeImpl::GetTensorType()) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + PluginLoopKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Loop, + kOnnxDomain, + 13, 18, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("I", DataTypeImpl::GetTensorType()) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + PluginLoopKernel); + +ONNX_OPERATOR_KERNEL_EX(Loop, + kOnnxDomain, + 19, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("I", DataTypeImpl::GetTensorType()) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypesIRv9()), + PluginLoopKernel); + +// --- Scan (opset 8) --- + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Scan, + kOnnxDomain, + 8, 8, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("I", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorTypes()), + PluginScanKernel); + +// --- Scan (opset 9+) --- + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Scan, + kOnnxDomain, + 9, 10, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + PluginScanKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Scan, + kOnnxDomain, + 11, 15, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + PluginScanKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Scan, + kOnnxDomain, + 16, 18, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + PluginScanKernel); + +ONNX_OPERATOR_KERNEL_EX(Scan, + kOnnxDomain, + 19, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypesIRv9()), + PluginScanKernel); + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu new file mode 100644 index 0000000000000..43df89468c42d --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// GPU transpose kernel for the Scan control flow helper. +// Handles arbitrary N-D permutations by computing output coordinates +// from linear indices. + +#include +#include +#include +#include + +namespace onnxruntime { +namespace cuda { +namespace plugin { + +// Maximum number of dimensions supported by the transpose kernel. +// Most real-world tensors have <= 8 dimensions. +constexpr int kMaxTransposeDims = 8; + +// Kernel: each thread handles one element, computing its output position +// from the input position via the permutation. +__global__ void TransposeNDKernel(const char* __restrict__ input, + char* __restrict__ output, + const int64_t* __restrict__ input_strides, + const int64_t* __restrict__ output_strides, + const int* __restrict__ perm, + int num_dims, + size_t element_size, + size_t total_elements) { + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= total_elements) return; + + // Decompose linear index into input coordinates + int64_t coords[kMaxTransposeDims]; + size_t remaining = idx; + for (int d = 0; d < num_dims; d++) { + coords[d] = static_cast(remaining / static_cast(input_strides[d])); + remaining %= static_cast(input_strides[d]); + } + + // Compute output linear index via permutation + size_t out_idx = 0; + for (int d = 0; d < num_dims; d++) { + out_idx += static_cast(coords[perm[d]]) * static_cast(output_strides[d]); + } + + // Copy element bytes + const char* src = input + idx * element_size; + char* dst = output + out_idx * element_size; + // Use memcpy for arbitrary element sizes (compiler optimizes for common sizes) + memcpy(dst, src, element_size); +} + +void LaunchTransposeKernel(const void* input, void* output, + const int64_t* input_shape, const size_t* permutation, + size_t num_dims, size_t element_size, size_t total_elements, + cudaStream_t stream) { + if (total_elements == 0 || num_dims == 0) return; + + // Compute input strides (row-major) + int64_t input_strides[kMaxTransposeDims]; + input_strides[num_dims - 1] = 1; + for (int d = static_cast(num_dims) - 2; d >= 0; d--) { + input_strides[d] = input_strides[d + 1] * input_shape[d + 1]; + } + + // Compute output shape and strides from permutation + int64_t output_shape[kMaxTransposeDims]; + int64_t output_strides[kMaxTransposeDims]; + int perm_int[kMaxTransposeDims]; + for (size_t d = 0; d < num_dims; d++) { + output_shape[d] = input_shape[permutation[d]]; + perm_int[d] = static_cast(permutation[d]); + } + output_strides[num_dims - 1] = 1; + for (int d = static_cast(num_dims) - 2; d >= 0; d--) { + output_strides[d] = output_strides[d + 1] * output_shape[d + 1]; + } + + // Copy arrays to device + int64_t* d_input_strides = nullptr; + int64_t* d_output_strides = nullptr; + int* d_perm = nullptr; + + cudaMalloc(&d_input_strides, num_dims * sizeof(int64_t)); + cudaMalloc(&d_output_strides, num_dims * sizeof(int64_t)); + cudaMalloc(&d_perm, num_dims * sizeof(int)); + + cudaMemcpyAsync(d_input_strides, input_strides, num_dims * sizeof(int64_t), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(d_output_strides, output_strides, num_dims * sizeof(int64_t), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(d_perm, perm_int, num_dims * sizeof(int), cudaMemcpyHostToDevice, stream); + + constexpr int kBlockSize = 256; + int num_blocks = static_cast((total_elements + kBlockSize - 1) / kBlockSize); + + TransposeNDKernel<<>>( + static_cast(input), + static_cast(output), + d_input_strides, + d_output_strides, + d_perm, + static_cast(num_dims), + element_size, + total_elements); + + // Free device arrays asynchronously + cudaFreeAsync(d_input_strides, stream); + cudaFreeAsync(d_output_strides, stream); + cudaFreeAsync(d_perm, stream); +} + +} // namespace plugin +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.h b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.h new file mode 100644 index 0000000000000..efee4bb215342 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.h @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Plugin EP control flow kernel wrappers for If, Loop, and Scan. +// These delegate to OrtEpApi::CreateIfKernel/CreateLoopKernel/CreateScanKernel +// instead of inheriting from CPU base classes. + +#pragma once + +#include "core/session/onnxruntime_cxx_api.h" + +namespace onnxruntime { +namespace cuda { +namespace plugin { + +// =================================================================== +// If kernel wrapper — delegates to OrtEpApi::CreateIfKernel +// =================================================================== + +class PluginIfKernel : public OpKernel { + public: + explicit PluginIfKernel(const OpKernelInfo& info) : OpKernel(info) {} + + Status CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) override; + + Status Compute(OpKernelContext*) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Plugin If kernel should not be called directly"); + } +}; + +// =================================================================== +// Loop kernel helper — provides GPU ConcatOutput via cudaMemcpyAsync +// =================================================================== + +class PluginLoopHelper : public OrtLoopKernelHelper { + public: + PluginLoopHelper(); + + static void ORT_API_CALL ReleaseImpl(_In_ OrtLoopKernelHelper* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL ConcatOutputImpl( + _In_ OrtLoopKernelHelper* this_ptr, + _In_opt_ void* stream_handle, + _In_reads_(num_per_iteration_outputs) const OrtValue* const* per_iteration_outputs, + _In_ size_t num_per_iteration_outputs, + _Out_writes_bytes_all_(output_size_in_bytes) void* output, + _In_ size_t output_size_in_bytes) noexcept; +}; + +class PluginLoopKernel : public OpKernel { + public: + explicit PluginLoopKernel(const OpKernelInfo& info) : OpKernel(info) {} + + Status CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) override; + + Status Compute(OpKernelContext*) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Plugin Loop kernel should not be called directly"); + } +}; + +// =================================================================== +// Scan kernel helper — provides GPU Transpose via CUDA kernel +// =================================================================== + +class PluginScanHelper : public OrtScanKernelHelper { + public: + PluginScanHelper(); + + static void ORT_API_CALL ReleaseImpl(_In_ OrtScanKernelHelper* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL TransposeImpl( + _In_ OrtScanKernelHelper* this_ptr, + _In_reads_(num_permutation_elems) const size_t* permutation, + _In_ size_t num_permutation_elems, + _In_ const OrtValue* input, + _In_opt_ OrtSyncStream* stream, + _Inout_ OrtValue* output) noexcept; +}; + +class PluginScanKernel : public OpKernel { + public: + explicit PluginScanKernel(const OpKernelInfo& info) : OpKernel(info) {} + + Status CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) override; + + Status Compute(OpKernelContext*) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Plugin Scan kernel should not be called directly"); + } +}; + +// GPU transpose helper (defined in cuda_controlflow_plugin.cu) +void LaunchTransposeKernel(const void* input, void* output, + const int64_t* input_shape, const size_t* permutation, + size_t num_dims, size_t element_size, size_t total_elements, + cudaStream_t stream); + +} // namespace plugin +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc new file mode 100644 index 0000000000000..5226ff98b08f5 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "cuda_data_transfer_plugin.h" + +namespace onnxruntime { +namespace cuda_plugin { + +CudaDataTransfer::CudaDataTransfer(const OrtApi& ort_api, const OrtEpApi& ep_api, + const OrtMemoryDevice* gpu_device) + : OrtDataTransferImpl{}, + ort_api_(ort_api), + ep_api_(ep_api), + gpu_device_(gpu_device) { + ort_version_supported = ORT_API_VERSION; + Release = ReleaseImpl; + CanCopy = CanCopyImpl; + CopyTensors = CopyTensorsImpl; +} + +/*static*/ void ORT_API_CALL CudaDataTransfer::ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} + +/*static*/ bool ORT_API_CALL CudaDataTransfer::CanCopyImpl( + const OrtDataTransferImpl* this_ptr, + const OrtMemoryDevice* src_device, + const OrtMemoryDevice* dst_device) noexcept { + auto* dt = static_cast(this_ptr); + const OrtEpApi& ep_api = dt->ep_api_; + auto src_type = ep_api.MemoryDevice_GetDeviceType(src_device); + auto dst_type = ep_api.MemoryDevice_GetDeviceType(dst_device); + + bool src_is_cpu = (src_type == OrtMemoryInfoDeviceType_CPU); + bool dst_is_cpu = (dst_type == OrtMemoryInfoDeviceType_CPU); + bool src_is_gpu = (src_type == OrtMemoryInfoDeviceType_GPU); + bool dst_is_gpu = (dst_type == OrtMemoryInfoDeviceType_GPU); + + // Support CPU→GPU, GPU→CPU, GPU→GPU + return (src_is_cpu && dst_is_gpu) || + (src_is_gpu && dst_is_cpu) || + (src_is_gpu && dst_is_gpu); +} + +/*static*/ OrtStatus* ORT_API_CALL CudaDataTransfer::CopyTensorsImpl( + OrtDataTransferImpl* this_ptr, + const OrtValue** src_tensors, + OrtValue** dst_tensors, + OrtSyncStream** streams, + size_t count) noexcept { + EXCEPTION_TO_STATUS_BEGIN + + auto* dt = static_cast(this_ptr); + + for (size_t i = 0; i < count; ++i) { + Ort::ConstValue src{src_tensors[i]}; + Ort::UnownedValue dst{dst_tensors[i]}; + + auto src_type_shape = src.GetTensorTypeAndShapeInfo(); + size_t count_elems = src_type_shape.GetElementCount(); + + // Get element size from data type + ONNXTensorElementDataType elem_type = src_type_shape.GetElementType(); + size_t elem_size = 0; + // Compute byte size of the tensor elements. + // ORT's C API doesn't expose an element-size helper directly, so we + // map the ONNX element type to its byte width manually. + // Cases are grouped by element size for clarity. + switch (elem_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + elem_size = 1; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + elem_size = 2; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + elem_size = 4; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + elem_size = 8; + break; + default: + return dt->ort_api_.CreateStatus(ORT_EP_FAIL, "Unsupported tensor element type for copy"); + } + + size_t bytes = count_elems * elem_size; + if (bytes == 0) continue; + + const void* src_data = src.GetTensorRawData(); + void* dst_data = dst.GetTensorMutableRawData(); + + // Determine copy direction + const OrtMemoryInfo* src_mem_info = src.GetTensorMemoryInfo(); + const OrtMemoryInfo* dst_mem_info = dst.GetTensorMemoryInfo(); + const OrtMemoryDevice* src_dev = dt->ep_api_.MemoryInfo_GetMemoryDevice(src_mem_info); + const OrtMemoryDevice* dst_dev = dt->ep_api_.MemoryInfo_GetMemoryDevice(dst_mem_info); + auto src_dev_type = dt->ep_api_.MemoryDevice_GetDeviceType(src_dev); + auto dst_dev_type = dt->ep_api_.MemoryDevice_GetDeviceType(dst_dev); + + cudaMemcpyKind copy_kind; + if (src_dev_type == OrtMemoryInfoDeviceType_CPU && dst_dev_type == OrtMemoryInfoDeviceType_GPU) { + copy_kind = cudaMemcpyHostToDevice; + } else if (src_dev_type == OrtMemoryInfoDeviceType_GPU && dst_dev_type == OrtMemoryInfoDeviceType_CPU) { + copy_kind = cudaMemcpyDeviceToHost; + } else if (src_dev_type == OrtMemoryInfoDeviceType_GPU && dst_dev_type == OrtMemoryInfoDeviceType_GPU) { + copy_kind = cudaMemcpyDeviceToDevice; + } else { + return dt->ort_api_.CreateStatus(ORT_EP_FAIL, "Unsupported copy direction"); + } + + // Use async copy if stream is provided + if (streams != nullptr && streams[i] != nullptr) { + cudaStream_t cuda_stream = static_cast( + Ort::GetApi().SyncStream_GetHandle(streams[i])); + PL_CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, copy_kind, cuda_stream)); + } else { + PL_CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, copy_kind)); + } + } + + return nullptr; + + EXCEPTION_TO_STATUS_END +} + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.h b/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.h new file mode 100644 index 0000000000000..cd662b105973d --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "cuda_plugin_utils.h" + +namespace onnxruntime { +namespace cuda_plugin { + +/// CUDA data transfer implementation for CPU↔GPU and GPU↔GPU copies. +class CudaDataTransfer : public OrtDataTransferImpl { + public: + CudaDataTransfer(const OrtApi& ort_api, const OrtEpApi& ep_api, + const OrtMemoryDevice* gpu_device); + ~CudaDataTransfer() = default; + + private: + static void ORT_API_CALL ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept; + + static bool ORT_API_CALL CanCopyImpl( + const OrtDataTransferImpl* this_ptr, + const OrtMemoryDevice* src_device, + const OrtMemoryDevice* dst_device) noexcept; + + static OrtStatus* ORT_API_CALL CopyTensorsImpl( + OrtDataTransferImpl* this_ptr, + const OrtValue** src_tensors, + OrtValue** dst_tensors, + OrtSyncStream** streams, + size_t count) noexcept; + + const OrtApi& ort_api_; + const OrtEpApi& ep_api_; + const OrtMemoryDevice* gpu_device_; +}; + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc new file mode 100644 index 0000000000000..6215cc74f828a --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -0,0 +1,309 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "cuda_ep.h" +#include "cuda_ep_factory.h" +#include "cuda_stream_plugin.h" +#include "core/providers/cuda/plugin/cuda_kernel_adapter.h" +#include "ep/get_capability_utils.h" + +#include +#include +#include +#include + +namespace onnxruntime { +namespace cuda_plugin { + +CudaEp::CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& logger) + : OrtEp{}, + factory_(factory), + name_(factory.GetEpName()), + config_(config), + logger_(logger), + cuda_graph_enabled_(config.enable_cuda_graph), + min_runs_before_capture_(config.min_num_runs_before_cuda_graph_capture) { + ort_version_supported = ORT_API_VERSION; + + // Set function pointers for kernel-registry-based EP + GetName = GetNameImpl; + GetCapability = GetCapabilityImpl; + GetKernelRegistry = GetKernelRegistryImpl; + GetPreferredDataLayout = GetPreferredDataLayoutImpl; + ShouldConvertDataLayoutForOp = ShouldConvertDataLayoutForOpImpl; + OnRunStart = OnRunStartImpl; + OnRunEnd = OnRunEndImpl; + + // Not a compile-based EP + Compile = nullptr; + ReleaseNodeComputeInfos = nullptr; + + const OrtApi& ort_api = factory_.GetOrtApi(); + Ort::Status log_status(ort_api.Logger_LogMessage(&logger_, ORT_LOGGING_LEVEL_INFO, + "CUDA Plugin EP created", + ORT_FILE, __LINE__, __FUNCTION__)); + + // Seed adapter-level runtime options for migrated kernels. + onnxruntime::cuda::SetCudaKernelAdapterRuntimeConfig( + config_.use_tf32, config_.device_id, config_.enable_skip_layer_norm_strict_mode, + config_.cudnn_conv_algo, config_.cudnn_conv1d_pad_to_nc1d); +} + +CudaEp::~CudaEp() = default; + +/*static*/ +const char* ORT_API_CALL CudaEp::GetNameImpl(const OrtEp* this_ptr) noexcept { + return static_cast(this_ptr)->name_.c_str(); +} + +/*static*/ +OrtStatus* ORT_API_CALL CudaEp::GetCapabilityImpl( + OrtEp* this_ptr, const OrtGraph* ort_graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept { + EXCEPTION_TO_STATUS_BEGIN + + auto* ep = static_cast(this_ptr); + const OrtEpApi& ep_api = ep->factory_.GetEpApi(); + + Ort::ConstGraph graph{ort_graph}; + std::vector all_nodes = graph.GetNodes(); + + if (all_nodes.empty()) { + return nullptr; + } + + // Phase 1: Collect tentative nodes — those for which we have a registered kernel. + std::vector tentative_nodes; + tentative_nodes.reserve(all_nodes.size()); + + for (const auto& node : all_nodes) { + // Skip nodes already assigned to another EP. + std::string ep_name = node.GetEpName(); + if (!ep_name.empty()) { + continue; + } + + const OrtKernelDef* kernel_def = nullptr; + RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_LookUpKernel( + graph_support_info, node, &kernel_def)); + + if (kernel_def != nullptr) { + tentative_nodes.push_back(node); + } + } + + // Phase 2: Filter out CPU-preferred nodes (e.g., Shape, NonZero, small compute ops + // that would be cheaper on CPU than incurring device-to-host copy overhead). + std::unordered_set cpu_preferred_nodes; + RETURN_IF_ERROR(ep::GetCpuPreferredNodes( + *ort_graph, *graph_support_info, ep->logger_, + gsl::span(tentative_nodes.data(), tentative_nodes.size()), + cpu_preferred_nodes)); + + // Phase 3: Add final supported nodes (tentative minus CPU-preferred). + for (const OrtNode* ort_node : tentative_nodes) { + if (cpu_preferred_nodes.count(ort_node) == 0) { + Ort::ConstNode node{ort_node}; + RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_AddSingleNode( + graph_support_info, node)); + } + } + + return nullptr; + + EXCEPTION_TO_STATUS_END +} + +/*static*/ +OrtStatus* ORT_API_CALL CudaEp::GetKernelRegistryImpl( + OrtEp* this_ptr, + const OrtKernelRegistry** kernel_registry) noexcept { + auto* ep = static_cast(this_ptr); + *kernel_registry = nullptr; + + RETURN_IF_ERROR(ep->factory_.GetKernelRegistryForEp(*ep, kernel_registry)); + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL CudaEp::GetPreferredDataLayoutImpl( + OrtEp* this_ptr, OrtEpDataLayout* preferred_data_layout) noexcept { + const auto* ep = static_cast(this_ptr); + *preferred_data_layout = ep->config_.prefer_nhwc ? OrtEpDataLayout_NHWC : OrtEpDataLayout_NCHW; + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL CudaEp::ShouldConvertDataLayoutForOpImpl( + OrtEp* this_ptr, const char* domain, const char* op_type, + OrtEpDataLayout target_data_layout, int* should_convert) noexcept { + (void)this_ptr; + + // Only convert to NHWC; for any other target layout, let ORT decide. + if (target_data_layout != OrtEpDataLayout_NHWC) { + *should_convert = -1; // Let ORT decide + return nullptr; + } + + // ONNX domain ops that have NHWC kernel registrations. + static const std::unordered_set cuda_nhwc_onnx_ops{ + "BatchNormalization", + "Conv", + "ConvTranspose", + "GlobalMaxPool", + "MaxPool", + "GlobalAveragePool", + "AveragePool", + "GridSample", + "DepthToSpace", + "SpaceToDepth", + "LRN", + }; + + // Check ONNX domain (empty string) or MS domain (com.microsoft) + bool is_onnx_domain = (domain[0] == '\0'); + bool is_ms_domain = (std::strcmp(domain, "com.microsoft") == 0); + + if (is_onnx_domain && cuda_nhwc_onnx_ops.count(op_type) > 0) { + *should_convert = 1; // Convert + return nullptr; + } + + if (is_ms_domain && std::strcmp(op_type, "GridSample") == 0) { + *should_convert = 1; // Convert + return nullptr; + } + + *should_convert = -1; // Let ORT decide for other ops + return nullptr; +} + +// --------------------------------------------------------------------------- +// CUDA Graph helpers +// --------------------------------------------------------------------------- + +CudaGraphAnnotation_t CudaEp::GetAnnotationId(const ::OrtRunOptions* run_options) const { + const OrtApi& ort_api = factory_.GetOrtApi(); + // Use the same key as the bundled CUDA EP: "gpu_graph_id" + const char* val = ort_api.GetRunConfigEntry(run_options, "gpu_graph_id"); + if (val == nullptr) { + return kCudaGraphAnnotationDefault; + } + try { + return std::stoi(val); + } catch (...) { + return kCudaGraphAnnotationDefault; + } +} + +bool CudaEp::IsGraphCaptureAllowed(CudaGraphAnnotation_t annotation_id) const { + if (!cuda_graph_manager_.IsGraphCaptureAllowedOnRun(annotation_id)) { + return false; + } + auto it = graph_id_to_run_count_.find(annotation_id); + if (it == graph_id_to_run_count_.end()) { + return false; + } + return it->second >= min_runs_before_capture_; +} + +// --------------------------------------------------------------------------- +// OnRunStart — manage CUDA graph capture/replay state machine +// --------------------------------------------------------------------------- + +/*static*/ +OrtStatus* ORT_API_CALL CudaEp::OnRunStartImpl( + OrtEp* this_ptr, const ::OrtRunOptions* run_options) noexcept { + EXCEPTION_TO_STATUS_BEGIN + + auto* ep = static_cast(this_ptr); + + if (!ep->cuda_graph_enabled_.load(std::memory_order_relaxed)) { + return nullptr; // Graph capture not enabled — no-op + } + + // gpu_graph_id == -1 means skip capture/replay for this run + // (matches bundled CUDA EP behavior via kOrtRunOptionsConfigCudaGraphAnnotation) + CudaGraphAnnotation_t annotation_id = ep->GetAnnotationId(run_options); + if (annotation_id == kCudaGraphAnnotationSkip) { + return nullptr; + } + + // Lazily set the graph manager's stream from the factory's compute stream. + CudaSyncStream* compute_stream = ep->factory_.GetComputeStream(); + if (compute_stream == nullptr) { + // Stream not yet created — skip graph capture for this run. + // This can happen if OnRunStart is called before CreateSyncStreamForDevice. + return nullptr; + } + ep->cuda_graph_manager_.SetStream(compute_stream->GetCudaStream()); + + if (ep->cuda_graph_manager_.IsGraphCaptured(annotation_id)) { + // Already captured — replay happens in OnRunEnd for the plugin EP. + // ORT runtime will still dispatch kernels normally; the captured graph + // replays the actual GPU work. For the plugin EP without stream executor + // hooks, we replay at OnRunEnd after kernel dispatch completes. + return nullptr; + } + + if (!ep->cuda_graph_manager_.IsGraphCaptured(annotation_id) && + ep->IsGraphCaptureAllowed(annotation_id)) { + // Warm-up period complete — begin capture + ep->cuda_graph_manager_.CaptureBegin(annotation_id); + ep->is_capturing_ = true; + ep->capturing_annotation_id_ = annotation_id; + } + + return nullptr; + + EXCEPTION_TO_STATUS_END +} + +// --------------------------------------------------------------------------- +// OnRunEnd — end capture or handle replay +// --------------------------------------------------------------------------- + +/*static*/ +OrtStatus* ORT_API_CALL CudaEp::OnRunEndImpl( + OrtEp* this_ptr, const ::OrtRunOptions* run_options, bool sync_stream) noexcept { + EXCEPTION_TO_STATUS_BEGIN + + auto* ep = static_cast(this_ptr); + + if (!ep->cuda_graph_enabled_.load(std::memory_order_relaxed)) { + return nullptr; + } + + // gpu_graph_id == -1 means skip capture/replay for this run + CudaGraphAnnotation_t annotation_id = ep->GetAnnotationId(run_options); + if (annotation_id == kCudaGraphAnnotationSkip) { + return nullptr; + } + + if (!ep->cuda_graph_manager_.IsGraphCaptured(annotation_id)) { + if (ep->is_capturing_ && ep->capturing_annotation_id_ == annotation_id) { + // Was capturing — end capture and replay the first time + ep->cuda_graph_manager_.CaptureEnd(annotation_id); + ep->is_capturing_ = false; + + // CUDA work issued to a capturing stream doesn't actually run on the GPU, + // so replay the captured graph to actually execute the work. + OrtStatus* replay_status = ep->cuda_graph_manager_.Replay(annotation_id, sync_stream); + if (replay_status != nullptr) return replay_status; + } else { + // Still in warm-up period — increment run count + ep->graph_id_to_run_count_[annotation_id]++; + } + } + // Note: For subsequent runs after capture, the captured graph is not replayed + // here. The ORT framework dispatches kernels normally (it does not know about + // CUDA graph capture). Full graph-only replay (with kernel dispatch bypass) + // requires stream executor support which is not yet available in the plugin EP. + + return nullptr; + + EXCEPTION_TO_STATUS_END +} + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.h b/onnxruntime/core/providers/cuda/plugin/cuda_ep.h new file mode 100644 index 0000000000000..9d15cd6abdb8a --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.h @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "cuda_plugin_utils.h" +#include "cuda_graph_plugin.h" + +#include +#include +#include +#include + +namespace onnxruntime { +namespace cuda_plugin { + +class CudaEpFactory; + +/// CUDA execution provider implementation using public OrtEp interface. +class CudaEp : public OrtEp { + public: + /// Configuration parameters for the CUDA EP, parsed from session options. + struct Config { + bool prefer_nhwc = false; ///< Use NHWC data layout when available. + bool use_tf32 = true; ///< Enable TF32 math on Ampere+ GPUs. + bool enable_skip_layer_norm_strict_mode = false; ///< Strict mode for SkipLayerNorm kernel. + int device_id = 0; ///< CUDA device ordinal. + int cudnn_conv_algo = 0; ///< cuDNN convolution algorithm selection. + bool cudnn_conv1d_pad_to_nc1d = false; ///< Pad 1D convolutions to NC1D format. + bool enable_cuda_graph = false; ///< Enable CUDA graph capture/replay. + int min_num_runs_before_cuda_graph_capture = 1; ///< Warm-up runs before graph capture. + }; + + CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& logger); + ~CudaEp(); + + const char* GetEpName() const { return name_.c_str(); } + const Config& GetConfig() const { return config_; } + + private: + // OrtEp callback implementations + static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL GetCapabilityImpl( + OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept; + + static OrtStatus* ORT_API_CALL GetKernelRegistryImpl( + OrtEp* this_ptr, + const OrtKernelRegistry** kernel_registry) noexcept; + + static OrtStatus* ORT_API_CALL GetPreferredDataLayoutImpl( + OrtEp* this_ptr, OrtEpDataLayout* preferred_data_layout) noexcept; + + static OrtStatus* ORT_API_CALL OnRunStartImpl( + OrtEp* this_ptr, const ::OrtRunOptions* run_options) noexcept; + + static OrtStatus* ORT_API_CALL ShouldConvertDataLayoutForOpImpl( + OrtEp* this_ptr, const char* domain, const char* op_type, + OrtEpDataLayout target_data_layout, int* should_convert) noexcept; + + static OrtStatus* ORT_API_CALL OnRunEndImpl( + OrtEp* this_ptr, const ::OrtRunOptions* run_options, bool sync_stream) noexcept; + + // CUDA Graph helpers + CudaGraphAnnotation_t GetAnnotationId(const ::OrtRunOptions* run_options) const; + bool IsGraphCaptureAllowed(CudaGraphAnnotation_t annotation_id) const; + + CudaEpFactory& factory_; + std::string name_; + Config config_; + const OrtLogger& logger_; + + // CUDA Graph state + std::atomic cuda_graph_enabled_{false}; + int min_runs_before_capture_ = 1; + CUDAGraphManager cuda_graph_manager_; + std::unordered_map graph_id_to_run_count_; + bool is_capturing_ = false; + CudaGraphAnnotation_t capturing_annotation_id_ = kCudaGraphAnnotationDefault; +}; + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc new file mode 100644 index 0000000000000..e6c6df2d6ccfe --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -0,0 +1,344 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "cuda_ep_factory.h" +#include "cuda_ep.h" +#include "cuda_plugin_kernels.h" + +namespace onnxruntime { +namespace cuda_plugin { + +CudaEpFactory::CudaEpFactory(const OrtApi& ort_api, const OrtEpApi& ep_api, + const OrtLogger& default_logger) + : OrtEpFactory{}, + ort_api_(ort_api), + ep_api_(ep_api), + default_logger_(default_logger), + default_memory_info_{nullptr}, + pinned_memory_info_{nullptr} { + ort_version_supported = ORT_API_VERSION; + + // Assign callback function pointers + GetName = GetNameImpl; + GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; + GetVersion = GetVersionImpl; + GetSupportedDevices = GetSupportedDevicesImpl; + CreateEp = CreateEpImpl; + ReleaseEp = ReleaseEpImpl; + CreateAllocator = CreateAllocatorImpl; + ReleaseAllocator = ReleaseAllocatorImpl; + CreateDataTransfer = CreateDataTransferImpl; + IsStreamAware = IsStreamAwareImpl; + CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; + + // Initialize default memory info for CUDA device memory. + // The NVIDIA PCI vendor ID (0x10DE) is used to identify the device type. + default_memory_info_ = Ort::MemoryInfo{"Cuda", + OrtMemoryInfoDeviceType_GPU, + vendor_id_, + static_cast(device_id_), + OrtDeviceMemoryType_DEFAULT, + /*alignment*/ 0, + OrtAllocatorType::OrtDeviceAllocator}; + + // Initialize pinned (host accessible) memory info + pinned_memory_info_ = Ort::MemoryInfo{"CudaPinned", + OrtAllocatorType::OrtDeviceAllocator, + 0, + OrtMemType::OrtMemTypeCPU}; +} + +CudaEpFactory::~CudaEpFactory() { + if (kernel_registry_ != nullptr) { + ep_api_.ReleaseKernelRegistry(kernel_registry_); + } +} + +OrtStatus* CudaEpFactory::GetKernelRegistryForEp(CudaEp& ep, + const OrtKernelRegistry** out_kernel_registry) { + *out_kernel_registry = nullptr; + + std::lock_guard lock(registry_mutex_); + + if (kernel_registry_ == nullptr) { + const char* ep_name = ep.GetEpName(); + // CreateCudaKernelRegistry dispatches between legacy/generated registrations + // and adapter-mode registration path based on build configuration. + RETURN_IF_ERROR(CreateCudaKernelRegistry(ep_api_, ep_name, nullptr, &kernel_registry_)); + } + + *out_kernel_registry = kernel_registry_; + return nullptr; +} + +// --------------------------------------------------------------------------- +// OrtEpFactory callback implementations +// --------------------------------------------------------------------------- + +/*static*/ +const char* ORT_API_CALL CudaEpFactory::GetNameImpl(const OrtEpFactory* this_ptr) noexcept { + return static_cast(this_ptr)->ep_name_.c_str(); +} + +/*static*/ +const char* ORT_API_CALL CudaEpFactory::GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { + return static_cast(this_ptr)->vendor_.c_str(); +} + +/*static*/ +uint32_t ORT_API_CALL CudaEpFactory::GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { + return static_cast(this_ptr)->vendor_id_; +} + +/*static*/ +const char* ORT_API_CALL CudaEpFactory::GetVersionImpl(const OrtEpFactory* this_ptr) noexcept { + return static_cast(this_ptr)->ep_version_.c_str(); +} + +/*static*/ +OrtStatus* ORT_API_CALL CudaEpFactory::GetSupportedDevicesImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* hw_devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + auto* factory = static_cast(this_ptr); + size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *hw_devices[i]; + auto hw_type = factory->ort_api_.HardwareDevice_Type(&device); + + if (hw_type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + // Check if this GPU is an NVIDIA GPU by trying to match vendor ID + // For now, accept all GPU devices and let CUDA runtime handle validation + OrtKeyValuePairs* ep_metadata = nullptr; + factory->ort_api_.CreateKeyValuePairs(&ep_metadata); + + // Get CUDA device properties for metadata + int cuda_device_count = 0; + cudaError_t err = cudaGetDeviceCount(&cuda_device_count); + if (err == cudaSuccess && cuda_device_count > 0) { + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, factory->device_id_); + factory->ort_api_.AddKeyValuePair(ep_metadata, "cuda_device_name", prop.name); + factory->ort_api_.AddKeyValuePair( + ep_metadata, "cuda_compute_capability", + (std::to_string(prop.major) + "." + std::to_string(prop.minor)).c_str()); + } + + OrtEpDevice* ep_device = nullptr; + auto* status = factory->ep_api_.CreateEpDevice(factory, &device, ep_metadata, nullptr, + &ep_device); + factory->ort_api_.ReleaseKeyValuePairs(ep_metadata); + + if (status != nullptr) { + return status; + } + + // Register allocator info for GPU device memory + RETURN_IF_ERROR(factory->ep_api_.EpDevice_AddAllocatorInfo( + ep_device, factory->default_memory_info_)); + + // Register allocator info for CPU pinned memory (host accessible) + RETURN_IF_ERROR(factory->ep_api_.EpDevice_AddAllocatorInfo( + ep_device, factory->pinned_memory_info_)); + + ep_devices[num_ep_devices++] = ep_device; + } + } + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* logger, + OrtEp** ep) noexcept { + EXCEPTION_TO_STATUS_BEGIN + + auto* factory = static_cast(this_ptr); + *ep = nullptr; + + if (num_devices != 1) { + return factory->ort_api_.CreateStatus( + ORT_INVALID_ARGUMENT, + "CUDA EP factory currently supports only one device at a time."); + } + + // Parse configuration from session options. + // The read helpers intentionally swallow errors: if a config entry is + // absent or malformed the default value in Config is kept. + CudaEp::Config config{}; + + auto read_session_config_bool = [&](const char* key, bool& value) { + size_t size = 0; + OrtStatus* status = factory->ort_api_.GetSessionConfigEntry(session_options, key, nullptr, &size); + if (status != nullptr) { + Ort::Status s(status); + return; + } + if (size == 0) return; + std::vector buf(size); + status = factory->ort_api_.GetSessionConfigEntry(session_options, key, buf.data(), &size); + if (status != nullptr) { + Ort::Status s(status); + return; + } + const std::string val(buf.data()); + value = (val == "1" || val == "true"); + }; + + auto read_session_config_int = [&](const char* key, int& value) { + size_t size = 0; + OrtStatus* status = factory->ort_api_.GetSessionConfigEntry(session_options, key, nullptr, &size); + if (status != nullptr) { + Ort::Status s(status); + return; + } + if (size == 0) return; + std::vector buf(size); + status = factory->ort_api_.GetSessionConfigEntry(session_options, key, buf.data(), &size); + if (status != nullptr) { + Ort::Status s(status); + return; + } + try { + value = std::stoi(buf.data()); + } catch (...) { + } + }; + + // Read from flat keys first, then from ep.cuda.* prefixed keys. + // The second pass intentionally overwrites the first so that + // ep.cuda.* takes precedence over unprefixed keys. + read_session_config_bool("prefer_nhwc", config.prefer_nhwc); + read_session_config_bool("use_tf32", config.use_tf32); + read_session_config_bool("enable_skip_layer_norm_strict_mode", config.enable_skip_layer_norm_strict_mode); + read_session_config_bool("cudnn_conv1d_pad_to_nc1d", config.cudnn_conv1d_pad_to_nc1d); + read_session_config_int("cudnn_conv_algo", config.cudnn_conv_algo); + + read_session_config_bool("ep.cuda.prefer_nhwc_layout", config.prefer_nhwc); + read_session_config_bool("ep.cuda.use_tf32", config.use_tf32); + read_session_config_bool("ep.cuda.enable_skip_layer_norm_strict_mode", config.enable_skip_layer_norm_strict_mode); + read_session_config_bool("ep.cuda.cudnn_conv1d_pad_to_nc1d", config.cudnn_conv1d_pad_to_nc1d); + read_session_config_int("ep.cuda.cudnn_conv_algo", config.cudnn_conv_algo); + read_session_config_bool("ep.cuda.enable_cuda_graph", config.enable_cuda_graph); + read_session_config_int("ep.cuda.min_num_runs_before_cuda_graph_capture", config.min_num_runs_before_cuda_graph_capture); + + const OrtLogger& ep_logger = logger ? *logger : factory->default_logger_; + auto actual_ep = std::make_unique(*factory, config, ep_logger); + *ep = actual_ep.release(); + + return nullptr; + + EXCEPTION_TO_STATUS_END +} + +/*static*/ +void ORT_API_CALL CudaEpFactory::ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept { + delete static_cast(ep); +} + +/*static*/ +OrtStatus* ORT_API_CALL CudaEpFactory::CreateAllocatorImpl( + OrtEpFactory* this_ptr, + const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept { + auto& factory = *static_cast(this_ptr); + *allocator = nullptr; + + const OrtMemoryInfo* default_memory_info_ptr = factory.default_memory_info_.operator OrtMemoryInfo*(); + const OrtMemoryInfo* pinned_memory_info_ptr = factory.pinned_memory_info_.operator OrtMemoryInfo*(); + + auto is_equal_memory_info = [&](const OrtMemoryInfo* expected, bool& out_equal) -> OrtStatus* { + int is_equal = 0; + auto* status = factory.ort_api_.CompareMemoryInfo(memory_info, expected, &is_equal); + if (status != nullptr) { + return status; + } + out_equal = (is_equal != 0); + return nullptr; + }; + + bool is_default = false; + bool is_pinned = false; + RETURN_IF_ERROR(is_equal_memory_info(default_memory_info_ptr, is_default)); + RETURN_IF_ERROR(is_equal_memory_info(pinned_memory_info_ptr, is_pinned)); + + if (is_default) { + auto cuda_allocator = std::make_unique(memory_info, factory.device_id_); + *allocator = cuda_allocator.release(); + return nullptr; + } + + if (is_pinned) { + auto pinned_allocator = std::make_unique(memory_info); + *allocator = pinned_allocator.release(); + return nullptr; + } + + return factory.ort_api_.CreateStatus( + ORT_INVALID_ARGUMENT, + "Unknown memory info provided to CUDA EP CreateAllocator."); +} + +/*static*/ +void ORT_API_CALL CudaEpFactory::ReleaseAllocatorImpl( + OrtEpFactory* /*this_ptr*/, OrtAllocator* allocator) noexcept { + // We know the allocator was created by us, so cast and delete. + // OrtAllocator itself has no Release method. + delete allocator; +} + +/*static*/ +OrtStatus* ORT_API_CALL CudaEpFactory::CreateDataTransferImpl( + OrtEpFactory* this_ptr, + OrtDataTransferImpl** data_transfer) noexcept { + auto& factory = *static_cast(this_ptr); + const OrtMemoryDevice* gpu_device = factory.ep_api_.MemoryInfo_GetMemoryDevice(factory.default_memory_info_); + auto data_transfer_impl = std::make_unique(factory.ort_api_, factory.ep_api_, gpu_device); + *data_transfer = data_transfer_impl.release(); + return nullptr; +} + +/*static*/ +bool ORT_API_CALL CudaEpFactory::IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return true; // CUDA EP is stream-aware +} + +/*static*/ +OrtStatus* ORT_API_CALL CudaEpFactory::CreateSyncStreamForDeviceImpl( + OrtEpFactory* this_ptr, + const OrtMemoryDevice* /*memory_device*/, + const OrtKeyValuePairs* /*stream_options*/, + OrtSyncStreamImpl** stream) noexcept { + EXCEPTION_TO_STATUS_BEGIN + + auto* factory = static_cast(this_ptr); + auto cuda_stream = std::make_unique(*factory, factory->device_id_, nullptr); + + // Initialize CUDA handles (stream, cuBLAS, cuDNN) + RETURN_IF_ERROR(cuda_stream->InitHandles()); + + // Track the compute stream for CUDA graph integration. + // The factory does NOT own this stream — ORT manages its lifetime. + factory->compute_stream_ = cuda_stream.get(); + + *stream = cuda_stream.release(); + return nullptr; + + EXCEPTION_TO_STATUS_END +} + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h new file mode 100644 index 0000000000000..429cf50eaf854 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "cuda_plugin_utils.h" +#include "cuda_allocator_plugin.h" +#include "cuda_data_transfer_plugin.h" +#include "cuda_stream_plugin.h" + +#include +#include +#include + +namespace onnxruntime { +namespace cuda_plugin { + +class CudaEp; + +/// CUDA EP factory implementing OrtEpFactory. +/// Manages device enumeration, allocator creation, data transfer, and stream creation. +class CudaEpFactory : public OrtEpFactory { + public: + CudaEpFactory(const OrtApi& ort_api, const OrtEpApi& ep_api, + const OrtLogger& default_logger); + ~CudaEpFactory(); + + const OrtApi& GetOrtApi() const { return ort_api_; } + const OrtEpApi& GetEpApi() const { return ep_api_; } + const std::string& GetEpName() const { return ep_name_; } + + /// Get or create the shared kernel registry for this factory. + OrtStatus* GetKernelRegistryForEp(CudaEp& ep, + const OrtKernelRegistry** out_kernel_registry); + + /// Get the compute stream (set by CreateSyncStreamForDevice). + /// Returns nullptr if no stream has been created yet. + CudaSyncStream* GetComputeStream() const { return compute_stream_; } + + private: + // OrtEpFactory callback implementations + static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; + static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept; + static uint32_t ORT_API_CALL GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept; + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, size_t num_devices, + OrtEpDevice** ep_devices, size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept; + + static OrtStatus* ORT_API_CALL CreateEpImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* logger, + OrtEp** ep) noexcept; + + static void ORT_API_CALL ReleaseEpImpl(OrtEpFactory* this_ptr, OrtEp* ep) noexcept; + + static OrtStatus* ORT_API_CALL CreateAllocatorImpl( + OrtEpFactory* this_ptr, + const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept; + + static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* this_ptr, + OrtAllocator* allocator) noexcept; + + static OrtStatus* ORT_API_CALL CreateDataTransferImpl( + OrtEpFactory* this_ptr, + OrtDataTransferImpl** data_transfer) noexcept; + + static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl( + OrtEpFactory* this_ptr, + const OrtMemoryDevice* memory_device, + const OrtKeyValuePairs* stream_options, + OrtSyncStreamImpl** stream) noexcept; + + const OrtApi& ort_api_; + const OrtEpApi& ep_api_; + const OrtLogger& default_logger_; + + const std::string ep_name_{"CudaPluginExecutionProvider"}; + const std::string vendor_{"NVIDIA"}; + const uint32_t vendor_id_ = 0x10DE; // NVIDIA PCI vendor ID + const std::string ep_version_{"1.0.0"}; + + // Memory info for GPU device and CPU pinned memory + Ort::MemoryInfo default_memory_info_{nullptr}; + Ort::MemoryInfo pinned_memory_info_{nullptr}; + int device_id_ = 0; + + // Kernel registry (cached, shared across EP instances) + OrtKernelRegistry* kernel_registry_ = nullptr; + std::mutex registry_mutex_; + + // Compute stream (set by CreateSyncStreamForDevice, non-owning). + CudaSyncStream* compute_stream_ = nullptr; +}; + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.cc new file mode 100644 index 0000000000000..3e5f28b4b0491 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.cc @@ -0,0 +1,144 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "cuda_graph_plugin.h" + +#include +#include +#include + +namespace onnxruntime { +namespace cuda_plugin { + +// --------------------------------------------------------------------------- +// CudaGraphSet +// --------------------------------------------------------------------------- + +CudaGraphSet::~CudaGraphSet() { + Clear(); +} + +void CudaGraphSet::Clear() { + for (auto& [id, graph_exec] : cuda_graphs_) { + (void)cudaGraphExecDestroy(graph_exec); + } + cuda_graphs_.clear(); +} + +bool CudaGraphSet::Contains(CudaGraphAnnotation_t id) const { + return cuda_graphs_.find(id) != cuda_graphs_.end(); +} + +void CudaGraphSet::Put(CudaGraphAnnotation_t id, cudaGraphExec_t graph_exec) { + if (Contains(id)) { + throw std::runtime_error( + "CudaGraphSet::Put: annotation id " + std::to_string(id) + + " already exists. Use a different annotation id."); + } + cuda_graphs_.emplace(id, graph_exec); +} + +cudaGraphExec_t CudaGraphSet::Get(CudaGraphAnnotation_t id) const { + auto it = cuda_graphs_.find(id); + if (it == cuda_graphs_.end()) { + throw std::runtime_error( + "CudaGraphSet::Get: no graph found for annotation id " + std::to_string(id)); + } + return it->second; +} + +// --------------------------------------------------------------------------- +// CUDAGraphManager +// --------------------------------------------------------------------------- + +CUDAGraphManager::CUDAGraphManager(cudaStream_t stream) : stream_(stream) {} + +CUDAGraphManager::~CUDAGraphManager() { + Reset(); +} + +void CUDAGraphManager::SetStream(cudaStream_t stream) { + stream_ = stream; +} + +void CUDAGraphManager::CaptureBegin(CudaGraphAnnotation_t annotation_id) { + if (!IsGraphCaptureAllowedOnRun(annotation_id)) { + throw std::runtime_error("CUDAGraphManager::CaptureBegin: capture not allowed for annotation " + + std::to_string(annotation_id)); + } + + if (cuda_graph_set_.Contains(annotation_id)) { + throw std::runtime_error( + "CUDAGraphManager::CaptureBegin: annotation id " + std::to_string(annotation_id) + + " already captured. Use a different annotation id."); + } + + auto err = cudaStreamSynchronize(stream_); + if (err != cudaSuccess) { + throw std::runtime_error(std::string("cudaStreamSynchronize failed: ") + cudaGetErrorString(err)); + } + + // cudaStreamCaptureModeGlobal: single-thread capture (future: ThreadLocal for multi-stream) + err = cudaStreamBeginCapture(stream_, cudaStreamCaptureModeGlobal); + if (err != cudaSuccess) { + throw std::runtime_error(std::string("cudaStreamBeginCapture failed: ") + cudaGetErrorString(err)); + } +} + +void CUDAGraphManager::CaptureEnd(CudaGraphAnnotation_t annotation_id) { + cudaGraph_t graph = nullptr; + auto err = cudaStreamEndCapture(stream_, &graph); + if (err != cudaSuccess) { + throw std::runtime_error(std::string("cudaStreamEndCapture failed: ") + cudaGetErrorString(err)); + } + if (graph == nullptr) { + throw std::runtime_error("CUDAGraphManager::CaptureEnd: captured graph is NULL"); + } + + cudaGraphExec_t graph_exec = nullptr; + err = cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0); + (void)cudaGraphDestroy(graph); + + if (err != cudaSuccess) { + throw std::runtime_error(std::string("cudaGraphInstantiate failed: ") + cudaGetErrorString(err)); + } + + cuda_graph_set_.Put(annotation_id, graph_exec); +} + +OrtStatus* CUDAGraphManager::Replay(CudaGraphAnnotation_t annotation_id, bool sync) { + cudaGraphExec_t graph_exec = cuda_graph_set_.Get(annotation_id); + + auto err = cudaGraphLaunch(graph_exec, stream_); + if (err != cudaSuccess) { + return Ort::GetApi().CreateStatus( + ORT_EP_FAIL, + (std::string("cudaGraphLaunch failed: ") + cudaGetErrorString(err)).c_str()); + } + + if (sync) { + err = cudaStreamSynchronize(stream_); + if (err != cudaSuccess) { + return Ort::GetApi().CreateStatus( + ORT_EP_FAIL, + (std::string("cudaStreamSynchronize after graph replay failed: ") + cudaGetErrorString(err)).c_str()); + } + } + + return nullptr; +} + +bool CUDAGraphManager::IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t annotation_id) const { + return annotation_id != kCudaGraphAnnotationSkip; +} + +bool CUDAGraphManager::IsGraphCaptured(CudaGraphAnnotation_t annotation_id) const { + return cuda_graph_set_.Contains(annotation_id); +} + +void CUDAGraphManager::Reset() { + cuda_graph_set_.Clear(); +} + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.h b/onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.h new file mode 100644 index 0000000000000..84c6361d9c5e1 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.h @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Plugin-compatible CUDA graph manager for capture/replay lifecycle. +// Adapted from core/providers/cuda/cuda_graph.h — removes dependencies +// on internal EP types (CUDAExecutionProvider, CudaStream). + +#pragma once + +#include "cuda_plugin_utils.h" + +#include +#include + +namespace onnxruntime { +namespace cuda_plugin { + +using CudaGraphAnnotation_t = int; + +constexpr CudaGraphAnnotation_t kCudaGraphAnnotationSkip = -1; +constexpr CudaGraphAnnotation_t kCudaGraphAnnotationDefault = 0; + +/// Stores instantiated CUDA graph executables keyed by annotation ID. +struct CudaGraphSet { + CudaGraphSet() = default; + ~CudaGraphSet(); + + void Clear(); + bool Contains(CudaGraphAnnotation_t id) const; + void Put(CudaGraphAnnotation_t id, cudaGraphExec_t graph_exec); + cudaGraphExec_t Get(CudaGraphAnnotation_t id) const; + + private: + std::unordered_map cuda_graphs_; +}; + +/// Manages CUDA graph capture/instantiation/replay for the plugin EP. +/// Each instance is associated with a single cudaStream_t. +struct CUDAGraphManager { + CUDAGraphManager() = default; + explicit CUDAGraphManager(cudaStream_t stream); + ~CUDAGraphManager(); + + void SetStream(cudaStream_t stream); + + /// Begin capturing CUDA work on the associated stream. + void CaptureBegin(CudaGraphAnnotation_t annotation_id); + + /// End capture, instantiate the graph, and store it. + void CaptureEnd(CudaGraphAnnotation_t annotation_id); + + /// Launch a previously captured graph. + OrtStatus* Replay(CudaGraphAnnotation_t annotation_id, bool sync = true); + + /// Destroy all captured graphs. + void Reset(); + + /// Whether capture is allowed for the given annotation (i.e., not the skip sentinel). + bool IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t annotation_id) const; + + /// Whether a graph has already been captured for the given annotation. + bool IsGraphCaptured(CudaGraphAnnotation_t annotation_id) const; + + private: + CudaGraphSet cuda_graph_set_; + cudaStream_t stream_ = nullptr; // Does not own the stream +}; + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h new file mode 100644 index 0000000000000..a0bf3e4b1862f --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -0,0 +1,810 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// cuda_kernel_adapter.h — Compatibility shim for migrating CUDA kernels to the +// plugin EP architecture. +// +// This header provides: +// - CudaKernel base class (scratch buffers, CUDA handles, etc.) +// - Error-return macros (CUDA_RETURN_IF_ERROR, etc.) +// - Type mapping helpers (ToCudaType) +// - Math/compute shims (HalfGemmOptions, CublasMathModeSetter) +// - Self-registering BuildKernelCreateInfo<> macros via PluginKernelCollector +// - CUDAExecutionProvider shim class +// - CPU provider shims for the plugin build + +#pragma once + +#include "core/common/status.h" +#include "core/common/narrow.h" +#include "core/common/float16.h" +#include "core/common/float8.h" +#include "core/framework/float4.h" +#include "core/framework/allocator.h" +#include "core/framework/tensor_shape.h" +#include "core/util/math.h" +#include + +#include +#include +#include "core/providers/cuda/shared_inc/cuda_call.h" +#include "contrib_ops/cuda/bert/attention_kernel_options.h" + +#ifdef __CUDACC__ +#include +#include +#endif + +// =================================================================== +// Macros will be defined later to override core definitions. +// =================================================================== + +#include "core/providers/cuda/plugin/cuda_stream_plugin.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" +#include "core/session/onnxruntime_cxx_api.h" + +// Forward-declare CudaStream so adapter overloads accepting CudaStream* compile +// without pulling in cuda_stream_handle.h (which depends on CUDAExecutionProvider). +namespace onnxruntime { +struct CudaStream; + +// Lightweight Stream shim for plugin build: wraps a raw cudaStream_t as a +// framework-compatible Stream* that can be passed to _impl.cu functions which +// call stream->GetHandle(). Stack-allocated; does NOT own the stream. +// Only available in .cc translation units (not .cu) since Stream is incomplete in NVCC context. +#ifndef __CUDACC__ +struct PluginStreamShim : public onnxruntime::Stream { + explicit PluginStreamShim(void* cuda_stream_handle) + : onnxruntime::Stream(cuda_stream_handle, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, + OrtDevice::VendorIds::NVIDIA, 0)) {} +}; +#endif +} // namespace onnxruntime + +// =================================================================== +// Section 1: Include path selection +// =================================================================== + +#include "core/graph/constants.h" +#include "ep/adapters.h" +#include "core/framework/op_kernel.h" +#include "core/providers/common.h" + +namespace onnxruntime { +namespace cuda { + +#ifndef CUDA_STREAM_FROM_CTX +// Helper for kernels that need a cudaStream_t from OpKernelContext in plugin build. +#define CUDA_STREAM_FROM_CTX(ctx) static_cast(GetComputeStream(ctx)) +#endif + +// Forward declare the template for kernel registration macros to specialize +// inside the onnxruntime::cuda namespace. +template +KernelCreateInfo BuildKernelCreateInfo(); + +// Tensor creation helper to replace deprecated Tensor::Create +inline std::unique_ptr<::onnxruntime::Tensor> TensorCreate(MLDataType type, const TensorShape& shape, AllocatorPtr allocator) { + return std::make_unique<::onnxruntime::Tensor>(type, shape, std::move(allocator)); +} + +using ::onnxruntime::HandleNegativeAxis; + +} // namespace cuda + +#ifndef DISABLE_CONTRIB_OPS +namespace contrib { +namespace cuda { + +// Forward declare the template for kernel registration macros to specialize +// inside the onnxruntime::contrib::cuda namespace. +template +KernelCreateInfo BuildKernelCreateInfo(); + +inline std::unique_ptr<::onnxruntime::Tensor> TensorCreate(MLDataType type, const TensorShape& shape, AllocatorPtr allocator) { + return std::make_unique<::onnxruntime::Tensor>(type, shape, std::move(allocator)); +} + +using ::onnxruntime::HandleNegativeAxis; + +} // namespace cuda +} // namespace contrib +#endif +} // namespace onnxruntime + +// =================================================================== +// Section 2: Error-return macros (redefined for all plugin paths) +// =================================================================== + +// Redefine error macros for ported code to use our adapter-specific Status translation +#undef CUDA_RETURN_IF_ERROR +#define CUDA_RETURN_IF_ERROR(expr) \ + { \ + cudaError_t _err = (expr); \ + if (_err != cudaSuccess) { \ + return onnxruntime::common::Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, std::string("CUDA error: ") + cudaGetErrorString(_err)); \ + } \ + } + +#undef CUBLAS_RETURN_IF_ERROR +#define CUBLAS_RETURN_IF_ERROR(expr) \ + { \ + cublasStatus_t _status = (expr); \ + if (_status != CUBLAS_STATUS_SUCCESS) { \ + return onnxruntime::common::Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, std::string("cuBLAS error: ") + std::to_string(static_cast(_status))); \ + } \ + } + +#undef CUDNN_RETURN_IF_ERROR +#define CUDNN_RETURN_IF_ERROR(expr) \ + { \ + cudnnStatus_t _status = (expr); \ + if (_status != CUDNN_STATUS_SUCCESS) { \ + return onnxruntime::common::Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, std::string("cuDNN error: ") + cudnnGetErrorString(_status)); \ + } \ + } + +// =================================================================== +// Section 3: Self-registering kernel collector using BuildKernelCreateInfo<> +// +// Each ONNX_OPERATOR_*_KERNEL_EX macro expansion both: +// 1. Creates a BuildKernelCreateInfo() template specialization +// (identical to the framework's macro output, but using adapter types) +// 2. Auto-registers the BuildKernelCreateInfoFn pointer into a global +// PluginKernelCollector singleton +// +// At registration time, the factory iterates the collector and calls +// adapter::KernelRegistry::Register(build_fn()) for each compiled kernel. +// Only ops whose .cc files are actually compiled get registered. +// =================================================================== + +#include + +namespace onnxruntime { +namespace cuda { + +/// Singleton collector for BuildKernelCreateInfoFn pointers. +/// Each compiled kernel .cc file's macro expansion auto-registers here. +class PluginKernelCollector { + public: + static PluginKernelCollector& Instance() { + static PluginKernelCollector instance; + return instance; + } + + void Add(BuildKernelCreateInfoFn fn) { entries_.push_back(fn); } + const std::vector& Entries() const { return entries_; } + + private: + std::vector entries_; +}; + +} // namespace cuda +} // namespace onnxruntime + +// --- Macro overrides: produce BuildKernelCreateInfo<> AND auto-register --- +// +// These macros mirror the framework's ONNX_OPERATOR_*_KERNEL_EX definitions +// from core/framework/op_kernel.h, but additionally register each +// BuildKernelCreateInfoFn into PluginKernelCollector at static init time. + +#define ORT_ADAPTER_CONCAT_IMPL(x, y) x##y +#define ORT_ADAPTER_CONCAT(x, y) ORT_ADAPTER_CONCAT_IMPL(x, y) + +#undef ONNX_OPERATOR_KERNEL_EX +#define ONNX_OPERATOR_KERNEL_EX(name, domain, ver, provider, builder, ...) \ + class ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name).SetDomain(domain).SinceVersion(ver).Provider(provider).Build(), \ + static_cast( \ + [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ + out = std::make_unique<__VA_ARGS__>(info); \ + return Status::OK(); \ + })); \ + } \ + static const bool ORT_ADAPTER_CONCAT(_autoreg_##name##_, __COUNTER__) = \ + (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ + &BuildKernelCreateInfo), \ + true); + +#undef ONNX_OPERATOR_VERSIONED_KERNEL_EX +#define ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, domain, startver, endver, provider, builder, ...) \ + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name).SetDomain(domain).SinceVersion(startver, endver).Provider(provider).Build(), \ + static_cast( \ + [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ + out = std::make_unique<__VA_ARGS__>(info); \ + return Status::OK(); \ + })); \ + } \ + static const bool ORT_ADAPTER_CONCAT(_autoreg_##name##_, __COUNTER__) = \ + (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ + &BuildKernelCreateInfo), \ + true); + +#undef ONNX_OPERATOR_TYPED_KERNEL_EX +#define ONNX_OPERATOR_TYPED_KERNEL_EX(name, domain, ver, type, provider, builder, ...) \ + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name).SetDomain(domain).SinceVersion(ver).Provider(provider).Build(), \ + static_cast( \ + [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ + out = std::make_unique<__VA_ARGS__>(info); \ + return Status::OK(); \ + })); \ + } \ + static const bool ORT_ADAPTER_CONCAT(_autoreg_##name##_##type##_, __COUNTER__) = \ + (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ + &BuildKernelCreateInfo), \ + true); + +#undef ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX +#define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, domain, startver, endver, type, provider, builder, ...) \ + class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name).SetDomain(domain).SinceVersion(startver, endver).Provider(provider).Build(), \ + static_cast( \ + [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ + out = std::make_unique<__VA_ARGS__>(info); \ + return Status::OK(); \ + })); \ + } \ + static const bool ORT_ADAPTER_CONCAT(_autoreg_##name##_##type##_, __COUNTER__) = \ + (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ + &BuildKernelCreateInfo), \ + true); + +// =================================================================== +// Section 4: Logging shim (adapter path only) +// Replaces LOGS_DEFAULT with a no-op stream to avoid pulling in the +// full ORT logging framework inside the plugin shared library. +// =================================================================== + +// Explicit function instantiation — called once per unique class in each .cc file +#define ONNX_OPERATOR_TYPED_KERNEL_COMPUTE_INSTANTIATION(cls) template Status cls::ComputeInternal(OpKernelContext* context) const; + +#undef CREATE_MESSAGE +#undef LOGS +#undef LOGS_DEFAULT +#undef ORT_LOG_MESSAGE + +namespace onnxruntime { +namespace cuda { +struct PluginNoOpLogStream { + template + PluginNoOpLogStream& operator<<(const T&) { return *this; } +}; +} // namespace cuda +} // namespace onnxruntime + +#ifndef LOGS_DEFAULT +#define LOGS_DEFAULT(severity) ::onnxruntime::cuda::PluginNoOpLogStream() +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace onnxruntime { +namespace cuda { + +// =================================================================== +// Section 5: Runtime configuration for migrated kernels +// Stored as atomics so SetCudaKernelAdapterRuntimeConfig() can be +// called from CudaEp's constructor on any thread. +// =================================================================== + +namespace detail { +struct CudaKernelAdapterRuntimeConfig { + std::atomic use_tf32{true}; + std::atomic skip_layer_norm_strict_mode{false}; + std::atomic device_id{0}; + std::atomic cudnn_conv_algo{0}; + std::atomic cudnn_conv1d_pad_to_nc1d{false}; +}; +inline CudaKernelAdapterRuntimeConfig& GetCudaKernelAdapterRuntimeConfig() { + static CudaKernelAdapterRuntimeConfig config; + return config; +} +template +struct SizeOf { + static constexpr size_t value = sizeof(T); +}; +template <> +struct SizeOf { + static constexpr size_t value = 0; +}; +inline size_t BytesForCount(size_t count_or_bytes, size_t element_size) { + if (element_size == 0) return count_or_bytes; + if (count_or_bytes > (std::numeric_limits::max() / element_size)) return 0; + return count_or_bytes * element_size; +} +} // namespace detail +} // namespace cuda + +// =================================================================== +// Section 6: CUDAExecutionProvider shim +// Provides the minimal API surface that migrated kernels expect +// (GetCudnnConvAlgo, UseTF32, GetDeviceProp, etc.) without the full +// CUDAExecutionProvider class from onnxruntime/core/providers/cuda/. +// =================================================================== + +// Shim for CUDAExecutionProvider required by conv.cc, einsum, and others +class CUDAExecutionProvider : public onnxruntime::IExecutionProvider { + public: + explicit CUDAExecutionProvider(const std::string& name) : onnxruntime::IExecutionProvider{name} {} + int GetCudnnConvAlgo() const { + return cuda::detail::GetCudaKernelAdapterRuntimeConfig().cudnn_conv_algo.load(std::memory_order_relaxed); + } + bool GetCudnnConv1dPadToNc1d() const { + return cuda::detail::GetCudaKernelAdapterRuntimeConfig().cudnn_conv1d_pad_to_nc1d.load(std::memory_order_relaxed); + } + bool UseTF32() const { + return cuda::detail::GetCudaKernelAdapterRuntimeConfig().use_tf32.load(std::memory_order_relaxed); + } + bool IsFuseConvBias() const { + return false; + } + const cudaDeviceProp& GetDeviceProp() const { + static cudaDeviceProp prop; + static std::once_flag flag; + std::call_once(flag, []() { + int device_id = cuda::detail::GetCudaKernelAdapterRuntimeConfig().device_id.load(std::memory_order_relaxed); + if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) { + std::memset(&prop, 0, sizeof(prop)); + prop.major = -1; + } + }); + return prop; + } +}; + +namespace cuda { + +inline void SetCudaKernelAdapterRuntimeConfig(bool use_tf32, int device_id, bool skip_layer_norm_strict_mode = false, + int cudnn_conv_algo = 0, bool cudnn_conv1d_pad_to_nc1d = false) { + auto& config = detail::GetCudaKernelAdapterRuntimeConfig(); + config.use_tf32.store(use_tf32, std::memory_order_relaxed); + config.skip_layer_norm_strict_mode.store(skip_layer_norm_strict_mode, std::memory_order_relaxed); + config.device_id.store(device_id, std::memory_order_relaxed); + config.cudnn_conv_algo.store(cudnn_conv_algo, std::memory_order_relaxed); + config.cudnn_conv1d_pad_to_nc1d.store(cudnn_conv1d_pad_to_nc1d, std::memory_order_relaxed); +} + +inline bool GetCudaKernelAdapterSkipLayerNormStrictMode() { + const auto& config = detail::GetCudaKernelAdapterRuntimeConfig(); + return config.skip_layer_norm_strict_mode.load(std::memory_order_relaxed); +} + +// Global aliases and shims +using Status = onnxruntime::common::Status; +using MLFloat16 = onnxruntime::MLFloat16; +using BFloat16 = onnxruntime::BFloat16; +using Float8E4M3FN = onnxruntime::Float8E4M3FN; +using Float8E4M3FNUZ = onnxruntime::Float8E4M3FNUZ; +using Float8E5M2 = onnxruntime::Float8E5M2; +using Float8E5M2FNUZ = onnxruntime::Float8E5M2FNUZ; + +// Type mapping for CUDA +template +struct ToCudaType { + typedef T MappedType; + static MappedType FromFloat(float f) { return static_cast(f); } +}; + +template <> +struct ToCudaType { + typedef half MappedType; + static MappedType FromFloat(float f) { + uint16_t h = onnxruntime::math::floatToHalf(f); + return *reinterpret_cast(&h); + } +}; + +#ifdef __CUDACC__ +template <> +struct ToCudaType { + typedef nv_bfloat16 MappedType; + static MappedType FromFloat(float f) { + return nv_bfloat16(f); + } +}; + +// Forward declare templates from common.cuh to allow specialization +// Match signatures from common.cuh exactly (no default parameters) +template +struct _IsInf; +template +struct _IsNan; + +namespace bf16_isinf_nan { +template +struct IsInfTyped; +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(nv_bfloat16 a) { + uint16_t val = *reinterpret_cast(&a); + return (val & 0x7F80) == 0x7F80 && (val & 0x007F) == 0x0000; + } + static __device__ __inline__ bool IsInfPos(nv_bfloat16 a) { + return *reinterpret_cast(&a) == 0x7F80; + } + static __device__ __inline__ bool IsInfNeg(nv_bfloat16 a) { + return *reinterpret_cast(&a) == 0xFF80; + } +}; +} // namespace bf16_isinf_nan + +// Specialize for nv_bfloat16 to avoid ambiguity with isnan/isinf overloads +template <> +struct _IsNan { + __device__ __inline__ bool operator()(nv_bfloat16 a) const { + uint16_t val = *reinterpret_cast(&a); + return (val & 0x7F80) == 0x7F80 && (val & 0x007F) != 0x0000; + } +}; + +template +struct _IsInf { + __device__ __inline__ bool operator()(nv_bfloat16 a) const { + if constexpr (detect_positive && detect_negative) { + return bf16_isinf_nan::IsInfTyped::IsInf(a); + } else if constexpr (detect_positive) { + return bf16_isinf_nan::IsInfTyped::IsInfPos(a); + } else if constexpr (detect_negative) { + return bf16_isinf_nan::IsInfTyped::IsInfNeg(a); + } else { + return false; + } + } +}; +#endif + +// =================================================================== +// Section 6b: CPU provider shims for the plugin build +// Inline implementations of CPU utility functions that CUDA kernels +// reference (e.g., OneHot validation, GatherElements shape checks). +// These are normally provided by onnxruntime_providers but the plugin +// does not link against it. +// +// We temporarily close namespace cuda so these shims live directly in +// namespace onnxruntime, where unqualified lookup from onnxruntime::cuda +// will find them, and where onnxruntime::GatherElements resolves correctly. +// =================================================================== + +} // namespace cuda + +// Shim for ValidateInputs from core/providers/cpu/tensor/onehot.h +inline Status ValidateInputs(const Tensor* depth, const Tensor* values) { + if (!depth->Shape().IsScalar()) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Invalid argument for depth; it's not a scalar."); + } + if (!(values->Shape().NumDimensions() == 1 && values->Shape().Size() == 2)) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Invalid argument for values; either it's rank is more than 1" + " or it has more than 2 elements"); + } + return Status::OK(); +} + +// Shim for PrepareOutputShape from core/providers/cpu/tensor/onehot.h +inline Status PrepareOutputShape(const Tensor* indices, const int64_t depth_val, const int64_t axis, + int64_t& prefix_dim_size, int64_t& suffix_dim_size, + TensorShapeVector& output_shape) { + const auto& indices_shape = indices->Shape(); + const auto indices_dims = indices_shape.GetDims(); + const auto indices_num_dims = indices_shape.NumDimensions(); + output_shape = indices_shape.AsShapeVector(); + const auto output_rank = static_cast(indices_num_dims) + 1; + auto true_axis = HandleNegativeAxis(axis, output_rank); + output_shape.insert(output_shape.begin() + true_axis, depth_val); + prefix_dim_size = 1; + for (int64_t i = 0; i < true_axis; ++i) { + prefix_dim_size *= indices_dims[narrow(i)]; + } + suffix_dim_size = indices_shape.Size() / prefix_dim_size; + return Status::OK(); +} + +// Shim for GatherElements::ValidateInputShapes from +// core/providers/cpu/tensor/gather_elements.h +class GatherElements { + public: + static Status ValidateInputShapes(const TensorShape& input_data_shape, + const TensorShape& indices_shape, + int64_t axis) { + int64_t input_data_rank = static_cast(input_data_shape.NumDimensions()); + int64_t indices_rank = static_cast(indices_shape.NumDimensions()); + if (input_data_rank < 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "GatherElements op: Cannot operate on scalar input"); + if (input_data_rank != indices_rank) + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "GatherElements op: Rank of input 'data' needs to be equal to rank of input 'indices'"); + for (int64_t i = 0; i < indices_rank; ++i) { + if (i != axis) { + if (indices_shape[narrow(i)] < 0 || + indices_shape[narrow(i)] > input_data_shape[narrow(i)]) + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "GatherElements op: 'indices' shape should have values within bounds of 'data' shape. " + "Invalid value in indices shape is: ", + indices_shape[narrow(i)]); + } + } + return Status::OK(); + } +}; + +namespace cuda { // re-open onnxruntime::cuda + +// =================================================================== +// Section 7: CudaKernel base class +// Base class for all migrated CUDA kernels. Provides scratch-buffer +// management, CUDA handle access (cuBLAS, cuDNN), device property +// queries, and the CudaAsyncBuffer helper for host→device transfers. +// =================================================================== + +// Additional adapter logic for CudaKernel + +class CudaKernel : public OpKernel { + public: + explicit CudaKernel(const OpKernelInfo& info) : OpKernel(info), info_(info) { + const auto& config = detail::GetCudaKernelAdapterRuntimeConfig(); + use_tf32_ = config.use_tf32.load(std::memory_order_relaxed); + device_id_ = config.device_id.load(std::memory_order_relaxed); + int cur = device_id_; + if (cudaGetDevice(&cur) == cudaSuccess) device_id_ = cur; + if (cudaGetDeviceProperties(&device_prop_, device_id_) != cudaSuccess) { + std::memset(&device_prop_, 0, sizeof(device_prop_)); + device_prop_.major = -1; + } + } + virtual ~CudaKernel() = default; + Status Compute(OpKernelContext* ctx) const { + Status s = ComputeInternal(ctx); + if (s.IsOK()) { + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, "CUDA error: " + std::string(cudaGetErrorString(err))); + } + return s; + } + virtual Status ComputeInternal(OpKernelContext* ctx) const = 0; + + inline cudaStream_t DefaultCudaStream() const { return Stream(static_cast(nullptr)); } + inline cublasHandle_t DefaultCublasHandle() const { return GetCublasHandle(static_cast(nullptr)); } + inline cudnnHandle_t DefaultCudnnHandle() const { return GetCudnnHandle(static_cast(nullptr)); } + + inline Status CopyTensor(const onnxruntime::Tensor& src, onnxruntime::Tensor& dst, onnxruntime::Stream& stream) const { + if (src.Shape().Size() == 0) return Status::OK(); + if (cudaMemcpyAsync(dst.MutableDataRaw(), src.DataRaw(), src.SizeInBytes(), cudaMemcpyDeviceToDevice, (cudaStream_t)stream.GetHandle()) != cudaSuccess) { + return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, "Memcpy fail"); + } + return Status::OK(); + } + + cudaStream_t Stream(OpKernelContext* ctx) const { + if (!ctx) return nullptr; + return static_cast(ctx->GetGPUComputeStream()); + } + + // Returns an opaque stream pointer for passing to GetScratchBuffer/AddDeferredReleaseCPUPtr/CopyToGpu. + // Returns void* for dual-build compatibility: framework wraps Stream*, plugin wraps cudaStream_t. + inline void* GetComputeStream(OpKernelContext* ctx) const { + return ctx->GetGPUComputeStream(); + } + + static cudnnHandle_t GetCudnnHandle(cudaStream_t s) { + auto* sync = cuda_plugin::CudaSyncStream::FromCudaStream(s); + return sync ? sync->GetCudnnHandle() : nullptr; + } + static inline cudnnHandle_t GetCudnnHandle(onnxruntime::CudaStream* stream) { + return stream ? GetCudnnHandle(static_cast(reinterpret_cast(stream)->GetHandle())) : nullptr; + } + static inline cudnnHandle_t GetCudnnHandle(onnxruntime::Stream* stream) { + return stream ? GetCudnnHandle(static_cast(stream->GetHandle())) : nullptr; + } + cudnnHandle_t GetCudnnHandle(OpKernelContext* ctx) const { return GetCudnnHandle(Stream(ctx)); } + + static cublasHandle_t GetCublasHandle(cudaStream_t s) { + auto* sync = cuda_plugin::CudaSyncStream::FromCudaStream(s); + return sync ? sync->GetCublasHandle() : nullptr; + } + static inline cublasHandle_t GetCublasHandle(onnxruntime::CudaStream* stream) { + return stream ? GetCublasHandle(static_cast(reinterpret_cast(stream)->GetHandle())) : nullptr; + } + static inline cublasHandle_t GetCublasHandle(onnxruntime::Stream* stream) { + return stream ? GetCublasHandle(static_cast(stream->GetHandle())) : nullptr; + } + cublasHandle_t GetCublasHandle(OpKernelContext* ctx) const { return GetCublasHandle(Stream(ctx)); } + + const cudaDeviceProp& GetDeviceProp() const { return device_prop_; } + bool UseTF32() const { return use_tf32_; } + bool IsArchAvailable(int arch) const { return device_prop_.major >= arch; } + const OpKernelInfo& Info() const { return info_; } + const onnxruntime::AttentionKernelOptions* GetAttentionKernelOptions() const { + static onnxruntime::AttentionKernelOptions options; + return &options; + } + + // Stub for GetTuningContext — tunable ops are not supported in the plugin. + struct PluginTuningContextStub { + bool IsTunableOpEnabled() const { return false; } + }; + PluginTuningContextStub* GetTuningContext() const { + static PluginTuningContextStub stub; + return &stub; + } + + // GetConstOnes: returns a device buffer of constant ones. + // Delegates to IConstantBuffer from cuda_utils.h (compiled in cuda_utils.cu). + template + const T* GetConstOnes(size_t count, cudaStream_t stream) const { + static std::unique_ptr> buf; + static std::once_flag flag; + std::call_once(flag, []() { buf = CreateConstantOnes(); }); + return buf->GetBuffer(stream, count); + } + + template + using IAllocatorUniquePtr = std::unique_ptr>; + template + inline IAllocatorUniquePtr GetScratchBuffer(size_t cnt, void* s) const { + if (cnt == 0) return IAllocatorUniquePtr(nullptr, [](T*) {}); + size_t sz = detail::BytesForCount(cnt, detail::SizeOf::value); + void* p = nullptr; + if (cudaMalloc(&p, sz) != cudaSuccess) return IAllocatorUniquePtr(nullptr, [](T*) {}); + return IAllocatorUniquePtr(static_cast(p), [s](T* ptr) { + if (ptr) { + if (s) { + cudaFreeAsync(ptr, static_cast(s)); + } else { + cudaFree(ptr); + } + } + }); + } + template + inline IAllocatorUniquePtr GetTransientScratchBuffer(size_t cnt) const { + return GetScratchBuffer(cnt, nullptr); + } + inline void AddDeferredReleaseCPUPtr(void* p, void* s) const { + if (!p) return; + auto* sync = cuda_plugin::CudaSyncStream::FromCudaStream(static_cast(s)); + if (sync) { + sync->EnqueueDeferredCPUBuffer(p); + return; + } + cudaFreeHost(p); + } + template + inline IAllocatorUniquePtr AllocateBufferOnCPUPinned(size_t cnt) const { + if (cnt == 0) return IAllocatorUniquePtr(nullptr, [](T*) {}); + size_t sz = detail::BytesForCount(cnt, detail::SizeOf::value); + void* p = nullptr; + if (cudaHostAlloc(&p, sz, cudaHostAllocDefault) != cudaSuccess) return IAllocatorUniquePtr(nullptr, [](T*) {}); + return IAllocatorUniquePtr(static_cast(p), [](T* ptr) { if (ptr) cudaFreeHost(ptr); }); + } + + template + class CudaAsyncBuffer { + public: + CudaAsyncBuffer(const CudaKernel* ok) : gpu_(nullptr, [](T*) {}), count_(0), op_kernel_(ok) {} + CudaAsyncBuffer(const CudaKernel* ok, size_t n) : CudaAsyncBuffer(ok) { AllocCpuPtr(n); } + CudaAsyncBuffer(const CudaKernel* ok, const T& v, size_t n) : CudaAsyncBuffer(ok, n) { + T* p = CpuPtr(); + for (size_t i = 0; i != n; ++i) *p++ = v; + } + CudaAsyncBuffer(const CudaKernel* ok, gsl::span vec) : CudaAsyncBuffer(ok, vec.size()) { memcpy(CpuPtr(), vec.data(), vec.size() * sizeof(T)); } + void AllocCpuPtr(size_t n) { + cpu_ = op_kernel_->AllocateBufferOnCPUPinned(n); + if (!cpu_) throw std::runtime_error("alloc fail"); + count_ = n; + } + Status CopyToGpu(void* s) { + if (cpu_) { + gpu_ = op_kernel_->GetScratchBuffer(count_, s); + if (cudaMemcpyAsync(gpu_.get(), cpu_.get(), count_ * sizeof(T), cudaMemcpyHostToDevice, static_cast(s)) != cudaSuccess) return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, "Memcpy fail"); + op_kernel_->AddDeferredReleaseCPUPtr(cpu_.release(), s); + } + return Status::OK(); + } + T* CpuPtr() const { return cpu_.get(); } + gsl::span CpuSpan() const { return gsl::span(CpuPtr(), count_); } + T* GpuPtr() const { return gpu_.get(); } + size_t count() const { return count_; } + + protected: + IAllocatorUniquePtr gpu_; + std::unique_ptr> cpu_{nullptr, [](T*) {}}; + size_t count_; + const CudaKernel* op_kernel_; + }; + + private: + const OpKernelInfo& info_; + cudaDeviceProp device_prop_{}; + bool use_tf32_ = true; + int device_id_ = 0; +}; + +// =================================================================== +// Section 8: Compute helper shims (HalfGemmOptions, CublasMathModeSetter) +// =================================================================== + +// Shims for HalfGemmOptions and CublasMathModeSetter required by fpgeneric.h +class HalfGemmOptions { + public: + static const HalfGemmOptions* GetInstance() { + static HalfGemmOptions instance; + return &instance; + } + cublasMath_t GetMathMode() const { return CUBLAS_DEFAULT_MATH; } + bool IsCompute16F() const { return false; } +#if defined(CUBLAS_COMPUTE_32F) + cublasComputeType_t GetComputeType() const { return CUBLAS_COMPUTE_32F; } +#else + cudaDataType_t GetComputeType() const { return CUDA_R_32F; } +#endif +}; + +class CublasMathModeSetter { + public: + CublasMathModeSetter(const cudaDeviceProp& prop, cublasHandle_t handle, cublasMath_t mode) : handle_(handle) { + enable_ = (mode == CUBLAS_TF32_TENSOR_OP_MATH ? prop.major >= 8 : true); + if (enable_) { + cublasGetMathMode(handle, &mode_); + enable_ = (mode_ != mode); + if (enable_) { + cublasSetMathMode(handle, mode); + } + } + } + + ~CublasMathModeSetter() { + if (enable_) { + cublasSetMathMode(handle_, mode_); + } + } + + private: + cublasHandle_t handle_; + cublasMath_t mode_ = CUBLAS_DEFAULT_MATH; + bool enable_; +}; + +} // namespace cuda + +// Global aliases for convenience +using MLFloat16 = onnxruntime::MLFloat16; +using BFloat16 = onnxruntime::BFloat16; +using Float8E4M3FN = onnxruntime::Float8E4M3FN; +using Float8E4M3FNUZ = onnxruntime::Float8E4M3FNUZ; +using Float8E5M2 = onnxruntime::Float8E5M2; +using Float8E5M2FNUZ = onnxruntime::Float8E5M2FNUZ; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_plugin_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_ep.cc new file mode 100644 index 0000000000000..8ad5da2fa6ab7 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_ep.cc @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// DLL entry points for the CUDA Plugin Execution Provider. +// Exports CreateEpFactories() and ReleaseEpFactory() as the +// public interface for ORT to load and use the CUDA EP as a plugin. + +#include "onnxruntime_cxx_api.h" + +#include "cuda_ep_factory.h" + +#ifndef _WIN32 +#define EXPORT_SYMBOL __attribute__((visibility("default"))) +#else +#define EXPORT_SYMBOL +#endif + +extern "C" { + +/// Create the CUDA EP factory instances. +/// Called by ORT when loading the CUDA plugin EP DLL. +EXPORT_SYMBOL OrtStatus* CreateEpFactories( + const char* registration_name, + const OrtApiBase* ort_api_base, + const OrtLogger* default_logger, + OrtEpFactory** factories, + size_t max_factories, + size_t* num_factories) { + const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); + const OrtEpApi* ep_api = ort_api->GetEpApi(); + + // Initialize the C++ API FIRST before any C++ wrapper usage + Ort::InitApi(ort_api); + + if (default_logger == nullptr) { + return ort_api->CreateStatus( + ORT_INVALID_ARGUMENT, + "CUDA Plugin EP: default_logger must not be null."); + } + + // Log initialization (default_logger is guaranteed non-null after the check above). + { + std::string msg = "CreateEpFactories: Initializing CUDA Plugin EP with registration name: "; + msg += (registration_name ? registration_name : "NULL"); + auto* status = ort_api->Logger_LogMessage(default_logger, ORT_LOGGING_LEVEL_INFO, + msg.c_str(), + ORT_FILE, __LINE__, __FUNCTION__); + if (status) ort_api->ReleaseStatus(status); + } + + if (max_factories < 1) { + auto* log_status = ort_api->Logger_LogMessage(default_logger, ORT_LOGGING_LEVEL_ERROR, + "CreateEpFactories: max_factories < 1", + ORT_FILE, __LINE__, __FUNCTION__); + if (log_status) ort_api->ReleaseStatus(log_status); + return ort_api->CreateStatus( + ORT_INVALID_ARGUMENT, + "CUDA Plugin EP: Not enough space to return EP factory. Need at least one."); + } + + try { + auto factory = std::make_unique( + *ort_api, *ep_api, *default_logger); + + factories[0] = factory.release(); + *num_factories = 1; + + auto* log_status = ort_api->Logger_LogMessage(default_logger, ORT_LOGGING_LEVEL_INFO, + "CreateEpFactories: Successfully created CUDA EP factory", + ORT_FILE, __LINE__, __FUNCTION__); + if (log_status) ort_api->ReleaseStatus(log_status); + } catch (const std::exception& ex) { + auto* log_status = ort_api->Logger_LogMessage(default_logger, ORT_LOGGING_LEVEL_ERROR, + ex.what(), ORT_FILE, __LINE__, __FUNCTION__); + if (log_status) ort_api->ReleaseStatus(log_status); + return ort_api->CreateStatus(ORT_EP_FAIL, ex.what()); + } + + return nullptr; +} + +/// Release a CUDA EP factory instance. +/// Called by ORT when unloading the CUDA plugin EP DLL. +EXPORT_SYMBOL OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { + delete static_cast(factory); + return nullptr; +} + +} // extern "C" diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_plugin_ep_symbols.def b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_ep_symbols.def new file mode 100644 index 0000000000000..6d0efad28c669 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_ep_symbols.def @@ -0,0 +1,4 @@ +; Symbol export definition file for CUDA Plugin EP (Windows) +EXPORTS + CreateEpFactories + ReleaseEpFactory diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.cu b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.cu new file mode 100644 index 0000000000000..822ea3d5fa72f --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.cu @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file provides the CreateCudaKernelRegistry entrypoint for the CUDA plugin EP. +// +// Kernel registration is now fully automatic: each compiled kernel .cc file's +// ONNX_OPERATOR_*_KERNEL_EX macro expansion creates a BuildKernelCreateInfo<>() +// template specialization and auto-registers it in PluginKernelCollector via +// the macro overrides in cuda_kernel_adapter.h. +// +// CreateCudaKernelRegistry iterates the collector and registers each entry +// into an adapter::KernelRegistry which is then returned to the EP factory. + +#include "cuda_plugin_kernels.h" +#include "cuda_stream_plugin.h" +#include "cuda_kernel_adapter.h" + +// Define the BuildKernelCreateInfo() sentinel in onnxruntime::cuda. +// This is normally defined in cuda_execution_provider.cc (excluded from plugin). +// The NHWC registration tables reference it as a placeholder to prevent empty arrays. +namespace onnxruntime::cuda { +template <> +KernelCreateInfo BuildKernelCreateInfo() { + KernelCreateInfo info; + return info; +} +} // namespace onnxruntime::cuda + +namespace onnxruntime { +namespace cuda_plugin { + +OrtStatus* CreateCudaKernelRegistry(const OrtEpApi& /*ep_api*/, + const char* /*ep_name*/, + void* /*create_kernel_state*/, + OrtKernelRegistry** out_registry) { + *out_registry = nullptr; + + EXCEPTION_TO_STATUS_BEGIN + + // adapter::KernelRegistry wraps OrtKernelRegistry via the Ort C++ API. + ::onnxruntime::ep::adapter::KernelRegistry registry; + + // Iterate all self-registered BuildKernelCreateInfoFn pointers. + const auto& entries = ::onnxruntime::cuda::PluginKernelCollector::Instance().Entries(); + for (auto build_fn : entries) { + ::onnxruntime::ep::adapter::KernelCreateInfo info = build_fn(); + if (info.kernel_def != nullptr) { // filter the BuildKernelCreateInfo sentinel + (void)registry.Register(std::move(info)); + } + } + + *out_registry = registry.release(); + return nullptr; + + EXCEPTION_TO_STATUS_END +} + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.h b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.h new file mode 100644 index 0000000000000..1ff979de22743 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "cuda_plugin_utils.h" + +namespace onnxruntime { +namespace cuda_plugin { + +/// Create the CUDA kernel registry using self-registered BuildKernelCreateInfo<> +/// entries from PluginKernelCollector. All compiled kernel .cc files automatically +/// contribute their registrations via the macro overrides in cuda_kernel_adapter.h. +OrtStatus* CreateCudaKernelRegistry(const OrtEpApi& ep_api, + const char* ep_name, + void* create_kernel_state, + OrtKernelRegistry** out_registry); + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_plugin_utils.h b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_utils.h new file mode 100644 index 0000000000000..82a792faa3258 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_utils.h @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Common utilities, error-handling macros, and type definitions shared by +// all source files in the CUDA plugin EP implementation. + +#pragma once + +#include "onnxruntime_c_api.h" +#include "onnxruntime_cxx_api.h" + +#include +#include +#include +#include + +// Error handling macros + +#ifndef PL_CUDA_RETURN_IF_ERROR +#define PL_CUDA_RETURN_IF_ERROR(cuda_call_expr) \ + do { \ + cudaError_t _cuda_err = (cuda_call_expr); \ + if (_cuda_err != cudaSuccess) { \ + return Ort::GetApi().CreateStatus( \ + ORT_EP_FAIL, \ + (std::string("CUDA error: ") + cudaGetErrorName(_cuda_err) + ": " + \ + cudaGetErrorString(_cuda_err)) \ + .c_str()); \ + } \ + } while (0) +#endif + +#ifndef PL_CUBLAS_RETURN_IF_ERROR +#define PL_CUBLAS_RETURN_IF_ERROR(cublas_call_expr) \ + do { \ + cublasStatus_t _cublas_err = (cublas_call_expr); \ + if (_cublas_err != CUBLAS_STATUS_SUCCESS) { \ + return Ort::GetApi().CreateStatus( \ + ORT_EP_FAIL, \ + (std::string("cuBLAS error: ") + \ + std::to_string(static_cast(_cublas_err))) \ + .c_str()); \ + } \ + } while (0) +#endif + +#ifndef PL_CUDNN_RETURN_IF_ERROR +#define PL_CUDNN_RETURN_IF_ERROR(cudnn_call_expr) \ + do { \ + cudnnStatus_t _cudnn_err = (cudnn_call_expr); \ + if (_cudnn_err != CUDNN_STATUS_SUCCESS) { \ + return Ort::GetApi().CreateStatus( \ + ORT_EP_FAIL, \ + (std::string("cuDNN error: ") + \ + cudnnGetErrorString(_cudnn_err)) \ + .c_str()); \ + } \ + } while (0) +#endif + +#define EXCEPTION_TO_STATUS_BEGIN try { +#define EXCEPTION_TO_STATUS_END \ + } \ + catch (const Ort::Exception& ex) { \ + Ort::Status status(ex); \ + return status.release(); \ + } \ + catch (const std::exception& ex) { \ + Ort::Status status(ex.what(), ORT_EP_FAIL); \ + return status.release(); \ + } + +/// Stored API pointers accessible to all plugin components. +struct CudaPluginApis { + const OrtApi& ort_api; + const OrtEpApi& ep_api; +}; diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc new file mode 100644 index 0000000000000..c1e753440cca4 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc @@ -0,0 +1,194 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "cuda_stream_plugin.h" +#include "cuda_ep_factory.h" +#include + +namespace onnxruntime { +namespace cuda_plugin { + +namespace { + +// Global stream-to-CudaSyncStream mapping. +// Required because migrated CUDA kernels receive only a raw cudaStream_t +// but need access to associated cuBLAS/cuDNN handles. The map is +// lazily initialized and never freed (process-lifetime singleton). +static std::unordered_map* g_stream_map = nullptr; +static std::mutex* g_stream_map_mutex = nullptr; +static std::once_flag g_stream_map_init_flag; + +void InitStreamMap() { + std::call_once(g_stream_map_init_flag, []() { + g_stream_map = new std::unordered_map(); + g_stream_map_mutex = new std::mutex(); + }); +} +} // namespace + +// --------------------------------------------------------------------------- +// CudaSyncStream +// --------------------------------------------------------------------------- + +CudaSyncStream::CudaSyncStream(CudaEpFactory& factory, int device_id, + const OrtEp* /*ep*/) + : OrtSyncStreamImpl{}, + factory_(factory), + device_id_(device_id) { + ort_version_supported = ORT_API_VERSION; + GetHandle = GetHandleImpl; + CreateNotification = CreateNotificationImpl; + Flush = FlushImpl; + OnSessionRunEnd = OnSessionRunEndImpl; + Release = ReleaseImpl; +} + +CudaSyncStream::~CudaSyncStream() { + CleanupDeferredCPUBuffers(); + + if (cuda_stream_) UnregisterStream(cuda_stream_); + + if (cublas_handle_) cublasDestroy(cublas_handle_); + if (cudnn_handle_) cudnnDestroy(cudnn_handle_); + if (cublas_lt_handle_) cublasLtDestroy(cublas_lt_handle_); + if (cuda_stream_) cudaStreamDestroy(cuda_stream_); +} + +OrtStatus* CudaSyncStream::InitHandles() { + cudaSetDevice(device_id_); + + PL_CUDA_RETURN_IF_ERROR(cudaStreamCreateWithFlags(&cuda_stream_, cudaStreamNonBlocking)); + RegisterStream(cuda_stream_, this); + + PL_CUBLAS_RETURN_IF_ERROR(cublasCreate(&cublas_handle_)); + PL_CUBLAS_RETURN_IF_ERROR(cublasSetStream(cublas_handle_, cuda_stream_)); + + PL_CUDNN_RETURN_IF_ERROR(cudnnCreate(&cudnn_handle_)); + PL_CUDNN_RETURN_IF_ERROR(cudnnSetStream(cudnn_handle_, cuda_stream_)); + + PL_CUBLAS_RETURN_IF_ERROR(cublasLtCreate(&cublas_lt_handle_)); + + return nullptr; +} + +void CudaSyncStream::EnqueueDeferredCPUBuffer(void* cpu_buffer) { + deferred_cpu_buffers_.push_back(cpu_buffer); +} + +void CudaSyncStream::CleanupDeferredCPUBuffers() { + for (void* buf : deferred_cpu_buffers_) { + cudaFreeHost(buf); + } + deferred_cpu_buffers_.clear(); +} + +/*static*/ void* ORT_API_CALL CudaSyncStream::GetHandleImpl(OrtSyncStreamImpl* this_ptr) noexcept { + auto* stream = static_cast(this_ptr); + return stream->cuda_stream_; +} + +/*static*/ OrtStatus* ORT_API_CALL CudaSyncStream::CreateNotificationImpl( + OrtSyncStreamImpl* this_ptr, OrtSyncNotificationImpl** notification) noexcept { + EXCEPTION_TO_STATUS_BEGIN + auto* stream = static_cast(this_ptr); + auto notif = std::make_unique(*stream); + *notification = notif.release(); + return nullptr; + EXCEPTION_TO_STATUS_END +} + +/*static*/ OrtStatus* ORT_API_CALL CudaSyncStream::FlushImpl(OrtSyncStreamImpl* this_ptr) noexcept { + auto* stream = static_cast(this_ptr); + PL_CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream->cuda_stream_)); + return nullptr; +} + +/*static*/ OrtStatus* ORT_API_CALL CudaSyncStream::OnSessionRunEndImpl(OrtSyncStreamImpl* this_ptr) noexcept { + auto* stream = static_cast(this_ptr); + // Synchronize before releasing deferred CPU buffers to ensure + // all async copies using those buffers have completed. + PL_CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream->cuda_stream_)); + stream->CleanupDeferredCPUBuffers(); + return nullptr; +} + +/*static*/ void ORT_API_CALL CudaSyncStream::ReleaseImpl(OrtSyncStreamImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} + +/*static*/ CudaSyncStream* CudaSyncStream::FromCudaStream(cudaStream_t stream) { + InitStreamMap(); + std::lock_guard lock(*g_stream_map_mutex); + auto it = g_stream_map->find(stream); + if (it != g_stream_map->end()) { + return it->second; + } + return nullptr; +} + +/*static*/ void CudaSyncStream::RegisterStream(cudaStream_t stream, CudaSyncStream* sync_stream) { + InitStreamMap(); + std::lock_guard lock(*g_stream_map_mutex); + (*g_stream_map)[stream] = sync_stream; +} + +/*static*/ void CudaSyncStream::UnregisterStream(cudaStream_t stream) { + if (!g_stream_map) return; + std::lock_guard lock(*g_stream_map_mutex); + g_stream_map->erase(stream); +} + +// --------------------------------------------------------------------------- +// CudaSyncNotification +// --------------------------------------------------------------------------- + +CudaSyncNotification::CudaSyncNotification(CudaSyncStream& stream) + : OrtSyncNotificationImpl{}, + stream_(stream) { + ort_version_supported = ORT_API_VERSION; + Activate = ActivateImpl; + WaitOnDevice = WaitOnDeviceImpl; + WaitOnHost = WaitOnHostImpl; + Release = ReleaseImpl; + + // Create a CUDA event for synchronization (disable timing for performance) + cudaSetDevice(stream_.GetDeviceId()); + cudaEventCreateWithFlags(&event_, cudaEventDisableTiming); +} + +CudaSyncNotification::~CudaSyncNotification() { + if (event_) { + cudaEventDestroy(event_); + } +} + +/*static*/ OrtStatus* ORT_API_CALL CudaSyncNotification::ActivateImpl( + OrtSyncNotificationImpl* this_ptr) noexcept { + auto* notif = static_cast(this_ptr); + PL_CUDA_RETURN_IF_ERROR(cudaEventRecord(notif->event_, notif->stream_.GetCudaStream())); + return nullptr; +} + +/*static*/ OrtStatus* ORT_API_CALL CudaSyncNotification::WaitOnDeviceImpl( + OrtSyncNotificationImpl* this_ptr, OrtSyncStream* stream) noexcept { + auto* notif = static_cast(this_ptr); + // SyncStream_GetHandle is in the main ORT API + cudaStream_t wait_stream = static_cast(Ort::GetApi().SyncStream_GetHandle(stream)); + PL_CUDA_RETURN_IF_ERROR(cudaStreamWaitEvent(wait_stream, notif->event_, 0)); + return nullptr; +} + +/*static*/ OrtStatus* ORT_API_CALL CudaSyncNotification::WaitOnHostImpl( + OrtSyncNotificationImpl* this_ptr) noexcept { + auto* notif = static_cast(this_ptr); + PL_CUDA_RETURN_IF_ERROR(cudaEventSynchronize(notif->event_)); + return nullptr; +} + +/*static*/ void ORT_API_CALL CudaSyncNotification::ReleaseImpl( + OrtSyncNotificationImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h new file mode 100644 index 0000000000000..347ac0ede0dfa --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "cuda_plugin_utils.h" + +#include +#include +#include + +namespace onnxruntime { +namespace cuda_plugin { + +class CudaSyncNotification; +class CudaEpFactory; + +/// CUDA stream implementation for the plugin EP. +/// Owns a cudaStream_t and associated cuBLAS/cuDNN handles. +class CudaSyncStream : public OrtSyncStreamImpl { + public: + CudaSyncStream(CudaEpFactory& factory, int device_id, + const OrtEp* ep); + ~CudaSyncStream(); + + int GetDeviceId() const { return device_id_; } + cudaStream_t GetCudaStream() const { return cuda_stream_; } + cublasHandle_t GetCublasHandle() const { return cublas_handle_; } + cudnnHandle_t GetCudnnHandle() const { return cudnn_handle_; } + cublasLtHandle_t GetCublasLtHandle() const { return cublas_lt_handle_; } + + void EnqueueDeferredCPUBuffer(void* cpu_buffer); + OrtStatus* InitHandles(); + + static CudaSyncStream* FromCudaStream(cudaStream_t stream); + + private: + static void RegisterStream(cudaStream_t stream, CudaSyncStream* sync_stream); + static void UnregisterStream(cudaStream_t stream); + static void* ORT_API_CALL GetHandleImpl(OrtSyncStreamImpl* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL CreateNotificationImpl( + OrtSyncStreamImpl* this_ptr, OrtSyncNotificationImpl** notification) noexcept; + static OrtStatus* ORT_API_CALL FlushImpl(OrtSyncStreamImpl* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL OnSessionRunEndImpl(OrtSyncStreamImpl* this_ptr) noexcept; + static void ORT_API_CALL ReleaseImpl(OrtSyncStreamImpl* this_ptr) noexcept; + + void CleanupDeferredCPUBuffers(); + + CudaEpFactory& factory_; + int device_id_; + cudaStream_t cuda_stream_ = nullptr; + cublasHandle_t cublas_handle_ = nullptr; + cudnnHandle_t cudnn_handle_ = nullptr; + cublasLtHandle_t cublas_lt_handle_ = nullptr; + + std::vector deferred_cpu_buffers_; +}; + +/// CUDA event-based notification for stream synchronization. +class CudaSyncNotification : public OrtSyncNotificationImpl { + public: + explicit CudaSyncNotification(CudaSyncStream& stream); + ~CudaSyncNotification(); + + private: + static OrtStatus* ORT_API_CALL ActivateImpl(OrtSyncNotificationImpl* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL WaitOnDeviceImpl( + OrtSyncNotificationImpl* this_ptr, OrtSyncStream* stream) noexcept; + static OrtStatus* ORT_API_CALL WaitOnHostImpl(OrtSyncNotificationImpl* this_ptr) noexcept; + static void ORT_API_CALL ReleaseImpl(OrtSyncNotificationImpl* this_ptr) noexcept; + + CudaSyncStream& stream_; + cudaEvent_t event_ = nullptr; +}; + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/provider_api_shims.cc b/onnxruntime/core/providers/cuda/plugin/provider_api_shims.cc new file mode 100644 index 0000000000000..7e35096039341 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/provider_api_shims.cc @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Provider API shims used by migrated CUDA kernels. +// Direct implementations — no SHARED_PROVIDER bridge needed. + +#include +#include +#include "core/common/float16.h" + +namespace onnxruntime { + +std::string GetEnvironmentVar(const std::string& var_name) { + const char* val = std::getenv(var_name.c_str()); + return val ? std::string(val) : std::string(); +} + +namespace math { + +uint16_t floatToHalf(float f) { + return MLFloat16(f).val; +} + +float halfToFloat(uint16_t h) { + return MLFloat16::FromBits(h).ToFloat(); +} + +} // namespace math + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 0e3ecdf2385c7..2ba0d24c55fef 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -7,6 +7,18 @@ // switching providers to be runnable as shared libraries. The interfaces will become more tightly integrated into the core code. #pragma once + +// When building the CUDA EP as a plugin (BUILD_CUDA_EP_AS_PLUGIN), +// skip all SHARED_PROVIDER type redefinitions. The adapter header (ep/adapters.h) +// provides its own facade types, and the SHARED_PROVIDER bridge would conflict. +#ifdef BUILD_CUDA_EP_AS_PLUGIN + +// Plugin build: provider_api.h is a complete no-op. We do NOT define +// SHARED_PROVIDER so that #ifndef SHARED_PROVIDER guards in framework +// headers (op_kernel.h, etc.) remain active. + +#else // !BUILD_CUDA_EP_AS_PLUGIN — normal SHARED_PROVIDER path + #define SHARED_PROVIDER 1 #ifdef _WIN32 @@ -307,7 +319,7 @@ constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider"; constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider"; template -using IAllocatorUniquePtr = std::unique_ptr>; +using IAllocatorUniquePtr = std::unique_ptr >; inline OrtStatus* CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept { return g_host->CreateStatus(code, msg); } @@ -418,7 +430,7 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT2; } -inline std::vector> +inline std::vector > CreateSupportedPartitions(const GraphViewer& graph_viewer, const std::unordered_set& supported_nodes, const std::unordered_set& stop_ops, @@ -466,7 +478,7 @@ inline Status ConvertInMemoryDataToInline(Graph& graph, const std::string& name) } // namespace graph_utils namespace QDQ { -inline std::pair>, std::unordered_map> +inline std::pair >, std::unordered_map > GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) { return g_host->QDQ__GetAllNodeUnits(graph_viewer, logger); } @@ -511,3 +523,5 @@ inline T* Initializer::data() { #define LOGS_DEFAULT(severity) \ LOGS_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime) + +#endif // !BUILD_CUDA_EP_AS_PLUGIN From 1aab300e69804583aac6ab3d78040be505d2929a Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 20 Mar 2026 17:16:21 -0700 Subject: [PATCH 02/48] Add doc --- docs/cuda_plugin_ep/cuda_ops_for_plugin_ep.md | 269 +++++++ docs/cuda_plugin_ep/cuda_plugin_ep_design.md | 709 ++++++++++++++++++ 2 files changed, 978 insertions(+) create mode 100644 docs/cuda_plugin_ep/cuda_ops_for_plugin_ep.md create mode 100644 docs/cuda_plugin_ep/cuda_plugin_ep_design.md diff --git a/docs/cuda_plugin_ep/cuda_ops_for_plugin_ep.md b/docs/cuda_plugin_ep/cuda_ops_for_plugin_ep.md new file mode 100644 index 0000000000000..0a6e627b1b5f7 --- /dev/null +++ b/docs/cuda_plugin_ep/cuda_ops_for_plugin_ep.md @@ -0,0 +1,269 @@ +# CUDA Kernel Changes for Plugin EP Compatibility + +## Overview + +The CUDA Plugin EP builds CUDA operator kernels into a separate shared library +(`onnxruntime_providers_cuda_plugin.so`) that communicates with the ORT core +through the ORT EP API. This architecture requires that kernel source files +**not** depend on framework-internal types that are unavailable across the +shared-library boundary. + +The plugin build uses two key mechanisms to achieve compatibility with minimal +(or zero) changes to existing kernel `.cc` files: + +1. **Force-included adapter headers** — The CMake build injects + `adapters.h` and `cuda_kernel_adapter.h` via `-include` compiler flags. + These headers redefine macros (`ONNX_OPERATOR_*_KERNEL_EX`), provide a + plugin-compatible `CudaKernel` base class, and supply shims for + `OpKernelContext`, `OpKernelInfo`, etc. + +2. **`BUILD_CUDA_EP_AS_PLUGIN` preprocessor guard** — For cases where the + adapter headers alone are insufficient, kernel headers can use + `#ifdef BUILD_CUDA_EP_AS_PLUGIN` to select an alternative code path + (e.g., a self-contained class instead of inheriting from a CPU base class). + +## Common Incompatibility Patterns + +| Pattern | Description | Typical Fix | +|---------|-------------|-------------| +| **`GetComputeStream()` returning `onnxruntime::Stream*`** | The adapter `OpKernelContext` exposes `GetComputeStream()` that returns the adapter `Stream*` with `GetHandle()`. Most kernels call `GetScratchBuffer(n, ctx->GetComputeStream())` which already works through the adapter. Kernels that `dynamic_cast` or call `CudaStream`-specific methods break. | Use `static_cast(ctx->GetComputeStream()->GetHandle())` instead of `CudaStream*` methods. The adapter `CudaKernel::GetCublasHandle(cudaStream_t)` and `GetCudnnHandle(cudaStream_t)` are available. | +| **Inheritance from CPU base class** | Kernels like `Resize : Upsample`, `SpaceToDepth : SpaceDepthBase`, `NonMaxSuppression : NonMaxSuppressionBase` inherit from CPU provider classes that are not linked into the plugin. | Add a `#ifdef BUILD_CUDA_EP_AS_PLUGIN` block in the header with a self-contained class that inlines the needed logic (see `constant_of_shape.h` for an example). | +| **`TensorSeq` (incomplete type)** | `TensorSeq` is not available in the plugin build. `identity_op.cc` and `sequence_op.cc` operate on sequence types. | These ops should remain excluded or need `TensorSeq` to be exposed through the EP API. | +| **`CudaTuningContext`** | Kernels that call `GetTuningContext()` and use `CudaTuningContext` methods directly. The adapter provides a stub `GetTuningContext()` but full tuning infra is unavailable. | Guard tuning-specific calls with `#ifndef BUILD_CUDA_EP_AS_PLUGIN` or use the adapter's stub which returns `nullptr` (callers should null-check). | +| **`PhiloxGenerator` / RNG state** | Dropout-family ops use `PhiloxGenerator` from the `CudaStream` object. This requires `CudaStream*` access. | Needs a `PhiloxGenerator` accessor in the adapter or exclusion. | +| **`QkvToContext` taking `Stream*`** | Attention ops pass `context->GetComputeStream()` (an `onnxruntime::Stream*`) to `QkvToContext`. This function dereferences `Stream*` internally. | Either change `QkvToContext` signature to accept `cudaStream_t` + handles, or provide a `PluginStreamShim` wrapper (already in the adapter). | +| **Pure CPU ops** | `Shape`, `Size` — these register CPU-side `OpKernel` classes whose `Compute()` is in the CPU provider library. | Permanently excluded; handled by `GetCpuPreferredNodes()`. | +| **`cuda_execution_provider.h` include** | Files that directly include the real `CUDAExecutionProvider` class definition conflict with the adapter's shim class. | Use the adapter's `CUDAExecutionProvider` shim (automatically provided by `cuda_kernel_adapter.h`). | +| **KernelInfoGetAttributeArray\_string** | RNN ops call `GetAttrs(...)` which maps to a C API function not yet available. | Wait for C API extension, or inline attribute parsing. | +| **Registration tables** | `cuda_nhwc_kernels.cc` and `cuda_contrib_kernels.cc` contain centralized `BuildKernelCreateInfo<>` tables that reference all kernel classes, including excluded ones. | Not needed — the plugin uses `PluginKernelCollector` for self-registration via macro overrides. | + +## How to Bring an Excluded Kernel to Plugin EP + +### Step 1: Identify the Dependency + +Check why the kernel is excluded by looking at: +- The comment in `cmake/onnxruntime_providers_cuda_plugin.cmake` +- The kernel `.cc`/`.h` files for the patterns listed above + +### Step 2: Apply the Minimal Fix + +The preferred approach (in order of preference): + +1. **No source change needed** — If the only issue was `GetComputeStream()` + usage with `GetScratchBuffer()`, the adapter already handles this. Just + remove the exclusion from the cmake file and test. + +2. **Use `BUILD_CUDA_EP_AS_PLUGIN` guard in the header** — For CPU base + class dependencies, add an alternative class definition: + ```cpp + #ifdef BUILD_CUDA_EP_AS_PLUGIN + class MyOp final : public CudaKernel { + // Self-contained implementation that inlines base class logic + }; + #else + class MyOp final : public CpuBaseClass, public CudaKernel { + // Original implementation + }; + #endif + ``` + +3. **Modify calling convention** — For functions that take + `onnxruntime::Stream*` or `CudaStream*`, change to accept + `cudaStream_t` + explicit handles: + ```cpp + // Before: + SomeHelper(context->GetComputeStream(), ...); + // After: + SomeHelper(static_cast(context->GetComputeStream()->GetHandle()), ...); + ``` + +4. **Add a shim in `cuda_kernel_adapter.h`** — For utility functions from + CPU providers (e.g., `ValidateInputs`, `PrepareCompute`), inline the + logic in the adapter header so it's available in the plugin build. + +5. **Inline CPU helper to header** — Move the helper implementation + from the CPU `.cc` file to the `.h` header, wrapped in + `#ifdef SHARED_PROVIDER` (declaration only) / `#else` (inline body). + The `SHARED_PROVIDER` build retains the existing `ProviderHostCPU` + bridge path. See `padbase.h`, `slice.h`, `scatter_nd.h` for examples. + +6. **Templatize on info/context type** — For base class constructors + that call `GetAttr()`, templatize on `KernelInfoType` with + `info.template GetAttr(...)`. For methods that take + `OpKernelContext&`, templatize on `KernelContextType`. + See `roialign.h`, `unsqueeze.h`, `attention_base.h` for examples. + +7. **Move CUDA type helpers to shared header** — For utility functions + that only depend on CUDA types (not framework types), move from + `.cc` to a header so the plugin build can consume them directly. + See `cuda_common_type_helpers.h`. + +### Step 3: Remove the CMake Exclusion + +In `cmake/onnxruntime_providers_cuda_plugin.cmake`, comment out the exclusion +line with a note about what was done: +```cmake +# myop.cc: . +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/myop\\.cc$") # REMOVED in Stage N +``` + +### Step 4: Build and Test + +```bash +# Build with plugin EP enabled +./build.sh --config Release --use_cuda --build_cuda_ep_as_plugin +# Run parity tests +python tools/ci_build/cuda_plugin_parity_report.py +``` + +--- + +## Excluded Operators Table + +### Infrastructure Files (Not Operator Kernels) + +These are excluded because they define the real EP infrastructure, which is +replaced by the plugin's own implementations in `plugin/`. + +| File | Reason | Resolution | +|------|--------|------------| +| `cuda_execution_provider.cc` | Defines the real `CUDAExecutionProvider` class; conflicts with adapter shim. | Permanently excluded; replaced by `plugin/cuda_ep.cc`. | +| `cuda_provider_factory.cc` | Creates the real CUDA EP via `ProviderFactory`; not used in plugin architecture. | Permanently excluded; replaced by `plugin/cuda_ep_factory.cc`. | +| `cuda_provider_interface.cc` | Shared-library provider interface for the old (non-plugin) shared-library model. | Permanently excluded; not applicable to plugin EP. | +| `cuda_stream_handle.cc` | Defines `CudaStream` class; replaced by plugin stream adapter. | Permanently excluded; replaced by `plugin/cuda_stream_plugin.cc`. | +| `cuda_execution_provider_info.cc` | EP configuration parsing tied to the real EP. | Permanently excluded; replaced by `plugin/cuda_ep.cc` config. | +| `cuda_graph.cc` | CUDA graph capture tied to real EP stream management. | Permanently excluded; replaced by `plugin/cuda_graph_plugin.cc`. | +| `cuda_mempool_arena.cc` | Memory arena tied to real EP allocator infrastructure. | Permanently excluded; replaced by `plugin/cuda_allocator_plugin.cc`. | +| `cuda_common.cc` | `HalfGemmOptions` definitions conflict with adapter's inline shim. | Permanently excluded; shims provided in `cuda_kernel_adapter.h`. | +| `cuda_nhwc_kernels.cc` | Centralized kernel registration table; references all NHWC kernel classes. | Permanently excluded; `PluginKernelCollector` auto-registers. | +| `cuda_contrib_kernels.cc` | Centralized kernel registration table; references all contrib kernel classes. | Permanently excluded; `PluginKernelCollector` auto-registers. | + +### Standard ONNX Operator Kernels — Currently Excluded + +| File | Exclusion Reason | Change Needed to Include | +|------|-----------------|--------------------------| +| `math/einsum.cc` | Inherits from `onnxruntime::Einsum` (CPU provider); calls `Einsum::Compute()` which chains to `DeviceCompute()` through the CPU base class vtable. Also depends on `einsum_utils/` which calls `ReduceCompute`. | Add `#ifdef BUILD_CUDA_EP_AS_PLUGIN` path that directly implements `ComputeInternal()` without the CPU base class. Substantial effort — einsum is complex. | +| `math/einsum_utils/*` | `einsum_auxiliary_ops.cc` calls `ReductionOps::ReduceCompute` which is a framework-only function. | Must inline or rewrite reduction logic for plugin build. Coupled with `einsum.cc`. | +| `controlflow/*` (If, Loop, Scan) | Inherits from CPU base classes (`If`, `Loop`, `Scan` from `core/providers/cpu/controlflow/`). These ops call into the ORT session to execute subgraphs. | Plugin has custom wrappers in `plugin/cuda_controlflow_plugin.cc` that delegate to `OrtEpApi`. Permanently excluded from standard source; plugin equivalents exist. | +| `tunable/*` | Depends on `CudaTuningContext` and the real `CUDAExecutionProvider` for tuning infrastructure. | Needs full tuning API exposure through plugin interface. Low priority — tuning is optional. | +| `rnn/*` (RNN, GRU, LSTM) | Kernel constructors call `GetAttrs("activations", ...)` which maps to `KernelInfoGetAttributeArray_string` — a C API function that does not yet exist. Also uses `CudnnRnnBase` which manages cuDNN RNN descriptors. | Extend the ORT C API with `KernelInfoGetAttributeArray_string`. After that, the dual-build signatures (already in place) should work. | +| `tensor/identity_op.cc` | Uses `TensorSeq` (incomplete type in plugin build) for sequence pass-through in `IdentityOp`. | Expose `TensorSeq` through the EP API adapter, or split the sequence codepath into a separate file with `#ifdef`. | +| `tensor/sequence_op.cc` | All ops (`SequenceAt`, `SequenceConstruct`, `SequenceInsert`, etc.) heavily use `TensorSeq`. | Same as `identity_op.cc` — requires `TensorSeq` support in the adapter. | +| `tensor/size.cc` | Pure CPU op — registers `onnxruntime::Size` whose `Compute()` is in the CPU provider. | **Permanently excluded.** Handled by `GetCpuPreferredNodes()`. | +| `tensor/shape_op.cc` | Pure CPU op — inherits from `onnxruntime::OpKernel` (framework class, not adapter `OpKernel`). Output is on CPU. | **Permanently excluded.** Handled by `GetCpuPreferredNodes()`. | +| `tensor/space_depth_ops.cc` | Inherits from `SpaceDepthBase` (CPU provider, `core/providers/cpu/tensor/space_depth_ops.h`). | `SpaceDepthBase` constructor templatized on `KernelInfoType` (#27628). Remaining: inline `SpaceDepthCompute` validation logic or add adapter-compatible path. Reduced effort. | +| `tensor/upsample.cc` | Inherits from `UpsampleBase` (CPU provider). `UpsampleBase` uses `InputDefs()` and complex attribute/input parsing in its constructor. | `UpsampleBase::AdjustOutputSizeAsPolicy` moved to header (#27628). Remaining blockers: `InputDefs()` and `OpKernelInfo::GetAllocator()` not available in adapter. Moderate effort. | +| `tensor/resize.cc` | Inherits from `Upsample` which inherits from `UpsampleBase`. | Blocked on `upsample.cc` — must fix `Upsample` first, then `Resize` follows. | +| `generator/constant_of_shape.cc` | Inherits from `ConstantOfShapeBase` (CPU provider) which uses `TensorProto`/`UnpackTensor`. | **Already has `#ifdef BUILD_CUDA_EP_AS_PLUGIN` path** in the header with a self-contained class. Currently excluded because the `.cc` file's `#else` path still compiles `ConstantOfShapeBase` version. Need to verify the `#ifdef` path compiles and remove the exclusion. | +| `object_detection/*` (NonMaxSuppression, RoiAlign) | `NonMaxSuppression` inherits from `NonMaxSuppressionBase`; `RoiAlign` inherits from `RoiAlignBase`. Both CPU base classes. `NonMaxSuppression` also uses CPU helper `PrepareCompute`. | `NonMaxSuppressionBase` refactored to `NonMaxSuppressionBaseImpl` template (#27617). `RoiAlignBase` constructor templatized, `CheckROIAlignValidInput` inlined (#27628). Remaining: integration verification and residual `GetComputeStream()` issues. | +| `llm/*` | Attention kernels that call `QkvToContext` with `onnxruntime::Stream*`. Deep dependency on attention implementation internals. | Change `QkvToContext` to accept `cudaStream_t` + explicit handles, or use `PluginStreamShim`. Large surface area. | + +### Contrib Operator Kernels — Currently Excluded + +| File | Exclusion Reason | Change Needed to Include | +|------|-----------------|--------------------------| +| **aten_ops/\*** | PyTorch ATen operator bindings; requires `libtorch`. Not relevant for plugin EP. | **Permanently excluded.** | +| **collective/\*** | NCCL/MPI collective ops; requires distributed runtime. | **Permanently excluded** (or separate plugin). | +| **contrib llm/\*** | Same as standard `llm/` — deep `Stream*` and attention infra dependencies. | Same fix as standard `llm/`. | +| **transformers/\*** (beam_search, greedy_search, sampling) | Directly includes `cuda_execution_provider.h`. Uses session-level APIs to run subgraphs (encoder/decoder). Heavy framework dependency. | Would need significant refactoring to route subgraph execution through `OrtEpApi`. Very high effort. | +| **bert/attention.cc** | Calls `GetScratchBuffer` with `context->GetComputeStream()` (works via adapter). Main issue: calls `QkvToContext` passing `context->GetComputeStream()` (`Stream*`), and uses `IAllocator::MakeUniquePtr` with stream. | `AttentionBase::CheckInputs`/`CheckMask`/`GetPresent` moved to header (#27628). Remaining blocker: `QkvToContext` takes `Stream*`. Moderate-high effort. | +| **bert/decoder_attention.cc** | Same pattern as `attention.cc` — `QkvToContext` with `Stream*`. | Same fix as `attention.cc`. | +| **bert/decoder_masked_self_attention.cc** | Uses `GetComputeStream()` for scratch buffers and stream handle extraction. | Replace `GetComputeStream()` → adapter-compatible calls. Moderate effort. | +| **bert/embed_layer_norm.cc** | `embed_layer_norm_helper::CheckInputs` templatized and moved to header (#27617). CPU base class dependency resolved. | Verify compilation with exclusion removed — helper refactoring complete. **Very low effort.** | +| **bert/fast_gelu.cc** | Was excluded due to `bias_gelu_helper` CPU base class dependency. `bias_gelu_helper::CheckInputs` now templatized and inlined (#27617). | Verify compilation with exclusion removed — helper refactoring complete. **Very low effort.** | +| **bert/group_query_attention.cc** | Heavy use of `GetComputeStream()` (scratch buffers, stream handle extraction, `CudaStream*` cast). Complex attention pipeline with flash attention, XQA loader. | Same approach as `attention.cc`. High effort due to many code paths. | +| **bert/longformer_attention.cc** | Uses `GetScratchBuffer` with `GetComputeStream()`, workspace allocation. `LongformerAttentionBase::CheckInputs` moved to header (#27628). | Remaining blocker: `GetComputeStream()` / `Stream*` usage. Moderate effort. | +| **bert/multihead_attention.cc** | Same pattern as `attention.cc` — `QkvToContext` with `Stream*`. | Same fix as `attention.cc`. | +| **bert/packed_attention.cc** | Same attention pipeline dependency. | Same fix as `attention.cc`. | +| **bert/packed_multihead_attention.cc** | Same attention pipeline dependency. | Same fix as `attention.cc`. | +| **bert/paged_attention.cc** | Uses `GetComputeStream()` for scratch buffers and paged KV-cache management. | Replace stream access pattern. Moderate effort. | +| **bert/relative_attn_bias.cc** | Uses `GetScratchBuffer` with `GetComputeStream()`. | Simple `GetComputeStream()` pattern — may work with adapter. **Low effort to try.** | +| **bert/remove_padding.cc** | Uses `GetScratchBuffer` with `GetComputeStream()`. | Simple `GetComputeStream()` pattern — may work with adapter. **Low effort to try.** | +| **diffusion/group_norm.cc** | Uses `CudaTuningContext*` and `Stream*` in the `DispatchGroupNorm` helper. | Guard tuning path with `#ifndef BUILD_CUDA_EP_AS_PLUGIN`, change stream parameter. Moderate effort. | +| **fused_conv.cc** | Uses `GetComputeStream()` for cuDNN workspace allocation. | Replace stream access with adapter-compatible calls. Moderate effort. | +| **inverse.cc** | Uses `GetScratchBuffer` with `GetComputeStream()`. cuBLAS batched operations. | Simple pattern — likely works with adapter. **Low effort to try.** | +| **math/bias_dropout.cc** | Uses `PhiloxGenerator` from `CudaStream` for RNG state. Also `GetComputeStream()`. | Needs `PhiloxGenerator` accessor in adapter. Blocked on RNG infrastructure. | +| **math/fft_ops.cc** | Uses `onnxruntime::Stream*` directly. cuFFT plan management. | Change stream access to adapter pattern. Moderate effort. | +| **math/gemm_float8.cc/.cu** | `ComputeInternal` is in `.cu` file which uses `GetComputeStream()`. `.cu` files don't receive the force-include adapter header. | Move `GetComputeStream()` usage to `.cc` file, or pass stream as parameter to `.cu` function. Moderate effort. | +| **moe/moe.cc** | Uses `GetComputeStream()`. MoE routing + expert computation. | Replace `context->GetComputeStream()` with adapter-compatible calls. Moderate effort. | +| **sparse/sparse_attention.cc** | Uses `onnxruntime::Stream*`. Sparse attention kernel dispatch. | Same stream pattern fix. Moderate effort. | +| **tensor/shrunken_gather.cc** | Training op — includes `provider_api.h` in header. `ENABLE_TRAINING_OPS` guard. | **Permanently excluded** (training op, not needed for inference plugin). | +| **tensor/crop.cc** | `CropBase` constructor templatized on `KernelInfoType` (#27628). No `GetComputeStream()` usage. | Verify compilation with exclusion removed — constructor refactoring complete. **Very low effort.** | +| **tensor/dynamic_time_warping.cc** | Uses `GetScratchBuffer` with `GetComputeStream()`. | Simple pattern — likely works with adapter. **Low effort to try.** | +| **tensor/dynamicslice.cc** | Uses `onnxruntime::Stream*` via `GetComputeStream()`. | Simple pattern — likely works with adapter. **Low effort to try.** | +| **quantization/attention_quantization.cc** | Uses `GetScratchBuffer` with `GetComputeStream()`, calls `QkvToContext`. | Same fix as `attention.cc`. Moderate-high effort. | +| **quantization/matmul_bnb4.cc** | Uses `GetScratchBuffer` and `GetComputeStream()->GetHandle()`. | Adapter should handle this pattern. **Low effort to try.** | +| **quantization/matmul_nbits.cc** | Uses `GetScratchBuffer` with `GetComputeStream()` and `GetHandle()`. | Adapter should handle this pattern. **Low effort to try.** | +| **quantization/moe_quantization.cc** | Uses `GetComputeStream()`. Quantized MoE pipeline. | Same as `moe.cc`. Moderate effort. | +| **quantization/qordered_ops/\*** | Ordered quantization ops with framework dependencies. | Needs investigation. Low priority. | + +### Operators Successfully Brought to Plugin EP (Reference) + +These were previously excluded and are now included thanks to adapter +compatibility. Listed here as examples of the fix patterns applied. + +| File | Fix Applied | Stage | +|------|-------------|-------| +| `tensor/reshape.cc` | `CopyTensor` replaced with explicit `cudaMemcpyAsync` on kernel stream (#27719). | 5A | +| `tensor/concat.cc` | `InputArgCount`/`GetComputeStream` usage works through adapter `OpKernelContext`. | 5A | +| `tensor/split.cc` | `GetComputeStream` usage works via `CudaKernel::GetComputeStream`. | 5A | +| `tensor/gather.cc` | Switched to `GatherBase::PrepareForComputeImpl` compatible with adapter context. | 5B | +| `tensor/gather_nd.cc` | `PrepareCompute` signature changed to `void*`/`cudaStream_t`. | 5 | +| `tensor/unsqueeze.cc` | Plugin-local `PrepareCompute` path added for adapter context. | 5B | +| `tensor/tile.cc` | Plugin-local `IsTileMemcpy` helper added. | 5B | +| `math/cumsum.cc` | Axis parsing helper inlined for plugin build. | 5B | +| `tensor/scatter_nd.cc` | `ValidateShapes` inlined for plugin; `GetComputeStream` fixed. | 5 | +| `tensor/pad.cc` | Plugin-local wrappers for `PadBase` static helpers. | 5C.2 | +| `tensor/slice.cc` | Plugin-local wrappers for `SliceBase::PrepareForCompute`/`FlattenOutputDims`. | 5C.3 | +| `math/variadic_elementwise_ops.cc` | Adapter `InputCount`/`RequiredInput`/`RequiredOutput` supported. | 5C | +| `math/matmul.cc` | `GetComputeStream` fixed; `GetTuningContext` guarded. | 5 | +| `math/matmul_integer.cc` | `GetComputeStream` fixed; `GemmInt8` signature updated. | 5 | +| `math/integer_gemm.cc` | `dynamic_cast` replaced with stream-based `GetCublasHandle()` overload (#27719). | 5 | +| `contrib/math/fused_matmul.cc` | Included after `matmul.cc` was fixed. | 5 | + +## Priority Recommendations + +### High Priority (Common ops, likely low effort) + +These excluded ops use simple `GetComputeStream()` patterns that the adapter +already supports. They should be tried first. Ops marked with (✓) have had +their CPU helper dependencies fully refactored and are ready for build +verification: + +- `contrib/bert/embed_layer_norm.cc` (✓ helper refactored #27617) +- `contrib/bert/fast_gelu.cc` (✓ helper refactored #27617) +- `contrib/bert/relative_attn_bias.cc` +- `contrib/bert/remove_padding.cc` +- `contrib/tensor/crop.cc` (✓ constructor templatized #27628) +- `contrib/tensor/dynamic_time_warping.cc` +- `contrib/tensor/dynamicslice.cc` +- `contrib/inverse.cc` +- `contrib/quantization/matmul_bnb4.cc` +- `contrib/quantization/matmul_nbits.cc` +- `generator/constant_of_shape.cc` (already has `#ifdef` path) + +### Medium Priority (Moderate refactoring needed) + +- `tensor/space_depth_ops.cc` — constructor templatized; remaining validation to inline +- `contrib/diffusion/group_norm.cc` — guard tuning context +- `contrib/moe/moe.cc` — fix stream access +- `contrib/fused_conv.cc` — fix stream access +- `contrib/math/fft_ops.cc` — fix stream access +- `contrib/math/gemm_float8.cc/.cu` — move stream access to `.cc` + +### Low Priority (Significant effort or niche) + +- `tensor/upsample.cc` + `tensor/resize.cc` — `AdjustOutputSizeAsPolicy` moved to header; `InputDefs()`/`GetAllocator()` still needed +- `rnn/*` — blocked on C API string array extension +- `llm/*` + `bert/attention*.cc` family — deep attention pipeline changes +- `math/einsum.cc` — complex CPU base class +- `object_detection/*` — base classes partially refactored; integration verification needed +- `transformers/*` — subgraph execution, very high effort + +### Permanently Excluded + +- `tensor/size.cc`, `tensor/shape_op.cc` — pure CPU ops +- `aten_ops/*` — PyTorch dependency +- `collective/*` — distributed runtime +- `tensor/shrunken_gather.cc` — training only +- Infrastructure files — replaced by `plugin/` equivalents diff --git a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md new file mode 100644 index 0000000000000..d07c1216ff5c9 --- /dev/null +++ b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md @@ -0,0 +1,709 @@ +# CUDA Plugin EP — Design Document + +## 1. Overview + +The CUDA Plugin EP is an alternative build of the ONNX Runtime CUDA Execution Provider that compiles as a standalone shared library (`libonnxruntime_providers_cuda_plugin.so`). It loads at runtime through the ORT EP Plugin API instead of being statically linked into the main runtime binary. + +**Goals:** +- Allow CUDA EP updates independent of ORT core releases +- Support all operators currently supported by the in-tree CUDA EP (tunable ops are low priority) +- Minimize changes to existing CUDA kernel source files + +**Current status:** ~80% of CUDA kernels compile in the plugin build. Excluded operators are documented in [Section 7](#7-excluded-operators). + +--- + +## 2. Architecture + +### 2.1 Build Targets + +The ORT CUDA build produces four separate libraries: + +| Target | Output | Type | Description | +|--------|--------|------|-------------| +| `onnxruntime_providers` | `libonnxruntime_providers.a` | Static lib | CPU provider + framework ops | +| `onnxruntime_providers_shared` | `libonnxruntime_providers_shared.so` | Shared lib | DLL-boundary bridge for in-tree EPs | +| `onnxruntime_providers_cuda` | `libonnxruntime_providers_cuda.so` | Shared module | In-tree CUDA EP (uses `SHARED_PROVIDER` bridge) | +| `onnxruntime_providers_cuda_plugin` | `libonnxruntime_providers_cuda_plugin.so` | Shared module | Plugin CUDA EP (uses EP API adapters) | + +### 2.2 Preprocessor Defines + +Each build target uses different preprocessor defines that control how framework types are resolved: + +| Define | Set In | Purpose | +|--------|--------|---------| +| `SHARED_PROVIDER` | `onnxruntime_providers_shared`, `onnxruntime_providers_cuda` | Activates the DLL-boundary proxy types in `provider_api.h` | +| `BUILD_CUDA_EP_AS_PLUGIN` | `onnxruntime_providers_cuda_plugin` | Makes `provider_api.h` a no-op; activates plugin-specific code paths | +| `ORT_USE_EP_API_ADAPTERS` | `onnxruntime_providers_cuda_plugin` | Enables the EP adapter type aliases (`ep/adapters.h`) | +| `ORT_API_MANUAL_INIT` | `onnxruntime_providers_cuda_plugin` | Manual ORT API initialization in plugin DLL | + +### 2.3 Class Hierarchy + +``` +OrtEpFactory OrtEp + ↑ ↑ +CudaEpFactory adapter::Ep (holds unique_ptr) + │ ↑ + │ CudaEp + │ │ + │ └─ owns ──→ CUDAExecutionProvider + │ (: IExecutionProvider) + │ ├─ config members + │ ├─ device properties + │ └─ stream→handle map + │ + └─ creates ──→ CudaSyncStream (owns cublasHandle_t, cudnnHandle_t, cublasLtHandle_t) +``` + +Key ownership relationships: +- `CudaEpFactory` creates `CudaEp` instances and `CudaSyncStream` objects +- `CudaEp` inherits from `ep::adapter::Ep` and owns a `CUDAExecutionProvider` instance (accessible via `EpImpl()`) +- `CUDAExecutionProvider` is a plugin-local class (not the framework one) that inherits from `IExecutionProvider` and provides the full API surface CUDA kernels need +- `CudaSyncStream` owns CUDA/cuBLAS/cuDNN handles per stream + +### 2.4 Plugin DLL Entry Points + +The plugin exports exactly two C symbols: +- `CreateEpFactories()` — called by ORT to create the EP factory +- `ReleaseEpFactory()` — called by ORT to destroy the factory + +All other symbols have hidden visibility. + +--- + +## 3. Type Resolution — How Kernel Code Compiles Unchanged + +The core design principle is that existing CUDA kernel `.cc` files compile in the plugin build with **zero or minimal source changes**. This is achieved through a two-layer force-include mechanism. + +### 3.1 Force-Include Chain + +For every `.cc` file in the plugin build, CMake injects two force-includes before any source code: + +``` +1. ep/adapters.h — adapter type aliasing +2. cuda_kernel_adapter.h — CudaKernel base class, macros, CPU shims +``` + +Note: `.cu` files do NOT receive force-includes (conflicts with CUTLASS/cute). They must include `cuda_kernel_adapter.h` explicitly if needed. + +### 3.2 Adapter Type Aliasing (`ep/adapters.h`) + +`ep/adapters.h` defines `using` aliases in both `onnxruntime::cuda` and `onnxruntime::contrib::cuda` namespaces: + +```cpp +namespace onnxruntime::cuda { + using OpKernel = ep::adapter::OpKernel; + using OpKernelContext = ep::adapter::OpKernelContext; + using OpKernelInfo = ep::adapter::OpKernelInfo; + using KernelRegistry = ep::adapter::KernelRegistry; + using KernelDefBuilder = ep::adapter::KernelDefBuilder; + using DataTransferManager = ep::adapter::DataTransferManager; + // ... etc +} +``` + +When kernel code in `namespace onnxruntime::cuda` references `OpKernelContext`, it resolves to the adapter type instead of the framework type. **No kernel source changes needed.** + +### 3.3 Provider API Bypass + +In the plugin build, `provider_api.h` (normally included from `cuda_common.h`) is a **complete no-op** — it does NOT define `SHARED_PROVIDER`. This means: + +- `#ifndef SHARED_PROVIDER` guards in framework headers remain **active**, exposing real types +- Header-inlined utility methods (see [Section 4](#4-cpu-base-class-helpers)) get their inline bodies +- The `ProviderHostCPU` virtual table bridge is bypassed entirely + +### 3.4 Kernel Adapter (`cuda_kernel_adapter.h`) + +This 700+ line header provides everything CUDA kernels need that would normally come from framework infrastructure: + +| Section | What It Provides | +|---------|-----------------| +| Error macros | `CUDA_RETURN_IF_ERROR`, `CUBLAS_RETURN_IF_ERROR`, `CUDNN_RETURN_IF_ERROR` | +| Type mappings | `ToCudaType::MappedType = half`, etc. | +| CudaKernel base | Scratch buffers, handle access, `Stream()`, `GetComputeStream()` | +| Kernel registration | Self-registering `ONNX_OPERATOR_*_KERNEL_EX` macro overrides via `PluginKernelCollector` | +| CPU shims | Lightweight reimplementations of CPU helpers not linked into plugin | +| Math helpers | `HalfGemmOptions`, `CublasMathModeSetter` | +| Stream shim | `PluginStreamShim` wrapping raw `cudaStream_t` as `onnxruntime::Stream*` | + +### 3.5 Kernel Registration + +In the in-tree build, kernels register through centralized tables (`cuda_nhwc_kernels.cc`, `cuda_contrib_kernels.cc`). In the plugin build, the `ONNX_OPERATOR_*_KERNEL_EX` macros are overridden to auto-register each kernel into the `PluginKernelCollector` singleton at static initialization time: + +```cpp +// Macro override generates: +// 1. BuildKernelCreateInfo() function +// 2. Static PluginKernelCollector::Register() call + +// At plugin startup, CreateCudaKernelRegistry() iterates the collector +// and registers each kernel into an adapter::KernelRegistry. +``` + +--- + +## 4. CPU Base Class Helpers — The SHARED_PROVIDER Pattern + +Many CUDA kernels inherit from CPU base classes and call utility methods (e.g., `PadBase::HandleDimValueZero`, `SliceBase::PrepareForCompute`). In the in-tree build, these call across the DLL boundary through `ProviderHostCPU`. The plugin doesn't use this bridge. + +### 4.1 Pattern: Inline in Header + +The primary approach moves pure-computation helpers from CPU `.cc` files to headers: + +```cpp +// In padbase.h: +#ifdef SHARED_PROVIDER + // In-tree build: declaration only, body in ProviderHostCPU bridge + static void HandleDimValueZero(Mode mode, const TensorShape& input_shape, TensorShape& output_shape); +#else + // Plugin build + CPU provider: inline body + static inline void HandleDimValueZero(Mode mode, const TensorShape& input_shape, + TensorShape& output_shape) { + // ... implementation ... + } +#endif +``` + +**Files refactored with this pattern:** +- `padbase.h` — `HandleDimValueZero`, `ComputePads` (delegates to `ComputePadsImpl` template) +- `scatter_nd.h` — `ValidateShapes` +- `split.h` — `PrepareForCompute` +- `tile.h` — `IsTileMemcpy` +- `slice.h` — `PrepareForCompute`, `FlattenOutputDims` +- `cumsum.h` — `cumsum_op::GetAxis` +- `bias_gelu_helper.h` — `bias_gelu_helper::CheckInputs` +- `concatbase.h` — `PrepareForCompute` +- `gatherbase.h` — `PrepareForCompute`/`PrepareForComputeImpl` (template) +- `unsqueeze.h` — `PrepareCompute` +- `embed_layer_norm_helper.h` — `embed_layer_norm::CheckInputs` (templatized on context type) +- `non_max_suppression_helper.h` — `NonMaxSuppressionBaseImpl` template class (new file) +- `attention_base.h` — `AttentionBase::CheckInputs`, `CheckMask`, `GetPresent` (templatized on context type) +- `longformer_attention_base.h` — `LongformerAttentionBase::CheckInputs` +- `roialign.h` — `CheckROIAlignValidInput`, `RoiAlignBase` constructor (templatized on info type) +- `upsamplebase.h` — `UpsampleBase::AdjustOutputSizeAsPolicy` +- `crop.h` — `CropBase` constructor (templatized on info type) +- `space_depth_ops.h` — `SpaceDepthBase` constructor (templatized on info type) +- `clip.h` — Clip min/max attribute handling (removed `Clip_6Base` CPU dependency) +- `cuda_common_type_helpers.h` — CUDA type conversion and handle error string helpers (moved from `cuda_common.cc`) + +### 4.2 Pattern: Template Methods + +For methods that take `OpKernelContext&` (which differs between plugin and in-tree builds), template versions accept any context type: + +```cpp +// In padbase.h: +template +static void ComputePadsImpl(KernelContextType& ctx, size_t data_rank, + gsl::span pads_data, PadsVector& pads) { ... } +``` + +The CUDA kernel calls `PadBase::ComputePadsImpl(*ctx, ...)` directly, avoiding the `OpKernelContext&` type mismatch. + +The same pattern is applied to constructors that receive `OpKernelInfo`: + +```cpp +// In roialign.h: +template +RoiAlignBase(const TKernelInfo& info) { + info.template GetAttr("mode", &mode_string); + info.template GetAttr("output_height", &output_height_); + // ... +} +``` + +This allows the base class constructor to work with both the framework `OpKernelInfo` and the plugin adapter's `OpKernelInfo`. Applied to: `RoiAlignBase`, `CropBase`, `SpaceDepthBase` (#27628). + +### 4.3 Files That Cannot Be Inlined + +Some CPU base classes have heavy dependencies (protobuf, `UnpackTensor`) that make inlining impractical: + +- **`ConstantOfShapeBase`** — depends on `TensorProto` and `UnpackTensor`. Plugin uses a self-contained duplicate class in `constant_of_shape.h` guarded by `#ifdef BUILD_CUDA_EP_AS_PLUGIN`. +- **`UpsampleBase`** — partially addressed: `AdjustOutputSizeAsPolicy` moved to header (#27628). Still depends on `InputDefs()` and `OpKernelInfo::GetAllocator()` which are not in the adapter. + +--- + +## 5. Handle and Stream Management + +### 5.1 Stream Ownership + +`CudaSyncStream` is the plugin's CUDA stream implementation: +- Owns `cudaStream_t`, `cublasHandle_t`, `cudnnHandle_t`, `cublasLtHandle_t` +- Created by `CudaEpFactory::CreateSyncStreamForDevice` +- Registered with `CUDAExecutionProvider` for handle lookup + +### 5.2 Handle Access Path + +``` +CudaKernel::GetCublasHandle(OpKernelContext* ctx) + → Stream(ctx) // raw cudaStream_t from ctx + → CUDAExecutionProvider::GetActiveProvider() // static pointer to active EP + → provider->GetCublasHandle(cudaStream_t) // stream→handle map lookup +``` + +The `CUDAExecutionProvider` maintains a `std::unordered_map` for handle lookups. + +### 5.3 Provider Access + +Kernels access the provider through two paths: +1. **`CudaKernel::provider_`** — set in the constructor from `info.GetExecutionProvider()` +2. **`CUDAExecutionProvider::GetActiveProvider()`** — static atomic pointer (for `.cu` code that doesn't have a `CudaKernel` instance) + +### 5.4 CUDA Graph Support + +#### 5.4.1 How CUDA Graph Works in Bundled CUDA EP + +CUDA Graph capture/replay in ORT is a **cooperative protocol** between the ORT session framework and the execution provider. Understanding this protocol is critical for the plugin EP design. + +**Session-level orchestration** (`inference_session.cc`): + +1. During session initialization, if an EP reports `IsGraphCaptureEnabled() == true` and all graph nodes are assigned to that EP (plus allowed CPU shape nodes), the session caches a pointer to the EP in `cached_execution_provider_for_graph_replay_`. + +2. At `Run()` time, the session checks `cached_execution_provider_for_graph_replay_.IsGraphCaptured(annotation_id)`: + - **If captured**: The session **skips the entire kernel dispatch pipeline** — no `OnRunStart`, no executor, no `OnRunEnd` — and calls `ReplayGraph(annotation_id)` directly. This is the fast path. + - **If not yet captured**: The session runs the normal kernel dispatch pipeline (including `OnRunStart` → executor → `OnRunEnd`), which allows the EP to manage warm-up counting and trigger capture. + +3. After each normal run, the session checks if graph capture is enabled but not yet captured, and **recursively calls `Run()`** to accumulate the required warm-up runs and trigger capture — so from the user's perspective, a single `Run()` call handles the entire warm-up + capture sequence. + +**EP-level capture** (`CUDAExecutionProvider`): + +- `OnRunStart()`: If warm-up is complete and graph not yet captured, calls `cudaStreamBeginCapture()`. +- `OnRunEnd()`: If capturing, calls `cudaStreamEndCapture()` + `cudaGraphInstantiate()` + first `Replay()` (since captured kernels don't execute on GPU during capture). +- `IsGraphCaptureEnabled()`: Returns `true` if `enable_cuda_graph` provider option is set. +- `IsGraphCaptured(annotation_id)`: Returns `true` if a graph has been captured for this annotation. +- `ReplayGraph(annotation_id)`: Calls `cudaGraphLaunch()` for the stored `cudaGraphExec_t`. + +The key insight is that the **session-level replay bypass** (`ReplayGraph()` without kernel dispatch) is what makes CUDA Graph efficient. Without it, the EP can capture a graph but can never replay it efficiently — kernels would still be dispatched by the executor on every run. + +``` +Session::Run() + ├── [Graph captured?] ──YES──→ ep->ReplayGraph(id) ──→ return ← FAST PATH + │ + └── [Not captured] ──→ OnRunStart() → executor dispatches kernels → OnRunEnd() + │ │ + │ (EP begins cudaStreamBeginCapture) │ (EP ends capture, first replay) + │ │ + └──────── Session recurses if warmup needed ─┘ +``` + +#### 5.4.2 Current Plugin EP Behavior — API Gap + +The `OrtEp` C API (`onnxruntime_ep_c_api.h`) provides `OnRunStart` and `OnRunEnd` callbacks but **does not include**: +- `IsGraphCaptureEnabled()` +- `IsGraphCaptured(annotation_id)` +- `ReplayGraph(annotation_id)` + +The `PluginExecutionProvider` bridge (`ep_plugin_provider_interfaces.cc`) does not override these `IExecutionProvider` virtual methods, so they return the base class defaults (`false`, `false`, `Status::OK()`). + +**Consequence**: The session's `cached_execution_provider_for_graph_replay_` is never set for the plugin EP. The session-level replay bypass **never activates**. Even after the plugin captures a CUDA graph via `OnRunStart`/`OnRunEnd`, subsequent runs still go through the full kernel dispatch pipeline — the captured graph sits unused. + +The current plugin implementation has a partial mitigation: it captures the graph and replays it once (in `OnRunEnd` after capture). But on subsequent runs, `OnRunEnd` sees the graph is already captured and does nothing. + +#### 5.4.3 Revised Design — Remove EP-Level Graph Management + +Given the API gap, the correct design for the plugin EP is: + +> **The plugin EP should NOT manage CUDA graph capture/replay internally.** CUDA graph support requires session-level cooperation that is not available through the current `OrtEp` C API. + +**Rationale:** + +1. The `OrtEp` C API has no `IsGraphCaptureEnabled`/`IsGraphCaptured`/`ReplayGraph` callbacks. Without these, the session cannot know that the EP supports graph capture, cannot bypass kernel dispatch for replay, and cannot trigger the recursive warm-up sequence. + +2. Implementing capture in `OnRunStart`/`OnRunEnd` without the session-level replay bypass is **incorrect** — the captured graph would never be replayed on subsequent runs (the session always dispatches kernels normally). + +3. The session's graph validation logic (all nodes on one EP, no control flow) is also not triggered without `IsGraphCaptureEnabled()`. + +**Recommended approach:** + +| Option | Description | Effort | Status | +|--------|------------|--------|--------| +| **A. Extend the OrtEp C API** | Add `IsGraphCaptureEnabled`, `IsGraphCaptured`, `ReplayGraph` to `OrtEp`. Update `PluginExecutionProvider` to delegate to these. | Medium — requires ORT core changes | Preferred long-term solution | +| **B. Disable graph capture in plugin EP** | Remove `CUDAGraphManager` and graph-related code from the plugin. Document as a known limitation. Re-enable when Option A is available. | Small | Recommended for now | +| **C. Keep capture-only (no replay)** | Keep the current code but document that it only captures + replays once (the first time), with no subsequent replay optimization. | None | Misleading — gives false confidence | + +**Recommendation**: Option B for the current release, with Option A tracked as a public API enhancement request. + +#### 5.4.4 What Needs to Change in ORT Core (Option A) + +To enable full CUDA graph support for plugin EPs, the `OrtEp` struct needs three new optional callbacks: + +```c +// Proposed additions to OrtEp (onnxruntime_ep_c_api.h) +struct OrtEp { + // ... existing fields ... + + /// Returns true if CUDA graph capture is enabled for this EP. + /// If nullptr, defaults to false. + ORT_API2_STATUS(IsGraphCaptureEnabled, _In_ const OrtEp* this_ptr, _Out_ bool* enabled); + + /// Returns true if a graph has been captured for the given annotation ID. + /// If nullptr, defaults to false. + ORT_API2_STATUS(IsGraphCaptured, _In_ const OrtEp* this_ptr, + _In_ int graph_annotation_id, _Out_ bool* captured); + + /// Replay a previously captured graph. + /// If nullptr, returns OK (no-op). + ORT_API2_STATUS(ReplayGraph, _In_ OrtEp* this_ptr, _In_ int graph_annotation_id); +}; +``` + +The `PluginExecutionProvider` bridge would then delegate these to the plugin: + +```cpp +// In ep_plugin_provider_interfaces.cc +bool PluginExecutionProvider::IsGraphCaptureEnabled() const { + if (ort_ep_->IsGraphCaptureEnabled == nullptr) return false; + bool enabled = false; + auto* status = ort_ep_->IsGraphCaptureEnabled(ort_ep_.get(), &enabled); + // handle status... + return enabled; +} +``` + +This would plug into the existing `cached_execution_provider_for_graph_replay_` mechanism in `InferenceSession` with no other session-level changes needed. + +#### 5.4.5 Current State + +| Component | Status | Notes | +|-----------|--------|-------| +| `cuda_graph_plugin.h/.cc` | Implemented | `CUDAGraphManager` adapted from bundled EP. Captures/replays correctly. | +| `CudaEp::OnRunStartImpl` | Implemented | Reads `gpu_graph_id`, manages warm-up, begins capture. | +| `CudaEp::OnRunEndImpl` | Implemented | Ends capture, first replay. No subsequent replay. | +| Session-level replay bypass | **Not functional** | `OrtEp` API lacks `IsGraphCaptureEnabled`/`IsGraphCaptured`/`ReplayGraph`. | +| Tests | Pass (capture + first replay) | `test_cuda_plugin_cuda_graph()` tests warm-up, capture, and `gpu_graph_id=-1` disable. | + +**Action items:** +1. Keep `cuda_graph_plugin.h/.cc` and `CudaEp` graph state machine code — it is correct and will be needed when the API gap is closed. +2. Default `enable_cuda_graph` to `false` in the plugin EP config and document the limitation. +3. File an ORT core feature request to add `IsGraphCaptureEnabled`/`IsGraphCaptured`/`ReplayGraph` to the `OrtEp` C API. +4. When the API is extended, wire up the existing `CUDAGraphManager` through the new callbacks. + +--- + +## 6. EP Adapter Layer (`include/onnxruntime/ep/adapter/`) + +The adapter layer provides thin wrappers around the ORT C API that present a C++ interface matching the framework types: + +| Adapter Class | Wraps | Key Methods | +|---------------|-------|-------------| +| `OpKernel` | `OrtKernelImpl` | `Compute()`, `PrePack()` | +| `OpKernelContext` | `OrtKernelContext` | `Input()`, `Output()`, `InputCount()`, `GetGPUComputeStream()`, `GetComputeStream()` | +| `OpKernelInfo` | `Ort::ConstKernelInfo` | `GetAttr()`, `GetExecutionProvider()`, `TryGetConstantInput()`, `GetDataTransferManager()` | +| `KernelRegistry` | `Ort::KernelRegistry` | `Register(KernelCreateInfo&&)` | +| `KernelDefBuilder` | `Ort::KernelDefBuilder` | `TypeConstraint()`, `InputMemoryType()`, `SetName()` | +| `Ep` | `OrtEp` | `EpImpl()`, allocators, data transfer | +| `Logger` | Plugin logger | Logging interface | +| `DataTransferManager` | `IDataTransfer` | `CopyTensor()` | + +--- + +## 7. Excluded Operators + +The following operators are excluded from the plugin build. Each exclusion has a specific technical reason and a path to inclusion. + +### 7.1 Infrastructure (Permanently Excluded — Replaced by Plugin Equivalents) + +| File | Reason | +|------|--------| +| `cuda_execution_provider.cc` | Replaced by `cuda_ep_provider.h` + `cuda_ep.h` | +| `cuda_provider_factory.cc` | Replaced by `cuda_ep_factory.cc` | +| `cuda_provider_interface.cc` | Not needed in plugin architecture | +| `cuda_stream_handle.cc` | Replaced by `cuda_stream_plugin.cc` | +| `cuda_execution_provider_info.cc` | Config parsed directly in `CudaEp::Config` | +| `cuda_graph.cc` | Replaced by `cuda_graph_plugin.cc` | +| `cuda_mempool_arena.cc` | Plugin uses `cudaMalloc`/`cudaFree` directly | +| `cuda_common.cc` | Utility functions shimmed in `cuda_kernel_adapter.h` | +| `cuda_nhwc_kernels.cc` | Replaced by `PluginKernelCollector` auto-registration | +| `cuda_contrib_kernels.cc` | Replaced by `PluginKernelCollector` auto-registration | + +### 7.2 Pure CPU Ops (Permanently Excluded) + +| File | Reason | +|------|--------| +| `tensor/size.cc` | Pure CPU op, handled by `GetCpuPreferredNodes` | +| `tensor/shape_op.cc` | Pure CPU op, inherits from `onnxruntime::OpKernel` (framework) | + +### 7.3 Operators Excluded Due to Missing Features + +| File | Exclusion Reason | What's Needed to Include | +|------|-----------------|--------------------------| +| `controlflow/*` | CPU base class If/Loop/Scan not linked | Plugin has own wrappers in `cuda_controlflow_plugin.cc` via `OrtEpApi`. Already functional. | +| `tunable/*` | Depends on real `CudaTuningContext` | Implement plugin-side `ITuningContext` that delegates to ORT tuning APIs. Low priority. | +| `rnn/*` | ORT C API lacks `KernelInfoGetAttributeArray_string` | Extend C API with string-array attribute support. | +| `math/einsum.cc`, `math/einsum_utils/*` | `einsum_auxiliary_ops.cc` calls `ReductionOps::ReduceCompute` (framework-only) | Extract `ReduceCompute` into a shared interface or reimplement the reduction path. | +| `tensor/identity_op.cc` | Uses `TensorSeq` (incomplete type in plugin) | Add `TensorSeq` adapter to the EP adapter layer. | +| `tensor/sequence_op.cc` | Uses `TensorSeq` (incomplete type in plugin) | Same as above. | +| `tensor/space_depth_ops.cc` | Inherits `SpaceDepthBase` (CPU provider) | Constructor templatized on `KernelInfoType` (#27628). Remaining: inline `SpaceDepthCompute` validation logic. | +| `tensor/upsample.cc` | `UpsampleBase` uses `InputDefs()` and `OpKernelInfo::GetAllocator()` | `AdjustOutputSizeAsPolicy` moved to header (#27628). Remaining: extend adapter with `GetAllocator()` and `InputDefs()`. | +| `tensor/resize.cc` | Inherits from `Upsample` (excluded above) | Fix `Upsample` first, then `Resize` follows. | +| `generator/constant_of_shape.cc` | `ConstantOfShapeBase` depends on `TensorProto`/`UnpackTensor` | Plugin already has self-contained implementation in `constant_of_shape.h` via `#ifdef BUILD_CUDA_EP_AS_PLUGIN`. The `.cc` is excluded but the kernel works. | +| `object_detection/*` | `NonMaxSuppressionBase`, `RoiAlignBase` from CPU provider | `NonMaxSuppressionBaseImpl` template (#27617), `RoiAlignBase` constructor templatized (#27628). Remaining: integration verification. | +| `llm/*` | Attention ops dereference `onnxruntime::Stream*` (not adapter-compatible) | Extend adapter `OpKernelContext::GetComputeStream()` to return a full `Stream*` implementation. | +| `contrib_ops/cuda/llm/*` | Same as above | Same as above. | +| `contrib_ops/cuda/bert/attention.cc` | `GetComputeStream()` returns real `Stream*` which is needed | `AttentionBase` helpers moved to header (#27628). Remaining: `Stream*` adapter extension for `QkvToContext`. | +| `contrib_ops/cuda/bert/decoder_attention.cc` | Same | Same. | +| `contrib_ops/cuda/bert/decoder_masked_self_attention.cc` | Same | Same. | +| `contrib_ops/cuda/bert/embed_layer_norm.cc` | `EmbedLayerNormHelper` CPU base class | Already refactored helper; needs `GetComputeStream()` fix. | +| `contrib_ops/cuda/bert/fast_gelu.cc` | Was excluded due to `bias_gelu_helper` CPU base class dep | `bias_gelu_helper::CheckInputs` is now inlined. Remove this exclusion and verify. | +| `contrib_ops/cuda/bert/group_query_attention.cc` | `GetComputeStream()` / attention infra | Same `Stream*` adapter extension. | +| `contrib_ops/cuda/bert/longformer_attention.cc` | `LongformerAttentionBase::CheckInputs` moved to header (#27628) | `Stream*` adapter extension. | +| `contrib_ops/cuda/bert/multihead_attention.cc` | Same | Same. | +| `contrib_ops/cuda/bert/packed_attention.cc` | Same | Same. | +| `contrib_ops/cuda/bert/packed_multihead_attention.cc` | Same | Same. | +| `contrib_ops/cuda/bert/paged_attention.cc` | Same | Same. | +| `contrib_ops/cuda/bert/relative_attn_bias.cc` | Same | Same. | +| `contrib_ops/cuda/bert/remove_padding.cc` | Same | Same. | +| `contrib_ops/cuda/diffusion/group_norm.cc` | `GetComputeStream()` | Same `Stream*` adapter extension. | +| `contrib_ops/cuda/fused_conv.cc` | Framework type deps | Audit specific deps; likely `Stream*` related. | +| `contrib_ops/cuda/inverse.cc` | Framework type deps | Audit specific deps. | +| `contrib_ops/cuda/math/bias_dropout.cc` | `GetComputeStream()` | Same `Stream*` adapter extension. | +| `contrib_ops/cuda/math/fft_ops.cc` | `GetComputeStream()` | Same. | +| `contrib_ops/cuda/math/gemm_float8.cc`/`.cu` | `GetComputeStream()` in `.cu` file | Same, plus NVCC compatibility. | +| `contrib_ops/cuda/moe/moe.cc` | `GetComputeStream()` | Same. | +| `contrib_ops/cuda/sparse/sparse_attention.cc` | `GetComputeStream()` | Same. | +| `contrib_ops/cuda/tensor/crop.cc` | `CropBase` constructor templatized (#27628). No `GetComputeStream()` usage. | Verify compilation — very low effort. | +| `contrib_ops/cuda/tensor/dynamic_time_warping.cc` | `GetComputeStream()` | Same. | +| `contrib_ops/cuda/tensor/dynamicslice.cc` | `GetComputeStream()` | Same. | +| `contrib_ops/cuda/tensor/shrunken_gather.cc` | Training op, `provider_api.h` header dep | Low priority (training). | +| `contrib_ops/cuda/quantization/attention_quantization.cc` | `GetComputeStream()` | Same. | +| `contrib_ops/cuda/quantization/matmul_bnb4.cc` | `GetComputeStream()` | Same. | +| `contrib_ops/cuda/quantization/matmul_nbits.cc` | `GetComputeStream()` | Same. | +| `contrib_ops/cuda/quantization/moe_quantization.cc` | `GetComputeStream()` | Same. | +| `contrib_ops/cuda/quantization/qordered_ops/*` | `GetComputeStream()` | Same. | +| `contrib_ops/cuda/transformers/*` | Beam search, greedy search, sampling | Complex framework deps; needs significant adapter work. | +| `aten_ops/*` | ATen interop | Out of scope for plugin. | +| `collective/*` | NCCL collective ops | Out of scope for plugin. | + +### 7.4 Common Exclusion Themes + +The majority of excluded operators fall into a few categories: + +1. **`GetComputeStream()` returning `onnxruntime::Stream*`** (~25 ops) — The adapter's `GetComputeStream()` returns a `PluginCudaComputeStreamShim` which wraps a raw `cudaStream_t`. Many attention/LLM ops dereference `Stream*` expecting a `CudaStream` with extra members. **Unblocking this is the single highest-impact change.** + +2. **CPU base class inheritance** (~5 ops) — Some ops inherit from CPU base classes not linked into the plugin. Most have been refactored with the inline-header pattern. `SpaceDepthBase` and `RoiAlignBase` constructors are now templatized (#27628); `NonMaxSuppressionBase` refactored to a template (#27617); `UpsampleBase::AdjustOutputSizeAsPolicy` moved to header (#27628). Remaining: `UpsampleBase` `InputDefs()`/`GetAllocator()`. + +3. **Missing C API features** (~2 ops) — RNN ops need string-array attribute support via the C API. + +4. **Framework-only code paths** (~3 ops) — Einsum's reduction path, tunable infrastructure. + +--- + +## 8. Remaining `#ifdef` Guards in Kernel Code + +After refactoring, only 6 files contain `BUILD_CUDA_EP_AS_PLUGIN` or `ORT_USE_EP_API_ADAPTERS` guards: + +| File | Guard | Purpose | Removable? | +|------|-------|---------|------------| +| `cuda_kernel.h` | Both | Three-way gate: plugin → adapter; in-tree → real CudaKernel | No — infrastructure | +| `cuda_common.h` | Both | Logging macros, error macros, `HalfGemmOptions` | No — infrastructure | +| `cuda_execution_provider.h` | `ORT_USE_EP_API_ADAPTERS` | Skip entire class in plugin build | No — infrastructure | +| `generator/constant_of_shape.h` | `BUILD_CUDA_EP_AS_PLUGIN` | Self-contained plugin implementation | No — can't inline `ConstantOfShapeBase` | +| `math/matmul.cc` | `ORT_USE_EP_API_ADAPTERS` | Guards `FuncManager` registration (tunable) | Only when tunable is supported | +| `math/gemm.cc` | `ORT_USE_EP_API_ADAPTERS` | Guards `FuncManager` registration (tunable) | Only when tunable is supported | + +All kernel-level `#ifdef` guards in operator `.cc` files have been eliminated through the inline-header refactoring pattern, except for `matmul.cc`, `gemm.cc` (tunable dispatch), and `constant_of_shape.h` (protobuf dependency). + +--- + +## 9. Building + +### 9.1 CMake Flag + +The plugin is enabled by setting `onnxruntime_BUILD_CUDA_EP_AS_PLUGIN=ON`: + +```bash +sh build.sh --config Release --build_dir build/cuda --parallel --use_cuda \ + --cuda_version 12.8 --cuda_home /path/to/cuda \ + --cudnn_home /path/to/cudnn \ + --build_wheel --skip_tests \ + --cmake_generator Ninja \ + --enable_cuda_nhwc_ops \ + --cmake_extra_defines onnxruntime_BUILD_CUDA_EP_AS_PLUGIN=ON \ + --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES="90" +``` + +Or using the existing `cuda.sh` convenience script: + +```bash +./cuda.sh --build --test_plugin # --test_plugin sets BUILD_CUDA_EP_AS_PLUGIN=ON +``` + +### 9.2 Impact on Other Build Targets + +The `onnxruntime_BUILD_CUDA_EP_AS_PLUGIN=ON` flag has **no impact** on `libonnxruntime_providers_cuda.so` or `libonnxruntime_providers_shared.so`. It only: + +1. Adds the `onnxruntime_providers_cuda_plugin` target (producing `libonnxruntime_providers_cuda_plugin.so`) +2. Appends `"cuda-plugin-ep=1"` to the build info string (cosmetic) + +The in-tree CUDA EP and shared provider bridge are compiled identically regardless of this flag. A single build with the flag ON produces all four libraries — there is no need for separate build scripts or build directories. + +### 9.3 Plugin Independence + +`libonnxruntime_providers_cuda_plugin.so` is **fully self-contained**. It does not depend on `libonnxruntime_providers_cuda.so` or `libonnxruntime_providers_shared.so` at load time. It statically links against `onnxruntime_framework`, `onnxruntime_graph`, `onnxruntime_common`, `onnxruntime_mlas`, `onnxruntime_flatbuffers`, and links dynamically against CUDA/cuDNN/protobuf. Communication with the ORT runtime happens exclusively through the C API (`OrtApi`/`OrtEpApi`) passed at load time. + +### 9.4 Build Outputs + +After a successful build with the plugin flag ON, `build/cuda/Release/` contains: + +| File | Description | +|------|-------------| +| `libonnxruntime_providers.a` | CPU provider (static, linked into main binary) | +| `libonnxruntime_providers_shared.so` | Shared provider bridge (for in-tree CUDA EP) | +| `libonnxruntime_providers_cuda.so` | In-tree CUDA EP (uses shared provider bridge) | +| `libonnxruntime_providers_cuda_plugin.so` | Plugin CUDA EP (standalone, uses C API) | + +### 9.5 Deployment + +To use the plugin EP, copy the `.so` to the ORT Python package's `capi/` directory: + +```bash +cp build/cuda/Release/libonnxruntime_providers_cuda_plugin.so \ + $(python -c "import onnxruntime; print(onnxruntime.__path__[0])")/capi/ +``` + +The plugin is then available as `CudaPluginExecutionProvider` in session provider lists. + +--- + +## 10. Testing + +### 10.1 Test Script + +`onnxruntime/test/python/transformers/test_cuda_plugin_ep.py` provides multi-stage testing: + +| Stage | What It Tests | +|-------|---------------| +| Stage 2 | Basic ops: Add, MatMul, Gemm, Conv | +| Stage 3 | NHWC layout: Conv, BatchNorm, MaxPool, AveragePool | +| Stage 4 | CUDA Graph capture/replay | +| Stage 5A | Standard ops: Reshape, Split, Concat, Gather, Unsqueeze | +| Stage 5B | More ops: Tile, CumSum, ConstantOfShape, SpaceToDepth, Pad, Slice, Resize, Sum | +| Stage 5C | CPU base class ops: Upsample, DepthToSpace | +| Stage 5D | Contrib ops: FastGelu, BiasDropout, SkipLayerNorm | + +### 10.2 Running Tests + +After building and deploying the plugin (see [Section 9.5](#95-deployment)): + +```bash +# Run tests from /tmp to avoid module shadowing +cd /tmp +python /path/to/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py +``` + +### 10.3 Parity Report + +`tools/ci_build/cuda_plugin_parity_report.py` generates a report comparing registered kernels between the in-tree CUDA EP and the plugin EP, identifying gaps. + +--- + +## 11. How to Add a New Kernel to the Plugin + +### 11.1 If the kernel compiles as-is + +Most kernels that don't use `GetComputeStream()` (returning `Stream*`) or inherit from excluded CPU base classes will compile without changes. The force-include mechanism handles type resolution automatically. + +Just verify it's not in the exclusion list in `cmake/onnxruntime_providers_cuda_plugin.cmake`. + +### 11.2 If the kernel calls a CPU base class helper + +Apply the inline-header pattern: + +1. Move the helper implementation from the CPU `.cc` file to the `.h` file +2. Wrap with `#ifdef SHARED_PROVIDER` (declaration only) / `#else` (inline body) +3. In the CUDA kernel, call the base class method directly (remove any local wrappers) +4. Verify all 4 build targets compile + +Example from `cumsum.h`: +```cpp +namespace cumsum_op { +#ifdef SHARED_PROVIDER +Status GetAxis(const Tensor* axis_tensor, int64_t input_rank, int64_t& axis_out); +#else +inline Status GetAxis(const Tensor* axis_tensor, int64_t input_rank, int64_t& axis_out) { + // ... implementation ... +} +#endif +} +``` + +### 11.3 If the helper takes OpKernelContext& + +Use a template version that accepts any context type: + +```cpp +template +static void ComputePadsImpl(KernelContextType& ctx, ...) { ... } +``` + +The CUDA kernel calls `ComputePadsImpl(*ctx, ...)` directly. + +### 11.4 If the kernel uses GetComputeStream() + +Check whether the kernel actually dereferences the `Stream*` or just needs the raw `cudaStream_t`: + +- If it only needs `stream->GetHandle()` → use `Stream(ctx)` instead (returns `cudaStream_t`) +- If it dereferences `CudaStream*` members → the kernel is blocked until the `Stream*` adapter is extended + +### 11.5 If the kernel uses handle accessors + +Use the plugin-compatible overloads already in `CudaKernel`: + +```cpp +// Instead of: GetCublasHandle(ctx->GetComputeStream()) +// Use: GetCublasHandle(ctx) // or GetCublasHandle(Stream(ctx)) +``` + +--- + +## 12. File Layout + +``` +onnxruntime/core/providers/cuda/plugin/ +├── cuda_kernel_adapter.h # CudaKernel base, macros, CPU shims (force-included) +├── cuda_ep_provider.h # Plugin-local CUDAExecutionProvider +├── cuda_ep.h / .cc # CudaEp : adapter::Ep +├── cuda_ep_factory.h / .cc # CudaEpFactory : OrtEpFactory +├── cuda_plugin_ep.cc # DLL entry points (CreateEpFactories/ReleaseEpFactory) +├── cuda_plugin_kernels.h / .cu # Kernel registry creation +├── cuda_stream_plugin.h / .cc # CudaSyncStream (handles, notifications) +├── cuda_allocator_plugin.h / .cc # Device/pinned allocators +├── cuda_data_transfer_plugin.h / .cc # GPU↔CPU data transfer +├── cuda_controlflow_plugin.h / .cc / .cu # If/Loop/Scan wrappers +├── cuda_graph_plugin.h / .cc # CUDA Graph support +├── cuda_plugin_utils.h # Common macros, error handling +├── cuda_iallocator_plugin.h # IAllocator declarations +├── cuda_idata_transfer_plugin.h # IDataTransfer declarations +└── provider_api_shims.cc # Reimplemented utility functions + +include/onnxruntime/ep/ +├── adapters.h # Master include + type aliasing (force-included) +├── api.h # ORT C API includes +├── common.h # EP common utilities +└── adapter/ + ├── allocator.h # IAllocator adapter + ├── data_transfer_manager.h # DataTransferManager adapter + ├── ep.h # Ep base class (wraps IExecutionProvider) + ├── kernel_def.h # KernelDef adapter + ├── kernel_def_builder.h # KernelDefBuilder adapter + ├── kernel_registry.h # KernelRegistry adapter + ├── logging.h # Logger adapter + ├── node.h # Node adapter + ├── op_kernel.h # OpKernel + OpKernelContext adapters + ├── op_kernel_info.h # OpKernelInfo adapter + └── tensor_helper.h # Tensor creation from C API values +``` + +--- + +## 13. Future Work + +1. **`Stream*` adapter** — Extend the adapter `OpKernelContext::GetComputeStream()` to return a full `Stream*` that attention/LLM ops can use. This unblocks ~25 operators. + +2. **Tunable ops** — Implement a plugin-side `ITuningContext` and remove the `ORT_USE_EP_API_ADAPTERS` guards in `matmul.cc`/`gemm.cc`. + +3. **String-array C API** — Add `KernelInfoGetAttributeArray_string` to the ORT C API to unblock RNN ops. + +4. **Remaining CPU base classes** — Inline `SpaceDepthBase`, `UpsampleBase`, and object detection base classes. + +5. **CI integration** — Add plugin build + test to the CI pipeline. + +6. **CUDA Graph API for plugin EPs** — Add `IsGraphCaptureEnabled`, `IsGraphCaptured`, and `ReplayGraph` callbacks to the `OrtEp` C API (see [Section 5.4.4](#544-what-needs-to-change-in-ort-core-option-a)). This is required for efficient CUDA graph replay in the plugin EP. The capture/replay infrastructure (`cuda_graph_plugin.h/.cc`, `CudaEp` state machine) is already implemented and will activate once the API is extended. From f97bbe41fc313096c0a93f195fa4121ea220ae72 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 23 Mar 2026 13:24:18 -0700 Subject: [PATCH 03/48] ops --- cmake/onnxruntime_providers_cuda_plugin.cmake | 2 +- cmake/onnxruntime_unittests.cmake | 5 + .../cuda/bert/attention_softmax.cu | 2 +- .../bert/decoder_masked_self_attention.cc | 2 +- .../contrib_ops/cuda/bert/fast_gelu.cc | 31 +++++ .../cuda/bert/group_query_attention.cc | 40 +++--- .../contrib_ops/cuda/bert/packed_attention.cc | 4 +- .../cuda/bert/packed_multihead_attention.cc | 2 +- .../contrib_ops/cuda/bert/paged_attention.cc | 18 ++- .../cuda/bert/relative_attn_bias.cc | 2 +- .../contrib_ops/cuda/bert/remove_padding.cc | 2 +- .../contrib_ops/cuda/bert/skip_layer_norm.cc | 5 +- .../cuda/collective/distributed_reduce.cc | 4 +- onnxruntime/contrib_ops/cuda/inverse.cc | 10 +- .../contrib_ops/cuda/math/bias_dropout.cc | 2 +- .../quantization/attention_quantization.cc | 3 +- .../cuda/quantization/matmul_bnb4.cc | 10 +- .../cuda/quantization/matmul_nbits.cc | 6 +- .../cuda/quantization/matmul_nbits.h | 22 ++- .../cuda/sparse/sparse_attention.cc | 19 ++- .../cuda/tensor/dynamic_time_warping.cc | 2 +- .../transformers/generation_device_helper.cc | 9 +- .../cuda/transformers/sampling_cuda_helper.h | 4 +- .../core/providers/cpu/tensor/upsamplebase.h | 31 +++-- .../cuda/generator/constant_of_shape.h | 116 ++++++++++++++++ .../core/providers/cuda/integer_gemm.cc | 15 +- .../core/providers/cuda/math/cumsum.cc | 28 +++- .../core/providers/cuda/math/matmul.cc | 16 ++- .../providers/cuda/math/matmul_integer.cc | 7 +- .../core/providers/cuda/math/softmax.cc | 11 +- .../core/providers/cuda/math/softmax.h | 10 +- .../core/providers/cuda/math/softmax_impl.cu | 44 +++--- onnxruntime/core/providers/cuda/math/topk.cc | 3 +- .../core/providers/cuda/math/topk_impl.cuh | 16 +-- .../core/providers/cuda/math/topk_impl.h | 2 +- .../cuda/math/variadic_elementwise_ops.cc | 2 +- .../core/providers/cuda/nn/batch_norm.cc | 10 +- onnxruntime/core/providers/cuda/nn/conv.cc | 2 +- onnxruntime/core/providers/cuda/nn/conv.h | 2 +- onnxruntime/core/providers/cuda/nn/conv_8.h | 8 +- .../core/providers/cuda/nn/conv_transpose.cc | 8 ++ .../core/providers/cuda/nn/conv_transpose.h | 2 +- .../core/providers/cuda/nn/conv_transpose_8.h | 8 +- onnxruntime/core/providers/cuda/nn/dropout.cc | 2 +- .../core/providers/cuda/nn/instance_norm.cc | 20 +-- onnxruntime/core/providers/cuda/nn/pool.cc | 4 +- .../providers/cuda/reduction/reduction_ops.cc | 76 ++++++----- .../providers/cuda/reduction/reduction_ops.h | 4 +- .../core/providers/cuda/rnn/cudnn_rnn_base.cc | 37 +++-- .../core/providers/cuda/rnn/cudnn_rnn_base.h | 5 +- .../providers/cuda/shared_inc/integer_gemm.h | 5 +- .../core/providers/cuda/tensor/compress.cc | 4 +- .../core/providers/cuda/tensor/concat.cc | 14 +- .../core/providers/cuda/tensor/gather.cc | 2 +- .../core/providers/cuda/tensor/gather_nd.cc | 10 +- .../core/providers/cuda/tensor/gather_nd.h | 3 +- .../core/providers/cuda/tensor/nonzero_op.cc | 4 +- onnxruntime/core/providers/cuda/tensor/pad.cc | 43 +++++- .../core/providers/cuda/tensor/scatter_nd.cc | 55 +++++++- .../core/providers/cuda/tensor/scatter_nd.h | 2 + .../core/providers/cuda/tensor/slice.cc | 15 +- .../core/providers/cuda/tensor/slice.h | 6 +- .../providers/cuda/tensor/space_depth_ops.h | 129 +++++++++++++++++- .../core/providers/cuda/tensor/split.cc | 81 +++++++++-- .../core/providers/cuda/tensor/split.h | 9 ++ .../core/providers/cuda/tensor/tile.cc | 34 +++++ .../core/providers/cuda/tensor/transpose.cc | 2 +- .../core/providers/cuda/tensor/unsqueeze.cc | 55 ++++++++ .../core/providers/cuda/tensor/upsample.cc | 28 ++-- .../core/providers/cuda/tensor/upsample.h | 3 - .../python/onnxruntime_pybind_module.cc | 23 +++- .../python/onnxruntime_pybind_state.cc | 50 ++++++- onnxruntime/test/unittest_util/base_tester.cc | 20 ++- 73 files changed, 995 insertions(+), 297 deletions(-) diff --git a/cmake/onnxruntime_providers_cuda_plugin.cmake b/cmake/onnxruntime_providers_cuda_plugin.cmake index 1acc9cc133024..9834700ec220b 100644 --- a/cmake/onnxruntime_providers_cuda_plugin.cmake +++ b/cmake/onnxruntime_providers_cuda_plugin.cmake @@ -328,7 +328,7 @@ target_link_libraries(onnxruntime_providers_cuda_plugin PRIVATE ) # Symbol visibility — only export CreateEpFactories and ReleaseEpFactory -target_compile_definitions(onnxruntime_providers_cuda_plugin PRIVATE ORT_API_MANUAL_INIT BUILD_CUDA_EP_AS_PLUGIN ONNX_ML=1 ONNX_NAMESPACE=onnx ONNX_USE_LITE_PROTO=1) +target_compile_definitions(onnxruntime_providers_cuda_plugin PRIVATE ORT_API_MANUAL_INIT BUILD_CUDA_EP_AS_PLUGIN ORT_USE_EP_API_ADAPTERS=1 ONNX_ML=1 ONNX_NAMESPACE=onnx ONNX_USE_LITE_PROTO=1) if(WIN32) # Windows: use .def file for symbol exports diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 9ae3e79d86443..98171599228c1 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1461,6 +1461,11 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() else() target_link_libraries(onnxruntime_perf_test PRIVATE onnx_test_runner_common absl::flags absl::flags_parse ${onnx_test_libs}) + # When onnxruntime_BUILD_SHARED_LIB is OFF (the plugin build path), perf test was missing CUDA include directories and CUDA::cudart linkage. + if (onnxruntime_USE_CUDA OR onnxruntime_USE_NV OR onnxruntime_USE_TENSORRT) + target_include_directories(onnxruntime_perf_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_link_libraries(onnxruntime_perf_test PRIVATE CUDA::cudart) + endif() endif() set_target_properties(onnxruntime_perf_test PROPERTIES FOLDER "ONNXRuntimeTest") diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu index ff7ac67852427..ba647bc3f6d33 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu @@ -974,7 +974,7 @@ Status ComputeSoftmaxWithRawMask(Stream* ort_stream, if (use_persistent_softmax) { return onnxruntime::cuda::dispatch_warpwise_softmax_forward( - ort_stream, + stream, output, persistent_softmax_workspace, total_sequence_length, diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc index b4643da58eba5..416bf24dd818c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc @@ -134,7 +134,7 @@ Status DecoderMaskedSelfAttention::ComputeInternal(OpKernelContext* cont int m = batch_size * sequence_length; int n = (parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size); int k = parameters.input_hidden_size; - gemm_buffer = GetScratchBuffer(static_cast(m) * n, context->GetComputeStream()); + gemm_buffer = GetScratchBuffer(static_cast(m) * n, GetComputeStream(context)); CudaT one = ToCudaType::FromFloat(1.0f); CudaT zero = ToCudaType::FromFloat(0.0f); diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc index e7ed96d7f5ee2..3c1feb7af956a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc @@ -30,6 +30,33 @@ REGISTER_KERNEL_TYPED(double) using namespace ONNX_NAMESPACE; +#ifdef BUILD_CUDA_EP_AS_PLUGIN +static Status CheckInputsForPlugin(const OpKernelContext* context) { + const Tensor* input = context->Input(0); + const Tensor* bias = context->Input(1); + + const auto& input_dims = input->Shape().GetDims(); + if (input_dims.size() < 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 0 is expected to have 1 or more dimensions, got ", input_dims.size()); + } + + if (nullptr != bias) { + const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 1 is expected to have 1 dimensions, got ", bias_dims.size()); + } + if (bias_dims[0] != input_dims[input_dims.size() - 1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 1 dimension 0 should have same length as the last dimension of input 0"); + } + } + + return Status::OK(); +} +#endif + template FastGelu::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) { const TransformerOptions* options = TransformerOptions::GetInstance(); @@ -38,7 +65,11 @@ FastGelu::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel template Status FastGelu::ComputeInternal(OpKernelContext* context) const { +#ifdef BUILD_CUDA_EP_AS_PLUGIN + ORT_RETURN_IF_ERROR(CheckInputsForPlugin(context)); +#else ORT_RETURN_IF_ERROR(bias_gelu_helper::CheckInputs(context)); +#endif const Tensor* input = context->Input(0); const Tensor* bias = context->Input(1); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 6e6d463889007..ae2fd1c43ceb7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -142,6 +142,14 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) // 11. head_sink (Tensor) - Attention sink for GPT-OSS template Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { + // Stream access: void* for GetScratchBuffer, Stream* for QkvToContext. +#ifdef BUILD_CUDA_EP_AS_PLUGIN + onnxruntime::PluginStreamShim __stream_shim(GetComputeStream(context)); + auto* ort_stream = static_cast(&__stream_shim); +#else + auto* ort_stream = context->GetComputeStream(); +#endif + const Tensor* query = context->Input(0); const Tensor* key = context->Input(1); const Tensor* value = context->Input(2); @@ -259,8 +267,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, dense_head_size}; TensorShape present_shape(present_dims); - Tensor* present_key_tensor = context->Output(1, present_shape); - Tensor* present_value_tensor = context->Output(2, present_shape); + Tensor* present_key_output = context->Output(1, present_shape); // present_key + Tensor* present_value_output = context->Output(2, present_shape); // present_value IAllocatorUniquePtr k_buffer; IAllocatorUniquePtr v_buffer; @@ -288,8 +296,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data()); data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast(past_value->Data()); - data.present_key = reinterpret_cast(present_key_tensor->MutableData()); - data.present_value = reinterpret_cast(present_value_tensor->MutableData()); + data.present_key = reinterpret_cast(present_key_output->MutableData()); + data.present_value = reinterpret_cast(present_value_output->MutableData()); // Compute past_present_share_buffer early since it's needed for flash attention path selection. // This compares the final pointer values after quantization handling. @@ -370,7 +378,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons xqa_total_bytes += q_bytes + k_bytes; } - xqa_scratch_buffer = this->GetScratchBuffer(xqa_total_bytes, context->GetComputeStream()); + xqa_scratch_buffer = this->GetScratchBuffer(xqa_total_bytes, GetComputeStream(context)); data.xqa_buffer = xqa_scratch_buffer.get(); data.xqa_buffer_bytes = xqa_internal_bytes; @@ -413,11 +421,11 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons out_accum_bytes = onnxruntime::flash::get_out_accum_size(num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length, round_multiple(parameters.head_size, 32)); } - softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); - softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); - out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); + softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, GetComputeStream(context)); + softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, GetComputeStream(context)); + out_accum_buffer = GetScratchBuffer(out_accum_bytes, GetComputeStream(context)); - auto cuda_stream = static_cast(context->GetComputeStream()->GetHandle()); + auto cuda_stream = Stream(context); if (softmax_lse_accum_bytes > 0) { // Initialize to 0 is fine because Flash kernel will write -inf to it if needed. // However, the standard Flash kernel often doesn't zero it globally. @@ -442,8 +450,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons } else { // Compute sequence length buffers (past_seq_lens and total_seq_lens). // Allocate buffer for both: first half is past_seq_lens, second half is total_seq_lens. - seq_lens_buffer = GetScratchBuffer(3 * parameters.batch_size, context->GetComputeStream()); - auto cuda_stream = static_cast(context->GetComputeStream()->GetHandle()); + seq_lens_buffer = GetScratchBuffer(3 * parameters.batch_size, GetComputeStream(context)); + auto cuda_stream = Stream(context); data.past_seq_lens = seq_lens_buffer.get(); data.total_seq_lens = seq_lens_buffer.get() + parameters.batch_size; data.padded_seq_lens = data.total_seq_lens + parameters.batch_size; @@ -480,9 +488,9 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons ? (sizeof(float) * parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size) : 0; - k_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); - v_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); - fmha_buffer = GetScratchBuffer(fmha_buffer_bytes, context->GetComputeStream()); + k_buffer = GetScratchBuffer(kv_buffer_bytes, GetComputeStream(context)); + v_buffer = GetScratchBuffer(kv_buffer_bytes, GetComputeStream(context)); + fmha_buffer = GetScratchBuffer(fmha_buffer_bytes, GetComputeStream(context)); data.k = reinterpret_cast(k_buffer.get()); data.v = reinterpret_cast(v_buffer.get()); @@ -501,7 +509,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons data.use_memory_efficient_attention); if (buffer_req.qkv_buffer_bytes > 0) { - unpacked_qkv_buffer = GetScratchBuffer(buffer_req.qkv_buffer_bytes, context->GetComputeStream()); + unpacked_qkv_buffer = GetScratchBuffer(buffer_req.qkv_buffer_bytes, GetComputeStream(context)); data.qkv_buffer = reinterpret_cast(unpacked_qkv_buffer.get()); } @@ -556,7 +564,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons cublasHandle_t cublas = GetCublasHandle(context); ORT_RETURN_IF_ERROR((QkvToContext( - device_prop, cublas, context->GetComputeStream(), parameters, data))); + device_prop, cublas, ort_stream, parameters, data))); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc index 10bd1170d8f07..df4bc7b170106 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc @@ -286,7 +286,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { int m = parameters.token_count; int n = parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size; int k = parameters.input_hidden_size; - gemm_buffer = this->template GetScratchBuffer(static_cast(m) * n, context->GetComputeStream()); + gemm_buffer = this->template GetScratchBuffer(static_cast(m) * n, this->GetComputeStream(context)); cublasHandle_t cublas = this->GetCublasHandle(context); @@ -310,7 +310,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { false, use_memory_efficient_attention, no_qkv_workspace); - auto work_space = this->template GetScratchBuffer(workSpaceSize, context->GetComputeStream()); + auto work_space = this->template GetScratchBuffer(workSpaceSize, this->GetComputeStream(context)); typedef typename ToCudaType::MappedType CudaT; PackedAttentionData data; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc index 928a8b17229e6..20a45d18367bc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc @@ -267,7 +267,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co use_flash_attention, use_memory_efficient_attention, no_qkv_workspace); - auto work_space = this->template GetScratchBuffer(workSpaceSize, context->GetComputeStream()); + auto work_space = this->template GetScratchBuffer(workSpaceSize, this->GetComputeStream(context)); PackedMultiHeadAttentionData data; data.query = reinterpret_cast(query->Data()); diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc b/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc index b9f30f71e9a66..5f68282726c2a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc @@ -54,6 +54,14 @@ PagedAttention::PagedAttention(const OpKernelInfo& info) template Status PagedAttention::ComputeInternal(OpKernelContext* context) const { + // Stream access: void* for GetScratchBuffer, Stream* for QkvToContext. +#ifdef BUILD_CUDA_EP_AS_PLUGIN + onnxruntime::PluginStreamShim __stream_shim(GetComputeStream(context)); + auto* ort_stream = static_cast(&__stream_shim); +#else + auto* ort_stream = context->GetComputeStream(); +#endif + const Tensor* query = context->Input(0); const Tensor* key = context->Input(1); const Tensor* value = context->Input(2); @@ -151,10 +159,10 @@ Status PagedAttention::ComputeInternal(OpKernelContext* context) const { softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.token_count, parameters.num_heads); } - auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); + auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, GetComputeStream(context)); #else constexpr bool use_flash_attention = false; - auto softmax_lse_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto softmax_lse_buffer = GetScratchBuffer(0, GetComputeStream(context)); // nullptr #endif if (!use_flash_attention) { @@ -163,7 +171,7 @@ Status PagedAttention::ComputeInternal(OpKernelContext* context) const { } size_t cumulative_seqlens_kv_bytes = sizeof(int) * (parameters.batch_size + 1); - auto cumulative_seqlens_kv_buffer = GetScratchBuffer(cumulative_seqlens_kv_bytes, context->GetComputeStream()); + auto cumulative_seqlens_kv_buffer = GetScratchBuffer(cumulative_seqlens_kv_bytes, GetComputeStream(context)); size_t workspace_buffer_bytes = 0; if (do_rotary_) { @@ -171,7 +179,7 @@ Status PagedAttention::ComputeInternal(OpKernelContext* context) const { } else if (parameters.is_packed_qkv) { workspace_buffer_bytes = sizeof(T) * parameters.token_count * parameters.hidden_size; } - auto workspace_buffer = GetScratchBuffer(workspace_buffer_bytes, context->GetComputeStream()); + auto workspace_buffer = GetScratchBuffer(workspace_buffer_bytes, GetComputeStream(context)); // Print debug info if (kernel_options_->AllowDebugInfo()) { @@ -210,7 +218,7 @@ Status PagedAttention::ComputeInternal(OpKernelContext* context) const { cublasHandle_t cublas = GetCublasHandle(context); return QkvToContext( - device_prop, cublas, context->GetComputeStream(), parameters, data); + device_prop, cublas, ort_stream, parameters, data); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc index 05f55d9106d0e..e62020a09216d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc @@ -163,7 +163,7 @@ Status GatedRelativePositionBias::ComputeInternal(OpKernelContext* context) c const size_t elements_after_gemm = (size_t)BNS * (size_t)D; bool reuse_output = (seq_len >= D); size_t workspace_size = sizeof(T) * (elements_in_query + (reuse_output ? (size_t)0 : elements_after_gemm)); - auto workspace = GetScratchBuffer(workspace_size, context->GetComputeStream()); + auto workspace = this->GetScratchBuffer(workspace_size, this->GetComputeStream(context)); cudaStream_t stream = Stream(context); if (!is_padding_removed) { diff --git a/onnxruntime/contrib_ops/cuda/bert/remove_padding.cc b/onnxruntime/contrib_ops/cuda/bert/remove_padding.cc index ec9617982c776..eba4c48301cf3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/remove_padding.cc +++ b/onnxruntime/contrib_ops/cuda/bert/remove_padding.cc @@ -53,7 +53,7 @@ Status RemovePadding::ComputeInternal(OpKernelContext* context) const { int64_t sequence_length = dims[1]; int64_t hidden_size = dims[2]; - auto token_count_buffer = GetScratchBuffer(2, context->GetComputeStream()); + auto token_count_buffer = GetScratchBuffer(2, GetComputeStream(context)); TensorShapeVector token_offset_shape(2); token_offset_shape[0] = batch_size; diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc index 92ae7e81fb5bd..1cbb44a82f97b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @@ -43,9 +43,12 @@ SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); ORT_ENFORCE(epsilon_ >= 0); +#ifdef BUILD_CUDA_EP_AS_PLUGIN + strict_ = onnxruntime::cuda::GetCudaKernelAdapterSkipLayerNormStrictMode(); +#else const CUDAExecutionProvider* cuda_ep = static_cast(op_kernel_info.GetExecutionProvider()); - strict_ = cuda_ep->IsSkipLayerNormInStrictMode(); +#endif } template diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc index 967f30a304ac2..00e50a80c327b 100644 --- a/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc @@ -83,9 +83,11 @@ Status DistributedReduceBase::ComputeInternal(OpKernelContext* context) const const bool enable_fast_but_non_deterministic_reduction = !context->GetUseDeterministicCompute(); return onnxruntime::cuda::ReduceComputeCore( /* GPU allocator */ Info().GetAllocator(OrtMemType::OrtMemTypeDefault), + /* kernel */ this, *input_tensor, metadata, *output_tensor, cudnn_reduce_op_, axes_span, /* calculate_log */ false, /* calculate_sqt */ false, /* log_sum_exp_ */ false, - enable_fast_but_non_deterministic_reduction, context->GetComputeStream()); + enable_fast_but_non_deterministic_reduction, + Stream(context), GetComputeStream(context), GetCudnnHandle(context)); } return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/inverse.cc b/onnxruntime/contrib_ops/cuda/inverse.cc index 9075dda26f86b..e3ece229142aa 100644 --- a/onnxruntime/contrib_ops/cuda/inverse.cc +++ b/onnxruntime/contrib_ops/cuda/inverse.cc @@ -65,7 +65,7 @@ Status CheckForSingularity(cudaStream_t stream, const IAllocatorUniquePtr& template struct Inverse::ComputeImpl { - Status operator()(onnxruntime::Stream* ort_stream, Inverse::CublasHandle cublas_h, const Inverse* inst, const Tensor& input, Tensor& output, + Status operator()(void* ort_stream, Inverse::CublasHandle cublas_h, const Inverse* inst, const Tensor& input, Tensor& output, const IAllocatorUniquePtr& info, const IAllocatorUniquePtr& pivots, size_t num_batches, size_t rows) const { using namespace onnxruntime::cuda; @@ -75,7 +75,7 @@ struct Inverse::ComputeImpl { auto info_cpu = std::make_unique(num_batches); const auto dim = static_cast(rows); const auto n_batches = static_cast(num_batches); - cudaStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; + cudaStream_t stream = ort_stream ? static_cast(ort_stream) : nullptr; // Make a copy of the input which will serve as a workspace as well. if constexpr (std::is_same::value || std::is_same::value) { @@ -150,13 +150,13 @@ Status Inverse::ComputeInternal(OpKernelContext* ctx) const { num_batches = static_cast(input_shape.SizeToDimension(num_dim - 2)); } - IAllocatorUniquePtr info = GetScratchBuffer(num_batches, ctx->GetComputeStream()); + IAllocatorUniquePtr info = GetScratchBuffer(num_batches, GetComputeStream(ctx)); CUDA_RETURN_IF_ERROR(cudaMemsetAsync(info.get(), 0, num_batches * sizeof(int), Stream(ctx))); - IAllocatorUniquePtr pivots = GetScratchBuffer(rows * num_batches, ctx->GetComputeStream()); + IAllocatorUniquePtr pivots = GetScratchBuffer(rows * num_batches, GetComputeStream(ctx)); utils::MLTypeCallDispatcher t_disp(input->GetElementType()); return t_disp.InvokeRet( - ctx->GetComputeStream(), GetCublasHandle(ctx), this, *input, *output, info, pivots, num_batches, rows); + GetComputeStream(ctx), GetCublasHandle(ctx), this, *input, *output, info, pivots, num_batches, rows); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/math/bias_dropout.cc b/onnxruntime/contrib_ops/cuda/math/bias_dropout.cc index 79cc656ba0491..163d85b458ba1 100644 --- a/onnxruntime/contrib_ops/cuda/math/bias_dropout.cc +++ b/onnxruntime/contrib_ops/cuda/math/bias_dropout.cc @@ -124,7 +124,7 @@ Status BiasDropout::ComputeInternal(OpKernelContext* context) const } IAllocatorUniquePtr temp_mask_buffer{}; // buffer to use if mask is not provided - auto* ort_stream = context->GetComputeStream(); + auto* ort_stream = GetComputeStream(context); void* const mask_data = [this, mask_element_count, mask, &temp_mask_buffer, ort_stream]() { if (mask) return mask->MutableDataRaw(); temp_mask_buffer = diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index 33cd906508bcf..20ac68d602684 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -149,7 +149,8 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { weights->Data(), n, gemm_buffer_quantized.get(), n, this, - context->GetComputeStream())); + GetComputeStream(context), Stream(context), + GetCublasHandle(context))); CudaT dequant_scale; CudaT input_scale = *(reinterpret_cast(input_scale_tensor->Data())); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc index 0534ed6dc7fc0..d5d7153d0c8b9 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc @@ -56,12 +56,12 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { // TODO: find a better way to create the quant_map without using a buffer // don't want to use malloc directly so asking from the caller // can create a __device__ static array for float but doesn't work for half - IAllocatorUniquePtr quant_map_buffer = GetScratchBuffer(16, ctx->GetComputeStream()); + IAllocatorUniquePtr quant_map_buffer = this->template GetScratchBuffer(16, this->GetComputeStream(ctx)); auto* quant_map_buffer_data = quant_map_buffer.get(); ORT_RETURN_IF_ERROR(SetBnbQuantMap( SafeInt(quant_type_), reinterpret_cast(quant_map_buffer_data), - static_cast(ctx->GetComputeStream()->GetHandle()))); + this->Stream(ctx))); constexpr bool transa = false; const bool transb = transB_; @@ -85,10 +85,10 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { SafeInt(helper.N()), SafeInt(helper.K()), SafeInt(block_size_), - static_cast(ctx->GetComputeStream()->GetHandle())); + this->Stream(ctx)); if (!is_4bit_done) { - IAllocatorUniquePtr b_dequant_ptr = GetScratchBuffer(N_ * K_, ctx->GetComputeStream()); + IAllocatorUniquePtr b_dequant_ptr = this->template GetScratchBuffer(N_ * K_, this->GetComputeStream(ctx)); auto* b_dequant_data = b_dequant_ptr.get(); ORT_RETURN_IF_ERROR(DequantizeBnb4( reinterpret_cast(quant_map_buffer_data), @@ -97,7 +97,7 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { reinterpret_cast(absmax_data), SafeInt(block_size_), SafeInt(N_ * K_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + this->Stream(ctx))); const CudaT alpha = ToCudaType::FromFloat(1.f); const CudaT zero = ToCudaType::FromFloat(0.f); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 3b667bf68634c..699ead654c83f 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -304,7 +304,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { if (Y->Shape().Size() == 0) return Status::OK(); - cudaStream_t stream = static_cast(ctx->GetComputeStream()->GetHandle()); + cudaStream_t stream = this->Stream(ctx); typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT; @@ -352,7 +352,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { onnxruntime::llm::kernels::fpA_intB_gemv::kernel_launcher(sm_, params, stream); } else { const size_t workspace_size = weightOnlyGemmRunner_->getWorkspaceSize(m, n, k); - auto workspace_buffer = GetScratchBuffer(workspace_size, ctx->GetComputeStream()); + auto workspace_buffer = this->template GetScratchBuffer(workspace_size, this->GetComputeStream(ctx)); weightOnlyGemmRunner_->gemm( a_data, @@ -394,7 +394,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { } int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; - IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); + IAllocatorUniquePtr b_data_ptr = this->template GetScratchBuffer(N_ * K_padded, this->GetComputeStream(ctx)); auto* b_data = b_data_ptr.get(); if (nbits_ == 8) { diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h index a7f0a9516584c..87a675d282fdd 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h @@ -44,21 +44,31 @@ class MatMulNBits final : public CudaKernel { ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); - constexpr size_t kInputIndexScale = 2; - constexpr size_t kInputIndexZeroPoints = 3; - constexpr size_t kInputIndexGroupIndex = 4; - constexpr size_t kInputIndexBias = 5; - + constexpr int kInputIndexScale = 2; + constexpr int kInputIndexZeroPoints = 3; + constexpr int kInputIndexGroupIndex = 4; + constexpr int kInputIndexBias = 5; + +#ifdef BUILD_CUDA_EP_AS_PLUGIN + // Plugin adapter Node does not have InputDefs(). Defer existence checks to ComputeInternal + // where we can check if the actual input tensor is null or not. + (void)kInputIndexScale; // used only in non-plugin path + has_zero_points_ = info.GetInputCount() > kInputIndexZeroPoints; + has_g_idx_ = info.GetInputCount() > kInputIndexGroupIndex; + has_bias_ = info.GetInputCount() > kInputIndexBias; + // is_zero_points_scale_same_type_ defaults to false; runtime will handle differences. +#else has_zero_points_ = info.GetInputCount() > kInputIndexZeroPoints && info.node().InputDefs()[kInputIndexZeroPoints]->Exists(); has_g_idx_ = info.GetInputCount() > kInputIndexGroupIndex && info.node().InputDefs()[kInputIndexGroupIndex]->Exists(); has_bias_ = info.GetInputCount() > kInputIndexBias && info.node().InputDefs()[kInputIndexBias]->Exists(); - sm_ = this->GetDeviceProp().major * 10 + this->GetDeviceProp().minor; if (has_zero_points_) { int32_t zero_point_type = info.node().InputDefs()[kInputIndexZeroPoints]->TypeAsProto()->tensor_type().elem_type(); int32_t scale_type = info.node().InputDefs()[kInputIndexScale]->TypeAsProto()->tensor_type().elem_type(); is_zero_points_scale_same_type_ = (zero_point_type == scale_type); } +#endif + sm_ = this->GetDeviceProp().major * 10 + this->GetDeviceProp().minor; #if USE_FPA_INTB_GEMM if constexpr (std::is_same::value || std::is_same::value) { diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc index 865a1dc29ce47..bb4a9fabaca7e 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc @@ -60,6 +60,14 @@ SparseAttention::SparseAttention(const OpKernelInfo& info) template Status SparseAttention::ComputeInternal(OpKernelContext* context) const { + // Stream access: void* for GetScratchBuffer, Stream* for QkvToContext. +#ifdef BUILD_CUDA_EP_AS_PLUGIN + onnxruntime::PluginStreamShim __stream_shim(GetComputeStream(context)); + auto* ort_stream = static_cast(&__stream_shim); +#else + auto* ort_stream = context->GetComputeStream(); +#endif + auto& device_prop = GetDeviceProp(); if constexpr (std::is_same::value) { if (device_prop.major < 8) { @@ -219,8 +227,7 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { parameters.sequence_length * parameters.head_size; rotary_buffer_bytes += sizeof(int64_t) * parameters.batch_size * parameters.sequence_length; } - onnxruntime::Stream* stream = context->GetComputeStream(); - auto rotary_buffer = GetScratchBuffer(rotary_buffer_bytes, stream); + auto rotary_buffer = GetScratchBuffer(rotary_buffer_bytes, GetComputeStream(context)); data.rotary_buffer = reinterpret_cast(rotary_buffer.get()); size_t transposed_q_bytes = 0; @@ -228,7 +235,7 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { transposed_q_bytes = parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(T); } - auto transposed_q_buffer = GetScratchBuffer(transposed_q_bytes, stream); + auto transposed_q_buffer = GetScratchBuffer(transposed_q_bytes, GetComputeStream(context)); if (transposed_q_buffer) { data.transposed_q_buffer = reinterpret_cast(transposed_q_buffer.get()); } @@ -239,7 +246,7 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { (parameters.num_heads + 2 * parameters.kv_num_heads) * parameters.head_size * sizeof(T)); } - auto unpacked_qkv_buffer = GetScratchBuffer(unpacked_qkv_bytes, stream); + auto unpacked_qkv_buffer = GetScratchBuffer(unpacked_qkv_bytes, GetComputeStream(context)); if (unpacked_qkv_buffer) { data.unpacked_qkv_buffer = reinterpret_cast(unpacked_qkv_buffer.get()); } @@ -303,7 +310,7 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { } } - v2_kernel_buffer = GetScratchBuffer(v2_kernel_buffer_size, stream); + v2_kernel_buffer = GetScratchBuffer(v2_kernel_buffer_size, GetComputeStream(context)); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(v2_kernel_buffer.get(), v2_kernel_inputs_pinned, sizeof(int32_t) * v2_kernel_buffer_size, cudaMemcpyHostToDevice, cuda_stream)); @@ -317,7 +324,7 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { data.active_q_blocks = active_q_blocks; } - return QkvToContext(device_prop, stream, parameters, data); + return QkvToContext(device_prop, ort_stream, parameters, data); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.cc b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.cc index 381316f605fc9..2c9f5c1cd9014 100644 --- a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.cc +++ b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.cc @@ -35,7 +35,7 @@ Status DynamicTimeWarping::ComputeInternal(OpKernelContext* ctx) const { size_t max_index_len = 0; size_t buffer_size_in_bytes = GetDynamicTimeWarpingBufferSize(1, rows, cols, max_index_len); - IAllocatorUniquePtr buffer = GetScratchBuffer(buffer_size_in_bytes, ctx->GetComputeStream()); + IAllocatorUniquePtr buffer = GetScratchBuffer(buffer_size_in_bytes, GetComputeStream(ctx)); size_t result_len = 0; ORT_RETURN_IF_ERROR(LaunchDynamicTimeWarping( diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index e2e60066ec36d..36d0287685838 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -139,10 +139,12 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest, output_indices = std::move(*Tensor::Create(DataTypeImpl::GetType(), output_shape, std::move(allocator))); Status result; + auto cuda_stream = stream ? static_cast(stream->GetHandle()) : nullptr; if (input->IsDataType()) { result = TopKImpl(nullptr, // We limit number of beams in BeamSearchParameters, so K <= 256 and use NULL here false /*use_deterministic_compute*/, - stream, + cuda_stream, + nullptr, // alloc_stream not needed when kernel is nullptr input->Data(), static_cast(output_values.MutableDataRaw()), static_cast(output_indices.MutableDataRaw()), @@ -157,7 +159,8 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest, } else if (input->IsDataType()) { result = TopKImpl(nullptr, false /*use_deterministic_compute*/, - stream, + cuda_stream, + nullptr, // alloc_stream not needed when kernel is nullptr input->Data(), static_cast(output_values.MutableDataRaw()), static_cast(output_indices.MutableDataRaw()), @@ -411,7 +414,7 @@ Status ProcessLogits(const OrtValue& logits, // const CudaT* X_data = is_reuse_logits_buffer ? logits_data : reinterpret_cast(next_token_logits.data()); ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward( - ort_stream, Y_data, X_data, vocab_size, + ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr, Y_data, X_data, vocab_size, is_reuse_logits_buffer ? padded_vocab_size : vocab_size, vocab_size, batch_size * num_beams))); diff --git a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h index 7be3a4851aaed..d4e7e02f51fdd 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h @@ -93,7 +93,7 @@ Status Sample(AllocatorPtr& allocator, #endif gsl::span& d_sorted_softmaxed_score = sampling_state->d_sorted_softmaxed_score; - ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward(stream, + ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward(cuda_stream, d_sorted_softmaxed_score.data(), reinterpret_cast(d_sorted_score.data()), parameters->vocab_size, @@ -127,7 +127,7 @@ Status Sample(AllocatorPtr& allocator, #endif gsl::span& d_softmaxed_score = sampling_state->d_softmaxed_score; - ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward(stream, + ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward(cuda_stream, d_softmaxed_score.data(), reinterpret_cast(next_token_scores.data()), parameters->vocab_size, diff --git a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h index 7dcf88133e967..724adaaf0e3e3 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h +++ b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h @@ -16,6 +16,7 @@ #include "core/common/status.h" #include #include +#include #ifndef SHARED_PROVIDER #include "core/framework/op_kernel.h" #endif @@ -166,12 +167,28 @@ inline void AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span +struct has_input_defs : std::false_type {}; +template +struct has_input_defs().InputDefs())>> : std::true_type {}; +} // namespace upsamplebase_detail + class UpsampleBase { public: // Make this available in other EP via provider bridge // it works iff output_shape is specified +#ifdef SHARED_PROVIDER void AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims, InlinedVector& scales) const; +#else + void AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims, + InlinedVector& scales) const { + upsamplebase_helper::AdjustOutputSizeAsPolicy(output_dims, input_dims, scales, keep_aspect_ratio_policy_, axes_); + } +#endif protected: template @@ -265,8 +282,11 @@ class UpsampleBase { if (scales_input_idx_ > 0) { const Tensor* scale; bool get_scale = info.TryGetConstantInput(scales_input_idx_, &scale); - auto x_shape = node.InputDefs()[0]->Shape(); - int64_t rank = x_shape ? x_shape->dim_size() : -1; + int64_t rank = -1; + if constexpr (upsamplebase_detail::has_input_defs::value) { + auto x_shape = node.InputDefs()[0]->Shape(); + rank = x_shape ? x_shape->dim_size() : -1; + } if (get_scale && scale->Shape().Size() > 0 && ((opset < 18) || (rank > 0 && opset >= 18))) { ORT_THROW_IF_ERROR(ParseScalesData(scale, scales_, rank)); scales_cached_ = true; @@ -643,13 +663,6 @@ class UpsampleBase { } }; // UpsampleBase -#ifndef SHARED_PROVIDER -inline void UpsampleBase::AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims, - InlinedVector& scales) const { - upsamplebase_helper::AdjustOutputSizeAsPolicy(output_dims, input_dims, scales, keep_aspect_ratio_policy_, axes_); -} -#endif - } // namespace onnxruntime #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) diff --git a/onnxruntime/core/providers/cuda/generator/constant_of_shape.h b/onnxruntime/core/providers/cuda/generator/constant_of_shape.h index 99c5da0615ede..10d9b979a82bd 100644 --- a/onnxruntime/core/providers/cuda/generator/constant_of_shape.h +++ b/onnxruntime/core/providers/cuda/generator/constant_of_shape.h @@ -4,12 +4,126 @@ #pragma once #include "core/providers/cuda/cuda_kernel.h" +#ifndef BUILD_CUDA_EP_AS_PLUGIN #include "core/providers/cpu/generator/constant_of_shape_base.h" +#endif #include "core/providers/cuda/shared_inc/cuda_utils.h" namespace onnxruntime { namespace cuda { +#ifdef BUILD_CUDA_EP_AS_PLUGIN + +// Plugin build: self-contained ConstantOfShape without ConstantOfShapeBase dependency. +// ConstantOfShapeBase uses TensorProto/UnpackTensor utilities not available in the plugin, +// so we read the 'value' attribute via the ORT C API (KernelInfoGetAttribute_tensor) instead. +class ConstantOfShape final : public CudaKernel { + public: + explicit ConstantOfShape(const OpKernelInfo& info) : CudaKernel(info) { + InitValue(info); + } + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConstantOfShape); + + Status ComputeInternal(OpKernelContext* ctx) const override; + + protected: + void* GetValuePtr() const { return p_value_; } + + static Status PrepareCompute(OpKernelContext* ctx, Tensor** output_tensor) { + const auto* shape_tensor = ctx->Input(0); + const auto& input_shape = shape_tensor->Shape(); + ORT_RETURN_IF_NOT(input_shape.NumDimensions() > 0, "Must have a valid input shape."); + const auto span = shape_tensor->DataAsSpan(); + TensorShape output_shape(span); + (*output_tensor) = ctx->Output(0, output_shape); + return Status::OK(); + } + + private: + union SizeBasedValue { + int8_t int8_; + int16_t int16_; + int32_t int32_; + int64_t int64_; + }; + + mutable SizeBasedValue s_value_{}; + mutable void* p_value_ = nullptr; + + void SetValue(size_t size, const void* value) { + switch (size) { + case sizeof(int8_t): + s_value_.int8_ = *(reinterpret_cast(value)); + p_value_ = reinterpret_cast(&(s_value_.int8_)); + break; + case sizeof(int16_t): + s_value_.int16_ = *(reinterpret_cast(value)); + p_value_ = reinterpret_cast(&(s_value_.int16_)); + break; + case sizeof(int32_t): + s_value_.int32_ = *(reinterpret_cast(value)); + p_value_ = reinterpret_cast(&(s_value_.int32_)); + break; + case sizeof(int64_t): + s_value_.int64_ = *(reinterpret_cast(value)); + p_value_ = reinterpret_cast(&(s_value_.int64_)); + break; + default: + ORT_THROW("Unsupported value attribute datatype with size: ", size); + } + } + + void InitValue(const OpKernelInfo& info) { + Ort::AllocatorWithDefaultOptions allocator; + auto ort_info = info.GetKernelInfo(); + try { + Ort::Value value_tensor = ort_info.GetTensorAttribute("value", allocator); + auto type_and_shape = value_tensor.GetTensorTypeAndShapeInfo(); + size_t elem_count = type_and_shape.GetElementCount(); + ORT_ENFORCE(elem_count == 1 || elem_count == 0, + "The value attribute of ConstantOfShape must be a single-element tensor"); + if (elem_count == 1) { + const void* data = value_tensor.GetTensorRawData(); + size_t elem_size = GetElementSize(type_and_shape.GetElementType()); + SetValue(elem_size, data); + } else { + float f_value = 0.f; + SetValue(sizeof(float), &f_value); + } + } catch (const Ort::Exception&) { + float f_value = 0.f; + SetValue(sizeof(float), &f_value); + } + } + + static size_t GetElementSize(ONNXTensorElementDataType type) { + switch (type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return 1; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + return 2; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return 4; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + return 8; + default: + ORT_THROW("Unsupported element type for ConstantOfShape: ", static_cast(type)); + } + } +}; + +#else // !BUILD_CUDA_EP_AS_PLUGIN + class ConstantOfShape final : public ConstantOfShapeBase<>, public CudaKernel { public: explicit ConstantOfShape(const OpKernelInfo& info) : ConstantOfShapeBase(info), CudaKernel(info) {} @@ -19,5 +133,7 @@ class ConstantOfShape final : public ConstantOfShapeBase<>, public CudaKernel { Status ComputeInternal(OpKernelContext* ctx) const override; }; +#endif // BUILD_CUDA_EP_AS_PLUGIN + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/integer_gemm.cc b/onnxruntime/core/providers/cuda/integer_gemm.cc index ef22f4793befc..a3e8ba0cc41af 100644 --- a/onnxruntime/core/providers/cuda/integer_gemm.cc +++ b/onnxruntime/core/providers/cuda/integer_gemm.cc @@ -17,12 +17,13 @@ constexpr int roundoff(int v, int d) { Status GemmInt8(int m, int n, int k, int32_t alpha, int32_t beta, const int8_t* a, int lda, const int8_t* b, int ldb, int32_t* c, int ldc, - const CudaKernel* cuda_kernel, onnxruntime::Stream* ort_stream) { + const CudaKernel* cuda_kernel, void* alloc_stream, cudaStream_t cuda_stream, + cublasHandle_t cublas_handle) { ORT_ENFORCE(a != nullptr && b != nullptr && c != nullptr, "input matrix should not be null"); ORT_ENFORCE(cuda_kernel != nullptr, "kernel is null"); - ORT_ENFORCE(ort_stream != nullptr, "Cuda kernel must have the stream instance"); + ORT_ENFORCE(cuda_stream != nullptr, "Cuda kernel must have the cuda stream"); - cudaStream_t stream = static_cast(ort_stream->GetHandle()); + cudaStream_t stream = cuda_stream; // pad A and B to make their leading dimension be multiples of 32 // because cublasGemmEx requires: @@ -34,7 +35,7 @@ Status GemmInt8(int m, int n, int k, IAllocatorUniquePtr a_padded; if ((mask & lda_aligned) != 0) { lda_aligned = roundoff(lda, 32); - a_padded = cuda_kernel->GetScratchBuffer(SafeInt(m) * lda_aligned, ort_stream); + a_padded = cuda_kernel->GetScratchBuffer(SafeInt(m) * lda_aligned, alloc_stream); cudaMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, cudaMemcpyDeviceToDevice, stream); } @@ -42,14 +43,12 @@ Status GemmInt8(int m, int n, int k, IAllocatorUniquePtr b_padded; if ((mask & ldb_aligned) != 0) { ldb_aligned = roundoff(ldb, 32); - b_padded = cuda_kernel->GetScratchBuffer(SafeInt(k) * ldb_aligned, ort_stream); + b_padded = cuda_kernel->GetScratchBuffer(SafeInt(k) * ldb_aligned, alloc_stream); cudaMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, cudaMemcpyDeviceToDevice, stream); } - auto cublas = cuda_kernel->GetCublasHandleOrDefault(ort_stream); - CUBLAS_RETURN_IF_ERROR(cublasGemmEx( - cublas, + cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, diff --git a/onnxruntime/core/providers/cuda/math/cumsum.cc b/onnxruntime/core/providers/cuda/math/cumsum.cc index b8b1f29c643d8..899a37f5bff16 100644 --- a/onnxruntime/core/providers/cuda/math/cumsum.cc +++ b/onnxruntime/core/providers/cuda/math/cumsum.cc @@ -3,12 +3,36 @@ #include "cumsum.h" #include "cumsum_impl.h" -#include "core/providers/cpu/math/cumsum.h" #include "core/providers/common.h" namespace onnxruntime { namespace cuda { +namespace { + +Status GetAxisFromInput(const Tensor* axis_tensor, int64_t input_rank, int64_t& axis_out) { + if (!axis_tensor) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Axis tensor must be provided to the CumSum op"); + } + + if (axis_tensor->Shape().NumDimensions() > 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Axis tensor should be 0D or 1D"); + } + + if (axis_tensor->IsDataType()) { + axis_out = static_cast(axis_tensor->Data()[0]); + } else if (axis_tensor->IsDataType()) { + axis_out = axis_tensor->Data()[0]; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Axis tensor should be of type `int32_t` or `int64_t`"); + } + + axis_out = HandleNegativeAxis(axis_out, input_rank); + return Status::OK(); +} + +} // namespace + ONNX_OPERATOR_VERSIONED_KERNEL_EX( CumSum, kOnnxDomain, @@ -53,7 +77,7 @@ Status CumSum::ComputeInternal(OpKernelContext* ctx) const { const Tensor* axis_tensor = ctx->Input(1); // axis input tensor int64_t axis = 0; - ORT_THROW_IF_ERROR(cumsum_op::GetAxis(axis_tensor, rank, axis)); + ORT_THROW_IF_ERROR(GetAxisFromInput(axis_tensor, rank, axis)); TensorShape output_shape(input->Shape()); auto& output = *ctx->Output(0, output_shape); // output tensor diff --git a/onnxruntime/core/providers/cuda/math/matmul.cc b/onnxruntime/core/providers/cuda/math/matmul.cc index 04ffa875c1b9d..67e84ac99322c 100644 --- a/onnxruntime/core/providers/cuda/math/matmul.cc +++ b/onnxruntime/core/providers/cuda/math/matmul.cc @@ -5,7 +5,9 @@ #include "core/providers/cuda/shared_inc/fpgeneric.h" #include "core/providers/cuda/cuda_allocator.h" +#ifndef BUILD_CUDA_EP_AS_PLUGIN #include "core/providers/cuda/tunable/math/matmul.h" +#endif namespace onnxruntime { namespace cuda { @@ -121,9 +123,11 @@ Status MatMul::ComputeInternal(OpKernelContext* ctx) const { return Status::OK(); } +#ifndef BUILD_CUDA_EP_AS_PLUGIN if (GetTuningContext()->IsTunableOpEnabled()) { return tunable::TunableMatMul(alpha_, trans_a, trans_b, trans_batch_a_, trans_batch_b_, helper, this, ctx); } +#endif return ComputeDefault(ctx, helper); } @@ -219,9 +223,9 @@ Status FuncMatMul( MatMulComputeHelper::OffsetToArrays(reinterpret_cast(A->Data()), helper.LeftOffsets(), A_arrays.CpuSpan()); MatMulComputeHelper::OffsetToArrays(reinterpret_cast(B->Data()), helper.RightOffsets(), B_arrays.CpuSpan()); MatMulComputeHelper::OffsetToArrays(reinterpret_cast(Y->MutableData()), helper.OutputOffsets(), Y_arrays.CpuSpan()); - ORT_RETURN_IF_ERROR(A_arrays.CopyToGpu(ctx->GetComputeStream())); - ORT_RETURN_IF_ERROR(B_arrays.CopyToGpu(ctx->GetComputeStream())); - ORT_RETURN_IF_ERROR(Y_arrays.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(A_arrays.CopyToGpu(cuda_kernel->GetComputeStream(ctx))); + ORT_RETURN_IF_ERROR(B_arrays.CopyToGpu(cuda_kernel->GetComputeStream(ctx))); + ORT_RETURN_IF_ERROR(Y_arrays.CopyToGpu(cuda_kernel->GetComputeStream(ctx))); // TF32 provides a huge performance gain for training and inference while preserving FP32 levels of accuracy. // It requires Ampere or newer GPU, and pointers of matrices shall be aligned (ideal alignment is 16-byte). @@ -370,9 +374,9 @@ Status MatMul::ComputeDefault(OpKernelContext* ctx, MatMulComputeHelper& help MatMulComputeHelper::OffsetToArrays(reinterpret_cast(left_X->Data()), helper.LeftOffsets(), left_arrays.CpuSpan()); MatMulComputeHelper::OffsetToArrays(reinterpret_cast(right_X->Data()), helper.RightOffsets(), right_arrays.CpuSpan()); MatMulComputeHelper::OffsetToArrays(reinterpret_cast(Y->MutableData()), helper.OutputOffsets(), output_arrays.CpuSpan()); - ORT_RETURN_IF_ERROR(left_arrays.CopyToGpu(ctx->GetComputeStream())); - ORT_RETURN_IF_ERROR(right_arrays.CopyToGpu(ctx->GetComputeStream())); - ORT_RETURN_IF_ERROR(output_arrays.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(left_arrays.CopyToGpu(this->GetComputeStream(ctx))); + ORT_RETURN_IF_ERROR(right_arrays.CopyToGpu(this->GetComputeStream(ctx))); + ORT_RETURN_IF_ERROR(output_arrays.CopyToGpu(this->GetComputeStream(ctx))); // TF32 provides a huge performance gain for training and inference while preserving FP32 levels of accuracy. // It requires Ampere or newer GPU, and pointers of matrices shall be aligned (ideal alignment is 16-byte). diff --git a/onnxruntime/core/providers/cuda/math/matmul_integer.cc b/onnxruntime/core/providers/cuda/math/matmul_integer.cc index 6e43834bdf1ef..b42f204efa430 100644 --- a/onnxruntime/core/providers/cuda/math/matmul_integer.cc +++ b/onnxruntime/core/providers/cuda/math/matmul_integer.cc @@ -69,13 +69,13 @@ Status MatMulInteger::ComputeInternal(OpKernelContext* ctx) cons // OffsetOutput computes gets the final result IAllocatorUniquePtr a_row_buf; if (b_offset != 0) { - a_row_buf = GetScratchBuffer(helper.OutputShape().Size() / helper.N(), ctx->GetComputeStream()); + a_row_buf = GetScratchBuffer(helper.OutputShape().Size() / helper.N(), GetComputeStream(ctx)); ORT_RETURN_IF_ERROR(ReduceRowSumOnMatrixA(Stream(ctx), a_ptr, a_row_buf.get(), b_offset, helper)); } IAllocatorUniquePtr b_col_buf; if (a_offset != 0) { - b_col_buf = GetScratchBuffer(helper.OutputShape().Size() / helper.M(), ctx->GetComputeStream()); + b_col_buf = GetScratchBuffer(helper.OutputShape().Size() / helper.M(), GetComputeStream(ctx)); ORT_RETURN_IF_ERROR(ReduceColSumOnMatrixB(Stream(ctx), b_ptr, b_col_buf.get(), a_offset, helper)); } @@ -105,7 +105,8 @@ Status MatMulInteger::ComputeInternal(OpKernelContext* ctx) cons output_ptr + helper.OutputOffsets()[batch], static_cast(helper.N()), this, - ctx->GetComputeStream())); + GetComputeStream(ctx), Stream(ctx), + GetCublasHandle(ctx))); } return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/math/softmax.cc b/onnxruntime/core/providers/cuda/math/softmax.cc index 9402232a24737..609f17891fc0a 100644 --- a/onnxruntime/core/providers/cuda/math/softmax.cc +++ b/onnxruntime/core/providers/cuda/math/softmax.cc @@ -13,7 +13,7 @@ namespace cuda { template Status SoftMaxComputeHelper( - Stream* stream, + SoftmaxComputeStreamT stream, const T* X, const TensorShape& input_shape, TOut* Y, @@ -40,9 +40,9 @@ Status SoftMaxComputeHelper( } #define SPECIALIZED_SOFTMAX_HELPER_IMPL(T, TOut) \ - template Status SoftMaxComputeHelper(Stream * stream, const T* input, \ + template Status SoftMaxComputeHelper(SoftmaxComputeStreamT stream, const T* input, \ const TensorShape& shape, TOut* Y, int64_t axis); \ - template Status SoftMaxComputeHelper(Stream * stream, const T* input, \ + template Status SoftMaxComputeHelper(SoftmaxComputeStreamT stream, const T* input, \ const TensorShape& shape, TOut* Y, int64_t axis); SPECIALIZED_SOFTMAX_HELPER_IMPL(MLFloat16, float) @@ -177,12 +177,13 @@ Status Softmax::ComputeInternal(OpKernelContext* ctx) const { } Status status; + auto compute_stream = Stream(ctx); if (log_softmax_) { - status = SoftMaxComputeHelper(ctx->GetComputeStream(), X_data, *compute_input_shape, Y_data, + status = SoftMaxComputeHelper(compute_stream, X_data, *compute_input_shape, Y_data, is_transpose_required ? static_cast(rank) - 1 : static_cast(axis)); } else { - status = SoftMaxComputeHelper(ctx->GetComputeStream(), X_data, *compute_input_shape, Y_data, + status = SoftMaxComputeHelper(compute_stream, X_data, *compute_input_shape, Y_data, is_transpose_required ? static_cast(rank) - 1 : static_cast(axis)); } diff --git a/onnxruntime/core/providers/cuda/math/softmax.h b/onnxruntime/core/providers/cuda/math/softmax.h index 6f4016b655c96..c0c0818042c15 100644 --- a/onnxruntime/core/providers/cuda/math/softmax.h +++ b/onnxruntime/core/providers/cuda/math/softmax.h @@ -9,20 +9,22 @@ namespace onnxruntime { namespace cuda { +using SoftmaxComputeStreamT = cudaStream_t; + template Status SoftMaxComputeHelper( - Stream* stream, + SoftmaxComputeStreamT stream, const T* input, const TensorShape& shape, TOut* Y, int64_t axis); template -Status dispatch_warpwise_softmax_forward(Stream* stream, output_t* dst, const input_t* src, +Status dispatch_warpwise_softmax_forward(SoftmaxComputeStreamT stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count); template -Status dispatch_blockwise_softmax_forward(Stream* stream, output_t* output, const input_t* input, +Status dispatch_blockwise_softmax_forward(SoftmaxComputeStreamT stream, output_t* output, const input_t* input, int softmax_elements, int input_stride, int output_stride, int batch_count); template @@ -32,7 +34,7 @@ class Softmax final : public CudaKernel { const auto& node = info.node(); opset_ = node.SinceVersion(); - int64_t axis; + int64_t axis = 0; Status status = info.GetAttr("axis", &axis); if (status.IsOK()) { diff --git a/onnxruntime/core/providers/cuda/math/softmax_impl.cu b/onnxruntime/core/providers/cuda/math/softmax_impl.cu index 04e66e9e1529e..10550fb0cb810 100644 --- a/onnxruntime/core/providers/cuda/math/softmax_impl.cu +++ b/onnxruntime/core/providers/cuda/math/softmax_impl.cu @@ -29,9 +29,9 @@ namespace onnxruntime { namespace cuda { template -Status dispatch_warpwise_softmax_forward(Stream* ort_stream, output_t* dst, const input_t* src, int softmax_elements, +Status dispatch_warpwise_softmax_forward(SoftmaxComputeStreamT ort_stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) { - auto stream = static_cast(ort_stream->GetHandle()); + auto stream = ort_stream; if (softmax_elements == 0) { return Status::OK(); } else { @@ -95,18 +95,18 @@ Status dispatch_warpwise_softmax_forward(Stream* ort_stream, output_t* dst, cons return CUDA_CALL(cudaGetLastError()); } -#define SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \ - template Status dispatch_warpwise_softmax_forward(Stream * ort_stream, \ - output_t * dst, \ - const input_t* src, \ - int softmax_elements, \ - int softmax_elements_stride, \ - int batch_count); \ - template Status dispatch_warpwise_softmax_forward(Stream * ort_stream, \ - output_t * dst, \ - const input_t* src, \ - int softmax_elements, \ - int softmax_elements_stride, \ +#define SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \ + template Status dispatch_warpwise_softmax_forward(SoftmaxComputeStreamT ort_stream, \ + output_t * dst, \ + const input_t* src, \ + int softmax_elements, \ + int softmax_elements_stride, \ + int batch_count); \ + template Status dispatch_warpwise_softmax_forward(SoftmaxComputeStreamT ort_stream, \ + output_t * dst, \ + const input_t* src, \ + int softmax_elements, \ + int softmax_elements_stride, \ int batch_count); SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(float, float, float) @@ -116,9 +116,9 @@ SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(double, double, double) SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(BFloat16, BFloat16, float) template -Status dispatch_blockwise_softmax_forward(Stream* ort_stream, output_t* output, const input_t* input, int softmax_elements, +Status dispatch_blockwise_softmax_forward(SoftmaxComputeStreamT ort_stream, output_t* output, const input_t* input, int softmax_elements, int input_stride, int output_stride, int batch_count) { - auto stream = static_cast(ort_stream->GetHandle()); + auto stream = ort_stream; dim3 grid(batch_count); constexpr int ILP = sizeof(float4) / sizeof(input_t); dim3 block = SoftMax_getBlockSize(ILP, softmax_elements); @@ -134,12 +134,12 @@ Status dispatch_blockwise_softmax_forward(Stream* ort_stream, output_t* output, return CUDA_CALL(cudaGetLastError()); } -#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \ - template Status dispatch_blockwise_softmax_forward( \ - Stream * ort_stream, output_t * output, const input_t* src, int softmax_elements, \ - int input_stride, int output_stride, int batch_count); \ - template Status dispatch_blockwise_softmax_forward( \ - Stream * ort_stream, output_t * output, const input_t* src, int softmax_elements, \ +#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \ + template Status dispatch_blockwise_softmax_forward( \ + SoftmaxComputeStreamT ort_stream, output_t * output, const input_t* src, int softmax_elements, \ + int input_stride, int output_stride, int batch_count); \ + template Status dispatch_blockwise_softmax_forward( \ + SoftmaxComputeStreamT ort_stream, output_t * output, const input_t* src, int softmax_elements, \ int input_stride, int output_stride, int batch_count); SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(float, float, float) diff --git a/onnxruntime/core/providers/cuda/math/topk.cc b/onnxruntime/core/providers/cuda/math/topk.cc index cf26e0acfa557..fedd33a3638e6 100644 --- a/onnxruntime/core/providers/cuda/math/topk.cc +++ b/onnxruntime/core/providers/cuda/math/topk.cc @@ -62,7 +62,8 @@ TopK::TopK(const OpKernelInfo& info) : CudaKernel(info) { #define IS_PRIM_TYPE(T) utils::IsPrimitiveDataType(prim_type) #define TOPKIMPL(T) TopKImpl(this, use_deterministic_compute, \ - ctx->GetComputeStream(), tensor_X->Data(), \ + Stream(ctx), GetComputeStream(ctx), \ + tensor_X->Data(), \ static_cast(tensor_V->MutableDataRaw()), \ static_cast(tensor_I->MutableDataRaw()), \ elem_nums_cuda, \ diff --git a/onnxruntime/core/providers/cuda/math/topk_impl.cuh b/onnxruntime/core/providers/cuda/math/topk_impl.cuh index 28efde1852d4f..e29e12fb09d6b 100644 --- a/onnxruntime/core/providers/cuda/math/topk_impl.cuh +++ b/onnxruntime/core/providers/cuda/math/topk_impl.cuh @@ -395,13 +395,12 @@ __global__ void ExcludeOutput(T* output_i, T K, T dimension) { template Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute, - Stream* ort_stream, const T* input_x, T* output_v, int64_t* output_i, + cudaStream_t stream, void* alloc_stream, const T* input_x, T* output_v, int64_t* output_i, const TArray& elem_nums, size_t size, int32_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension) { typedef typename ToCudaType::MappedType CudaT; const CudaT* input_x_ptr = reinterpret_cast(input_x); CudaT* output_v_ptr = reinterpret_cast(output_v); - cudaStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; auto aligned_K = ALIGN(K); auto aligned_dimension = ALIGN(dimension); @@ -436,17 +435,17 @@ Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute, NumericLimits::Lowest(), NumericLimits::Max()); } } else { - auto input_key_buffer = kernel->GetScratchBuffer(dimension, ort_stream); - auto output_key_buffer = kernel->GetScratchBuffer(dimension, ort_stream); - auto input_value_buffer = kernel->GetScratchBuffer(dimension, ort_stream); - auto output_value_buffer = kernel->GetScratchBuffer(dimension, ort_stream); + auto input_key_buffer = kernel->GetScratchBuffer(dimension, alloc_stream); + auto output_key_buffer = kernel->GetScratchBuffer(dimension, alloc_stream); + auto input_value_buffer = kernel->GetScratchBuffer(dimension, alloc_stream); + auto output_value_buffer = kernel->GetScratchBuffer(dimension, alloc_stream); auto* input_key = input_key_buffer.get(); auto* output_key = output_key_buffer.get(); auto* input_value = input_value_buffer.get(); auto* output_value = output_value_buffer.get(); size_t temp_bytes = 0; CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(nullptr, temp_bytes, input_key, output_key, input_value, output_value, dimension, 0, sizeof(T) * 8, stream)); - auto temp_storage_buffer = kernel->GetScratchBuffer(temp_bytes, ort_stream); + auto temp_storage_buffer = kernel->GetScratchBuffer(temp_bytes, alloc_stream); auto* temp_storage = temp_storage_buffer.get(); auto blocks_per_grid_D = (int)(ceil(static_cast(dimension) / BT)); auto blocks_per_grid_K = (int)(ceil(static_cast(K) / BT)); @@ -468,7 +467,8 @@ Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute, #define TOPKIMPLE(T) template Status TopKImpl(const CudaKernel* kernel, \ bool use_deterministic_compute, \ - Stream* ort_stream, \ + cudaStream_t stream, \ + void* alloc_stream, \ const T* input_x, \ T* output_v, \ int64_t* output_i, \ diff --git a/onnxruntime/core/providers/cuda/math/topk_impl.h b/onnxruntime/core/providers/cuda/math/topk_impl.h index c5f63aadc402a..f072dd4e66107 100644 --- a/onnxruntime/core/providers/cuda/math/topk_impl.h +++ b/onnxruntime/core/providers/cuda/math/topk_impl.h @@ -11,7 +11,7 @@ namespace onnxruntime { namespace cuda { template -Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute, Stream* ort_stream, +Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute, cudaStream_t stream, void* alloc_stream, const T* input_x, T* output_v, int64_t* output_i, const TArray& elem_nums, size_t size, int32_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension); diff --git a/onnxruntime/core/providers/cuda/math/variadic_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/variadic_elementwise_ops.cc index f543ba9c975e1..bb05c04a4dc1e 100644 --- a/onnxruntime/core/providers/cuda/math/variadic_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/variadic_elementwise_ops.cc @@ -158,7 +158,7 @@ Status VariadicElementwiseOp OpKernelContext* context) const { const auto& node = Node(); const auto& node_name = node.Name(); - auto input_count = node.InputArgCount().front(); + auto input_count = context->InputCount(); ORT_RETURN_IF_NOT(input_count >= 1, "Must have 1 or more inputs"); const InputTensorVector input_tensors = diff --git a/onnxruntime/core/providers/cuda/nn/batch_norm.cc b/onnxruntime/core/providers/cuda/nn/batch_norm.cc index 02da1a2c99dfd..b5aacff1d22ef 100644 --- a/onnxruntime/core/providers/cuda/nn/batch_norm.cc +++ b/onnxruntime/core/providers/cuda/nn/batch_norm.cc @@ -99,10 +99,10 @@ Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) // Convert the scale, B, mean, var to float const int64_t C = x_shape.GetDims()[NHWC ? 3 : 1]; - auto f_scale = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); - auto f_B = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); - auto f_mean = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); - auto f_var = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); + auto f_scale = GetScratchBuffer(C, GetComputeStream(p_op_kernel_context)); + auto f_B = GetScratchBuffer(C, GetComputeStream(p_op_kernel_context)); + auto f_mean = GetScratchBuffer(C, GetComputeStream(p_op_kernel_context)); + auto f_var = GetScratchBuffer(C, GetComputeStream(p_op_kernel_context)); Impl_Cast(Stream(p_op_kernel_context), scale_data, f_scale.get(), C); Impl_Cast(Stream(p_op_kernel_context), b_data, f_B.get(), C); Impl_Cast(Stream(p_op_kernel_context), mean_data, f_mean.get(), C); @@ -137,7 +137,7 @@ Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) auto saved_mean_data = reinterpret_cast(saved_mean->MutableData()); auto saved_inv_var_data = reinterpret_cast(saved_var->MutableData()); - auto stream = static_cast(p_op_kernel_context->GetComputeStream()->GetHandle()); + auto stream = Stream(p_op_kernel_context); CUDA_RETURN_IF_ERROR( cudaMemcpyAsync(running_mean_data, mean_data, mean->SizeInBytes(), cudaMemcpyDeviceToDevice, stream)); CUDA_RETURN_IF_ERROR( diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index dbda2d3ad0c88..7d6335dbd2cbf 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -491,7 +491,7 @@ Status Conv::ComputeInternal(OpKernelContext* context) const { CUDA_RETURN_IF_ERROR(cudaMemset(s_.y_data, 0, s_.Y->SizeInBytes())); } } - auto ws = GetWorkSpace(context->GetComputeStream()); + auto ws = GetWorkSpace(GetComputeStream(context)); CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_graph->execute(cudnn_handle, s_.variant_pack, diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index e4047a6af272e..fa9808a6d3d3e 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -224,7 +224,7 @@ class Conv : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; protected: - inline IAllocatorUniquePtr GetWorkSpace(onnxruntime::Stream* stream) const { + inline IAllocatorUniquePtr GetWorkSpace(void* stream) const { return GetScratchBuffer(s_.workspace_bytes, stream); } diff --git a/onnxruntime/core/providers/cuda/nn/conv_8.h b/onnxruntime/core/providers/cuda/nn/conv_8.h index bcee1bcb7e231..09745d785dd69 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_8.h +++ b/onnxruntime/core/providers/cuda/nn/conv_8.h @@ -189,7 +189,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) s_.Y = context->Output(0, TensorShape(s_.y_dims)); if (post_slicing_required) { // Post slicing needed. Create and fill in the Conv results in an intermediate buffer. - s_.memory_for_cudnn_conv_results = GetScratchBuffer(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream()); + s_.memory_for_cudnn_conv_results = GetScratchBuffer(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size, GetComputeStream(context)); s_.y_data = reinterpret_cast(s_.memory_for_cudnn_conv_results.get()); } else { // No post slicing needed. Fill the output tensor's buffer directly. @@ -376,7 +376,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) return Status::OK(); } if (s_.post_slicing_required) { - s_.memory_for_cudnn_conv_results = GetScratchBuffer(TensorShape(s_.y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream()); + s_.memory_for_cudnn_conv_results = GetScratchBuffer(TensorShape(s_.y_dims_with_adjusted_pads).Size() * s_.element_size, GetComputeStream(context)); s_.y_data = reinterpret_cast(s_.memory_for_cudnn_conv_results.get()); } else { s_.y_data = reinterpret_cast(s_.Y->MutableData()); @@ -394,7 +394,7 @@ Status Conv::ComputeInternal(OpKernelContext* context) const { } const auto alpha = Consts::One; const auto beta = Consts::Zero; - IAllocatorUniquePtr workspace = GetWorkSpace(context->GetComputeStream()); + IAllocatorUniquePtr workspace = GetWorkSpace(GetComputeStream(context)); auto cudnn_handle = GetCudnnHandle(context); CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward(cudnn_handle, &alpha, @@ -481,4 +481,4 @@ Status CudnnConvolutionDescriptor::Set( return Status::OK(); } } // namespace cuda -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 28197c20af052..0ad19815f0600 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -226,7 +226,11 @@ template Status ConvTranspose::UpdateState(OpKernelContext* context, bool dynamic_padding) const { constexpr bool channels_last = Layout == LAYOUT_NHWC; +#ifdef BUILD_CUDA_EP_AS_PLUGIN + size_t num_inputs = static_cast(Info().GetInputCount()); +#else size_t num_inputs = OpKernel::Node().InputDefs().size(); +#endif bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3; // set X @@ -483,7 +487,11 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool CUDA_RETURN_IF_ERROR(cudaMemset(s_.y_data, 0, s_.Y->SizeInBytes())); } } +#ifdef BUILD_CUDA_EP_AS_PLUGIN + auto ws = GetWorkSpace(context->GetGPUComputeStream()); +#else auto ws = GetWorkSpace(context->GetComputeStream()); +#endif CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_graph->execute(cudnn_handle, s_.variant_pack, diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.h b/onnxruntime/core/providers/cuda/nn/conv_transpose.h index 1a6957164d22f..62deb0475289a 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.h @@ -37,7 +37,7 @@ class ConvTranspose : public CudaKernel { bool W_already_nhwc = false; // In case NHWC == true and Conv is not in kMSInternalNHWCDomain protected: - inline IAllocatorUniquePtr GetWorkSpace(onnxruntime::Stream* stream) const { + inline IAllocatorUniquePtr GetWorkSpace(void* stream) const { return GetScratchBuffer(s_.workspace_bytes, stream); } diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h b/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h index aa1fe26ac97db..d296c8540dd5f 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h @@ -48,8 +48,8 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy TensorShapeVector w_dims = w_shape.AsShapeVector(); auto w_data = reinterpret_cast(W->Data()); - size_t num_inputs = OpKernel::Node().InputDefs().size(); - bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3; + const Tensor* B_tensor = context->Input(dynamic_padding ? 3 : 2); + bool has_bias = B_tensor != nullptr; CudaT* y_data = nullptr; @@ -190,7 +190,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy if (!s_.cached_benchmark_results.contains(x_dims)) { IAllocatorUniquePtr algo_search_workspace = - GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); + GetScratchBuffer(AlgoSearchWorkspaceSize, GetComputeStream(context)); // set math type to tensor core before algorithm search if constexpr (std::is_same::value) { @@ -241,7 +241,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy const auto alpha = Consts::One; const auto beta = Consts::Zero; - IAllocatorUniquePtr workspace = GetScratchBuffer(s_.workspace_bytes, context->GetComputeStream()); + IAllocatorUniquePtr workspace = GetScratchBuffer(s_.workspace_bytes, GetComputeStream(context)); CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardData(GetCudnnHandle(context), &alpha, s_.w_desc, w_data, s_.x_tensor, x_data, s_.conv_desc, s_.algo, workspace.get(), diff --git a/onnxruntime/core/providers/cuda/nn/dropout.cc b/onnxruntime/core/providers/cuda/nn/dropout.cc index 16818d010361a..54202e0231732 100644 --- a/onnxruntime/core/providers/cuda/nn/dropout.cc +++ b/onnxruntime/core/providers/cuda/nn/dropout.cc @@ -111,7 +111,7 @@ Status Dropout::ComputeInternal(OpKernelContext* context) const { void* const mask_data = [this, mask_element_count, mask, &temp_mask_buffer, context]() { if (mask) return mask->MutableDataRaw(); temp_mask_buffer = - GetScratchBuffer(mask_element_count * (UseBitmask ? sizeof(BitmaskElementType) : sizeof(bool)), context->GetComputeStream()); + GetScratchBuffer(mask_element_count * (UseBitmask ? sizeof(BitmaskElementType) : sizeof(bool)), GetComputeStream(context)); return temp_mask_buffer.get(); }(); diff --git a/onnxruntime/core/providers/cuda/nn/instance_norm.cc b/onnxruntime/core/providers/cuda/nn/instance_norm.cc index 30ba80dc8b05a..46c8086b74c3f 100644 --- a/onnxruntime/core/providers/cuda/nn/instance_norm.cc +++ b/onnxruntime/core/providers/cuda/nn/instance_norm.cc @@ -104,15 +104,15 @@ Status InstanceNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) co const size_t stats_byte_count = stats_count * sizeof(CudaT); // Mean & Variance are inputs & outputs and must be initialized to zero to work properly - auto mean = GetScratchBuffer(stats_count, p_op_kernel_context->GetComputeStream()); + auto mean = GetScratchBuffer(stats_count, GetComputeStream(p_op_kernel_context)); CUDA_RETURN_IF_ERROR(cudaMemsetAsync(mean.get(), 0, stats_byte_count, Stream(p_op_kernel_context))); - auto variance = GetScratchBuffer(stats_count, p_op_kernel_context->GetComputeStream()); + auto variance = GetScratchBuffer(stats_count, GetComputeStream(p_op_kernel_context)); CUDA_RETURN_IF_ERROR(cudaMemsetAsync(variance.get(), 0, stats_byte_count, Stream(p_op_kernel_context))); // We must set the scale & bias inputs to zero as they are inputs to the calculation - auto unused_scale = GetScratchBuffer(stats_count, p_op_kernel_context->GetComputeStream()); + auto unused_scale = GetScratchBuffer(stats_count, GetComputeStream(p_op_kernel_context)); CUDA_RETURN_IF_ERROR(cudaMemsetAsync(unused_scale.get(), 0, stats_byte_count, Stream(p_op_kernel_context))); - auto unused_bias = GetScratchBuffer(stats_count, p_op_kernel_context->GetComputeStream()); + auto unused_bias = GetScratchBuffer(stats_count, GetComputeStream(p_op_kernel_context)); CUDA_RETURN_IF_ERROR(cudaMemsetAsync(unused_bias.get(), 0, stats_byte_count, Stream(p_op_kernel_context))); // first, compute mean and variance per-instance per-channel using cudnnBatchNorm training @@ -201,10 +201,10 @@ Status InstanceNorm::ComputeInternal(OpKernelContext* p_op_kernel_con // alpha, beta will be of type float as the Consts struct specialization // for MLFloat16 type take care of that. Only Convert the scale, bias to float) - auto scale_data_fp32 = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); + auto scale_data_fp32 = GetScratchBuffer(C, GetComputeStream(p_op_kernel_context)); Impl_Cast(Stream(p_op_kernel_context), scale_data, scale_data_fp32.get(), C); - auto bias_data_fp32 = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); + auto bias_data_fp32 = GetScratchBuffer(C, GetComputeStream(p_op_kernel_context)); Impl_Cast(Stream(p_op_kernel_context), bias_data, bias_data_fp32.get(), C); CUDNN_RETURN_IF_ERROR(BatchNormalizationForwardTrainingHelper( @@ -247,15 +247,15 @@ Status InstanceNorm::ComputeInternal(OpKernelContext* p_op_kernel_con const size_t stats_byte_count = stats_count * sizeof(float); // Mean & Variance are inputs & outputs and must be initialized to zero to work properly - auto mean = GetScratchBuffer(stats_count, p_op_kernel_context->GetComputeStream()); + auto mean = GetScratchBuffer(stats_count, GetComputeStream(p_op_kernel_context)); CUDA_RETURN_IF_ERROR(cudaMemsetAsync(mean.get(), 0, stats_byte_count, Stream(p_op_kernel_context))); - auto variance = GetScratchBuffer(stats_count, p_op_kernel_context->GetComputeStream()); + auto variance = GetScratchBuffer(stats_count, GetComputeStream(p_op_kernel_context)); CUDA_RETURN_IF_ERROR(cudaMemsetAsync(variance.get(), 0, stats_byte_count, Stream(p_op_kernel_context))); // We must set the scale & bias inputs to zero as they are inputs to the calculation - auto unused_scale = GetScratchBuffer(stats_count, p_op_kernel_context->GetComputeStream()); + auto unused_scale = GetScratchBuffer(stats_count, GetComputeStream(p_op_kernel_context)); CUDA_RETURN_IF_ERROR(cudaMemsetAsync(unused_scale.get(), 0, stats_byte_count, Stream(p_op_kernel_context))); - auto unused_bias = GetScratchBuffer(stats_count, p_op_kernel_context->GetComputeStream()); + auto unused_bias = GetScratchBuffer(stats_count, GetComputeStream(p_op_kernel_context)); CUDA_RETURN_IF_ERROR(cudaMemsetAsync(unused_bias.get(), 0, stats_byte_count, Stream(p_op_kernel_context))); // first, compute mean and variance per-instance per-channel using cudnnBatchNorm training diff --git a/onnxruntime/core/providers/cuda/nn/pool.cc b/onnxruntime/core/providers/cuda/nn/pool.cc index 3a97a5f2481e7..89c6047769cf0 100644 --- a/onnxruntime/core/providers/cuda/nn/pool.cc +++ b/onnxruntime/core/providers/cuda/nn/pool.cc @@ -246,8 +246,8 @@ Status Pool::ComputeInternal(OpKernelContext* context) cons const auto input_count = x_shape.Size(); const auto output_count = y_shape.Size(); - IAllocatorUniquePtr temp_X = GetScratchBuffer(input_count, context->GetComputeStream()); - auto temp_Y = GetScratchBuffer(output_count, context->GetComputeStream()); + IAllocatorUniquePtr temp_X = GetScratchBuffer(input_count, GetComputeStream(context)); + auto temp_Y = GetScratchBuffer(output_count, GetComputeStream(context)); Impl_Cast(Stream(context), reinterpret_cast(x_data), temp_X.get(), input_count); CUDNN_RETURN_IF_ERROR(PoolingForwardHelper(GetCudnnHandle(context), pooling_desc, &alpha, x_tensor, temp_X.get(), &beta, y_tensor, temp_Y.get())); diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index 006b2366af0a5..5fdf64275c274 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -335,17 +335,27 @@ Status PrepareForReduce(const Tensor* X, return Status::OK(); } +template +inline IAllocatorUniquePtr AllocateScratchBuffer( + const AllocatorPtr& gpu_allocator, const CudaKernel* kernel, size_t count, void* compute_stream) { + if (count == 0) return nullptr; + if (kernel) { + return kernel->GetScratchBuffer(count, compute_stream); + } + return IAllocator::MakeUniquePtr(gpu_allocator, count, false, static_cast(compute_stream)); +} + // `input_shape_override` is the input shape for compute purposes (if provided) template -Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, +Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const CudaKernel* kernel, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, gsl::span axes, bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, - Stream* ort_stream, cudnnHandle_t cudnn_handle, + cudaStream_t cuda_stream, void* compute_stream, cudnnHandle_t cudnn_handle, const TensorShape* input_shape_override) { typedef typename ToCudaType::MappedType CudaT; const TensorShape& input_shape = input_shape_override ? *input_shape_override : input.Shape(); - cudaStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; + cudaStream_t stream = cuda_stream; int64_t input_count = prepare_reduce_metadata.input_count; int64_t output_count = prepare_reduce_metadata.output_count; @@ -367,7 +377,7 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, IAllocatorUniquePtr input_data_buffer(nullptr, [](T*) {}); const CudaT* input_data = reinterpret_cast(input.Data()); if (calculate_sqt) { - input_data_buffer = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream); + input_data_buffer = AllocateScratchBuffer(gpu_allocator, kernel, input_count, compute_stream); input_data = reinterpret_cast(input_data_buffer.get()); fast_divmod tmp_div; Impl_Mul(stream, static_cast(SimpleBroadcast::NoBroadcast), nullptr, @@ -386,7 +396,7 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, const auto buffer_size_bytes = compute_reduce_matrix_columns_buffer_size(m, n); auto buffer = buffer_size_bytes == 0 ? nullptr - : IAllocator::MakeUniquePtr(gpu_allocator, buffer_size_bytes, false, ort_stream); + : AllocateScratchBuffer(gpu_allocator, kernel, buffer_size_bytes, compute_stream); ORT_RETURN_IF_ERROR(reduce_matrix_columns(stream, input_data, reinterpret_cast(output.MutableData()), m, n, buffer.get(), buffer_size_bytes)); @@ -423,7 +433,7 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, if ((ReduceTensorIndices == CUDNN_REDUCE_TENSOR_FLATTENED_INDICES && std::is_same::value) || (ReduceTensorIndices == CUDNN_REDUCE_TENSOR_NO_INDICES && std::is_same::value)) { // ArgMax/ArgMin with FP16 are not supported by cudnn, so convert input to fp32 then call cudnn - temp_X = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream); + temp_X = AllocateScratchBuffer(gpu_allocator, kernel, input_count, compute_stream); Impl_Cast(stream, reinterpret_cast(input.Data()), temp_X.get(), input_shape.Size()); } else { cudnn_type_X = CudnnTensor::GetDataType(); @@ -448,20 +458,20 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, input_tensor, output_tensor, &workspace_bytes)); auto workspace_cuda = workspace_bytes == 0 ? nullptr - : IAllocator::MakeUniquePtr(gpu_allocator, workspace_bytes, false, ort_stream); + : AllocateScratchBuffer(gpu_allocator, kernel, workspace_bytes, compute_stream); size_t indices_bytes = 0; CUDNN_RETURN_IF_ERROR(cudnnGetReductionIndicesSize(cudnn_handle, reduce_desc, input_tensor, output_tensor, &indices_bytes)); auto indices_cuda = indices_bytes == 0 ? nullptr - : IAllocator::MakeUniquePtr(gpu_allocator, indices_bytes, false, ort_stream); + : AllocateScratchBuffer(gpu_allocator, kernel, indices_bytes, compute_stream); if (ReduceTensorIndices == CUDNN_REDUCE_TENSOR_NO_INDICES) { IAllocatorUniquePtr input_data_buffer(nullptr, [](T*) {}); CudaT* input_data = nullptr; if (calculate_sqt) { - input_data_buffer = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream); + input_data_buffer = AllocateScratchBuffer(gpu_allocator, kernel, input_count, compute_stream); input_data = reinterpret_cast(input_data_buffer.get()); fast_divmod tmp_div; Impl_Mul(stream, @@ -490,7 +500,7 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, input_tensor, output_tensor, &indices_bytes_max)); auto indices_cuda_max = indices_bytes_max == 0 ? nullptr - : IAllocator::MakeUniquePtr(gpu_allocator, indices_bytes_max, false, ort_stream); + : AllocateScratchBuffer(gpu_allocator, kernel, indices_bytes_max, compute_stream); auto* p_output = reinterpret_cast(output.template MutableData()); CUDNN_RETURN_IF_ERROR(cudnnReduceTensor( cudnn_handle, reduce_max_desc, indices_cuda_max.get(), indices_bytes_max, @@ -501,11 +511,11 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, // Exp(X-ReduceMax) const TensorShape output_shape(output_dims); - auto exp_result_buffer = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream); + auto exp_result_buffer = AllocateScratchBuffer(gpu_allocator, kernel, input_count, compute_stream); auto exp_result = exp_result_buffer.get(); auto log_sum_result_buffer = output_count == 0 ? nullptr - : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream); + : AllocateScratchBuffer(gpu_allocator, kernel, output_count, compute_stream); auto log_sum_result = log_sum_result_buffer.get(); BinaryElementwisePreparation prepare; ORT_RETURN_IF_ERROR(prepare.BinaryElementwiseBroadcastPrepareHelper(input_shape, output_shape, input_shape)); @@ -575,7 +585,7 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, if (temp_X) { auto temp_output = output_count == 0 ? nullptr - : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream); + : AllocateScratchBuffer(gpu_allocator, kernel, output_count, compute_stream); CUDNN_RETURN_IF_ERROR(cudnnReduceTensor( cudnn_handle, reduce_desc, indices_cuda.get(), indices_bytes, workspace_cuda.get(), workspace_bytes, @@ -603,7 +613,7 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, if (temp_X) { auto temp_output = output_count == 0 ? nullptr - : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream); + : AllocateScratchBuffer(gpu_allocator, kernel, output_count, compute_stream); CUDNN_RETURN_IF_ERROR(cudnnReduceTensor( cudnn_handle, reduce_desc, indices_cuda.get(), indices_bytes, workspace_cuda.get(), workspace_bytes, @@ -612,7 +622,7 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, } else { auto temp_output = output_count == 0 ? nullptr - : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream); + : AllocateScratchBuffer(gpu_allocator, kernel, output_count, compute_stream); CUDNN_RETURN_IF_ERROR(cudnnReduceTensor( cudnn_handle, reduce_desc, indices_cuda.get(), indices_bytes, workspace_cuda.get(), workspace_bytes, @@ -637,27 +647,27 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, } template Status ReduceComputeCore( - const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + const AllocatorPtr& gpu_allocator, const CudaKernel* kernel, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, gsl::span axes, bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, - Stream* ort_stream, cudnnHandle_t cudnn_handle, + cudaStream_t cuda_stream, void* compute_stream, cudnnHandle_t cudnn_handle, const TensorShape* input_shape_override); template Status ReduceComputeCore( - const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + const AllocatorPtr& gpu_allocator, const CudaKernel* kernel, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, gsl::span axes, bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, - Stream* ort_stream, cudnnHandle_t cudnn_handle, + cudaStream_t cuda_stream, void* compute_stream, cudnnHandle_t cudnn_handle, const TensorShape* input_shape_override); template Status ReduceComputeCore( - const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + const AllocatorPtr& gpu_allocator, const CudaKernel* kernel, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, gsl::span axes, bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, - Stream* ort_stream, cudnnHandle_t cudnn_handle, + cudaStream_t cuda_stream, void* compute_stream, cudnnHandle_t cudnn_handle, const TensorShape* input_shape_override); template @@ -704,9 +714,9 @@ Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, cudnnRe ORT_RETURN_IF_ERROR(PrepareForReduce(X, keepdims_, axes, prepare_reduce_metadata)); Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims); const bool fast_reduction = fast_reduction_ && !ctx->GetUseDeterministicCompute(); - return ReduceComputeCore(Info().GetAllocator(OrtMemType::OrtMemTypeDefault), *X, prepare_reduce_metadata, *Y, cudnn_reduce_op, axes, - calculate_log_, calculate_sqt_, log_sum_exp_, fast_reduction, ctx->GetComputeStream(), - GetCudnnHandleOrDefault(ctx->GetComputeStream())); + return ReduceComputeCore(AllocatorPtr{}, this, *X, prepare_reduce_metadata, *Y, cudnn_reduce_op, axes, + calculate_log_, calculate_sqt_, log_sum_exp_, fast_reduction, + Stream(ctx), GetComputeStream(ctx), GetCudnnHandle(ctx)); } #define SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(T) \ @@ -768,7 +778,7 @@ Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, cudnnRe CudnnReduceDescriptor reduce_desc; \ \ cudnnDataType_t cudnn_type_X = CUDNN_DATA_FLOAT; \ - IAllocatorUniquePtr temp_X = GetScratchBuffer(input_count, ctx->GetComputeStream()); \ + IAllocatorUniquePtr temp_X = GetScratchBuffer(input_count, GetComputeStream(ctx)); \ Impl_Cast(Stream(ctx), reinterpret_cast(X->Data()), temp_X.get(), X->Shape().Size()); \ \ ORT_RETURN_IF_ERROR(reduce_desc.Set(cudnn_reduce_op, cudnn_type_X, CUDNN_REDUCE_TENSOR_NO_INDICES)); \ @@ -778,12 +788,12 @@ Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, cudnnRe cudnnGetReductionIndicesSize(GetCudnnHandle(ctx), reduce_desc, input_tensor, output_tensor, &indices_bytes)); \ CUDNN_RETURN_IF_ERROR( \ cudnnGetReductionWorkspaceSize(GetCudnnHandle(ctx), reduce_desc, input_tensor, output_tensor, &workspace_bytes)); \ - IAllocatorUniquePtr indices_cuda = GetScratchBuffer(indices_bytes, ctx->GetComputeStream()); \ - IAllocatorUniquePtr workspace_cuda = GetScratchBuffer(workspace_bytes, ctx->GetComputeStream()); \ + IAllocatorUniquePtr indices_cuda = GetScratchBuffer(indices_bytes, GetComputeStream(ctx)); \ + IAllocatorUniquePtr workspace_cuda = GetScratchBuffer(workspace_bytes, GetComputeStream(ctx)); \ \ const auto one = Consts::One; \ const auto zero = Consts::Zero; \ - auto temp_Y = GetScratchBuffer(output_count, ctx->GetComputeStream()); \ + auto temp_Y = GetScratchBuffer(output_count, GetComputeStream(ctx)); \ CUDNN_RETURN_IF_ERROR(cudnnReduceTensor(GetCudnnHandle(ctx), reduce_desc, indices_cuda.get(), indices_bytes, \ workspace_cuda.get(), workspace_bytes, &one, input_tensor, temp_X.get(), \ &zero, output_tensor, temp_Y.get())); \ @@ -796,9 +806,9 @@ SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(int32_t) SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(int64_t) SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(int8_t) SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(uint8_t) - namespace ReductionOps { +#ifndef BUILD_CUDA_EP_AS_PLUGIN template std::unique_ptr ReduceCompute(const AllocatorPtr& gpu_allocator, cudnnReduceTensorOp_t cudnn_reduce_op, AllocatorPtr allocator, const Tensor& input, gsl::span axes, @@ -818,8 +828,10 @@ std::unique_ptr ReduceCompute(const AllocatorPtr& gpu_allocator, cudnnRe auto output = Tensor::Create(input.DataType(), prepare_reduce_metadata.squeezed_output_dims, allocator); - status = ReduceComputeCore(gpu_allocator, input, prepare_reduce_metadata, *output, cudnn_reduce_op, axes, - calculate_log, calculate_sqt, log_sum_exp, fast_reduction, stream, cudnn_handle, + status = ReduceComputeCore(gpu_allocator, nullptr, input, prepare_reduce_metadata, *output, cudnn_reduce_op, axes, + calculate_log, calculate_sqt, log_sum_exp, fast_reduction, + stream ? static_cast(stream->GetHandle()) : nullptr, + static_cast(stream), cudnn_handle, input_shape_override); if (!status.IsOK()) { @@ -852,6 +864,7 @@ template std::unique_ptr ReduceCompute -Status ReduceComputeCore(const AllocatorPtr& allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, +Status ReduceComputeCore(const AllocatorPtr& allocator, const CudaKernel* kernel, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, gsl::span axes, bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, - Stream* ort_stream, cudnnHandle_t cudnn_handle, + cudaStream_t cuda_stream, void* compute_stream, cudnnHandle_t cudnn_handle, const TensorShape* input_shape_override = nullptr); // CUDA's reduction descriptor cudnnReduceTensorDescriptor_t is a pointer so diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc index 4718e59e5a042..e7c8f52950141 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc @@ -88,7 +88,9 @@ Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, cons size_t& reorganized_w_data_size_in_bytes, IAllocatorUniquePtr& reorganized_w_data, CudnnFilterDescriptor& target_w_desc, - CudnnRNN& rnn_desc, onnxruntime::Stream* ort_stream) const { + CudnnRNN& rnn_desc, + void* alloc_stream, cudaStream_t cuda_stream, + cudnnHandle_t cudnn_handle) const { typedef typename ToCudaType::MappedType CudaT; int64_t input_size = W->Shape()[2]; // RNN W[num_directions_, hidden_size_, input_size] @@ -107,20 +109,18 @@ Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, cons // Prepare the weight data reorganized_w_data_size_in_bytes = w_size * sizeof(T); - reorganized_w_data = GetScratchBuffer(reorganized_w_data_size_in_bytes, ort_stream); + reorganized_w_data = GetScratchBuffer(reorganized_w_data_size_in_bytes, alloc_stream); // In many cases, this allocation is bigger than needed, leaving part of // the buffer uninitialized. non-zero garbage data leads to wrong result // in call to cudnnRNNForwardInference() // TODO! refine allocation size for each case. - cudaStream_t cuda_stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; CUDA_RETURN_IF_ERROR(cudaMemsetAsync(reorganized_w_data.get(), 0, reorganized_w_data_size_in_bytes, cuda_stream)); const T* W_data = W->Data(); const T* R_data = R->Data(); const T* B_data = B == nullptr ? nullptr : B->Data(); - cudnnHandle_t cudnn_handle = GetCudnnHandleOrDefault(ort_stream); ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(cudnn_handle, rnn_desc, reorganized_w_data_size_in_bytes, reorganized_w_data.get(), W_data, R_data, B_data, cuda_stream)); @@ -157,11 +157,11 @@ Status CudnnRnnBase::CacheCudnnRnnWeights(const OpKernelInfo& info) { if (get_B) { ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B, w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_, - tmp_rnn_desc, nullptr)); + tmp_rnn_desc, nullptr, nullptr, DefaultCudnnHandle())); } else { ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr, w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_, - tmp_rnn_desc, nullptr)); + tmp_rnn_desc, nullptr, nullptr, DefaultCudnnHandle())); } cudaStreamSynchronize(nullptr); @@ -238,11 +238,11 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // Prior to cuDNN 8.9.1 the sequence lens buffer must be passed to cudnnRNNForward and thus is must // be copied to the GPU always. - ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(GetComputeStream(ctx))); // Starting with cuDNN 8.9.1 the sequence lens buffer is ignored by cudnnRNNForward and thus it must // be copied to the GPU only for the ReverseBySequence kernels. // if (reverse_) { - // ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream())); + // ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(GetComputeStream(ctx))); // } // optional outputs @@ -257,7 +257,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { const T* x_data = X->Data(); if (reverse_) { // reverse input data - x_reversed_data = GetScratchBuffer(seq_length * batch_size * input_size, ctx->GetComputeStream()); + x_reversed_data = GetScratchBuffer(seq_length * batch_size * input_size, GetComputeStream(ctx)); ReverseBySequence(Stream(ctx), gsl::narrow_cast(seq_length), sequence_lens_buffer.GpuPtr(), @@ -280,7 +280,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { if (Y != nullptr) { y_data = Y->MutableData(); } else { - y_alloc_data = GetScratchBuffer(output_size, ctx->GetComputeStream()); + y_alloc_data = GetScratchBuffer(output_size, GetComputeStream(ctx)); y_data = y_alloc_data.get(); } @@ -307,7 +307,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { const Tensor& W = *ctx->Input(RNN_Input_Index::W); const Tensor& R = *ctx->Input(RNN_Input_Index::R); ORT_RETURN_IF_ERROR(ReorganizeWeights(&W, &R, B, w_data_size_in_bytes, w_data, w_desc, - rnn_desc, ctx->GetComputeStream())); + rnn_desc, GetComputeStream(ctx), Stream(ctx), GetCudnnHandle(ctx))); } CudnnDataTensor x_desc1; @@ -329,8 +329,8 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { CUDNN_RETURN_IF_ERROR(cudnnGetRNNTempSpaceSizes(GetCudnnHandle(ctx), rnn_desc, CUDNN_FWD_MODE_INFERENCE, x_desc1, &workspace_bytes, &reservespace_bytes)); - auto workspace_cuda = GetScratchBuffer(workspace_bytes, ctx->GetComputeStream()); - auto reservespace_cuda = GetScratchBuffer(reservespace_bytes, ctx->GetComputeStream()); + auto workspace_cuda = GetScratchBuffer(workspace_bytes, GetComputeStream(ctx)); + auto reservespace_cuda = GetScratchBuffer(reservespace_bytes, GetComputeStream(ctx)); CUDNN_RETURN_IF_ERROR(cudnnRNNForward(GetCudnnHandle(ctx), rnn_desc, @@ -357,7 +357,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // Mask on output for 0 sequence batches if (zero_seq_count > 0) { // Mask on output for 0 sequence batches - SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); + SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, GetComputeStream(ctx), Stream(ctx)); } return Status::OK(); } @@ -365,7 +365,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { IAllocatorUniquePtr y_reorganized_data; if (reverse_ || num_directions_ == 2) { // reverse output - y_reorganized_data = GetScratchBuffer(output_size, ctx->GetComputeStream()); + y_reorganized_data = GetScratchBuffer(output_size, GetComputeStream(ctx)); if (reverse_) { // reverse output data ReverseBySequence(Stream(ctx), @@ -397,7 +397,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // Mask on output for 0 sequence batches if (zero_seq_count > 0) { - SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); + SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, GetComputeStream(ctx), Stream(ctx)); } return Status::OK(); @@ -409,13 +409,12 @@ void CudnnRnnBase::SetZeroSequences(const int64_t zero_seq_index_cache_size, T* y_data, T* y_h_data, T* y_c_data, - onnxruntime::Stream* ort_stream) const { + void* alloc_stream, cudaStream_t cuda_stream) const { typedef typename ToCudaType::MappedType CudaT; CudaAsyncBuffer zero_seq_index_cache_async_buffer(this, zero_seq_index_cache_size); memcpy(zero_seq_index_cache_async_buffer.CpuPtr(), zero_seq_index_cache.data(), zero_seq_index_cache_size * sizeof(int32_t)); - ORT_THROW_IF_ERROR(zero_seq_index_cache_async_buffer.CopyToGpu(ort_stream)); - cudaStream_t cuda_stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; + ORT_THROW_IF_ERROR(zero_seq_index_cache_async_buffer.CopyToGpu(alloc_stream)); MaskZeroSequences(cuda_stream, gsl::narrow_cast(hidden_size_), reinterpret_cast(y_data), diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h index 7b827f8d04593..b7a3d67b45e93 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h @@ -138,7 +138,8 @@ class CudnnRnnBase : public CudaKernel { IAllocatorUniquePtr& target_w_data, CudnnFilterDescriptor& target_w_desc, CudnnRNN& rnn_desc, - onnxruntime::Stream* ort_stream) const; + void* alloc_stream, cudaStream_t cuda_stream, + cudnnHandle_t cudnn_handle) const; Status SetWeightBias(const cudnnHandle_t handle, const cudnnRNNDescriptor_t rnn_desc, @@ -156,7 +157,7 @@ class CudnnRnnBase : public CudaKernel { T* y_data, T* y_h_data, T* y_c_data, - onnxruntime::Stream* cuda_stream) const; + void* alloc_stream, cudaStream_t cuda_stream) const; protected: // W_lin_layer_id_ & R_lin_layer_id_ are set in Constructor diff --git a/onnxruntime/core/providers/cuda/shared_inc/integer_gemm.h b/onnxruntime/core/providers/cuda/shared_inc/integer_gemm.h index ea1f28f734c25..7ef4b7218611b 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/integer_gemm.h +++ b/onnxruntime/core/providers/cuda/shared_inc/integer_gemm.h @@ -19,6 +19,7 @@ Status GemmInt8(int m, int32_t* c, int ldc, const CudaKernel* cuda_kernel, - onnxruntime::Stream* stream); + void* alloc_stream, cudaStream_t cuda_stream, + cublasHandle_t cublas_handle); } -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/compress.cc b/onnxruntime/core/providers/cuda/tensor/compress.cc index a75f24341fc47..87c6f068dacd0 100644 --- a/onnxruntime/core/providers/cuda/tensor/compress.cc +++ b/onnxruntime/core/providers/cuda/tensor/compress.cc @@ -48,7 +48,7 @@ Status Compress::ComputeInternal(OpKernelContext* ctx) const { int64_t compress_input_length = has_axis_ ? input_dimensions[axis] : input_size; int64_t valid_condition_length = compress_input_length < condition_length ? compress_input_length : condition_length; - auto condition_cumulative_sum_buffer = GetScratchBuffer(gsl::narrow(valid_condition_length), ctx->GetComputeStream()); + auto condition_cumulative_sum_buffer = GetScratchBuffer(gsl::narrow(valid_condition_length), GetComputeStream(ctx)); auto condition_cumulative_sum = condition_cumulative_sum_buffer.get(); size_t temp_storage_bytes = 0; @@ -58,7 +58,7 @@ Status Compress::ComputeInternal(OpKernelContext* ctx) const { gsl::narrow(valid_condition_length), temp_storage_bytes)); - auto temp_buffer = GetScratchBuffer(temp_storage_bytes, ctx->GetComputeStream()); + auto temp_buffer = GetScratchBuffer(temp_storage_bytes, GetComputeStream(ctx)); auto d_temp_storage = temp_buffer.get(); CUDA_RETURN_IF_ERROR(CompressInclusivePrefixSum(Stream(ctx), d_temp_storage, diff --git a/onnxruntime/core/providers/cuda/tensor/concat.cc b/onnxruntime/core/providers/cuda/tensor/concat.cc index 4a390541009d9..711f30e7a57ca 100644 --- a/onnxruntime/core/providers/cuda/tensor/concat.cc +++ b/onnxruntime/core/providers/cuda/tensor/concat.cc @@ -33,7 +33,7 @@ ONNX_OPERATOR_KERNEL_EX(Concat, Concat); Status Concat::ComputeInternal(OpKernelContext* ctx) const { - auto input_count = Node().InputArgCount().front(); + auto input_count = ctx->InputCount(); // Hold pointers to the input tensors to be used in the PrepareForCompute() step InlinedTensorsVector input_tensors; @@ -43,7 +43,7 @@ Status Concat::ComputeInternal(OpKernelContext* ctx) const { } Prepare p; - ORT_RETURN_IF_ERROR(PrepareForCompute(ctx, input_tensors, p)); + ORT_RETURN_IF_ERROR(PrepareForComputeImpl(ctx, input_tensors, p)); // Return at this point if output tensor is going to be empty if (p.output_num_elements == 0) @@ -76,7 +76,7 @@ Status Concat::ComputeInternal(OpKernelContext* ctx) const { Stream(ctx), element_bytes, block_size_including_axis_dim, block_size_inside_axis_dim, concat_sizes[0], p.output_tensor->MutableDataRaw(), input_ptr_array, static_cast(p.output_num_elements))); } else { - ORT_RETURN_IF_ERROR(input_ptr.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(input_ptr.CopyToGpu(GetComputeStream(ctx))); ORT_RETURN_IF_ERROR(ConcatSameConcatDimImpl( Stream(ctx), element_bytes, block_size_including_axis_dim, block_size_inside_axis_dim, concat_sizes[0], p.output_tensor->MutableDataRaw(), input_ptr.GpuPtr(), static_cast(p.output_num_elements))); @@ -89,10 +89,10 @@ Status Concat::ComputeInternal(OpKernelContext* ctx) const { concat_sizes_range[i] += concat_sizes_range[i - 1]; } CudaAsyncBuffer concat_sizes_range_gpu(this, concat_sizes_range); - ORT_RETURN_IF_ERROR(concat_sizes_gpu.CopyToGpu(ctx->GetComputeStream())); - ORT_RETURN_IF_ERROR(axis_dimension_input_output_mapping_gpu.CopyToGpu(ctx->GetComputeStream())); - ORT_RETURN_IF_ERROR(concat_sizes_range_gpu.CopyToGpu(ctx->GetComputeStream())); - ORT_RETURN_IF_ERROR(input_ptr.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(concat_sizes_gpu.CopyToGpu(GetComputeStream(ctx))); + ORT_RETURN_IF_ERROR(axis_dimension_input_output_mapping_gpu.CopyToGpu(GetComputeStream(ctx))); + ORT_RETURN_IF_ERROR(concat_sizes_range_gpu.CopyToGpu(GetComputeStream(ctx))); + ORT_RETURN_IF_ERROR(input_ptr.CopyToGpu(GetComputeStream(ctx))); ORT_RETURN_IF_ERROR(ConcatImpl(Stream(ctx), element_bytes, block_size_including_axis_dim, block_size_inside_axis_dim, concat_sizes_gpu.GpuPtr(), concat_sizes_range_gpu.GpuPtr(), axis_dimension_input_output_mapping_gpu.GpuPtr(), p.output_tensor->MutableDataRaw(), diff --git a/onnxruntime/core/providers/cuda/tensor/gather.cc b/onnxruntime/core/providers/cuda/tensor/gather.cc index b7170872b0c0f..0f1ce81099b56 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather.cc +++ b/onnxruntime/core/providers/cuda/tensor/gather.cc @@ -46,7 +46,7 @@ ONNX_OPERATOR_KERNEL_EX( Status Gather::ComputeInternal(OpKernelContext* context) const { Prepare p; - ORT_RETURN_IF_ERROR(PrepareForCompute(context, p)); + ORT_RETURN_IF_ERROR(PrepareForComputeImpl(context, p)); const TensorShape& input_shape = p.input_tensor->Shape(); diff --git a/onnxruntime/core/providers/cuda/tensor/gather_nd.cc b/onnxruntime/core/providers/cuda/tensor/gather_nd.cc index 0be8bab97c005..7933cbc53148c 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_nd.cc +++ b/onnxruntime/core/providers/cuda/tensor/gather_nd.cc @@ -40,7 +40,8 @@ Status CheckBatchDimensionsMatch( template Status GatherNDBase::PrepareCompute( - onnxruntime::Stream* stream, + void* alloc_stream, + cudaStream_t cuda_stream, const int64_t batch_dims, const TensorShape& input_shape, const TensorShape& indices_shape, @@ -54,7 +55,6 @@ Status GatherNDBase::PrepareCompute( const auto num_batches = input_shape.SizeToDimension(batch_dims); const auto input_batch_stride = input_shape.SizeFromDimension(batch_dims); const auto num_slices_per_batch = num_slices / num_batches; - cudaStream_t cuda_stream = stream ? static_cast(stream->GetHandle()) : nullptr; const TIndex* const indices_data = indices_tensor->Data(); @@ -67,14 +67,14 @@ Status GatherNDBase::PrepareCompute( } } - auto sizes_from_slice_dims_buffer = GetScratchBuffer(sizes_from_slice_dims.size(), stream); + auto sizes_from_slice_dims_buffer = GetScratchBuffer(sizes_from_slice_dims.size(), alloc_stream); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( sizes_from_slice_dims_buffer.get(), sizes_from_slice_dims.data(), sizes_from_slice_dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice, cuda_stream)); - input_slice_offsets_buffer = GetScratchBuffer(num_slices, stream); + input_slice_offsets_buffer = GetScratchBuffer(num_slices, alloc_stream); TArray input_dims(input_shape.GetDims()); @@ -180,7 +180,7 @@ Status GatherND::ComputeInternal(OpKernelContext* context) const { int64_t num_slices; int64_t slice_size; IAllocatorUniquePtr input_slice_offsets_buffer; - ORT_RETURN_IF_ERROR(PrepareCompute(context->GetComputeStream(), + ORT_RETURN_IF_ERROR(PrepareCompute(GetComputeStream(context), Stream(context), batch_dims_, input_shape, indices_shape, indices_tensor, num_slices, slice_size, input_slice_offsets_buffer)); diff --git a/onnxruntime/core/providers/cuda/tensor/gather_nd.h b/onnxruntime/core/providers/cuda/tensor/gather_nd.h index 0d63c8a159783..6973aadbe73ae 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_nd.h +++ b/onnxruntime/core/providers/cuda/tensor/gather_nd.h @@ -23,7 +23,8 @@ class GatherNDBase : public CudaKernel { protected: template Status PrepareCompute( - onnxruntime::Stream* stream, + void* alloc_stream, + cudaStream_t cuda_stream, const int64_t batch_dims, const TensorShape& input_shape, const TensorShape& indices_shape, diff --git a/onnxruntime/core/providers/cuda/tensor/nonzero_op.cc b/onnxruntime/core/providers/cuda/tensor/nonzero_op.cc index 142bcb7eeb512..98bbff9855284 100644 --- a/onnxruntime/core/providers/cuda/tensor/nonzero_op.cc +++ b/onnxruntime/core/providers/cuda/tensor/nonzero_op.cc @@ -63,13 +63,13 @@ Status NonZero::ComputeInternal(OpKernelContext* context) const { auto x_data = reinterpret_cast::MappedType*>(x->Data()); const int number_of_blocks = NonZeroCalcBlockCount(x_size); - auto prefix_buffer = GetScratchBuffer(number_of_blocks, context->GetComputeStream()); + auto prefix_buffer = GetScratchBuffer(number_of_blocks, GetComputeStream(context)); int* prefix_counts = prefix_buffer.get(); CUDA_RETURN_IF_ERROR(NonZeroCountEachBlock(Stream(context), x_data, x_size, prefix_counts)); size_t temp_storage_bytes = 0; CUDA_RETURN_IF_ERROR(NonZeroCalcPrefixSumTempStorageBytes(Stream(context), prefix_counts, number_of_blocks, temp_storage_bytes)); - auto temp_buffer = GetScratchBuffer(temp_storage_bytes, context->GetComputeStream()); + auto temp_buffer = GetScratchBuffer(temp_storage_bytes, GetComputeStream(context)); auto d_temp_storage = temp_buffer.get(); CUDA_RETURN_IF_ERROR(NonZeroInclusivePrefixSum(Stream(context), d_temp_storage, temp_storage_bytes, prefix_counts, number_of_blocks)); diff --git a/onnxruntime/core/providers/cuda/tensor/pad.cc b/onnxruntime/core/providers/cuda/tensor/pad.cc index 9b23209953081..59a33899fd4b2 100644 --- a/onnxruntime/core/providers/cuda/tensor/pad.cc +++ b/onnxruntime/core/providers/cuda/tensor/pad.cc @@ -55,6 +55,47 @@ namespace cuda { using PadsVector = PadBase::PadsVector; +template +static void ComputePadsLocal(KernelContextType& ctx, + size_t data_rank, + gsl::span pads_data, + PadsVector& pads) { +#ifdef BUILD_CUDA_EP_AS_PLUGIN + PadBase::ComputePadsImpl(ctx, data_rank, pads_data, pads); +#else + PadBase::ComputePads(ctx, data_rank, pads_data, pads); +#endif +} + +static Status HandleDimValueZeroLocal(const Mode& mode, + const TensorShape& input_shape, + const TensorShape& output_shape) { +#ifdef BUILD_CUDA_EP_AS_PLUGIN + switch (mode) { + case Mode::Constant: + break; + case Mode::Edge: + case Mode::Reflect: { + for (size_t i = 0, end = input_shape.NumDimensions(); i < end; ++i) { + if (input_shape[i] == 0 && output_shape[i] > 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Cannot use '", mode == Mode::Edge ? "edge" : "reflect", + "' mode to pad dimension with a value of 0. Input shape:", + input_shape); + } + } + break; + } + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected mode of ", static_cast(mode)); + } + + return Status::OK(); +#else + return PadBase::HandleDimValueZero(mode, input_shape, output_shape); +#endif +} + static bool IsNCHWInputWithPaddingAlongHAndW(size_t input_rank, const TArray& lower_pads, const TArray& upper_pads) { @@ -161,7 +202,7 @@ Status Pad::ComputeInternal(OpKernelContext* ctx) const { // this is expected for constant mode only, otherwise the output is empty // no error if (input_shape.Size() == 0) { - ORT_RETURN_IF_ERROR(PadBase::HandleDimValueZero(mode_, input_shape, output_shape)); + ORT_RETURN_IF_ERROR(HandleDimValueZeroLocal(mode_, input_shape, output_shape)); if (mode_ == Mode::Constant) { const int64_t output_size = output_shape.Size(); if (output_size > 0) { diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc b/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc index 60f1a82605d26..81745733b16bc 100644 --- a/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc +++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc @@ -7,6 +7,49 @@ #include "core/providers/cuda/shared_inc/cuda_utils.h" #include "core/providers/cpu/tensor/utils.h" +#ifdef BUILD_CUDA_EP_AS_PLUGIN +// In the plugin build, SCATTER_ND_VALIDATE_SHAPES is not accessible +// (it lives in the CPU provider). Provide an inline equivalent. +namespace onnxruntime { +namespace scatter_nd_plugin { +inline Status ValidateShapes(const TensorShape& input_shape, + const TensorShape& indice_shape, + const TensorShape& update_shape) { + auto input_rank = input_shape.NumDimensions(); + auto indice_rank = indice_shape.NumDimensions(); + auto update_rank = update_shape.NumDimensions(); + if (input_rank == 0 || indice_rank == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input tensor and indices tensor must have rank larger than 0. ", + "input shape: ", input_shape, ", indices shape: ", indice_shape); + } + auto last_indice_dimension = indice_shape[indice_rank - 1]; + if (last_indice_dimension > static_cast(input_rank)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "last dimension of indices must not be larger than rank of input tensor"); + } + auto expected_update_rank = input_rank + indice_rank - 1 - static_cast(last_indice_dimension); + if (update_rank != expected_update_rank) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "update tensor shape does not match expected shape"); + } + if (indice_shape.Slice(0, indice_rank - 1) != update_shape.Slice(0, indice_rank - 1)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "update tensor shape mismatch with indices tensor shape"); + } + if (input_shape.Slice(onnxruntime::narrow(last_indice_dimension)) != update_shape.Slice(indice_rank - 1)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "update tensor shape mismatch with input tensor shape"); + } + return Status::OK(); +} +} // namespace scatter_nd_plugin +} // namespace onnxruntime +#define SCATTER_ND_VALIDATE_SHAPES onnxruntime::scatter_nd_plugin::ValidateShapes +#else +#define SCATTER_ND_VALIDATE_SHAPES onnxruntime::ScatterND::ValidateShapes +#endif + namespace onnxruntime { namespace cuda { @@ -50,7 +93,7 @@ template static Status InitializeElementCountsAndInputDimsSpanOrGpu(int64_t last_index_dimension, const TensorShape& input_shape, ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims, CudaKernel::CudaAsyncBuffer& element_counts_and_input_dims_gpu, - KernelContextType* context) { + KernelContextType* stream) { TensorPitches input_strides(input_shape); if (last_index_dimension < 6) { @@ -66,7 +109,7 @@ static Status InitializeElementCountsAndInputDimsSpanOrGpu(int64_t last_index_di element_counts_and_input_dims_gpu.CpuPtr()[i] = input_strides[i]; element_counts_and_input_dims_gpu.CpuPtr()[i + last_index_dimension] = input_shape[i]; } - ORT_RETURN_IF_ERROR(element_counts_and_input_dims_gpu.CopyToGpu(context->GetComputeStream())); + ORT_RETURN_IF_ERROR(element_counts_and_input_dims_gpu.CopyToGpu(stream)); element_counts_and_input_dims.gpu_ptr = element_counts_and_input_dims_gpu.GpuPtr(); } return Status::OK(); @@ -82,7 +125,7 @@ Status ScatterNDDisjointAndNoReduction::ComputeInternal(OpKernelContext* context const auto& updates_shape = updates_tensor->Shape(); // Validate input shapes - ORT_RETURN_IF_ERROR(onnxruntime::ScatterND::ValidateShapes(input_shape, indices_shape, updates_shape)); + ORT_RETURN_IF_ERROR(SCATTER_ND_VALIDATE_SHAPES(input_shape, indices_shape, updates_shape)); auto* output_tensor = context->Output(0, input_shape); @@ -111,7 +154,7 @@ Status ScatterNDDisjointAndNoReduction::ComputeInternal(OpKernelContext* context ORT_RETURN_IF_ERROR(InitializeElementCountsAndInputDimsSpanOrGpu(last_index_dimension, input_shape, element_counts_and_input_dims, element_counts_and_input_dims_gpu, - context)); + GetComputeStream(context))); ORT_RETURN_IF_ERROR(ScatterNDImpl( Stream(context), @@ -137,7 +180,7 @@ Status ScatterNDWithAtomicReduction::ComputeInternal(OpKernelContext* context) c const auto& updates_shape = updates_tensor->Shape(); // Validate input shapes - ORT_RETURN_IF_ERROR(onnxruntime::ScatterND::ValidateShapes(input_shape, indices_shape, updates_shape)); + ORT_RETURN_IF_ERROR(SCATTER_ND_VALIDATE_SHAPES(input_shape, indices_shape, updates_shape)); auto* output_tensor = context->Output(0, input_shape); @@ -163,7 +206,7 @@ Status ScatterNDWithAtomicReduction::ComputeInternal(OpKernelContext* context) c ORT_RETURN_IF_ERROR(InitializeElementCountsAndInputDimsSpanOrGpu(last_index_dimension, input_shape, element_counts_and_input_dims, element_counts_and_input_dims_gpu, - context)); + GetComputeStream(context))); switch (reduction_) { case ScatterNDReduction::None: { diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd.h b/onnxruntime/core/providers/cuda/tensor/scatter_nd.h index 6d8bbe6f463fd..b46d896331c1d 100644 --- a/onnxruntime/core/providers/cuda/tensor/scatter_nd.h +++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd.h @@ -7,7 +7,9 @@ #include "core/providers/shared_library/provider_api.h" #include "core/providers/cuda/cuda_kernel.h" #include "core/providers/cuda/tensor/scatter_nd_kind.h" +#ifndef BUILD_CUDA_EP_AS_PLUGIN #include "core/providers/cpu/tensor/scatter_nd.h" +#endif namespace onnxruntime { namespace cuda { diff --git a/onnxruntime/core/providers/cuda/tensor/slice.cc b/onnxruntime/core/providers/cuda/tensor/slice.cc index db285ba547b6a..34de6eeac3ea2 100644 --- a/onnxruntime/core/providers/cuda/tensor/slice.cc +++ b/onnxruntime/core/providers/cuda/tensor/slice.cc @@ -177,10 +177,10 @@ Status Slice::ComputeInternal(OpKernelContext* ctx) const { if (dynamic) { TensorShapeVector input_starts, input_ends, input_axes, input_steps; ORT_RETURN_IF_ERROR(FillInputVectors(ctx, input_starts, input_ends, input_axes, input_steps)); - ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps, compute_metadata)); + ORT_RETURN_IF_ERROR(SliceBase::PrepareForCompute(input_starts, input_ends, input_axes, input_steps, compute_metadata)); } else { - ORT_RETURN_IF_ERROR(PrepareForCompute(StartsAttribute(), EndsAttribute(), AxesAttribute(), compute_metadata)); + ORT_RETURN_IF_ERROR(SliceBase::PrepareForCompute(StartsAttribute(), EndsAttribute(), AxesAttribute(), compute_metadata)); } TensorShape output_shape(compute_metadata.output_dims_); @@ -212,8 +212,8 @@ template Status Slice::FillInputVectors(OpKernelContext* ctx, TensorShapeVector& input_starts, TensorShapeVector& input_ends, TensorShapeVector& input_axes, TensorShapeVector& input_steps) const { - return FillVectorsFromInput(*ctx->Input(1), *ctx->Input(2), ctx->Input(3), - ctx->Input(4), input_starts, input_ends, input_axes, input_steps); + return SliceBase::FillVectorsFromInput(*ctx->Input(1), *ctx->Input(2), ctx->Input(3), + ctx->Input(4), input_starts, input_ends, input_axes, input_steps); } template @@ -258,12 +258,7 @@ Status FuncSlice( SliceOp::PrepareForComputeMetadata compute_metadata(input_dimensions); - ORT_RETURN_IF_ERROR( - SliceOp::PrepareForComputeHelper(starts_span, ends_span, axes_span, steps_span, compute_metadata)); - - ORT_RETURN_IF_ERROR(SliceBase::FlattenOutputDims(compute_metadata.input_dimensions_, compute_metadata.output_dims_, compute_metadata.starts_, - compute_metadata.ends_, compute_metadata.steps_, compute_metadata.p_flattened_input_dims_, - compute_metadata.p_flattened_output_dims_)); + ORT_RETURN_IF_ERROR(SliceBase::PrepareForCompute(starts_span, ends_span, axes_span, steps_span, compute_metadata)); TensorShape output_shape(compute_metadata.output_dims_); diff --git a/onnxruntime/core/providers/cuda/tensor/slice.h b/onnxruntime/core/providers/cuda/tensor/slice.h index 1a3ccd11cb1b9..050206f47217b 100644 --- a/onnxruntime/core/providers/cuda/tensor/slice.h +++ b/onnxruntime/core/providers/cuda/tensor/slice.h @@ -23,9 +23,11 @@ Status Impl(cudaStream_t stream, template class Slice : public CudaKernel, public SliceBase { public: - explicit Slice(const OpKernelInfo& info) : CudaKernel(info), SliceBase(info, dynamic, CudaProviderTag{}) {} + explicit Slice(const OpKernelInfo& info) : CudaKernel(info), + SliceBase(info, dynamic, CudaProviderTag{}) {} - Status ComputeInternal(OpKernelContext* ctx) const override; + Status + ComputeInternal(OpKernelContext* ctx) const override; private: virtual const Tensor* GetSlicedOrUnslicedTensor(OpKernelContext* ctx) const; diff --git a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h index 8780d9b365005..57914ba3af321 100644 --- a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h +++ b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h @@ -4,24 +4,129 @@ #pragma once #include "core/providers/cuda/cuda_kernel.h" +#ifndef BUILD_CUDA_EP_AS_PLUGIN #include "core/providers/cpu/tensor/space_depth_ops.h" +#endif namespace onnxruntime { namespace cuda { +#ifdef BUILD_CUDA_EP_AS_PLUGIN +// Plugin-local equivalent of SpaceDepthBase. +// The CPU header cannot be included in the plugin build because it pulls in +// core/framework/op_kernel.h which conflicts with the adapter types. +namespace detail { + +template +Status InputValidationsAndOutputDimsCalc(int64_t blocksize, + const Tensor& input, + int64_t& batch, + int64_t& input_depth, int64_t& input_height, int64_t& input_width, + int64_t& output_depth, int64_t& output_height, int64_t& output_width, + bool is_space_to_depth) { + const TensorShape& input_shape = input.Shape(); + + if (input_shape.NumDimensions() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "SpaceDepth ops require a 4-D input. Provided rank: ", + input_shape.NumDimensions()); + } + + batch = input_shape[0]; + if constexpr (IsNHWC) { + input_depth = input_shape[3]; + input_height = input_shape[1]; + input_width = input_shape[2]; + } else { + input_depth = input_shape[1]; + input_height = input_shape[2]; + input_width = input_shape[3]; + } + + if (is_space_to_depth) { + if ((input_height % blocksize) != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "SpaceToDepth requires input height to be a multiple of block_size"); + } + if ((input_width % blocksize) != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "SpaceToDepth requires input width to be a multiple of block_size"); + } + output_depth = input_depth * blocksize * blocksize; + output_height = input_height / blocksize; + output_width = input_width / blocksize; + } else { + if ((input_depth % (blocksize * blocksize) != 0)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DepthToSpace requires input depth to be a multiple of (block_size * block_size)"); + } + output_depth = input_depth / blocksize / blocksize; + output_height = input_height * blocksize; + output_width = input_width * blocksize; + } + + return Status::OK(); +} + +} // namespace detail +#endif // BUILD_CUDA_EP_AS_PLUGIN + template -class SpaceToDepth final : public CudaKernel, SpaceDepthBase { +class SpaceToDepth final : public CudaKernel +#ifndef BUILD_CUDA_EP_AS_PLUGIN + , + SpaceDepthBase +#endif +{ public: - explicit SpaceToDepth(const OpKernelInfo& info) : CudaKernel(info), SpaceDepthBase(info) { + explicit SpaceToDepth(const OpKernelInfo& info) + : CudaKernel(info) +#ifndef BUILD_CUDA_EP_AS_PLUGIN + , + SpaceDepthBase(info) +#endif + { +#ifdef BUILD_CUDA_EP_AS_PLUGIN + ORT_ENFORCE(info.GetAttr("blocksize", &blocksize_).IsOK(), + "Attribute blocksize is not set."); +#endif } Status ComputeInternal(OpKernelContext* context) const override; + +#ifdef BUILD_CUDA_EP_AS_PLUGIN + protected: + template + Status InputValidationsAndOutputDimsCalc(const Tensor& input, + int64_t& batch, + int64_t& input_depth, int64_t& input_height, int64_t& input_width, + int64_t& output_depth, int64_t& output_height, int64_t& output_width, + bool is_space_to_depth) const { + return detail::InputValidationsAndOutputDimsCalc( + blocksize_, input, batch, input_depth, input_height, input_width, + output_depth, output_height, output_width, is_space_to_depth); + } + + int64_t blocksize_; +#endif }; template -class DepthToSpace final : public CudaKernel, SpaceDepthBase { +class DepthToSpace final : public CudaKernel +#ifndef BUILD_CUDA_EP_AS_PLUGIN + , + SpaceDepthBase +#endif +{ public: - explicit DepthToSpace(const OpKernelInfo& info) : CudaKernel(info), SpaceDepthBase(info) { + explicit DepthToSpace(const OpKernelInfo& info) + : CudaKernel(info) +#ifndef BUILD_CUDA_EP_AS_PLUGIN + , + SpaceDepthBase(info) +#endif + { +#ifdef BUILD_CUDA_EP_AS_PLUGIN + ORT_ENFORCE(info.GetAttr("blocksize", &blocksize_).IsOK(), + "Attribute blocksize is not set."); +#endif std::string mode; // if mode doesn't exist, then it is the default "DCR" mode // (or) it is an opset < 11 model for which the only mode is "DCR" mode @@ -38,6 +143,22 @@ class DepthToSpace final : public CudaKernel, SpaceDepthBase { private: bool is_dcr_ = true; + +#ifdef BUILD_CUDA_EP_AS_PLUGIN + protected: + template + Status InputValidationsAndOutputDimsCalc(const Tensor& input, + int64_t& batch, + int64_t& input_depth, int64_t& input_height, int64_t& input_width, + int64_t& output_depth, int64_t& output_height, int64_t& output_width, + bool is_space_to_depth) const { + return detail::InputValidationsAndOutputDimsCalc( + blocksize_, input, batch, input_depth, input_height, input_width, + output_depth, output_height, output_width, is_space_to_depth); + } + + int64_t blocksize_; +#endif }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/tensor/split.cc b/onnxruntime/core/providers/cuda/tensor/split.cc index ca82387600085..9322c261761a7 100644 --- a/onnxruntime/core/providers/cuda/tensor/split.cc +++ b/onnxruntime/core/providers/cuda/tensor/split.cc @@ -42,6 +42,63 @@ ONNX_OPERATOR_KERNEL_EX(Split, .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), Split_18); +Status SplitKernel::PrepareForComputeLocal(const TensorShape& input_shape, + int num_outputs, + int64_t& axis, + int& before_dims, + int& after_dims_including_split_axis, + int& after_dims_excluding_split, + std::vector& split_sizes) const { + auto input_dims = input_shape.GetDims(); + const auto num_dimensions = gsl::narrow_cast(input_shape.NumDimensions()); + axis = HandleNegativeAxis(axis_, num_dimensions); + const int64_t split_dim_size = input_dims[onnxruntime::narrow(axis)]; + + before_dims = gsl::narrow_cast(input_shape.SizeToDimension(onnxruntime::narrow(axis))); + after_dims_including_split_axis = gsl::narrow_cast(input_shape.SizeFromDimension(onnxruntime::narrow(axis))); + after_dims_excluding_split = (axis + 1 == num_dimensions) + ? 1 + : gsl::narrow_cast(input_shape.SizeFromDimension(onnxruntime::narrow(axis + 1))); + + if (num_outputs_ != -1) { + if (num_outputs_ > split_dim_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid num_outputs value of ", num_outputs_, + ". Size of dimension being split is ", split_dim_size); + } + + int64_t size = (split_dim_size + num_outputs_ - 1) / num_outputs_; + int64_t remainder = split_dim_size % size; + + split_sizes = std::vector(num_outputs, size); + if (remainder) { + split_sizes.back() = remainder; + } + } + + if (split_sizes.empty()) { + if (split_dim_size % static_cast(num_outputs) != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input cannot be split evenly on selected axis. Input shape=", input_shape, + " Axis=", axis_, " NumOutputs=", num_outputs); + } + split_sizes = std::vector(static_cast(num_outputs), split_dim_size / num_outputs); + } else { + int64_t split_size_sum = split_size_sum_; + if (split_size_sum == -1) { + split_size_sum = std::accumulate(split_sizes.cbegin(), split_sizes.cend(), 0LL); + } + if (split_sizes.size() != static_cast(num_outputs) || split_size_sum != split_dim_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Cannot split using values in 'split' attribute. Axis=", axis_, + " Input shape=", input_shape, + " NumOutputs=", num_outputs, + " Num entries in 'split' (must equal number of outputs) was ", split_sizes.size(), + " Sum of sizes in 'split' (must equal size of selected axis) was ", split_size_sum); + } + } + + return Status::OK(); +} + Status SplitKernel::ComputeInternal(OpKernelContext* ctx) const { const Tensor* input_tensor = ctx->Input(0); ORT_ENFORCE(input_tensor); @@ -63,13 +120,13 @@ Status SplitKernel::ComputeInternal(OpKernelContext* ctx) const { split_sizes.assign(split_sizes_.begin(), split_sizes_.end()); } - ORT_RETURN_IF_ERROR(PrepareForCompute(input_shape, - num_outputs, - axis, - before_dims, - block_size_including_axis_dim, - block_size_inside_axis_dim, - split_sizes)); + ORT_RETURN_IF_ERROR(PrepareForComputeLocal(input_shape, + num_outputs, + axis, + before_dims, + block_size_including_axis_dim, + block_size_inside_axis_dim, + split_sizes)); auto input_data = input_tensor->DataRaw(); @@ -129,23 +186,23 @@ Status SplitKernel::ComputeInternal(OpKernelContext* ctx) const { block_size_inside_axis_dim, split_sizes[0], num_outputs, input_data, output_ptr_array, static_cast(input_shape.Size()))); } else { - ORT_RETURN_IF_ERROR(output_ptr.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(output_ptr.CopyToGpu(GetComputeStream(ctx))); ORT_RETURN_IF_ERROR(SplitSameSplitDimImpl(Stream(ctx), element_size, block_size_including_axis_dim, block_size_inside_axis_dim, split_sizes[0], num_outputs, input_data, output_ptr.GpuPtr(), static_cast(input_shape.Size()))); } } else { - ORT_RETURN_IF_ERROR(output_ptr.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(output_ptr.CopyToGpu(GetComputeStream(ctx))); CudaAsyncBuffer split_sizes_gpu(this, split_sizes); - ORT_RETURN_IF_ERROR(split_sizes_gpu.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(split_sizes_gpu.CopyToGpu(GetComputeStream(ctx))); std::vector split_sizes_range(split_sizes); for (size_t i = 1; i < split_sizes_range.size(); ++i) { split_sizes_range[i] += split_sizes_range[i - 1]; } CudaAsyncBuffer split_sizes_range_gpu(this, split_sizes_range); - ORT_RETURN_IF_ERROR(split_sizes_range_gpu.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(split_sizes_range_gpu.CopyToGpu(GetComputeStream(ctx))); CudaAsyncBuffer axis_dimension_input_output_mapping_gpu(this, axis_dimension_input_output_mapping); - ORT_RETURN_IF_ERROR(axis_dimension_input_output_mapping_gpu.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(axis_dimension_input_output_mapping_gpu.CopyToGpu(GetComputeStream(ctx))); ORT_RETURN_IF_ERROR(SplitImpl(Stream(ctx), element_size, block_size_including_axis_dim, block_size_inside_axis_dim, split_sizes_gpu.GpuPtr(), split_sizes_range_gpu.GpuPtr(), axis_dimension_input_output_mapping_gpu.GpuPtr(), num_outputs, input_data, diff --git a/onnxruntime/core/providers/cuda/tensor/split.h b/onnxruntime/core/providers/cuda/tensor/split.h index 00d80f65b79c0..6820d58d9c953 100644 --- a/onnxruntime/core/providers/cuda/tensor/split.h +++ b/onnxruntime/core/providers/cuda/tensor/split.h @@ -13,6 +13,15 @@ class SplitKernel : public CudaKernel, public SplitBase { SplitKernel(const OpKernelInfo& info, uint32_t opset) : CudaKernel(info), SplitBase(info, opset) {} Status ComputeInternal(OpKernelContext* context) const override; + + private: + Status PrepareForComputeLocal(const TensorShape& input_shape, + int num_outputs, + int64_t& axis, + int& before_dims, + int& after_dims_including_split_axis, + int& after_dims_excluding_split, + std::vector& split_sizes) const; }; // versions 2, 11 and 13 diff --git a/onnxruntime/core/providers/cuda/tensor/tile.cc b/onnxruntime/core/providers/cuda/tensor/tile.cc index a0dd78b5e60e1..5fd250a93d6f8 100644 --- a/onnxruntime/core/providers/cuda/tensor/tile.cc +++ b/onnxruntime/core/providers/cuda/tensor/tile.cc @@ -9,6 +9,40 @@ using namespace onnxruntime::common; namespace onnxruntime { namespace cuda { +namespace { + +bool IsTileMemcpyForPlugin(const TensorShape& input_shape, + const int64_t* repeats, + size_t rank, + /*out*/ bool& is_batched_memcpy, + /*out*/ size_t& num_of_elements_per_batch, + /*out*/ size_t& num_of_copies_per_batch, + /*out*/ size_t& num_of_batch_copies) { + for (int64_t i = static_cast(rank) - 1; i >= 0; --i) { + if (repeats[i] != 1) { + if (input_shape.SizeToDimension(onnxruntime::narrow(i)) == 1) { + num_of_copies_per_batch = 1; + for (int64_t j = 0; j <= i; ++j) { + num_of_copies_per_batch *= onnxruntime::narrow(repeats[onnxruntime::narrow(j)]); + } + is_batched_memcpy = false; + return true; + } else if (i == 1) { + num_of_elements_per_batch = static_cast(input_shape.SizeFromDimension(1)); + num_of_copies_per_batch = onnxruntime::narrow(repeats[onnxruntime::narrow(i)]); + num_of_batch_copies = onnxruntime::narrow(repeats[0]); + is_batched_memcpy = true; + return true; + } else { + break; + } + } + } + return false; +} + +} // namespace + ONNX_OPERATOR_VERSIONED_KERNEL_EX( Tile, kOnnxDomain, diff --git a/onnxruntime/core/providers/cuda/tensor/transpose.cc b/onnxruntime/core/providers/cuda/tensor/transpose.cc index 51aa46df18bc8..82096a2f397a7 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose.cc +++ b/onnxruntime/core/providers/cuda/tensor/transpose.cc @@ -99,7 +99,7 @@ Status Transpose::DoTranspose(const Transpose& transpose_kernel, onnxruntime::Stream* ort_stream, const gsl::span& permutations, const Tensor& input, Tensor& output) { cudaStream_t cuda_stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; - const cublasHandle_t cublas_handle = transpose_kernel.GetCublasHandleOrDefault(ort_stream); + const cublasHandle_t cublas_handle = transpose_kernel.GetCublasHandle(ort_stream); return Transpose::DoTranspose(transpose_kernel.GetDeviceProp(), cuda_stream, cublas_handle, diff --git a/onnxruntime/core/providers/cuda/tensor/unsqueeze.cc b/onnxruntime/core/providers/cuda/tensor/unsqueeze.cc index 411f94f31c7ba..d4080b7ef49e5 100644 --- a/onnxruntime/core/providers/cuda/tensor/unsqueeze.cc +++ b/onnxruntime/core/providers/cuda/tensor/unsqueeze.cc @@ -6,6 +6,57 @@ namespace onnxruntime { namespace cuda { +namespace { + +Status PrepareComputeForPlugin(OpKernelContext* ctx, UnsqueezeBase::Prepare& p, const TensorShapeVector& axes_attr) { + const auto* input = ctx->Input(0); + ORT_ENFORCE(input != nullptr); + auto& input_tensor = *input; + + TensorShapeVector axes; + size_t num_inputs = static_cast(ctx->InputCount()); + if (num_inputs == 2) { + const Tensor* axes_tensor = ctx->Input(1); + ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 0 || + axes_tensor->Shape().NumDimensions() == 1, + "An axes tensor must be a scalar or a 1-D tensor."); + auto data_span = axes_tensor->DataAsSpan(); + axes.assign(data_span.begin(), data_span.end()); + } else { + axes.assign(axes_attr.begin(), axes_attr.end()); + } + + TensorShapeVector output_dims(axes.size() + input_tensor.Shape().NumDimensions(), 0); + for (int64_t axis : axes) { + axis = HandleNegativeAxis(axis, onnxruntime::narrow(output_dims.size())); + if (axis < 0 || axis >= static_cast(output_dims.size())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'axes' has an out of range axis"); + } + + auto axis_index = onnxruntime::narrow(axis); + if (output_dims[axis_index] != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'axes' has a duplicate axis"); + } + output_dims[axis_index] = 1; + } + + auto begin = input_tensor.Shape().GetDims().begin(); + for (auto& axis_size : output_dims) { + if (axis_size == 0) { + axis_size = *begin++; + } + } + + TensorShape output_shape(output_dims); + p.output_tensor = ctx->Output(0, output_shape); + ORT_ENFORCE(p.output_tensor != nullptr); + p.input_tensor = &input_tensor; + return Status::OK(); +} + +} // namespace + ONNX_OPERATOR_VERSIONED_KERNEL_EX( Unsqueeze, kOnnxDomain, @@ -85,7 +136,11 @@ ONNX_OPERATOR_KERNEL_EX( Status Unsqueeze::ComputeInternal(OpKernelContext* ctx) const { Prepare p; +#ifdef BUILD_CUDA_EP_AS_PLUGIN + ORT_RETURN_IF_ERROR(PrepareComputeForPlugin(ctx, p, axes_)); +#else ORT_RETURN_IF_ERROR(PrepareCompute(ctx, p)); +#endif const void* input = p.input_tensor->DataRaw(); void* output = p.output_tensor->MutableDataRaw(); diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.cc b/onnxruntime/core/providers/cuda/tensor/upsample.cc index e2c08618264dd..974430d7ee5dc 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.cc +++ b/onnxruntime/core/providers/cuda/tensor/upsample.cc @@ -41,16 +41,7 @@ REGISTER_VERSIONED_TYPED_KERNEL(int32_t, 9, 9); REGISTER_VERSIONED_TYPED_KERNEL(uint8_t, 9, 9); template -Upsample::Upsample(const OpKernelInfo& info) : UpsampleBase(info), CudaKernel(info) { - if (UpsampleBase::antialias_) { - // Copy the table on DEVICE - const uint8_t* lookup_table = GetLookupTableShared(); - auto alloc = info.GetAllocator(OrtMemTypeDefault); - shared_lookup_table_ondevice_ = IAllocator::MakeUniquePtr(std::move(alloc), kLookupTableSize); - CUDA_CALL_THROW(cudaMemcpyAsync(shared_lookup_table_ondevice_.get(), lookup_table, kLookupTableSize, - cudaMemcpyHostToDevice, nullptr)); - } -} +Upsample::Upsample(const OpKernelInfo& info) : UpsampleBase(info), CudaKernel(info) {} template Status Upsample::BaseCompute(OpKernelContext* context, @@ -104,8 +95,13 @@ Status Upsample::BaseCompute(OpKernelContext* context, } if (antialias_) { + const uint8_t* lookup_table = GetLookupTableShared(); + auto shared_lookup_table_ondevice = GetScratchBuffer(kLookupTableSize, GetComputeStream(context)); + CUDA_CALL_THROW(cudaMemcpyAsync(shared_lookup_table_ondevice.get(), lookup_table, kLookupTableSize, + cudaMemcpyHostToDevice, Stream(context))); + TempSpaceAllocateFunc allocate_temp_space = [&](size_t bytes_size) { - return GetScratchBuffer(bytes_size, context->GetComputeStream()); + return GetScratchBuffer(bytes_size, GetComputeStream(context)); }; std::optional extrapolation_value; @@ -170,7 +166,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, extrapolation_value, exclude_outside_, allocate_temp_space, - shared_lookup_table_ondevice_.get(), + shared_lookup_table_ondevice.get(), reinterpret_cast(X->Data()), reinterpret_cast(Y->MutableData()), output_count); @@ -213,7 +209,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, extrapolation_value, exclude_outside_, allocate_temp_space, - shared_lookup_table_ondevice_.get(), + shared_lookup_table_ondevice.get(), reinterpret_cast(X->Data()), reinterpret_cast(Y->MutableData()), output_count); @@ -259,7 +255,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, extrapolation_value, exclude_outside_, allocate_temp_space, - shared_lookup_table_ondevice_.get(), + shared_lookup_table_ondevice.get(), reinterpret_cast(X->Data()), reinterpret_cast(Y->MutableData()), output_count); @@ -274,7 +270,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, TArray scales_vals(scales); size_t temp_buffer_size = CalcResizeBufferSize(mode_, output_dims); - auto dims_mapping_buffer = GetScratchBuffer(temp_buffer_size, context->GetComputeStream()); + auto dims_mapping_buffer = GetScratchBuffer(temp_buffer_size, GetComputeStream(context)); void* dims_mapping = reinterpret_cast(dims_mapping_buffer.get()); ResizeImpl(Stream(context), mode_, rank, input_shape, output_shape, input_strides, output_div_pitches, scales_vals, roi_vals, @@ -341,7 +337,7 @@ Status Upsample::ComputeInternal(OpKernelContext* context) const { InlinedVector scales_array(input_dims.size()); // opset < 10 - if (OpKernel::Node().InputDefs().size() == 1) { + if (context->InputCount() == 1) { // Compute output shape from scales attributes and input dims scales_array = scales_; diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.h b/onnxruntime/core/providers/cuda/tensor/upsample.h index 50597e0fba1b9..baf7bc8b06915 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.h +++ b/onnxruntime/core/providers/cuda/tensor/upsample.h @@ -18,9 +18,6 @@ class Upsample : public UpsampleBase, public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; Status BaseCompute(OpKernelContext* context, gsl::span roi, gsl::span scales, gsl::span output_dims) const; - - private: - IAllocatorUniquePtr shared_lookup_table_ondevice_; }; } // namespace cuda diff --git a/onnxruntime/python/onnxruntime_pybind_module.cc b/onnxruntime/python/onnxruntime_pybind_module.cc index e1c883f960dde..dbb0a9330d262 100644 --- a/onnxruntime/python/onnxruntime_pybind_module.cc +++ b/onnxruntime/python/onnxruntime_pybind_module.cc @@ -13,6 +13,7 @@ #include "core/session/provider_bridge_ort.h" #include "core/framework/provider_options.h" #include "core/platform/env.h" +#include "core/common/inlined_containers.h" namespace onnxruntime { namespace python { @@ -113,7 +114,27 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { throw pybind11::import_error(st.ErrorMessage()); // move it out of shared method since training build has a little different behavior. m.def( - "get_available_providers", []() -> const std::vector& { return GetAvailableExecutionProviderNames(); }, + "get_available_providers", []() -> std::vector { + auto available = GetAvailableExecutionProviderNames(); +#if !defined(ORT_MINIMAL_BUILD) + InlinedHashSet existing; + existing.reserve(available.size()); + for (const auto& ep_name : available) { + existing.insert(ep_name); + } + + for (const OrtEpDevice* ep_device : GetEnv().GetOrtEpDevices()) { + if (!ep_device) { + continue; + } + + if (existing.insert(ep_device->ep_name).second) { + available.push_back(ep_device->ep_name); + } + } +#endif + return available; + }, "Return list of available Execution Providers in this installed version of Onnxruntime. " "The order of elements represents the default priority order of Execution Providers " "from highest to lowest."); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 937a96a619822..2e5c24db951ae 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -585,6 +585,44 @@ static std::shared_ptr CreateExecutionProviderFactory const SessionOptions& session_options, const std::string& type, const ProviderOptionsMap& provider_options_map) { +#if !defined(ORT_MINIMAL_BUILD) + auto try_create_registered_plugin_factory = [&]() -> std::shared_ptr { + const auto& ep_devices = GetEnv().GetOrtEpDevices(); + if (ep_devices.empty()) { + return nullptr; + } + + const OrtEpDevice* selected_device = nullptr; + for (const OrtEpDevice* ep_device : ep_devices) { + if (!ep_device || ep_device->ep_name != type) { + continue; + } + + if (selected_device == nullptr) { + selected_device = ep_device; + break; + } + } + + if (selected_device == nullptr) { + return nullptr; + } + + InlinedVector selected_devices; + selected_devices.push_back(selected_device); + + std::unique_ptr ep_factory; + const auto status = onnxruntime::CreateIExecutionProviderFactoryForEpDevices(GetEnv(), selected_devices, ep_factory); + if (!status.IsOK()) { + LOGS_DEFAULT(WARNING) << "Failed to create dynamic EP factory for '" << type + << "' from registered EP devices: " << status; + return nullptr; + } + + return std::shared_ptr(std::move(ep_factory)); + }; +#endif + if (type == kCpuExecutionProvider) { return onnxruntime::CPUProviderFactoryCreator::Create( session_options.enable_cpu_mem_arena); @@ -1203,6 +1241,13 @@ static std::shared_ptr CreateExecutionProviderFactory << " to ensure all dependencies are met."; #endif } else { +#if !defined(ORT_MINIMAL_BUILD) + // Try EPs dynamically registered via register_execution_provider_library(). + if (auto ep_factory = try_create_registered_plugin_factory(); ep_factory) { + return ep_factory; + } +#endif + // check whether it is a dynamic load EP: const auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { @@ -1241,7 +1286,10 @@ std::unique_ptr CreateExecutionProviderInstance(const Sessio const ProviderOptionsMap& provider_options_map) { auto ep_factory = CreateExecutionProviderFactoryInstance(session_options, type, provider_options_map); if (ep_factory) { - return ep_factory->CreateProvider(); + const auto& default_logger = GetEnv().GetLoggingManager()->DefaultLogger(); + OrtSessionOptions ort_session_options; + ort_session_options.value = session_options; + return ep_factory->CreateProvider(ort_session_options, *default_logger.ToExternal()); } return nullptr; } diff --git a/onnxruntime/test/unittest_util/base_tester.cc b/onnxruntime/test/unittest_util/base_tester.cc index 18ead92ce3f18..8397ea31531c0 100644 --- a/onnxruntime/test/unittest_util/base_tester.cc +++ b/onnxruntime/test/unittest_util/base_tester.cc @@ -41,6 +41,12 @@ void DebugTrap() { } #endif +bool ShouldRouteCudaToDynamicPluginEp(const std::optional& dynamic_plugin_ep_name) { + // Route CUDA requests to the CUDA plugin EP when unit test main has initialized + // dynamic plugin EP infrastructure with the CUDA plugin registration. + return dynamic_plugin_ep_name.has_value() && *dynamic_plugin_ep_name == "CudaPluginExecutionProvider"; +} + } // namespace BaseTester::~BaseTester() { @@ -689,9 +695,10 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, #endif const auto dynamic_plugin_ep_name = dynamic_plugin_ep_infra::GetEpName(); + const bool route_cuda_to_dynamic_plugin_ep = ShouldRouteCudaToDynamicPluginEp(dynamic_plugin_ep_name); std::optional> provider_types_including_dynamic_plugin_ep{}; - if (dynamic_plugin_ep_name.has_value()) { + if (dynamic_plugin_ep_name.has_value() && !route_cuda_to_dynamic_plugin_ep) { ORT_ENFORCE(std::find(all_provider_types.begin(), all_provider_types.end(), *dynamic_plugin_ep_name) == all_provider_types.end(), "Dynamic plugin EP name conflicts with a known EP name: ", *dynamic_plugin_ep_name); @@ -716,7 +723,7 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, if (provider_type == onnxruntime::kCpuExecutionProvider) execution_provider = DefaultCpuExecutionProvider(); else if (provider_type == onnxruntime::kCudaExecutionProvider) - execution_provider = DefaultCudaExecutionProvider(); + execution_provider = route_cuda_to_dynamic_plugin_ep ? dynamic_plugin_ep_infra::MakeEp() : DefaultCudaExecutionProvider(); #ifdef ENABLE_CUDA_NHWC_OPS else if (provider_type == onnxruntime::kCudaNHWCExecutionProvider) execution_provider = DefaultCudaNHWCExecutionProvider(); @@ -814,11 +821,20 @@ void BaseTester::ExecuteModelForEps( bool allow_released_onnx_opset_only, size_t* number_of_pre_packed_weights_counter, size_t* number_of_shared_pre_packed_weights_counter) { + const auto dynamic_plugin_ep_name = dynamic_plugin_ep_infra::GetEpName(); + const bool route_cuda_to_dynamic_plugin_ep = ShouldRouteCudaToDynamicPluginEp(dynamic_plugin_ep_name); + for (auto& entry : execution_providers) { // Be noted, entry in execution providers passed in OpTester will be std::moved in the first BaseTester::Run(), // To make the error more obvious to debug (instead of a segment fault), we do check explicitly here. ASSERT_TRUE(entry) << "Execution provider entry invalid."; + if (route_cuda_to_dynamic_plugin_ep && entry->Type() == kCudaExecutionProvider) { + auto plugin_ep = dynamic_plugin_ep_infra::MakeEp(); + ASSERT_TRUE(plugin_ep) << "Failed to create CUDA plugin EP while routing from CUDAExecutionProvider."; + entry = std::move(plugin_ep); + } + if (entry->Type() == kDmlExecutionProvider) { sess_options.enable_mem_pattern = false; sess_options.execution_mode = ExecutionMode::ORT_SEQUENTIAL; From 267bb061e79d79aaaa4686f26291d89ab8cefde8 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 23 Mar 2026 13:25:25 -0700 Subject: [PATCH 04/48] remove cuda graph --- .../core/providers/cuda/plugin/cuda_ep.cc | 135 +--------------- .../core/providers/cuda/plugin/cuda_ep.h | 23 --- .../providers/cuda/plugin/cuda_ep_factory.cc | 16 +- .../providers/cuda/plugin/cuda_ep_factory.h | 7 - .../cuda/plugin/cuda_graph_plugin.cc | 144 ------------------ .../providers/cuda/plugin/cuda_graph_plugin.h | 70 --------- 6 files changed, 11 insertions(+), 384 deletions(-) delete mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.cc delete mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.h diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index 6215cc74f828a..0e6badaf7ef94 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -20,9 +20,7 @@ CudaEp::CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& lo factory_(factory), name_(factory.GetEpName()), config_(config), - logger_(logger), - cuda_graph_enabled_(config.enable_cuda_graph), - min_runs_before_capture_(config.min_num_runs_before_cuda_graph_capture) { + logger_(logger) { ort_version_supported = ORT_API_VERSION; // Set function pointers for kernel-registry-based EP @@ -31,8 +29,8 @@ CudaEp::CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& lo GetKernelRegistry = GetKernelRegistryImpl; GetPreferredDataLayout = GetPreferredDataLayoutImpl; ShouldConvertDataLayoutForOp = ShouldConvertDataLayoutForOpImpl; - OnRunStart = OnRunStartImpl; - OnRunEnd = OnRunEndImpl; + OnRunStart = nullptr; + OnRunEnd = nullptr; // Not a compile-based EP Compile = nullptr; @@ -178,132 +176,5 @@ OrtStatus* ORT_API_CALL CudaEp::ShouldConvertDataLayoutForOpImpl( return nullptr; } -// --------------------------------------------------------------------------- -// CUDA Graph helpers -// --------------------------------------------------------------------------- - -CudaGraphAnnotation_t CudaEp::GetAnnotationId(const ::OrtRunOptions* run_options) const { - const OrtApi& ort_api = factory_.GetOrtApi(); - // Use the same key as the bundled CUDA EP: "gpu_graph_id" - const char* val = ort_api.GetRunConfigEntry(run_options, "gpu_graph_id"); - if (val == nullptr) { - return kCudaGraphAnnotationDefault; - } - try { - return std::stoi(val); - } catch (...) { - return kCudaGraphAnnotationDefault; - } -} - -bool CudaEp::IsGraphCaptureAllowed(CudaGraphAnnotation_t annotation_id) const { - if (!cuda_graph_manager_.IsGraphCaptureAllowedOnRun(annotation_id)) { - return false; - } - auto it = graph_id_to_run_count_.find(annotation_id); - if (it == graph_id_to_run_count_.end()) { - return false; - } - return it->second >= min_runs_before_capture_; -} - -// --------------------------------------------------------------------------- -// OnRunStart — manage CUDA graph capture/replay state machine -// --------------------------------------------------------------------------- - -/*static*/ -OrtStatus* ORT_API_CALL CudaEp::OnRunStartImpl( - OrtEp* this_ptr, const ::OrtRunOptions* run_options) noexcept { - EXCEPTION_TO_STATUS_BEGIN - - auto* ep = static_cast(this_ptr); - - if (!ep->cuda_graph_enabled_.load(std::memory_order_relaxed)) { - return nullptr; // Graph capture not enabled — no-op - } - - // gpu_graph_id == -1 means skip capture/replay for this run - // (matches bundled CUDA EP behavior via kOrtRunOptionsConfigCudaGraphAnnotation) - CudaGraphAnnotation_t annotation_id = ep->GetAnnotationId(run_options); - if (annotation_id == kCudaGraphAnnotationSkip) { - return nullptr; - } - - // Lazily set the graph manager's stream from the factory's compute stream. - CudaSyncStream* compute_stream = ep->factory_.GetComputeStream(); - if (compute_stream == nullptr) { - // Stream not yet created — skip graph capture for this run. - // This can happen if OnRunStart is called before CreateSyncStreamForDevice. - return nullptr; - } - ep->cuda_graph_manager_.SetStream(compute_stream->GetCudaStream()); - - if (ep->cuda_graph_manager_.IsGraphCaptured(annotation_id)) { - // Already captured — replay happens in OnRunEnd for the plugin EP. - // ORT runtime will still dispatch kernels normally; the captured graph - // replays the actual GPU work. For the plugin EP without stream executor - // hooks, we replay at OnRunEnd after kernel dispatch completes. - return nullptr; - } - - if (!ep->cuda_graph_manager_.IsGraphCaptured(annotation_id) && - ep->IsGraphCaptureAllowed(annotation_id)) { - // Warm-up period complete — begin capture - ep->cuda_graph_manager_.CaptureBegin(annotation_id); - ep->is_capturing_ = true; - ep->capturing_annotation_id_ = annotation_id; - } - - return nullptr; - - EXCEPTION_TO_STATUS_END -} - -// --------------------------------------------------------------------------- -// OnRunEnd — end capture or handle replay -// --------------------------------------------------------------------------- - -/*static*/ -OrtStatus* ORT_API_CALL CudaEp::OnRunEndImpl( - OrtEp* this_ptr, const ::OrtRunOptions* run_options, bool sync_stream) noexcept { - EXCEPTION_TO_STATUS_BEGIN - - auto* ep = static_cast(this_ptr); - - if (!ep->cuda_graph_enabled_.load(std::memory_order_relaxed)) { - return nullptr; - } - - // gpu_graph_id == -1 means skip capture/replay for this run - CudaGraphAnnotation_t annotation_id = ep->GetAnnotationId(run_options); - if (annotation_id == kCudaGraphAnnotationSkip) { - return nullptr; - } - - if (!ep->cuda_graph_manager_.IsGraphCaptured(annotation_id)) { - if (ep->is_capturing_ && ep->capturing_annotation_id_ == annotation_id) { - // Was capturing — end capture and replay the first time - ep->cuda_graph_manager_.CaptureEnd(annotation_id); - ep->is_capturing_ = false; - - // CUDA work issued to a capturing stream doesn't actually run on the GPU, - // so replay the captured graph to actually execute the work. - OrtStatus* replay_status = ep->cuda_graph_manager_.Replay(annotation_id, sync_stream); - if (replay_status != nullptr) return replay_status; - } else { - // Still in warm-up period — increment run count - ep->graph_id_to_run_count_[annotation_id]++; - } - } - // Note: For subsequent runs after capture, the captured graph is not replayed - // here. The ORT framework dispatches kernels normally (it does not know about - // CUDA graph capture). Full graph-only replay (with kernel dispatch bypass) - // requires stream executor support which is not yet available in the plugin EP. - - return nullptr; - - EXCEPTION_TO_STATUS_END -} - } // namespace cuda_plugin } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.h b/onnxruntime/core/providers/cuda/plugin/cuda_ep.h index 9d15cd6abdb8a..c973507d90f91 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.h @@ -4,11 +4,8 @@ #pragma once #include "cuda_plugin_utils.h" -#include "cuda_graph_plugin.h" -#include #include -#include #include namespace onnxruntime { @@ -27,8 +24,6 @@ class CudaEp : public OrtEp { int device_id = 0; ///< CUDA device ordinal. int cudnn_conv_algo = 0; ///< cuDNN convolution algorithm selection. bool cudnn_conv1d_pad_to_nc1d = false; ///< Pad 1D convolutions to NC1D format. - bool enable_cuda_graph = false; ///< Enable CUDA graph capture/replay. - int min_num_runs_before_cuda_graph_capture = 1; ///< Warm-up runs before graph capture. }; CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& logger); @@ -52,32 +47,14 @@ class CudaEp : public OrtEp { static OrtStatus* ORT_API_CALL GetPreferredDataLayoutImpl( OrtEp* this_ptr, OrtEpDataLayout* preferred_data_layout) noexcept; - static OrtStatus* ORT_API_CALL OnRunStartImpl( - OrtEp* this_ptr, const ::OrtRunOptions* run_options) noexcept; - static OrtStatus* ORT_API_CALL ShouldConvertDataLayoutForOpImpl( OrtEp* this_ptr, const char* domain, const char* op_type, OrtEpDataLayout target_data_layout, int* should_convert) noexcept; - static OrtStatus* ORT_API_CALL OnRunEndImpl( - OrtEp* this_ptr, const ::OrtRunOptions* run_options, bool sync_stream) noexcept; - - // CUDA Graph helpers - CudaGraphAnnotation_t GetAnnotationId(const ::OrtRunOptions* run_options) const; - bool IsGraphCaptureAllowed(CudaGraphAnnotation_t annotation_id) const; - CudaEpFactory& factory_; std::string name_; Config config_; const OrtLogger& logger_; - - // CUDA Graph state - std::atomic cuda_graph_enabled_{false}; - int min_runs_before_capture_ = 1; - CUDAGraphManager cuda_graph_manager_; - std::unordered_map graph_id_to_run_count_; - bool is_capturing_ = false; - CudaGraphAnnotation_t capturing_annotation_id_ = kCudaGraphAnnotationDefault; }; } // namespace cuda_plugin diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc index e6c6df2d6ccfe..3c8706ce2cef2 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -113,8 +113,14 @@ OrtStatus* ORT_API_CALL CudaEpFactory::GetSupportedDevicesImpl( auto hw_type = factory->ort_api_.HardwareDevice_Type(&device); if (hw_type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { - // Check if this GPU is an NVIDIA GPU by trying to match vendor ID - // For now, accept all GPU devices and let CUDA runtime handle validation + // Filter by vendor ID to avoid claiming non-NVIDIA GPUs on mixed-vendor hosts. + // vendor_id == 0 means the hardware enumeration did not provide a vendor ID, + // in which case we fall through and let the CUDA runtime validate the device. + uint32_t hw_vendor_id = factory->ort_api_.HardwareDevice_VendorId(&device); + if (hw_vendor_id != 0 && hw_vendor_id != factory->vendor_id_) { + continue; // Skip non-NVIDIA GPUs + } + OrtKeyValuePairs* ep_metadata = nullptr; factory->ort_api_.CreateKeyValuePairs(&ep_metadata); @@ -231,8 +237,6 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( read_session_config_bool("ep.cuda.enable_skip_layer_norm_strict_mode", config.enable_skip_layer_norm_strict_mode); read_session_config_bool("ep.cuda.cudnn_conv1d_pad_to_nc1d", config.cudnn_conv1d_pad_to_nc1d); read_session_config_int("ep.cuda.cudnn_conv_algo", config.cudnn_conv_algo); - read_session_config_bool("ep.cuda.enable_cuda_graph", config.enable_cuda_graph); - read_session_config_int("ep.cuda.min_num_runs_before_cuda_graph_capture", config.min_num_runs_before_cuda_graph_capture); const OrtLogger& ep_logger = logger ? *logger : factory->default_logger_; auto actual_ep = std::make_unique(*factory, config, ep_logger); @@ -330,10 +334,6 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateSyncStreamForDeviceImpl( // Initialize CUDA handles (stream, cuBLAS, cuDNN) RETURN_IF_ERROR(cuda_stream->InitHandles()); - // Track the compute stream for CUDA graph integration. - // The factory does NOT own this stream — ORT manages its lifetime. - factory->compute_stream_ = cuda_stream.get(); - *stream = cuda_stream.release(); return nullptr; diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h index 429cf50eaf854..96ec789d9abed 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h @@ -33,10 +33,6 @@ class CudaEpFactory : public OrtEpFactory { OrtStatus* GetKernelRegistryForEp(CudaEp& ep, const OrtKernelRegistry** out_kernel_registry); - /// Get the compute stream (set by CreateSyncStreamForDevice). - /// Returns nullptr if no stream has been created yet. - CudaSyncStream* GetComputeStream() const { return compute_stream_; } - private: // OrtEpFactory callback implementations static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; @@ -99,9 +95,6 @@ class CudaEpFactory : public OrtEpFactory { // Kernel registry (cached, shared across EP instances) OrtKernelRegistry* kernel_registry_ = nullptr; std::mutex registry_mutex_; - - // Compute stream (set by CreateSyncStreamForDevice, non-owning). - CudaSyncStream* compute_stream_ = nullptr; }; } // namespace cuda_plugin diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.cc deleted file mode 100644 index 3e5f28b4b0491..0000000000000 --- a/onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.cc +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "cuda_graph_plugin.h" - -#include -#include -#include - -namespace onnxruntime { -namespace cuda_plugin { - -// --------------------------------------------------------------------------- -// CudaGraphSet -// --------------------------------------------------------------------------- - -CudaGraphSet::~CudaGraphSet() { - Clear(); -} - -void CudaGraphSet::Clear() { - for (auto& [id, graph_exec] : cuda_graphs_) { - (void)cudaGraphExecDestroy(graph_exec); - } - cuda_graphs_.clear(); -} - -bool CudaGraphSet::Contains(CudaGraphAnnotation_t id) const { - return cuda_graphs_.find(id) != cuda_graphs_.end(); -} - -void CudaGraphSet::Put(CudaGraphAnnotation_t id, cudaGraphExec_t graph_exec) { - if (Contains(id)) { - throw std::runtime_error( - "CudaGraphSet::Put: annotation id " + std::to_string(id) + - " already exists. Use a different annotation id."); - } - cuda_graphs_.emplace(id, graph_exec); -} - -cudaGraphExec_t CudaGraphSet::Get(CudaGraphAnnotation_t id) const { - auto it = cuda_graphs_.find(id); - if (it == cuda_graphs_.end()) { - throw std::runtime_error( - "CudaGraphSet::Get: no graph found for annotation id " + std::to_string(id)); - } - return it->second; -} - -// --------------------------------------------------------------------------- -// CUDAGraphManager -// --------------------------------------------------------------------------- - -CUDAGraphManager::CUDAGraphManager(cudaStream_t stream) : stream_(stream) {} - -CUDAGraphManager::~CUDAGraphManager() { - Reset(); -} - -void CUDAGraphManager::SetStream(cudaStream_t stream) { - stream_ = stream; -} - -void CUDAGraphManager::CaptureBegin(CudaGraphAnnotation_t annotation_id) { - if (!IsGraphCaptureAllowedOnRun(annotation_id)) { - throw std::runtime_error("CUDAGraphManager::CaptureBegin: capture not allowed for annotation " + - std::to_string(annotation_id)); - } - - if (cuda_graph_set_.Contains(annotation_id)) { - throw std::runtime_error( - "CUDAGraphManager::CaptureBegin: annotation id " + std::to_string(annotation_id) + - " already captured. Use a different annotation id."); - } - - auto err = cudaStreamSynchronize(stream_); - if (err != cudaSuccess) { - throw std::runtime_error(std::string("cudaStreamSynchronize failed: ") + cudaGetErrorString(err)); - } - - // cudaStreamCaptureModeGlobal: single-thread capture (future: ThreadLocal for multi-stream) - err = cudaStreamBeginCapture(stream_, cudaStreamCaptureModeGlobal); - if (err != cudaSuccess) { - throw std::runtime_error(std::string("cudaStreamBeginCapture failed: ") + cudaGetErrorString(err)); - } -} - -void CUDAGraphManager::CaptureEnd(CudaGraphAnnotation_t annotation_id) { - cudaGraph_t graph = nullptr; - auto err = cudaStreamEndCapture(stream_, &graph); - if (err != cudaSuccess) { - throw std::runtime_error(std::string("cudaStreamEndCapture failed: ") + cudaGetErrorString(err)); - } - if (graph == nullptr) { - throw std::runtime_error("CUDAGraphManager::CaptureEnd: captured graph is NULL"); - } - - cudaGraphExec_t graph_exec = nullptr; - err = cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0); - (void)cudaGraphDestroy(graph); - - if (err != cudaSuccess) { - throw std::runtime_error(std::string("cudaGraphInstantiate failed: ") + cudaGetErrorString(err)); - } - - cuda_graph_set_.Put(annotation_id, graph_exec); -} - -OrtStatus* CUDAGraphManager::Replay(CudaGraphAnnotation_t annotation_id, bool sync) { - cudaGraphExec_t graph_exec = cuda_graph_set_.Get(annotation_id); - - auto err = cudaGraphLaunch(graph_exec, stream_); - if (err != cudaSuccess) { - return Ort::GetApi().CreateStatus( - ORT_EP_FAIL, - (std::string("cudaGraphLaunch failed: ") + cudaGetErrorString(err)).c_str()); - } - - if (sync) { - err = cudaStreamSynchronize(stream_); - if (err != cudaSuccess) { - return Ort::GetApi().CreateStatus( - ORT_EP_FAIL, - (std::string("cudaStreamSynchronize after graph replay failed: ") + cudaGetErrorString(err)).c_str()); - } - } - - return nullptr; -} - -bool CUDAGraphManager::IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t annotation_id) const { - return annotation_id != kCudaGraphAnnotationSkip; -} - -bool CUDAGraphManager::IsGraphCaptured(CudaGraphAnnotation_t annotation_id) const { - return cuda_graph_set_.Contains(annotation_id); -} - -void CUDAGraphManager::Reset() { - cuda_graph_set_.Clear(); -} - -} // namespace cuda_plugin -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.h b/onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.h deleted file mode 100644 index 84c6361d9c5e1..0000000000000 --- a/onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.h +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// Plugin-compatible CUDA graph manager for capture/replay lifecycle. -// Adapted from core/providers/cuda/cuda_graph.h — removes dependencies -// on internal EP types (CUDAExecutionProvider, CudaStream). - -#pragma once - -#include "cuda_plugin_utils.h" - -#include -#include - -namespace onnxruntime { -namespace cuda_plugin { - -using CudaGraphAnnotation_t = int; - -constexpr CudaGraphAnnotation_t kCudaGraphAnnotationSkip = -1; -constexpr CudaGraphAnnotation_t kCudaGraphAnnotationDefault = 0; - -/// Stores instantiated CUDA graph executables keyed by annotation ID. -struct CudaGraphSet { - CudaGraphSet() = default; - ~CudaGraphSet(); - - void Clear(); - bool Contains(CudaGraphAnnotation_t id) const; - void Put(CudaGraphAnnotation_t id, cudaGraphExec_t graph_exec); - cudaGraphExec_t Get(CudaGraphAnnotation_t id) const; - - private: - std::unordered_map cuda_graphs_; -}; - -/// Manages CUDA graph capture/instantiation/replay for the plugin EP. -/// Each instance is associated with a single cudaStream_t. -struct CUDAGraphManager { - CUDAGraphManager() = default; - explicit CUDAGraphManager(cudaStream_t stream); - ~CUDAGraphManager(); - - void SetStream(cudaStream_t stream); - - /// Begin capturing CUDA work on the associated stream. - void CaptureBegin(CudaGraphAnnotation_t annotation_id); - - /// End capture, instantiate the graph, and store it. - void CaptureEnd(CudaGraphAnnotation_t annotation_id); - - /// Launch a previously captured graph. - OrtStatus* Replay(CudaGraphAnnotation_t annotation_id, bool sync = true); - - /// Destroy all captured graphs. - void Reset(); - - /// Whether capture is allowed for the given annotation (i.e., not the skip sentinel). - bool IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t annotation_id) const; - - /// Whether a graph has already been captured for the given annotation. - bool IsGraphCaptured(CudaGraphAnnotation_t annotation_id) const; - - private: - CudaGraphSet cuda_graph_set_; - cudaStream_t stream_ = nullptr; // Does not own the stream -}; - -} // namespace cuda_plugin -} // namespace onnxruntime From e61f27a4ec3158d1ad8d54f11106666ecf6e69b0 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 23 Mar 2026 16:02:28 -0700 Subject: [PATCH 05/48] review feedback --- docs/cuda_plugin_ep/cuda_ops_for_plugin_ep.md | 269 ------------------ .../cuda/quantization/matmul_nbits.h | 2 +- .../cuda/plugin/cuda_controlflow_plugin.cc | 95 ++++--- .../cuda/plugin/cuda_controlflow_plugin.cu | 105 +++++-- .../cuda/plugin/cuda_controlflow_plugin.h | 8 +- .../cuda/plugin/cuda_data_transfer_plugin.cc | 40 +-- .../providers/cuda/reduction/reduction_ops.cc | 10 +- .../core/providers/cuda/tensor/tile.cc | 12 + .../core/providers/cuda/tensor/upsample.cc | 22 +- .../core/providers/cuda/tensor/upsample.h | 5 + 10 files changed, 193 insertions(+), 375 deletions(-) delete mode 100644 docs/cuda_plugin_ep/cuda_ops_for_plugin_ep.md diff --git a/docs/cuda_plugin_ep/cuda_ops_for_plugin_ep.md b/docs/cuda_plugin_ep/cuda_ops_for_plugin_ep.md deleted file mode 100644 index 0a6e627b1b5f7..0000000000000 --- a/docs/cuda_plugin_ep/cuda_ops_for_plugin_ep.md +++ /dev/null @@ -1,269 +0,0 @@ -# CUDA Kernel Changes for Plugin EP Compatibility - -## Overview - -The CUDA Plugin EP builds CUDA operator kernels into a separate shared library -(`onnxruntime_providers_cuda_plugin.so`) that communicates with the ORT core -through the ORT EP API. This architecture requires that kernel source files -**not** depend on framework-internal types that are unavailable across the -shared-library boundary. - -The plugin build uses two key mechanisms to achieve compatibility with minimal -(or zero) changes to existing kernel `.cc` files: - -1. **Force-included adapter headers** — The CMake build injects - `adapters.h` and `cuda_kernel_adapter.h` via `-include` compiler flags. - These headers redefine macros (`ONNX_OPERATOR_*_KERNEL_EX`), provide a - plugin-compatible `CudaKernel` base class, and supply shims for - `OpKernelContext`, `OpKernelInfo`, etc. - -2. **`BUILD_CUDA_EP_AS_PLUGIN` preprocessor guard** — For cases where the - adapter headers alone are insufficient, kernel headers can use - `#ifdef BUILD_CUDA_EP_AS_PLUGIN` to select an alternative code path - (e.g., a self-contained class instead of inheriting from a CPU base class). - -## Common Incompatibility Patterns - -| Pattern | Description | Typical Fix | -|---------|-------------|-------------| -| **`GetComputeStream()` returning `onnxruntime::Stream*`** | The adapter `OpKernelContext` exposes `GetComputeStream()` that returns the adapter `Stream*` with `GetHandle()`. Most kernels call `GetScratchBuffer(n, ctx->GetComputeStream())` which already works through the adapter. Kernels that `dynamic_cast` or call `CudaStream`-specific methods break. | Use `static_cast(ctx->GetComputeStream()->GetHandle())` instead of `CudaStream*` methods. The adapter `CudaKernel::GetCublasHandle(cudaStream_t)` and `GetCudnnHandle(cudaStream_t)` are available. | -| **Inheritance from CPU base class** | Kernels like `Resize : Upsample`, `SpaceToDepth : SpaceDepthBase`, `NonMaxSuppression : NonMaxSuppressionBase` inherit from CPU provider classes that are not linked into the plugin. | Add a `#ifdef BUILD_CUDA_EP_AS_PLUGIN` block in the header with a self-contained class that inlines the needed logic (see `constant_of_shape.h` for an example). | -| **`TensorSeq` (incomplete type)** | `TensorSeq` is not available in the plugin build. `identity_op.cc` and `sequence_op.cc` operate on sequence types. | These ops should remain excluded or need `TensorSeq` to be exposed through the EP API. | -| **`CudaTuningContext`** | Kernels that call `GetTuningContext()` and use `CudaTuningContext` methods directly. The adapter provides a stub `GetTuningContext()` but full tuning infra is unavailable. | Guard tuning-specific calls with `#ifndef BUILD_CUDA_EP_AS_PLUGIN` or use the adapter's stub which returns `nullptr` (callers should null-check). | -| **`PhiloxGenerator` / RNG state** | Dropout-family ops use `PhiloxGenerator` from the `CudaStream` object. This requires `CudaStream*` access. | Needs a `PhiloxGenerator` accessor in the adapter or exclusion. | -| **`QkvToContext` taking `Stream*`** | Attention ops pass `context->GetComputeStream()` (an `onnxruntime::Stream*`) to `QkvToContext`. This function dereferences `Stream*` internally. | Either change `QkvToContext` signature to accept `cudaStream_t` + handles, or provide a `PluginStreamShim` wrapper (already in the adapter). | -| **Pure CPU ops** | `Shape`, `Size` — these register CPU-side `OpKernel` classes whose `Compute()` is in the CPU provider library. | Permanently excluded; handled by `GetCpuPreferredNodes()`. | -| **`cuda_execution_provider.h` include** | Files that directly include the real `CUDAExecutionProvider` class definition conflict with the adapter's shim class. | Use the adapter's `CUDAExecutionProvider` shim (automatically provided by `cuda_kernel_adapter.h`). | -| **KernelInfoGetAttributeArray\_string** | RNN ops call `GetAttrs(...)` which maps to a C API function not yet available. | Wait for C API extension, or inline attribute parsing. | -| **Registration tables** | `cuda_nhwc_kernels.cc` and `cuda_contrib_kernels.cc` contain centralized `BuildKernelCreateInfo<>` tables that reference all kernel classes, including excluded ones. | Not needed — the plugin uses `PluginKernelCollector` for self-registration via macro overrides. | - -## How to Bring an Excluded Kernel to Plugin EP - -### Step 1: Identify the Dependency - -Check why the kernel is excluded by looking at: -- The comment in `cmake/onnxruntime_providers_cuda_plugin.cmake` -- The kernel `.cc`/`.h` files for the patterns listed above - -### Step 2: Apply the Minimal Fix - -The preferred approach (in order of preference): - -1. **No source change needed** — If the only issue was `GetComputeStream()` - usage with `GetScratchBuffer()`, the adapter already handles this. Just - remove the exclusion from the cmake file and test. - -2. **Use `BUILD_CUDA_EP_AS_PLUGIN` guard in the header** — For CPU base - class dependencies, add an alternative class definition: - ```cpp - #ifdef BUILD_CUDA_EP_AS_PLUGIN - class MyOp final : public CudaKernel { - // Self-contained implementation that inlines base class logic - }; - #else - class MyOp final : public CpuBaseClass, public CudaKernel { - // Original implementation - }; - #endif - ``` - -3. **Modify calling convention** — For functions that take - `onnxruntime::Stream*` or `CudaStream*`, change to accept - `cudaStream_t` + explicit handles: - ```cpp - // Before: - SomeHelper(context->GetComputeStream(), ...); - // After: - SomeHelper(static_cast(context->GetComputeStream()->GetHandle()), ...); - ``` - -4. **Add a shim in `cuda_kernel_adapter.h`** — For utility functions from - CPU providers (e.g., `ValidateInputs`, `PrepareCompute`), inline the - logic in the adapter header so it's available in the plugin build. - -5. **Inline CPU helper to header** — Move the helper implementation - from the CPU `.cc` file to the `.h` header, wrapped in - `#ifdef SHARED_PROVIDER` (declaration only) / `#else` (inline body). - The `SHARED_PROVIDER` build retains the existing `ProviderHostCPU` - bridge path. See `padbase.h`, `slice.h`, `scatter_nd.h` for examples. - -6. **Templatize on info/context type** — For base class constructors - that call `GetAttr()`, templatize on `KernelInfoType` with - `info.template GetAttr(...)`. For methods that take - `OpKernelContext&`, templatize on `KernelContextType`. - See `roialign.h`, `unsqueeze.h`, `attention_base.h` for examples. - -7. **Move CUDA type helpers to shared header** — For utility functions - that only depend on CUDA types (not framework types), move from - `.cc` to a header so the plugin build can consume them directly. - See `cuda_common_type_helpers.h`. - -### Step 3: Remove the CMake Exclusion - -In `cmake/onnxruntime_providers_cuda_plugin.cmake`, comment out the exclusion -line with a note about what was done: -```cmake -# myop.cc: . -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/myop\\.cc$") # REMOVED in Stage N -``` - -### Step 4: Build and Test - -```bash -# Build with plugin EP enabled -./build.sh --config Release --use_cuda --build_cuda_ep_as_plugin -# Run parity tests -python tools/ci_build/cuda_plugin_parity_report.py -``` - ---- - -## Excluded Operators Table - -### Infrastructure Files (Not Operator Kernels) - -These are excluded because they define the real EP infrastructure, which is -replaced by the plugin's own implementations in `plugin/`. - -| File | Reason | Resolution | -|------|--------|------------| -| `cuda_execution_provider.cc` | Defines the real `CUDAExecutionProvider` class; conflicts with adapter shim. | Permanently excluded; replaced by `plugin/cuda_ep.cc`. | -| `cuda_provider_factory.cc` | Creates the real CUDA EP via `ProviderFactory`; not used in plugin architecture. | Permanently excluded; replaced by `plugin/cuda_ep_factory.cc`. | -| `cuda_provider_interface.cc` | Shared-library provider interface for the old (non-plugin) shared-library model. | Permanently excluded; not applicable to plugin EP. | -| `cuda_stream_handle.cc` | Defines `CudaStream` class; replaced by plugin stream adapter. | Permanently excluded; replaced by `plugin/cuda_stream_plugin.cc`. | -| `cuda_execution_provider_info.cc` | EP configuration parsing tied to the real EP. | Permanently excluded; replaced by `plugin/cuda_ep.cc` config. | -| `cuda_graph.cc` | CUDA graph capture tied to real EP stream management. | Permanently excluded; replaced by `plugin/cuda_graph_plugin.cc`. | -| `cuda_mempool_arena.cc` | Memory arena tied to real EP allocator infrastructure. | Permanently excluded; replaced by `plugin/cuda_allocator_plugin.cc`. | -| `cuda_common.cc` | `HalfGemmOptions` definitions conflict with adapter's inline shim. | Permanently excluded; shims provided in `cuda_kernel_adapter.h`. | -| `cuda_nhwc_kernels.cc` | Centralized kernel registration table; references all NHWC kernel classes. | Permanently excluded; `PluginKernelCollector` auto-registers. | -| `cuda_contrib_kernels.cc` | Centralized kernel registration table; references all contrib kernel classes. | Permanently excluded; `PluginKernelCollector` auto-registers. | - -### Standard ONNX Operator Kernels — Currently Excluded - -| File | Exclusion Reason | Change Needed to Include | -|------|-----------------|--------------------------| -| `math/einsum.cc` | Inherits from `onnxruntime::Einsum` (CPU provider); calls `Einsum::Compute()` which chains to `DeviceCompute()` through the CPU base class vtable. Also depends on `einsum_utils/` which calls `ReduceCompute`. | Add `#ifdef BUILD_CUDA_EP_AS_PLUGIN` path that directly implements `ComputeInternal()` without the CPU base class. Substantial effort — einsum is complex. | -| `math/einsum_utils/*` | `einsum_auxiliary_ops.cc` calls `ReductionOps::ReduceCompute` which is a framework-only function. | Must inline or rewrite reduction logic for plugin build. Coupled with `einsum.cc`. | -| `controlflow/*` (If, Loop, Scan) | Inherits from CPU base classes (`If`, `Loop`, `Scan` from `core/providers/cpu/controlflow/`). These ops call into the ORT session to execute subgraphs. | Plugin has custom wrappers in `plugin/cuda_controlflow_plugin.cc` that delegate to `OrtEpApi`. Permanently excluded from standard source; plugin equivalents exist. | -| `tunable/*` | Depends on `CudaTuningContext` and the real `CUDAExecutionProvider` for tuning infrastructure. | Needs full tuning API exposure through plugin interface. Low priority — tuning is optional. | -| `rnn/*` (RNN, GRU, LSTM) | Kernel constructors call `GetAttrs("activations", ...)` which maps to `KernelInfoGetAttributeArray_string` — a C API function that does not yet exist. Also uses `CudnnRnnBase` which manages cuDNN RNN descriptors. | Extend the ORT C API with `KernelInfoGetAttributeArray_string`. After that, the dual-build signatures (already in place) should work. | -| `tensor/identity_op.cc` | Uses `TensorSeq` (incomplete type in plugin build) for sequence pass-through in `IdentityOp`. | Expose `TensorSeq` through the EP API adapter, or split the sequence codepath into a separate file with `#ifdef`. | -| `tensor/sequence_op.cc` | All ops (`SequenceAt`, `SequenceConstruct`, `SequenceInsert`, etc.) heavily use `TensorSeq`. | Same as `identity_op.cc` — requires `TensorSeq` support in the adapter. | -| `tensor/size.cc` | Pure CPU op — registers `onnxruntime::Size` whose `Compute()` is in the CPU provider. | **Permanently excluded.** Handled by `GetCpuPreferredNodes()`. | -| `tensor/shape_op.cc` | Pure CPU op — inherits from `onnxruntime::OpKernel` (framework class, not adapter `OpKernel`). Output is on CPU. | **Permanently excluded.** Handled by `GetCpuPreferredNodes()`. | -| `tensor/space_depth_ops.cc` | Inherits from `SpaceDepthBase` (CPU provider, `core/providers/cpu/tensor/space_depth_ops.h`). | `SpaceDepthBase` constructor templatized on `KernelInfoType` (#27628). Remaining: inline `SpaceDepthCompute` validation logic or add adapter-compatible path. Reduced effort. | -| `tensor/upsample.cc` | Inherits from `UpsampleBase` (CPU provider). `UpsampleBase` uses `InputDefs()` and complex attribute/input parsing in its constructor. | `UpsampleBase::AdjustOutputSizeAsPolicy` moved to header (#27628). Remaining blockers: `InputDefs()` and `OpKernelInfo::GetAllocator()` not available in adapter. Moderate effort. | -| `tensor/resize.cc` | Inherits from `Upsample` which inherits from `UpsampleBase`. | Blocked on `upsample.cc` — must fix `Upsample` first, then `Resize` follows. | -| `generator/constant_of_shape.cc` | Inherits from `ConstantOfShapeBase` (CPU provider) which uses `TensorProto`/`UnpackTensor`. | **Already has `#ifdef BUILD_CUDA_EP_AS_PLUGIN` path** in the header with a self-contained class. Currently excluded because the `.cc` file's `#else` path still compiles `ConstantOfShapeBase` version. Need to verify the `#ifdef` path compiles and remove the exclusion. | -| `object_detection/*` (NonMaxSuppression, RoiAlign) | `NonMaxSuppression` inherits from `NonMaxSuppressionBase`; `RoiAlign` inherits from `RoiAlignBase`. Both CPU base classes. `NonMaxSuppression` also uses CPU helper `PrepareCompute`. | `NonMaxSuppressionBase` refactored to `NonMaxSuppressionBaseImpl` template (#27617). `RoiAlignBase` constructor templatized, `CheckROIAlignValidInput` inlined (#27628). Remaining: integration verification and residual `GetComputeStream()` issues. | -| `llm/*` | Attention kernels that call `QkvToContext` with `onnxruntime::Stream*`. Deep dependency on attention implementation internals. | Change `QkvToContext` to accept `cudaStream_t` + explicit handles, or use `PluginStreamShim`. Large surface area. | - -### Contrib Operator Kernels — Currently Excluded - -| File | Exclusion Reason | Change Needed to Include | -|------|-----------------|--------------------------| -| **aten_ops/\*** | PyTorch ATen operator bindings; requires `libtorch`. Not relevant for plugin EP. | **Permanently excluded.** | -| **collective/\*** | NCCL/MPI collective ops; requires distributed runtime. | **Permanently excluded** (or separate plugin). | -| **contrib llm/\*** | Same as standard `llm/` — deep `Stream*` and attention infra dependencies. | Same fix as standard `llm/`. | -| **transformers/\*** (beam_search, greedy_search, sampling) | Directly includes `cuda_execution_provider.h`. Uses session-level APIs to run subgraphs (encoder/decoder). Heavy framework dependency. | Would need significant refactoring to route subgraph execution through `OrtEpApi`. Very high effort. | -| **bert/attention.cc** | Calls `GetScratchBuffer` with `context->GetComputeStream()` (works via adapter). Main issue: calls `QkvToContext` passing `context->GetComputeStream()` (`Stream*`), and uses `IAllocator::MakeUniquePtr` with stream. | `AttentionBase::CheckInputs`/`CheckMask`/`GetPresent` moved to header (#27628). Remaining blocker: `QkvToContext` takes `Stream*`. Moderate-high effort. | -| **bert/decoder_attention.cc** | Same pattern as `attention.cc` — `QkvToContext` with `Stream*`. | Same fix as `attention.cc`. | -| **bert/decoder_masked_self_attention.cc** | Uses `GetComputeStream()` for scratch buffers and stream handle extraction. | Replace `GetComputeStream()` → adapter-compatible calls. Moderate effort. | -| **bert/embed_layer_norm.cc** | `embed_layer_norm_helper::CheckInputs` templatized and moved to header (#27617). CPU base class dependency resolved. | Verify compilation with exclusion removed — helper refactoring complete. **Very low effort.** | -| **bert/fast_gelu.cc** | Was excluded due to `bias_gelu_helper` CPU base class dependency. `bias_gelu_helper::CheckInputs` now templatized and inlined (#27617). | Verify compilation with exclusion removed — helper refactoring complete. **Very low effort.** | -| **bert/group_query_attention.cc** | Heavy use of `GetComputeStream()` (scratch buffers, stream handle extraction, `CudaStream*` cast). Complex attention pipeline with flash attention, XQA loader. | Same approach as `attention.cc`. High effort due to many code paths. | -| **bert/longformer_attention.cc** | Uses `GetScratchBuffer` with `GetComputeStream()`, workspace allocation. `LongformerAttentionBase::CheckInputs` moved to header (#27628). | Remaining blocker: `GetComputeStream()` / `Stream*` usage. Moderate effort. | -| **bert/multihead_attention.cc** | Same pattern as `attention.cc` — `QkvToContext` with `Stream*`. | Same fix as `attention.cc`. | -| **bert/packed_attention.cc** | Same attention pipeline dependency. | Same fix as `attention.cc`. | -| **bert/packed_multihead_attention.cc** | Same attention pipeline dependency. | Same fix as `attention.cc`. | -| **bert/paged_attention.cc** | Uses `GetComputeStream()` for scratch buffers and paged KV-cache management. | Replace stream access pattern. Moderate effort. | -| **bert/relative_attn_bias.cc** | Uses `GetScratchBuffer` with `GetComputeStream()`. | Simple `GetComputeStream()` pattern — may work with adapter. **Low effort to try.** | -| **bert/remove_padding.cc** | Uses `GetScratchBuffer` with `GetComputeStream()`. | Simple `GetComputeStream()` pattern — may work with adapter. **Low effort to try.** | -| **diffusion/group_norm.cc** | Uses `CudaTuningContext*` and `Stream*` in the `DispatchGroupNorm` helper. | Guard tuning path with `#ifndef BUILD_CUDA_EP_AS_PLUGIN`, change stream parameter. Moderate effort. | -| **fused_conv.cc** | Uses `GetComputeStream()` for cuDNN workspace allocation. | Replace stream access with adapter-compatible calls. Moderate effort. | -| **inverse.cc** | Uses `GetScratchBuffer` with `GetComputeStream()`. cuBLAS batched operations. | Simple pattern — likely works with adapter. **Low effort to try.** | -| **math/bias_dropout.cc** | Uses `PhiloxGenerator` from `CudaStream` for RNG state. Also `GetComputeStream()`. | Needs `PhiloxGenerator` accessor in adapter. Blocked on RNG infrastructure. | -| **math/fft_ops.cc** | Uses `onnxruntime::Stream*` directly. cuFFT plan management. | Change stream access to adapter pattern. Moderate effort. | -| **math/gemm_float8.cc/.cu** | `ComputeInternal` is in `.cu` file which uses `GetComputeStream()`. `.cu` files don't receive the force-include adapter header. | Move `GetComputeStream()` usage to `.cc` file, or pass stream as parameter to `.cu` function. Moderate effort. | -| **moe/moe.cc** | Uses `GetComputeStream()`. MoE routing + expert computation. | Replace `context->GetComputeStream()` with adapter-compatible calls. Moderate effort. | -| **sparse/sparse_attention.cc** | Uses `onnxruntime::Stream*`. Sparse attention kernel dispatch. | Same stream pattern fix. Moderate effort. | -| **tensor/shrunken_gather.cc** | Training op — includes `provider_api.h` in header. `ENABLE_TRAINING_OPS` guard. | **Permanently excluded** (training op, not needed for inference plugin). | -| **tensor/crop.cc** | `CropBase` constructor templatized on `KernelInfoType` (#27628). No `GetComputeStream()` usage. | Verify compilation with exclusion removed — constructor refactoring complete. **Very low effort.** | -| **tensor/dynamic_time_warping.cc** | Uses `GetScratchBuffer` with `GetComputeStream()`. | Simple pattern — likely works with adapter. **Low effort to try.** | -| **tensor/dynamicslice.cc** | Uses `onnxruntime::Stream*` via `GetComputeStream()`. | Simple pattern — likely works with adapter. **Low effort to try.** | -| **quantization/attention_quantization.cc** | Uses `GetScratchBuffer` with `GetComputeStream()`, calls `QkvToContext`. | Same fix as `attention.cc`. Moderate-high effort. | -| **quantization/matmul_bnb4.cc** | Uses `GetScratchBuffer` and `GetComputeStream()->GetHandle()`. | Adapter should handle this pattern. **Low effort to try.** | -| **quantization/matmul_nbits.cc** | Uses `GetScratchBuffer` with `GetComputeStream()` and `GetHandle()`. | Adapter should handle this pattern. **Low effort to try.** | -| **quantization/moe_quantization.cc** | Uses `GetComputeStream()`. Quantized MoE pipeline. | Same as `moe.cc`. Moderate effort. | -| **quantization/qordered_ops/\*** | Ordered quantization ops with framework dependencies. | Needs investigation. Low priority. | - -### Operators Successfully Brought to Plugin EP (Reference) - -These were previously excluded and are now included thanks to adapter -compatibility. Listed here as examples of the fix patterns applied. - -| File | Fix Applied | Stage | -|------|-------------|-------| -| `tensor/reshape.cc` | `CopyTensor` replaced with explicit `cudaMemcpyAsync` on kernel stream (#27719). | 5A | -| `tensor/concat.cc` | `InputArgCount`/`GetComputeStream` usage works through adapter `OpKernelContext`. | 5A | -| `tensor/split.cc` | `GetComputeStream` usage works via `CudaKernel::GetComputeStream`. | 5A | -| `tensor/gather.cc` | Switched to `GatherBase::PrepareForComputeImpl` compatible with adapter context. | 5B | -| `tensor/gather_nd.cc` | `PrepareCompute` signature changed to `void*`/`cudaStream_t`. | 5 | -| `tensor/unsqueeze.cc` | Plugin-local `PrepareCompute` path added for adapter context. | 5B | -| `tensor/tile.cc` | Plugin-local `IsTileMemcpy` helper added. | 5B | -| `math/cumsum.cc` | Axis parsing helper inlined for plugin build. | 5B | -| `tensor/scatter_nd.cc` | `ValidateShapes` inlined for plugin; `GetComputeStream` fixed. | 5 | -| `tensor/pad.cc` | Plugin-local wrappers for `PadBase` static helpers. | 5C.2 | -| `tensor/slice.cc` | Plugin-local wrappers for `SliceBase::PrepareForCompute`/`FlattenOutputDims`. | 5C.3 | -| `math/variadic_elementwise_ops.cc` | Adapter `InputCount`/`RequiredInput`/`RequiredOutput` supported. | 5C | -| `math/matmul.cc` | `GetComputeStream` fixed; `GetTuningContext` guarded. | 5 | -| `math/matmul_integer.cc` | `GetComputeStream` fixed; `GemmInt8` signature updated. | 5 | -| `math/integer_gemm.cc` | `dynamic_cast` replaced with stream-based `GetCublasHandle()` overload (#27719). | 5 | -| `contrib/math/fused_matmul.cc` | Included after `matmul.cc` was fixed. | 5 | - -## Priority Recommendations - -### High Priority (Common ops, likely low effort) - -These excluded ops use simple `GetComputeStream()` patterns that the adapter -already supports. They should be tried first. Ops marked with (✓) have had -their CPU helper dependencies fully refactored and are ready for build -verification: - -- `contrib/bert/embed_layer_norm.cc` (✓ helper refactored #27617) -- `contrib/bert/fast_gelu.cc` (✓ helper refactored #27617) -- `contrib/bert/relative_attn_bias.cc` -- `contrib/bert/remove_padding.cc` -- `contrib/tensor/crop.cc` (✓ constructor templatized #27628) -- `contrib/tensor/dynamic_time_warping.cc` -- `contrib/tensor/dynamicslice.cc` -- `contrib/inverse.cc` -- `contrib/quantization/matmul_bnb4.cc` -- `contrib/quantization/matmul_nbits.cc` -- `generator/constant_of_shape.cc` (already has `#ifdef` path) - -### Medium Priority (Moderate refactoring needed) - -- `tensor/space_depth_ops.cc` — constructor templatized; remaining validation to inline -- `contrib/diffusion/group_norm.cc` — guard tuning context -- `contrib/moe/moe.cc` — fix stream access -- `contrib/fused_conv.cc` — fix stream access -- `contrib/math/fft_ops.cc` — fix stream access -- `contrib/math/gemm_float8.cc/.cu` — move stream access to `.cc` - -### Low Priority (Significant effort or niche) - -- `tensor/upsample.cc` + `tensor/resize.cc` — `AdjustOutputSizeAsPolicy` moved to header; `InputDefs()`/`GetAllocator()` still needed -- `rnn/*` — blocked on C API string array extension -- `llm/*` + `bert/attention*.cc` family — deep attention pipeline changes -- `math/einsum.cc` — complex CPU base class -- `object_detection/*` — base classes partially refactored; integration verification needed -- `transformers/*` — subgraph execution, very high effort - -### Permanently Excluded - -- `tensor/size.cc`, `tensor/shape_op.cc` — pure CPU ops -- `aten_ops/*` — PyTorch dependency -- `collective/*` — distributed runtime -- `tensor/shrunken_gather.cc` — training only -- Infrastructure files — replaced by `plugin/` equivalents diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h index 87a675d282fdd..64969ae499bf7 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h @@ -52,7 +52,7 @@ class MatMulNBits final : public CudaKernel { #ifdef BUILD_CUDA_EP_AS_PLUGIN // Plugin adapter Node does not have InputDefs(). Defer existence checks to ComputeInternal // where we can check if the actual input tensor is null or not. - (void)kInputIndexScale; // used only in non-plugin path + ORT_UNUSED_PARAMETER(kInputIndexScale); // used only in non-plugin path has_zero_points_ = info.GetInputCount() > kInputIndexZeroPoints; has_g_idx_ = info.GetInputCount() > kInputIndexGroupIndex; has_bias_ = info.GetInputCount() > kInputIndexBias; diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc index bc2ba2b8c6f8a..f03e69645df9c 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc @@ -12,6 +12,57 @@ namespace onnxruntime { namespace cuda { namespace plugin { +namespace { + +Status GetTensorElementStorageSize(ONNXTensorElementDataType elem_type, size_t& element_size) { + switch (elem_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ: + element_size = 1; + return Status::OK(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + element_size = 2; + return Status::OK(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + element_size = 4; + return Status::OK(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: + element_size = 8; + return Status::OK(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: + element_size = 16; + return Status::OK(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT2: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT2: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Scan Transpose: packed sub-byte tensor types are unsupported"); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Scan Transpose: string tensors are unsupported"); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED: + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Scan Transpose: unsupported element type ", static_cast(elem_type)); + } +} + +} // namespace + // =================================================================== // If kernel // =================================================================== @@ -154,46 +205,22 @@ OrtStatus* ORT_API_CALL PluginScanHelper::TransposeImpl( // Determine element size from the data type ONNXTensorElementDataType elem_type = input_info.GetElementType(); size_t element_size = 0; - switch (elem_type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - element_size = sizeof(float); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: - element_size = sizeof(double); - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: - element_size = 2; - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: - element_size = 1; - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: - element_size = 2; - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - element_size = 4; - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: - element_size = 8; - break; - default: - return Ort::Status("Scan Transpose: unsupported element type", ORT_FAIL).release(); + auto status = GetTensorElementStorageSize(elem_type, element_size); + if (!status.IsOK()) { + return Ort::Status(status.ErrorMessage().c_str(), ORT_EP_FAIL).release(); } const void* input_data = input.GetTensorRawData(); void* output_data = output.GetTensorMutableData(); // Launch the GPU transpose kernel - LaunchTransposeKernel(input_data, output_data, - input_shape.data(), permutation, - num_dims, element_size, total_elements, - cuda_stream); + status = LaunchTransposeKernel(input_data, output_data, + input_shape.data(), permutation, + num_dims, element_size, total_elements, + cuda_stream); + if (!status.IsOK()) { + return Ort::Status(status.ErrorMessage().c_str(), ORT_EP_FAIL).release(); + } return nullptr; } catch (const std::exception& ex) { diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu index 43df89468c42d..98d35baab2361 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu +++ b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu @@ -9,11 +9,45 @@ #include #include #include +#include + +#include "core/common/status.h" namespace onnxruntime { namespace cuda { namespace plugin { +namespace { + +Status CudaStatus(cudaError_t cuda_status, const char* operation) { + if (cuda_status == cudaSuccess) { + return Status::OK(); + } + + return common::Status(common::ONNXRUNTIME, common::FAIL, + std::string("Scan Transpose: ") + operation + " failed: " + cudaGetErrorString(cuda_status)); +} + +struct DeviceArraySet { + int64_t* input_strides = nullptr; + int64_t* output_strides = nullptr; + int* perm = nullptr; + + ~DeviceArraySet() { + if (perm != nullptr) { + cudaFree(perm); + } + if (output_strides != nullptr) { + cudaFree(output_strides); + } + if (input_strides != nullptr) { + cudaFree(input_strides); + } + } +}; + +} // namespace + // Maximum number of dimensions supported by the transpose kernel. // Most real-world tensors have <= 8 dimensions. constexpr int kMaxTransposeDims = 8; @@ -52,11 +86,19 @@ __global__ void TransposeNDKernel(const char* __restrict__ input, memcpy(dst, src, element_size); } -void LaunchTransposeKernel(const void* input, void* output, - const int64_t* input_shape, const size_t* permutation, - size_t num_dims, size_t element_size, size_t total_elements, - cudaStream_t stream) { - if (total_elements == 0 || num_dims == 0) return; +Status LaunchTransposeKernel(const void* input, void* output, + const int64_t* input_shape, const size_t* permutation, + size_t num_dims, size_t element_size, size_t total_elements, + cudaStream_t stream) { + if (total_elements == 0 || num_dims == 0) { + return Status::OK(); + } + + if (num_dims > static_cast(kMaxTransposeDims)) { + return common::Status(common::ONNXRUNTIME, common::FAIL, + "Scan Transpose: rank " + std::to_string(num_dims) + + " exceeds the supported maximum rank of " + std::to_string(kMaxTransposeDims)); + } // Compute input strides (row-major) int64_t input_strides[kMaxTransposeDims]; @@ -79,17 +121,35 @@ void LaunchTransposeKernel(const void* input, void* output, } // Copy arrays to device - int64_t* d_input_strides = nullptr; - int64_t* d_output_strides = nullptr; - int* d_perm = nullptr; - - cudaMalloc(&d_input_strides, num_dims * sizeof(int64_t)); - cudaMalloc(&d_output_strides, num_dims * sizeof(int64_t)); - cudaMalloc(&d_perm, num_dims * sizeof(int)); + DeviceArraySet device_arrays; + auto status = CudaStatus(cudaMalloc(&device_arrays.input_strides, num_dims * sizeof(int64_t)), "cudaMalloc(input_strides)"); + if (!status.IsOK()) { + return status; + } + status = CudaStatus(cudaMalloc(&device_arrays.output_strides, num_dims * sizeof(int64_t)), "cudaMalloc(output_strides)"); + if (!status.IsOK()) { + return status; + } + status = CudaStatus(cudaMalloc(&device_arrays.perm, num_dims * sizeof(int)), "cudaMalloc(perm)"); + if (!status.IsOK()) { + return status; + } - cudaMemcpyAsync(d_input_strides, input_strides, num_dims * sizeof(int64_t), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(d_output_strides, output_strides, num_dims * sizeof(int64_t), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(d_perm, perm_int, num_dims * sizeof(int), cudaMemcpyHostToDevice, stream); + status = CudaStatus(cudaMemcpyAsync(device_arrays.input_strides, input_strides, num_dims * sizeof(int64_t), cudaMemcpyHostToDevice, stream), + "cudaMemcpyAsync(input_strides)"); + if (!status.IsOK()) { + return status; + } + status = CudaStatus(cudaMemcpyAsync(device_arrays.output_strides, output_strides, num_dims * sizeof(int64_t), cudaMemcpyHostToDevice, stream), + "cudaMemcpyAsync(output_strides)"); + if (!status.IsOK()) { + return status; + } + status = CudaStatus(cudaMemcpyAsync(device_arrays.perm, perm_int, num_dims * sizeof(int), cudaMemcpyHostToDevice, stream), + "cudaMemcpyAsync(perm)"); + if (!status.IsOK()) { + return status; + } constexpr int kBlockSize = 256; int num_blocks = static_cast((total_elements + kBlockSize - 1) / kBlockSize); @@ -97,17 +157,18 @@ void LaunchTransposeKernel(const void* input, void* output, TransposeNDKernel<<>>( static_cast(input), static_cast(output), - d_input_strides, - d_output_strides, - d_perm, + device_arrays.input_strides, + device_arrays.output_strides, + device_arrays.perm, static_cast(num_dims), element_size, total_elements); - // Free device arrays asynchronously - cudaFreeAsync(d_input_strides, stream); - cudaFreeAsync(d_output_strides, stream); - cudaFreeAsync(d_perm, stream); + status = CudaStatus(cudaGetLastError(), "TransposeNDKernel launch"); + if (!status.IsOK()) { + return status; + } + return Status::OK(); } } // namespace plugin diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.h b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.h index efee4bb215342..f5e203322a0c3 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.h @@ -87,10 +87,10 @@ class PluginScanKernel : public OpKernel { }; // GPU transpose helper (defined in cuda_controlflow_plugin.cu) -void LaunchTransposeKernel(const void* input, void* output, - const int64_t* input_shape, const size_t* permutation, - size_t num_dims, size_t element_size, size_t total_elements, - cudaStream_t stream); +Status LaunchTransposeKernel(const void* input, void* output, + const int64_t* input_shape, const size_t* permutation, + size_t num_dims, size_t element_size, size_t total_elements, + cudaStream_t stream); } // namespace plugin } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc index 5226ff98b08f5..696712b2ea693 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc @@ -56,43 +56,11 @@ CudaDataTransfer::CudaDataTransfer(const OrtApi& ort_api, const OrtEpApi& ep_api Ort::ConstValue src{src_tensors[i]}; Ort::UnownedValue dst{dst_tensors[i]}; - auto src_type_shape = src.GetTensorTypeAndShapeInfo(); - size_t count_elems = src_type_shape.GetElementCount(); - - // Get element size from data type - ONNXTensorElementDataType elem_type = src_type_shape.GetElementType(); - size_t elem_size = 0; - // Compute byte size of the tensor elements. - // ORT's C API doesn't expose an element-size helper directly, so we - // map the ONNX element type to its byte width manually. - // Cases are grouped by element size for clarity. - switch (elem_type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - elem_size = 1; - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: - elem_size = 2; - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - elem_size = 4; - break; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: - elem_size = 8; - break; - default: - return dt->ort_api_.CreateStatus(ORT_EP_FAIL, "Unsupported tensor element type for copy"); + size_t bytes = 0; + auto* status = dt->ort_api_.GetTensorSizeInBytes(src_tensors[i], &bytes); + if (status != nullptr) { + return status; } - - size_t bytes = count_elems * elem_size; if (bytes == 0) continue; const void* src_data = src.GetTensorRawData(); diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index 5fdf64275c274..9a828c701d07f 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -458,14 +458,14 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const CudaKernel* ke input_tensor, output_tensor, &workspace_bytes)); auto workspace_cuda = workspace_bytes == 0 ? nullptr - : AllocateScratchBuffer(gpu_allocator, kernel, workspace_bytes, compute_stream); + : AllocateScratchBuffer(gpu_allocator, kernel, workspace_bytes, compute_stream); size_t indices_bytes = 0; CUDNN_RETURN_IF_ERROR(cudnnGetReductionIndicesSize(cudnn_handle, reduce_desc, input_tensor, output_tensor, &indices_bytes)); auto indices_cuda = indices_bytes == 0 ? nullptr - : AllocateScratchBuffer(gpu_allocator, kernel, indices_bytes, compute_stream); + : AllocateScratchBuffer(gpu_allocator, kernel, indices_bytes, compute_stream); if (ReduceTensorIndices == CUDNN_REDUCE_TENSOR_NO_INDICES) { IAllocatorUniquePtr input_data_buffer(nullptr, [](T*) {}); @@ -500,7 +500,7 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const CudaKernel* ke input_tensor, output_tensor, &indices_bytes_max)); auto indices_cuda_max = indices_bytes_max == 0 ? nullptr - : AllocateScratchBuffer(gpu_allocator, kernel, indices_bytes_max, compute_stream); + : AllocateScratchBuffer(gpu_allocator, kernel, indices_bytes_max, compute_stream); auto* p_output = reinterpret_cast(output.template MutableData()); CUDNN_RETURN_IF_ERROR(cudnnReduceTensor( cudnn_handle, reduce_max_desc, indices_cuda_max.get(), indices_bytes_max, @@ -788,8 +788,8 @@ Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, cudnnRe cudnnGetReductionIndicesSize(GetCudnnHandle(ctx), reduce_desc, input_tensor, output_tensor, &indices_bytes)); \ CUDNN_RETURN_IF_ERROR( \ cudnnGetReductionWorkspaceSize(GetCudnnHandle(ctx), reduce_desc, input_tensor, output_tensor, &workspace_bytes)); \ - IAllocatorUniquePtr indices_cuda = GetScratchBuffer(indices_bytes, GetComputeStream(ctx)); \ - IAllocatorUniquePtr workspace_cuda = GetScratchBuffer(workspace_bytes, GetComputeStream(ctx)); \ + IAllocatorUniquePtr indices_cuda = GetScratchBuffer(indices_bytes, GetComputeStream(ctx)); \ + IAllocatorUniquePtr workspace_cuda = GetScratchBuffer(workspace_bytes, GetComputeStream(ctx)); \ \ const auto one = Consts::One; \ const auto zero = Consts::Zero; \ diff --git a/onnxruntime/core/providers/cuda/tensor/tile.cc b/onnxruntime/core/providers/cuda/tensor/tile.cc index 5fd250a93d6f8..dc33fcb286acf 100644 --- a/onnxruntime/core/providers/cuda/tensor/tile.cc +++ b/onnxruntime/core/providers/cuda/tensor/tile.cc @@ -11,6 +11,7 @@ namespace cuda { namespace { +#ifdef BUILD_CUDA_EP_AS_PLUGIN bool IsTileMemcpyForPlugin(const TensorShape& input_shape, const int64_t* repeats, size_t rank, @@ -40,6 +41,7 @@ bool IsTileMemcpyForPlugin(const TensorShape& input_shape, } return false; } +#endif } // namespace @@ -143,6 +145,15 @@ Status Tile::ComputeInternal(OpKernelContext* ctx) const { size_t num_of_elements_per_batch = 1; size_t num_of_copies_per_batch = 1; size_t num_of_batch_copies = 1; +#ifdef BUILD_CUDA_EP_AS_PLUGIN + if (IsTileMemcpyForPlugin(input_shape, + repeats, + input_rank, + is_batched_memcpy, + num_of_elements_per_batch, + num_of_copies_per_batch, + num_of_batch_copies)) { +#else if (TileOp::IsTileMemcpy(input_shape, repeats, input_rank, @@ -150,6 +161,7 @@ Status Tile::ComputeInternal(OpKernelContext* ctx) const { num_of_elements_per_batch, num_of_copies_per_batch, num_of_batch_copies)) { +#endif if (!is_batched_memcpy) { switch (element_size) { CASE_TILE_MEMCPY(float); diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.cc b/onnxruntime/core/providers/cuda/tensor/upsample.cc index 974430d7ee5dc..ee0e2a66f48d3 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.cc +++ b/onnxruntime/core/providers/cuda/tensor/upsample.cc @@ -41,7 +41,17 @@ REGISTER_VERSIONED_TYPED_KERNEL(int32_t, 9, 9); REGISTER_VERSIONED_TYPED_KERNEL(uint8_t, 9, 9); template -Upsample::Upsample(const OpKernelInfo& info) : UpsampleBase(info), CudaKernel(info) {} +Upsample::Upsample(const OpKernelInfo& info) : UpsampleBase(info), CudaKernel(info) { +#ifndef BUILD_CUDA_EP_AS_PLUGIN + if (antialias_) { + const uint8_t* lookup_table = GetLookupTableShared(); + auto alloc = info.GetAllocator(OrtMemTypeDefault); + shared_lookup_table_ondevice_ = IAllocator::MakeUniquePtr(std::move(alloc), kLookupTableSize); + CUDA_CALL_THROW(cudaMemcpyAsync(shared_lookup_table_ondevice_.get(), lookup_table, kLookupTableSize, + cudaMemcpyHostToDevice, nullptr)); + } +#endif +} template Status Upsample::BaseCompute(OpKernelContext* context, @@ -95,10 +105,14 @@ Status Upsample::BaseCompute(OpKernelContext* context, } if (antialias_) { +#ifdef BUILD_CUDA_EP_AS_PLUGIN const uint8_t* lookup_table = GetLookupTableShared(); auto shared_lookup_table_ondevice = GetScratchBuffer(kLookupTableSize, GetComputeStream(context)); CUDA_CALL_THROW(cudaMemcpyAsync(shared_lookup_table_ondevice.get(), lookup_table, kLookupTableSize, cudaMemcpyHostToDevice, Stream(context))); +#else + const auto* shared_lookup_table_ondevice = shared_lookup_table_ondevice_.get(); +#endif TempSpaceAllocateFunc allocate_temp_space = [&](size_t bytes_size) { return GetScratchBuffer(bytes_size, GetComputeStream(context)); @@ -166,7 +180,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, extrapolation_value, exclude_outside_, allocate_temp_space, - shared_lookup_table_ondevice.get(), + shared_lookup_table_ondevice, reinterpret_cast(X->Data()), reinterpret_cast(Y->MutableData()), output_count); @@ -209,7 +223,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, extrapolation_value, exclude_outside_, allocate_temp_space, - shared_lookup_table_ondevice.get(), + shared_lookup_table_ondevice, reinterpret_cast(X->Data()), reinterpret_cast(Y->MutableData()), output_count); @@ -255,7 +269,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, extrapolation_value, exclude_outside_, allocate_temp_space, - shared_lookup_table_ondevice.get(), + shared_lookup_table_ondevice, reinterpret_cast(X->Data()), reinterpret_cast(Y->MutableData()), output_count); diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.h b/onnxruntime/core/providers/cuda/tensor/upsample.h index baf7bc8b06915..152862da0fdbd 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.h +++ b/onnxruntime/core/providers/cuda/tensor/upsample.h @@ -18,6 +18,11 @@ class Upsample : public UpsampleBase, public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; Status BaseCompute(OpKernelContext* context, gsl::span roi, gsl::span scales, gsl::span output_dims) const; + +#ifndef BUILD_CUDA_EP_AS_PLUGIN + private: + IAllocatorUniquePtr shared_lookup_table_ondevice_; +#endif }; } // namespace cuda From 44cf95515cddb278daf085d1e74797954c5137e0 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 23 Mar 2026 16:18:49 -0700 Subject: [PATCH 06/48] refactoring --- .../cuda/plugin/cuda_controlflow_plugin.cu | 90 +++++-------------- 1 file changed, 21 insertions(+), 69 deletions(-) diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu index 98d35baab2361..77eed62068413 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu +++ b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu @@ -28,37 +28,23 @@ Status CudaStatus(cudaError_t cuda_status, const char* operation) { std::string("Scan Transpose: ") + operation + " failed: " + cudaGetErrorString(cuda_status)); } -struct DeviceArraySet { - int64_t* input_strides = nullptr; - int64_t* output_strides = nullptr; - int* perm = nullptr; - - ~DeviceArraySet() { - if (perm != nullptr) { - cudaFree(perm); - } - if (output_strides != nullptr) { - cudaFree(output_strides); - } - if (input_strides != nullptr) { - cudaFree(input_strides); - } - } -}; - -} // namespace - // Maximum number of dimensions supported by the transpose kernel. // Most real-world tensors have <= 8 dimensions. constexpr int kMaxTransposeDims = 8; +struct TransposeArgs { + int64_t input_strides[kMaxTransposeDims]; + int64_t output_strides[kMaxTransposeDims]; + int perm[kMaxTransposeDims]; +}; + +} // namespace + // Kernel: each thread handles one element, computing its output position // from the input position via the permutation. __global__ void TransposeNDKernel(const char* __restrict__ input, char* __restrict__ output, - const int64_t* __restrict__ input_strides, - const int64_t* __restrict__ output_strides, - const int* __restrict__ perm, + TransposeArgs args, int num_dims, size_t element_size, size_t total_elements) { @@ -69,14 +55,14 @@ __global__ void TransposeNDKernel(const char* __restrict__ input, int64_t coords[kMaxTransposeDims]; size_t remaining = idx; for (int d = 0; d < num_dims; d++) { - coords[d] = static_cast(remaining / static_cast(input_strides[d])); - remaining %= static_cast(input_strides[d]); + coords[d] = static_cast(remaining / static_cast(args.input_strides[d])); + remaining %= static_cast(args.input_strides[d]); } // Compute output linear index via permutation size_t out_idx = 0; for (int d = 0; d < num_dims; d++) { - out_idx += static_cast(coords[perm[d]]) * static_cast(output_strides[d]); + out_idx += static_cast(coords[args.perm[d]]) * static_cast(args.output_strides[d]); } // Copy element bytes @@ -100,55 +86,23 @@ Status LaunchTransposeKernel(const void* input, void* output, " exceeds the supported maximum rank of " + std::to_string(kMaxTransposeDims)); } + TransposeArgs args; + // Compute input strides (row-major) - int64_t input_strides[kMaxTransposeDims]; - input_strides[num_dims - 1] = 1; + args.input_strides[num_dims - 1] = 1; for (int d = static_cast(num_dims) - 2; d >= 0; d--) { - input_strides[d] = input_strides[d + 1] * input_shape[d + 1]; + args.input_strides[d] = args.input_strides[d + 1] * input_shape[d + 1]; } // Compute output shape and strides from permutation int64_t output_shape[kMaxTransposeDims]; - int64_t output_strides[kMaxTransposeDims]; - int perm_int[kMaxTransposeDims]; for (size_t d = 0; d < num_dims; d++) { output_shape[d] = input_shape[permutation[d]]; - perm_int[d] = static_cast(permutation[d]); + args.perm[d] = static_cast(permutation[d]); } - output_strides[num_dims - 1] = 1; + args.output_strides[num_dims - 1] = 1; for (int d = static_cast(num_dims) - 2; d >= 0; d--) { - output_strides[d] = output_strides[d + 1] * output_shape[d + 1]; - } - - // Copy arrays to device - DeviceArraySet device_arrays; - auto status = CudaStatus(cudaMalloc(&device_arrays.input_strides, num_dims * sizeof(int64_t)), "cudaMalloc(input_strides)"); - if (!status.IsOK()) { - return status; - } - status = CudaStatus(cudaMalloc(&device_arrays.output_strides, num_dims * sizeof(int64_t)), "cudaMalloc(output_strides)"); - if (!status.IsOK()) { - return status; - } - status = CudaStatus(cudaMalloc(&device_arrays.perm, num_dims * sizeof(int)), "cudaMalloc(perm)"); - if (!status.IsOK()) { - return status; - } - - status = CudaStatus(cudaMemcpyAsync(device_arrays.input_strides, input_strides, num_dims * sizeof(int64_t), cudaMemcpyHostToDevice, stream), - "cudaMemcpyAsync(input_strides)"); - if (!status.IsOK()) { - return status; - } - status = CudaStatus(cudaMemcpyAsync(device_arrays.output_strides, output_strides, num_dims * sizeof(int64_t), cudaMemcpyHostToDevice, stream), - "cudaMemcpyAsync(output_strides)"); - if (!status.IsOK()) { - return status; - } - status = CudaStatus(cudaMemcpyAsync(device_arrays.perm, perm_int, num_dims * sizeof(int), cudaMemcpyHostToDevice, stream), - "cudaMemcpyAsync(perm)"); - if (!status.IsOK()) { - return status; + args.output_strides[d] = args.output_strides[d + 1] * output_shape[d + 1]; } constexpr int kBlockSize = 256; @@ -157,14 +111,12 @@ Status LaunchTransposeKernel(const void* input, void* output, TransposeNDKernel<<>>( static_cast(input), static_cast(output), - device_arrays.input_strides, - device_arrays.output_strides, - device_arrays.perm, + args, static_cast(num_dims), element_size, total_elements); - status = CudaStatus(cudaGetLastError(), "TransposeNDKernel launch"); + auto status = CudaStatus(cudaGetLastError(), "TransposeNDKernel launch"); if (!status.IsOK()) { return status; } From eb329a41096f9ca8f1a64dc1f72b2d277ef591b8 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 23 Mar 2026 16:43:48 -0700 Subject: [PATCH 07/48] Copilot feedback --- .../providers/cuda/plugin/cuda_ep_factory.cc | 44 +++++++++---------- .../python/onnxruntime_pybind_module.cc | 2 +- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc index 3c8706ce2cef2..58ab8fff413b1 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -121,15 +121,17 @@ OrtStatus* ORT_API_CALL CudaEpFactory::GetSupportedDevicesImpl( continue; // Skip non-NVIDIA GPUs } + int32_t current_device_id = factory->ort_api_.HardwareDevice_DeviceId(&device); + OrtKeyValuePairs* ep_metadata = nullptr; factory->ort_api_.CreateKeyValuePairs(&ep_metadata); // Get CUDA device properties for metadata int cuda_device_count = 0; cudaError_t err = cudaGetDeviceCount(&cuda_device_count); - if (err == cudaSuccess && cuda_device_count > 0) { + if (err == cudaSuccess && cuda_device_count > 0 && current_device_id < cuda_device_count) { cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, factory->device_id_); + cudaGetDeviceProperties(&prop, current_device_id); factory->ort_api_.AddKeyValuePair(ep_metadata, "cuda_device_name", prop.name); factory->ort_api_.AddKeyValuePair( ep_metadata, "cuda_compute_capability", @@ -145,9 +147,17 @@ OrtStatus* ORT_API_CALL CudaEpFactory::GetSupportedDevicesImpl( return status; } + Ort::MemoryInfo device_memory_info{"Cuda", + OrtMemoryInfoDeviceType_GPU, + factory->vendor_id_, + static_cast(current_device_id), + OrtDeviceMemoryType_DEFAULT, + /*alignment is default*/ 0, + OrtAllocatorType::OrtDeviceAllocator}; + // Register allocator info for GPU device memory RETURN_IF_ERROR(factory->ep_api_.EpDevice_AddAllocatorInfo( - ep_device, factory->default_memory_info_)); + ep_device, device_memory_info)); // Register allocator info for CPU pinned memory (host accessible) RETURN_IF_ERROR(factory->ep_api_.EpDevice_AddAllocatorInfo( @@ -184,6 +194,7 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( // The read helpers intentionally swallow errors: if a config entry is // absent or malformed the default value in Config is kept. CudaEp::Config config{}; + config.device_id = factory->ort_api_.HardwareDevice_DeviceId(devices[0]); auto read_session_config_bool = [&](const char* key, bool& value) { size_t size = 0; @@ -261,31 +272,18 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateAllocatorImpl( auto& factory = *static_cast(this_ptr); *allocator = nullptr; - const OrtMemoryInfo* default_memory_info_ptr = factory.default_memory_info_.operator OrtMemoryInfo*(); - const OrtMemoryInfo* pinned_memory_info_ptr = factory.pinned_memory_info_.operator OrtMemoryInfo*(); - - auto is_equal_memory_info = [&](const OrtMemoryInfo* expected, bool& out_equal) -> OrtStatus* { - int is_equal = 0; - auto* status = factory.ort_api_.CompareMemoryInfo(memory_info, expected, &is_equal); - if (status != nullptr) { - return status; - } - out_equal = (is_equal != 0); - return nullptr; - }; - - bool is_default = false; - bool is_pinned = false; - RETURN_IF_ERROR(is_equal_memory_info(default_memory_info_ptr, is_default)); - RETURN_IF_ERROR(is_equal_memory_info(pinned_memory_info_ptr, is_pinned)); + const char* name = ""; + factory.ort_api_.MemoryInfoGetName(memory_info, &name); + int req_device_id = 0; + factory.ort_api_.MemoryInfoGetId(memory_info, &req_device_id); - if (is_default) { - auto cuda_allocator = std::make_unique(memory_info, factory.device_id_); + if (strcmp(name, "Cuda") == 0) { + auto cuda_allocator = std::make_unique(memory_info, req_device_id); *allocator = cuda_allocator.release(); return nullptr; } - if (is_pinned) { + if (strcmp(name, "CudaPinned") == 0) { auto pinned_allocator = std::make_unique(memory_info); *allocator = pinned_allocator.release(); return nullptr; diff --git a/onnxruntime/python/onnxruntime_pybind_module.cc b/onnxruntime/python/onnxruntime_pybind_module.cc index dbb0a9330d262..4eaa057a68cff 100644 --- a/onnxruntime/python/onnxruntime_pybind_module.cc +++ b/onnxruntime/python/onnxruntime_pybind_module.cc @@ -117,7 +117,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { "get_available_providers", []() -> std::vector { auto available = GetAvailableExecutionProviderNames(); #if !defined(ORT_MINIMAL_BUILD) - InlinedHashSet existing; + InlinedHashSet existing; existing.reserve(available.size()); for (const auto& ep_name : available) { existing.insert(ep_name); From 59fda409d57cbd5f08a5abf5bdd14296fd67df7d Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 23 Mar 2026 17:21:29 -0700 Subject: [PATCH 08/48] fix test --- onnxruntime/contrib_ops/cuda/inverse.cc | 6 +- .../core/providers/cuda/nn/conv_transpose.cc | 13 +--- .../core/providers/cuda/nn/conv_transpose.h | 2 + .../core/providers/cuda/nn/conv_transpose_8.h | 6 +- .../providers/cuda/plugin/cuda_ep_factory.cc | 1 + .../cuda/plugin/cuda_kernel_adapter.h | 61 ++++++++++++++----- .../cuda/plugin/cuda_stream_plugin.cc | 46 +++++++------- 7 files changed, 84 insertions(+), 51 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/inverse.cc b/onnxruntime/contrib_ops/cuda/inverse.cc index e3ece229142aa..68f06b51c0067 100644 --- a/onnxruntime/contrib_ops/cuda/inverse.cc +++ b/onnxruntime/contrib_ops/cuda/inverse.cc @@ -65,7 +65,8 @@ Status CheckForSingularity(cudaStream_t stream, const IAllocatorUniquePtr& template struct Inverse::ComputeImpl { - Status operator()(void* ort_stream, Inverse::CublasHandle cublas_h, const Inverse* inst, const Tensor& input, Tensor& output, + Status operator()(void* ort_stream, cudaStream_t stream, Inverse::CublasHandle cublas_h, const Inverse* inst, + const Tensor& input, Tensor& output, const IAllocatorUniquePtr& info, const IAllocatorUniquePtr& pivots, size_t num_batches, size_t rows) const { using namespace onnxruntime::cuda; @@ -75,7 +76,6 @@ struct Inverse::ComputeImpl { auto info_cpu = std::make_unique(num_batches); const auto dim = static_cast(rows); const auto n_batches = static_cast(num_batches); - cudaStream_t stream = ort_stream ? static_cast(ort_stream) : nullptr; // Make a copy of the input which will serve as a workspace as well. if constexpr (std::is_same::value || std::is_same::value) { @@ -156,7 +156,7 @@ Status Inverse::ComputeInternal(OpKernelContext* ctx) const { utils::MLTypeCallDispatcher t_disp(input->GetElementType()); return t_disp.InvokeRet( - GetComputeStream(ctx), GetCublasHandle(ctx), this, *input, *output, info, pivots, num_batches, rows); + GetComputeStream(ctx), Stream(ctx), GetCublasHandle(ctx), this, *input, *output, info, pivots, num_batches, rows); } } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 0ad19815f0600..16d219ee4ef1c 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -225,12 +225,9 @@ Status ConvTranspose::CreateCudnnFeExecutionPlan(const onnxruntime::T template Status ConvTranspose::UpdateState(OpKernelContext* context, bool dynamic_padding) const { constexpr bool channels_last = Layout == LAYOUT_NHWC; - -#ifdef BUILD_CUDA_EP_AS_PLUGIN size_t num_inputs = static_cast(Info().GetInputCount()); -#else - size_t num_inputs = OpKernel::Node().InputDefs().size(); -#endif + // Standard ONNX ConvTranspose has inputs X, W, optional B. + // ConvTransposeWithDynamicPads inserts Pads at input 2, so bias becomes input 3. bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3; // set X @@ -487,11 +484,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool CUDA_RETURN_IF_ERROR(cudaMemset(s_.y_data, 0, s_.Y->SizeInBytes())); } } -#ifdef BUILD_CUDA_EP_AS_PLUGIN - auto ws = GetWorkSpace(context->GetGPUComputeStream()); -#else - auto ws = GetWorkSpace(context->GetComputeStream()); -#endif + auto ws = GetWorkSpace(GetComputeStream(context)); CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_graph->execute(cudnn_handle, s_.variant_pack, diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.h b/onnxruntime/core/providers/cuda/nn/conv_transpose.h index 62deb0475289a..072ef5c3221ef 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.h @@ -24,6 +24,8 @@ class ConvTranspose : public CudaKernel { Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) override; Status ComputeInternal(OpKernelContext* context) const override; + // `dynamic_padding` is used by the contrib op ConvTransposeWithDynamicPads, + // which adds a Pads input before the optional bias input. Status DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const; private: diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h b/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h index d296c8540dd5f..cf0a2723111b8 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h @@ -48,8 +48,10 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy TensorShapeVector w_dims = w_shape.AsShapeVector(); auto w_data = reinterpret_cast(W->Data()); - const Tensor* B_tensor = context->Input(dynamic_padding ? 3 : 2); - bool has_bias = B_tensor != nullptr; + const size_t num_inputs = static_cast(Info().GetInputCount()); + // Standard ONNX ConvTranspose has inputs X, W, optional B. + // ConvTransposeWithDynamicPads inserts Pads at input 2, so bias becomes input 3. + bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3; CudaT* y_data = nullptr; diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc index 58ab8fff413b1..f7083b178a71e 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -189,6 +189,7 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( ORT_INVALID_ARGUMENT, "CUDA EP factory currently supports only one device at a time."); } + ORT_RETURN_IF_NOT(devices != nullptr && devices[0] != nullptr, "CUDA EP factory requires a valid device"); // Parse configuration from session options. // The read helpers intentionally swallow errors: if a config entry is diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index a0bf3e4b1862f..f509901c54cde 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -306,6 +306,7 @@ struct PluginNoOpLogStream { #include #include #include +#include #include #include @@ -343,6 +344,34 @@ inline size_t BytesForCount(size_t count_or_bytes, size_t element_size) { if (count_or_bytes > (std::numeric_limits::max() / element_size)) return 0; return count_or_bytes * element_size; } + +template +IConstantBuffer* GetConstOnesBufferForDevice(int device_id) { + static std::mutex mutex; + static std::unordered_map>> buffers; + std::lock_guard lock(mutex); + auto& buffer = buffers[device_id]; + if (!buffer) { + buffer = CreateConstantOnes(); + } + return buffer.get(); +} + +inline const cudaDeviceProp& GetDevicePropForDevice(int device_id) { + static std::mutex mutex; + static std::unordered_map> props; + std::lock_guard lock(mutex); + auto it = props.find(device_id); + if (it == props.end()) { + auto prop = std::make_unique(); + if (cudaGetDeviceProperties(prop.get(), device_id) != cudaSuccess) { + std::memset(prop.get(), 0, sizeof(*prop)); + prop->major = -1; + } + it = props.emplace(device_id, std::move(prop)).first; + } + return *it->second; +} } // namespace detail } // namespace cuda @@ -370,16 +399,8 @@ class CUDAExecutionProvider : public onnxruntime::IExecutionProvider { return false; } const cudaDeviceProp& GetDeviceProp() const { - static cudaDeviceProp prop; - static std::once_flag flag; - std::call_once(flag, []() { - int device_id = cuda::detail::GetCudaKernelAdapterRuntimeConfig().device_id.load(std::memory_order_relaxed); - if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) { - std::memset(&prop, 0, sizeof(prop)); - prop.major = -1; - } - }); - return prop; + int device_id = cuda::detail::GetCudaKernelAdapterRuntimeConfig().device_id.load(std::memory_order_relaxed); + return cuda::detail::GetDevicePropForDevice(device_id); } }; @@ -664,9 +685,7 @@ class CudaKernel : public OpKernel { // Delegates to IConstantBuffer from cuda_utils.h (compiled in cuda_utils.cu). template const T* GetConstOnes(size_t count, cudaStream_t stream) const { - static std::unique_ptr> buf; - static std::once_flag flag; - std::call_once(flag, []() { buf = CreateConstantOnes(); }); + auto* buf = detail::GetConstOnesBufferForDevice(device_id_); return buf->GetBuffer(stream, count); } @@ -677,11 +696,23 @@ class CudaKernel : public OpKernel { if (cnt == 0) return IAllocatorUniquePtr(nullptr, [](T*) {}); size_t sz = detail::BytesForCount(cnt, detail::SizeOf::value); void* p = nullptr; - if (cudaMalloc(&p, sz) != cudaSuccess) return IAllocatorUniquePtr(nullptr, [](T*) {}); + cudaError_t alloc_result = cudaSuccess; + if (s) { + alloc_result = cudaMallocAsync(&p, sz, static_cast(s)); + if (alloc_result == cudaErrorNotSupported || alloc_result == cudaErrorInvalidValue) { + alloc_result = cudaMalloc(&p, sz); + } + } else { + alloc_result = cudaMalloc(&p, sz); + } + if (alloc_result != cudaSuccess) return IAllocatorUniquePtr(nullptr, [](T*) {}); return IAllocatorUniquePtr(static_cast(p), [s](T* ptr) { if (ptr) { if (s) { - cudaFreeAsync(ptr, static_cast(s)); + cudaError_t free_result = cudaFreeAsync(ptr, static_cast(s)); + if (free_result == cudaErrorNotSupported || free_result == cudaErrorInvalidValue) { + cudaFree(ptr); + } } else { cudaFree(ptr); } diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc index c1e753440cca4..06d16cb92ca8f 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc @@ -12,17 +12,17 @@ namespace { // Global stream-to-CudaSyncStream mapping. // Required because migrated CUDA kernels receive only a raw cudaStream_t -// but need access to associated cuBLAS/cuDNN handles. The map is -// lazily initialized and never freed (process-lifetime singleton). -static std::unordered_map* g_stream_map = nullptr; -static std::mutex* g_stream_map_mutex = nullptr; -static std::once_flag g_stream_map_init_flag; - -void InitStreamMap() { - std::call_once(g_stream_map_init_flag, []() { - g_stream_map = new std::unordered_map(); - g_stream_map_mutex = new std::mutex(); - }); +// but need access to associated cuBLAS/cuDNN handles. +using StreamMap = std::unordered_map; + +StreamMap& GetStreamMap() { + static StreamMap stream_map; + return stream_map; +} + +std::mutex& GetStreamMapMutex() { + static std::mutex stream_map_mutex; + return stream_map_mutex; } } // namespace @@ -117,25 +117,29 @@ void CudaSyncStream::CleanupDeferredCPUBuffers() { } /*static*/ CudaSyncStream* CudaSyncStream::FromCudaStream(cudaStream_t stream) { - InitStreamMap(); - std::lock_guard lock(*g_stream_map_mutex); - auto it = g_stream_map->find(stream); - if (it != g_stream_map->end()) { + if (stream == nullptr) { + return nullptr; + } + + auto& stream_map = GetStreamMap(); + std::lock_guard lock(GetStreamMapMutex()); + auto it = stream_map.find(stream); + if (it != stream_map.end()) { return it->second; } return nullptr; } /*static*/ void CudaSyncStream::RegisterStream(cudaStream_t stream, CudaSyncStream* sync_stream) { - InitStreamMap(); - std::lock_guard lock(*g_stream_map_mutex); - (*g_stream_map)[stream] = sync_stream; + auto& stream_map = GetStreamMap(); + std::lock_guard lock(GetStreamMapMutex()); + stream_map[stream] = sync_stream; } /*static*/ void CudaSyncStream::UnregisterStream(cudaStream_t stream) { - if (!g_stream_map) return; - std::lock_guard lock(*g_stream_map_mutex); - g_stream_map->erase(stream); + auto& stream_map = GetStreamMap(); + std::lock_guard lock(GetStreamMapMutex()); + stream_map.erase(stream); } // --------------------------------------------------------------------------- From 87edf0eb606dd111b9208542df690098f44dd83b Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 24 Mar 2026 00:56:43 -0700 Subject: [PATCH 09/48] add more ops --- cmake/onnxruntime_providers_cuda_plugin.cmake | 111 ++++++------------ .../contrib_ops/cuda/bert/attention.cc | 29 +++-- .../cuda/bert/decoder_attention.cc | 21 ++-- .../cuda/bert/group_query_attention.cc | 4 +- .../cuda/bert/longformer_attention.cc | 12 +- .../cuda/bert/multihead_attention.cc | 23 ++-- .../contrib_ops/cuda/bert/paged_attention.cc | 4 +- .../contrib_ops/cuda/diffusion/group_norm.cc | 4 +- onnxruntime/contrib_ops/cuda/moe/moe.cc | 22 ++-- .../quantization/attention_quantization.cc | 15 ++- .../cuda/quantization/moe_quantization.cc | 28 +++-- .../cuda/sparse/sparse_attention.cc | 4 +- .../object_detection/non_max_suppression.cc | 10 +- .../providers/cuda/plugin/cuda_ep_factory.cc | 18 ++- .../core/providers/cuda/tensor/upsample.cc | 5 +- 15 files changed, 166 insertions(+), 144 deletions(-) diff --git a/cmake/onnxruntime_providers_cuda_plugin.cmake b/cmake/onnxruntime_providers_cuda_plugin.cmake index 9834700ec220b..e840aa150bd1b 100644 --- a/cmake/onnxruntime_providers_cuda_plugin.cmake +++ b/cmake/onnxruntime_providers_cuda_plugin.cmake @@ -109,76 +109,33 @@ list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/llm/.*") # Exclude constant_of_shape — inherits from ConstantOfShapeBase (CPU provider) # which is not linked into the plugin. -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/generator/constant_of_shape\\.cc$") - -# matmul_integer.cc: GetComputeStream fixed, GemmInt8 signature updated. -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/math/matmul_integer\.cc$") # REMOVED in Stage 5 - -# matmul.cc: GetComputeStream fixed, GetTuningContext guarded. -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/math/matmul\.cc$") # REMOVED in Stage 5 - -# variadic_elementwise_ops.cc: adapter InputCount/RequiredInput/RequiredOutput supported. -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/math/variadic_elementwise_ops\\.cc$") # REMOVED in Stage 5C - -# slice.cc: plugin-local wrappers added for SliceBase::PrepareForCompute/FlattenOutputDims. -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/slice\\.cc$") # REMOVED in Stage 5C.3 +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/generator/constant_of_shape\\.cc$") # Exclude space_depth_ops — inherits from SpaceDepthBase (CPU provider). -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/space_depth_ops\\.cc$") - -# concat.cc: InputArgCount/GetComputeStream usage fixed for adapter. -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/concat\\.cc$") # REMOVED in Stage 5A - -# gather.cc: switched to GatherBase::PrepareForComputeImpl for adapter context. -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/gather\\.cc$") # REMOVED in Stage 5B - -# gather_nd.cc: PrepareCompute signature changed to void*/cudaStream_t, GetComputeStream fixed. -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/gather_nd\\.cc$") # REMOVED in Stage 5 - -# pad.cc: plugin-local wrappers added for PadBase static helpers. -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/pad\\.cc$") # REMOVED in Stage 5C.2 - -# reshape.cc: GetComputeStream/CopyTensor framework dependency fixed for adapter. -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/reshape\\.cc$") # REMOVED in Stage 5A - -# split.cc: GetComputeStream usage fixed for adapter via CudaKernel::GetComputeStream. -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/split\\.cc$") # REMOVED in Stage 5A +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/space_depth_ops\\.cc$") # Exclude object_detection/ — NonMaxSuppression and RoiAlign inherit from CPU # base classes (NonMaxSuppressionBase, RoiAlignBase) not linked into the plugin. -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/object_detection/.*") -list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/object_detection/.*") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/object_detection/.*") +# list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/object_detection/.*") # Exclude upsample.cc — UpsampleBase uses InputDefs() and # OpKernelInfo::GetAllocator() not available in adapter. -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/upsample\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/upsample\\.cc$") # Exclude resize.cc — Resize inherits from Upsample (excluded above). -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/resize\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/resize\\.cc$") # Exclude einsum — einsum_auxiliary_ops.cc calls ReductionOps::ReduceCompute # which is framework-only (guarded by #ifndef BUILD_CUDA_EP_AS_PLUGIN). -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/math/einsum_utils/.*") -list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/math/einsum_utils/.*") - -# unsqueeze.cc: plugin-local PrepareCompute path added for adapter context. -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/unsqueeze\\.cc$") # REMOVED in Stage 5B +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/math/einsum_utils/.*") +# list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/math/einsum_utils/.*") # Permanently excluded — pure CPU ops, handled by GetCpuPreferredNodes. # shape_op.cc inherits from onnxruntime::OpKernel (framework) # which cannot convert to ep::adapter::OpKernel in the plugin build. list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/shape_op\\.cc$") -# cumsum.cc: axis parsing helper inlined for plugin build. -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/math/cumsum\\.cc$") # REMOVED in Stage 5B - -# tile.cc: plugin-local IsTileMemcpy helper added. -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/tile\\.cc$") # REMOVED in Stage 5B - -# --- Contrib op exclusions --- -# Exclude contrib ops that have dependencies not available in the plugin build. -# Note: aten_ops/ and collective/ exclusions are applied above (near the glob). - # Exclude contrib llm/ — uses onnxruntime::Stream* in QkvToContext. list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/llm/.*") list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/llm/.*") @@ -187,37 +144,39 @@ list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/llm/.*") list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/shrunken_gather\\.cc$") # Exclude contrib bert ops that use GetComputeStream() or framework OpKernelContext. -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/attention\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/decoder_attention\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/decoder_masked_self_attention\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/embed_layer_norm\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/fast_gelu\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/group_query_attention\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/longformer_attention\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/multihead_attention\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/packed_attention\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/packed_multihead_attention\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/paged_attention\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/relative_attn_bias\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/remove_padding\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/attention\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/decoder_attention\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/decoder_masked_self_attention\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/embed_layer_norm\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/fast_gelu\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/group_query_attention\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/longformer_attention\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/multihead_attention\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/packed_attention\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/packed_multihead_attention\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/paged_attention\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/relative_attn_bias\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/remove_padding\\.cc$") # Exclude contrib ops using GetComputeStream() or framework type deps. +# group_norm.cc still requires the real CudaTuningContext/Stream types. list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/diffusion/group_norm\\.cc$") +# fused_conv.cc still depends on provider-only cuDNN config APIs beyond stream access. list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/fused_conv\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/inverse\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/math/bias_dropout\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/inverse\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/math/bias_dropout\\.cc$") list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/math/fft_ops\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/moe/moe\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/sparse/sparse_attention\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/moe/moe\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/sparse/sparse_attention\\.cc$") list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/crop\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/dynamic_time_warping\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/dynamic_time_warping\\.cc$") list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/dynamicslice\\.cc$") # Exclude contrib quantization ops with GetComputeStream() deps. -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/quantization/attention_quantization\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/quantization/matmul_bnb4\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/quantization/matmul_nbits\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/quantization/moe_quantization\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/quantization/attention_quantization\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/quantization/matmul_bnb4\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/quantization/matmul_nbits\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/quantization/moe_quantization\\.cc$") list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/quantization/qordered_ops/.*") # Exclude contrib transformers/ (beam search, greedy search, sampling). @@ -225,8 +184,8 @@ list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/transforme list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/transformers/.*") # Exclude gemm_float8.cc/.cu — ComputeInternal is in .cu which uses GetComputeStream(). -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/math/gemm_float8\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/math/gemm_float8\\.cu$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/math/gemm_float8\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/math/gemm_float8\\.cu$") # fused_matmul.cc: matmul.cc is now included, so fused_matmul can be too. # list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/math/fused_matmul\\.cc$") # REMOVED in Stage 5 @@ -246,7 +205,7 @@ set_target_properties(onnxruntime_providers_cuda_plugin PROPERTIES # adapter OpKernelInfo::GetAttr<> (GCC falsely warns about variables that are # initialised inside GetAttr’s output parameter path). target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE - $<$:-Wno-maybe-uninitialized> + $<$,$>:-Wno-maybe-uninitialized> ) target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE # Flash-attention, XQA, MoE, and other pure CUDA kernel .cu files must NOT diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index cff5d5d320423..e46ead949ce4f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -58,6 +58,13 @@ Attention::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionB template Status Attention::ComputeInternal(OpKernelContext* context) const { +#ifdef BUILD_CUDA_EP_AS_PLUGIN + onnxruntime::PluginStreamShim plugin_stream_shim(GetComputeStream(context)); + auto* ort_stream = static_cast(&plugin_stream_shim); +#else + auto* ort_stream = context->GetComputeStream(); +#endif + const Tensor* input = context->Input(0); const Tensor* weights = context->Input(1); const Tensor* bias = context->Input(2); @@ -139,14 +146,14 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { softmax_lse_accum_bytes = slse_accum_bytes; out_accum_bytes = o_accum_bytes; } - auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); - auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); - auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); + auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, GetComputeStream(context)); + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, GetComputeStream(context)); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, GetComputeStream(context)); #else constexpr bool use_flash_attention = false; - auto softmax_lse_buffer = GetScratchBuffer(0, context->GetComputeStream()); - auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr - auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto softmax_lse_buffer = GetScratchBuffer(0, GetComputeStream(context)); + auto softmax_lse_accum_buffer = GetScratchBuffer(0, GetComputeStream(context)); // nullptr + auto out_accum_buffer = GetScratchBuffer(0, GetComputeStream(context)); // nullptr #endif if (!use_flash_attention) { @@ -243,7 +250,11 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { int m = batch_size * sequence_length; int n = (parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size); int k = parameters.input_hidden_size; +#ifdef BUILD_CUDA_EP_AS_PLUGIN + IAllocatorUniquePtr gemm_buffer = GetScratchBuffer(static_cast(m * n) * sizeof(T), GetComputeStream(context)); +#else IAllocatorUniquePtr gemm_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(m * n) * sizeof(T), false, context->GetComputeStream()); +#endif CudaT one = ToCudaType::FromFloat(1.0f); CudaT zero = ToCudaType::FromFloat(0.0f); @@ -275,7 +286,11 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { use_memory_efficient_attention, use_cudnn_flash_attention, false); +#ifdef BUILD_CUDA_EP_AS_PLUGIN + IAllocatorUniquePtr work_space = GetScratchBuffer(workSpaceSize, GetComputeStream(context)); +#else IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, workSpaceSize, false, context->GetComputeStream()); +#endif data.gemm_buffer = reinterpret_cast(gemm_buffer.get()); if (nullptr != bias) { @@ -313,7 +328,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { } cudnnHandle_t cudnn = GetCudnnHandle(context); - return QkvToContext(device_prop, cublas, cudnn, context->GetComputeStream(), parameters, data); + return QkvToContext(device_prop, cublas, cudnn, ort_stream, parameters, data); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc index 5e5f909415fff..6478a3cca78a5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc @@ -175,6 +175,13 @@ DecoderAttention::DecoderAttention(const OpKernelInfo& info) : CudaKernel(inf template Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { +#ifdef BUILD_CUDA_EP_AS_PLUGIN + onnxruntime::PluginStreamShim plugin_stream_shim(GetComputeStream(context)); + auto* ort_stream = static_cast(&plugin_stream_shim); +#else + auto* ort_stream = context->GetComputeStream(); +#endif + const Tensor* query(context->Input(0)); const Tensor* key(context->Input(1)); const Tensor* q_weights(context->Input(2)); @@ -262,7 +269,7 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { // calculate q gemm_query_buffer_p = GetScratchBuffer(static_cast(batch_size) * sequence_length * hidden_size, - context->GetComputeStream()); + GetComputeStream(context)); m = sequence_length * batch_size; n = hidden_size; k = hidden_size; @@ -288,7 +295,7 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { if (!has_layer_state_ || !use_past_) { if (!static_kv_) { gemm_kv_buffer_p = GetScratchBuffer(static_cast(batch_size) * 2 * sequence_length * hidden_size, - context->GetComputeStream()); + GetComputeStream(context)); m = sequence_length * batch_size; n = 2 * hidden_size; k = hidden_size; @@ -308,7 +315,7 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { // gemm_kv_buffer in col-base: (2*h2, T_S*B) } else { gemm_kv_buffer_p = GetScratchBuffer(static_cast(batch_size) * 2 * key_sequence_length * hidden_size, - context->GetComputeStream()); + GetComputeStream(context)); m = key_sequence_length * batch_size; n = 2 * hidden_size; k = hidden_size; @@ -334,7 +341,7 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { int cache_sequence_length = static_cast(cache_shape[2]); if (!static_kv_) { gemm_kv_buffer_p = GetScratchBuffer(static_cast(batch_size) * 2 * sequence_length * hidden_size, - context->GetComputeStream()); + GetComputeStream(context)); m = sequence_length * batch_size; kv_sequence_length = cache_sequence_length + sequence_length; // broadcast bias for key and value: (2*h2, T_S*B) @@ -357,11 +364,11 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { size_t bytes = element_size * batch_size * (static_cast(sequence_length) + static_cast(2) * kv_sequence_length) * hidden_size; - auto qkv_buffer_p = GetScratchBuffer(bytes, context->GetComputeStream()); + auto qkv_buffer_p = GetScratchBuffer(bytes, GetComputeStream(context)); bytes = element_size * 2 * batch_size * sequence_length * num_heads_ * (static_cast(2) * head_size + static_cast(kv_sequence_length)); - auto workspace_p = GetScratchBuffer(bytes, context->GetComputeStream()); + auto workspace_p = GetScratchBuffer(bytes, GetComputeStream(context)); Tensor* output(context->Output(0, query_shape)); TensorShape new_cache_shape({batch_size, num_heads_, kv_sequence_length, head_size}); @@ -371,7 +378,7 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { return LaunchDecoderAttentionKernel( device_prop, UseTF32(), - context->GetComputeStream(), + ort_stream, cublas, element_size, batch_size, diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index ae2fd1c43ceb7..289ed0fc55f41 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -144,8 +144,8 @@ template Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { // Stream access: void* for GetScratchBuffer, Stream* for QkvToContext. #ifdef BUILD_CUDA_EP_AS_PLUGIN - onnxruntime::PluginStreamShim __stream_shim(GetComputeStream(context)); - auto* ort_stream = static_cast(&__stream_shim); + onnxruntime::PluginStreamShim plugin_stream_shim(GetComputeStream(context)); + auto* ort_stream = static_cast(&plugin_stream_shim); #else auto* ort_stream = context->GetComputeStream(); #endif diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc b/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc index 9c5d0e9834f6f..3501c3baff0b6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc @@ -78,11 +78,11 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { // TODO(tianleiwu): only calculate global index once per model instead of once per LongformerAttention node. // Build Global Index - auto global_index_buffer = GetScratchBuffer(static_cast(batch_size) * sequence_length, context->GetComputeStream()); - auto batch_global_num_buffer = GetScratchBuffer(batch_size, context->GetComputeStream()); + auto global_index_buffer = GetScratchBuffer(static_cast(batch_size) * sequence_length, GetComputeStream(context)); + auto batch_global_num_buffer = GetScratchBuffer(batch_size, GetComputeStream(context)); size_t global_scratch_bytes = GetGlobalScratchSize(sequence_length); - auto global_scratch_buffer = GetScratchBuffer(global_scratch_bytes, context->GetComputeStream()); + auto global_scratch_buffer = GetScratchBuffer(global_scratch_bytes, GetComputeStream(context)); auto& device_prop = GetDeviceProp(); ORT_RETURN_IF_ERROR(BuildGlobalIndex( @@ -116,7 +116,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { size_t qkv_size = static_cast(batch_size) * sequence_length * 3 * hidden_size * element_size; // Buffer for GEMM outputs of q, k, v, global_q, global_k and global_v // TODO(tianleiwu): compact global_q only need batch_size * window * hidden_size * element_size buffer size. - auto gemm_buffer = GetScratchBuffer(qkv_size + qkv_size, context->GetComputeStream()); + auto gemm_buffer = GetScratchBuffer(qkv_size + qkv_size, GetComputeStream(context)); bool use_merged_qkv_weights = (weights->Shape().NumDimensions() == 2); @@ -257,7 +257,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { max_num_global, window_, disable_compact_memory); - auto workspace_buffer = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); + auto workspace_buffer = GetScratchBuffer(workSpaceSize, GetComputeStream(context)); ORT_RETURN_IF_ERROR(LaunchLongformerAttentionKernel( device_prop, cublas, @@ -285,7 +285,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { use_half4_)); // Defer release of pinned memory since cudaStreamSynchronize is not used here and kernel need access the buffer. - this->AddDeferredReleaseCPUPtr(pinned_buffer.release(), context->GetComputeStream()); + this->AddDeferredReleaseCPUPtr(pinned_buffer.release(), GetComputeStream(context)); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 452a00deeb21c..e06437c0f07e2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -88,6 +88,13 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) template Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { +#ifdef BUILD_CUDA_EP_AS_PLUGIN + onnxruntime::PluginStreamShim plugin_stream_shim(GetComputeStream(context)); + auto* ort_stream = static_cast(&plugin_stream_shim); +#else + auto* ort_stream = context->GetComputeStream(); +#endif + const Tensor* query = context->Input(0); const Tensor* key = context->Input(1); const Tensor* value = context->Input(2); @@ -290,7 +297,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) cons kernel_type = AttentionKernelType::AttentionKernel_LeanAttention; } - auto lean_sync_flag_buffer = GetScratchBuffer(sync_flag_bytes, context->GetComputeStream()); + auto lean_sync_flag_buffer = GetScratchBuffer(sync_flag_bytes, GetComputeStream(context)); data.lean_sync_flag = reinterpret_cast(lean_sync_flag_buffer.get()); #else constexpr bool use_lean_attention = false; @@ -336,9 +343,9 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) cons #endif #if USE_LEAN_ATTENTION || USE_FLASH_ATTENTION - auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); - auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); - auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); + auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, GetComputeStream(context)); + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, GetComputeStream(context)); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, GetComputeStream(context)); if (use_flash_attention || use_lean_attention) { data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); @@ -485,7 +492,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) cons data.use_memory_efficient_attention = use_memory_efficient_attention; data.use_decoder_masked_multihead_attention = use_decoder_masked_multihead_attention; data.kernel_type = kernel_type; - data.allocator = Info().GetAllocator(OrtMemType::OrtMemTypeDefault); + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&data.allocator)); // Cache of cumulated sequence length that could help when sequence length does not change (for example, image model). // The cache will be initialized only once, and become readonly after that. @@ -515,7 +522,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) cons use_memory_efficient_attention, use_cudnn_sdpa, no_qkv_workspace); - auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + auto work_space = GetScratchBuffer(workspace_bytes, GetComputeStream(context)); data.has_qkv_workspace = !no_qkv_workspace; data.workspace = reinterpret_cast(work_space.get()); @@ -528,7 +535,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) cons std::vector seqlens_k(parameters.batch_size, parameters.total_sequence_length - 1); size_t seqlens_k_bytes = 0; seqlens_k_bytes = sizeof(int) * parameters.batch_size; - auto seqlens_k_buffer = GetScratchBuffer(seqlens_k_bytes, context->GetComputeStream()); + auto seqlens_k_buffer = GetScratchBuffer(seqlens_k_bytes, GetComputeStream(context)); if (seqlens_k_buffer != nullptr) { data.seqlens_k_total = reinterpret_cast(seqlens_k_buffer.get()); CUDA_RETURN_IF_ERROR(cudaMemcpy(data.seqlens_k_total, seqlens_k.data(), seqlens_k_bytes, cudaMemcpyHostToDevice)); @@ -557,7 +564,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) cons cudnnHandle_t cudnn = GetCudnnHandle(context); DUMP_STRING("Run QkvToContext from MHA CUDA"); return QkvToContext( - device_prop, cublas, cudnn, context->GetComputeStream(), parameters, data); + device_prop, cublas, cudnn, ort_stream, parameters, data); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc b/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc index 5f68282726c2a..7a3acc2867745 100644 --- a/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc @@ -56,8 +56,8 @@ template Status PagedAttention::ComputeInternal(OpKernelContext* context) const { // Stream access: void* for GetScratchBuffer, Stream* for QkvToContext. #ifdef BUILD_CUDA_EP_AS_PLUGIN - onnxruntime::PluginStreamShim __stream_shim(GetComputeStream(context)); - auto* ort_stream = static_cast(&__stream_shim); + onnxruntime::PluginStreamShim plugin_stream_shim(GetComputeStream(context)); + auto* ort_stream = static_cast(&plugin_stream_shim); #else auto* ort_stream = context->GetComputeStream(); #endif diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index dea5391c7629b..baa284f2fb859 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -208,11 +208,11 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { } auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups_), - context->GetComputeStream()); + GetComputeStream(context)); utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); return dispatcher.InvokeRet(GetTuningContext(), - context->GetComputeStream(), output, add_out, input, skip, bias, + GetComputeStream(context), output, add_out, input, skip, bias, gamma, beta, workspace.get(), epsilon_, batch_size, diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index 707eb24a386a9..e088e5241cc93 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -49,8 +49,6 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { 0)); // no block-wise quantization for regular MoE using CudaT = typename OrtToCudaType::type; - auto stream = context->GetComputeStream(); - auto& device_prop = GetDeviceProp(); const int sm = device_prop.major * 10 + device_prop.minor; @@ -68,18 +66,28 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int); size_t expert_for_source_row_size = k_ * moe_params.num_rows * sizeof(int); +#ifdef BUILD_CUDA_EP_AS_PLUGIN + IAllocatorUniquePtr work_space = this->template GetScratchBuffer(ws_size, this->GetComputeStream(context)); + IAllocatorUniquePtr fc2_output = this->template GetScratchBuffer(fc2_output_size, this->GetComputeStream(context)); + IAllocatorUniquePtr expert_scales = this->template GetScratchBuffer(expert_scales_size, this->GetComputeStream(context)); + IAllocatorUniquePtr expanded_source_row_to_expanded_dest_row = + this->template GetScratchBuffer(expanded_source_row_to_expanded_dest_row_size, this->GetComputeStream(context)); + IAllocatorUniquePtr expert_for_source_row = + this->template GetScratchBuffer(expert_for_source_row_size, this->GetComputeStream(context)); +#else AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); // TODO: allocate one buffer and reuse it. - IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, ws_size, false, stream); - IAllocatorUniquePtr fc2_output = IAllocator::MakeUniquePtr(allocator, fc2_output_size, false, stream); + IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, ws_size, false, context->GetComputeStream()); + IAllocatorUniquePtr fc2_output = IAllocator::MakeUniquePtr(allocator, fc2_output_size, false, context->GetComputeStream()); IAllocatorUniquePtr expert_scales = - IAllocator::MakeUniquePtr(allocator, expert_scales_size, false, stream); + IAllocator::MakeUniquePtr(allocator, expert_scales_size, false, context->GetComputeStream()); IAllocatorUniquePtr expanded_source_row_to_expanded_dest_row = - IAllocator::MakeUniquePtr(allocator, expanded_source_row_to_expanded_dest_row_size, false, stream); + IAllocator::MakeUniquePtr(allocator, expanded_source_row_to_expanded_dest_row_size, false, context->GetComputeStream()); IAllocatorUniquePtr expert_for_source_row = - IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, stream); + IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, context->GetComputeStream()); +#endif const CudaT* fc_scales_ptr = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index 20ac68d602684..551cd3dcdb9e9 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -95,6 +95,13 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { // Output 0 - output : (batch_size, sequence_length, hidden_size) // Output 1 - present : (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size) +#ifdef BUILD_CUDA_EP_AS_PLUGIN + onnxruntime::PluginStreamShim plugin_stream_shim(GetComputeStream(context)); + auto* ort_stream = static_cast(&plugin_stream_shim); +#else + auto* ort_stream = context->GetComputeStream(); +#endif + const Tensor* input = context->Input(0); const Tensor* weights = context->Input(1); const Tensor* bias = context->Input(2); @@ -138,8 +145,8 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { int n = 3 * hidden_size; int k = parameters.input_hidden_size; size_t num_elements = SafeInt(m) * n; - auto gemm_buffer = GetScratchBuffer(num_elements * element_size, context->GetComputeStream()); - auto gemm_buffer_quantized = GetScratchBuffer(num_elements, context->GetComputeStream()); + auto gemm_buffer = GetScratchBuffer(num_elements * element_size, GetComputeStream(context)); + auto gemm_buffer_quantized = GetScratchBuffer(num_elements, GetComputeStream(context)); typedef typename ToCudaType::MappedType CudaT; @@ -198,7 +205,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { use_cudnn_flash_attention, true); - auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); + auto work_space = GetScratchBuffer(workSpaceSize, GetComputeStream(context)); typedef typename ToCudaType::MappedType CudaT; AttentionData data; @@ -221,7 +228,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { } cudnnHandle_t cudnn = GetCudnnHandle(context); - return QkvToContext(GetDeviceProp(), cublas, cudnn, context->GetComputeStream(), parameters, data); + return QkvToContext(GetDeviceProp(), cublas, cudnn, ort_stream, parameters, data); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc index 31a793bb86f17..70f4690e60b92 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc @@ -57,13 +57,8 @@ Status QMoE::QuantizedMoEImpl(OpKernelContext* context, const Tensor* fc2_scales, const Tensor* fc3_scales_optional, const cudaDeviceProp& device_prop) const { - auto stream = context->GetComputeStream(); - const int sm = device_prop.major * 10 + device_prop.minor; - AllocatorPtr allocator; - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - using CudaT = typename OrtToCudaType::type; ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, @@ -81,14 +76,27 @@ Status QMoE::QuantizedMoEImpl(OpKernelContext* context, size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int); size_t expert_for_source_row_size = k_ * moe_params.num_rows * sizeof(int); - IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, ws_size, false, stream); - IAllocatorUniquePtr fc2_output = IAllocator::MakeUniquePtr(allocator, fc2_output_size, false, stream); +#ifdef BUILD_CUDA_EP_AS_PLUGIN + IAllocatorUniquePtr work_space = this->template GetScratchBuffer(ws_size, this->GetComputeStream(context)); + IAllocatorUniquePtr fc2_output = this->template GetScratchBuffer(fc2_output_size, this->GetComputeStream(context)); + IAllocatorUniquePtr expert_scales = this->template GetScratchBuffer(expert_scales_size, this->GetComputeStream(context)); + IAllocatorUniquePtr expanded_source_row_to_expanded_dest_row = + this->template GetScratchBuffer(expanded_source_row_to_expanded_dest_row_size, this->GetComputeStream(context)); + IAllocatorUniquePtr expert_for_source_row = + this->template GetScratchBuffer(expert_for_source_row_size, this->GetComputeStream(context)); +#else + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, ws_size, false, context->GetComputeStream()); + IAllocatorUniquePtr fc2_output = IAllocator::MakeUniquePtr(allocator, fc2_output_size, false, context->GetComputeStream()); IAllocatorUniquePtr expert_scales = - IAllocator::MakeUniquePtr(allocator, expert_scales_size, false, stream); + IAllocator::MakeUniquePtr(allocator, expert_scales_size, false, context->GetComputeStream()); IAllocatorUniquePtr expanded_source_row_to_expanded_dest_row = - IAllocator::MakeUniquePtr(allocator, expanded_source_row_to_expanded_dest_row_size, false, stream); + IAllocator::MakeUniquePtr(allocator, expanded_source_row_to_expanded_dest_row_size, false, context->GetComputeStream()); IAllocatorUniquePtr expert_for_source_row = - IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, stream); + IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, context->GetComputeStream()); +#endif moe_runner.run_moe_fc( reinterpret_cast(input->template Data()), diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc index bb4a9fabaca7e..63ed7ce189a7b 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc @@ -62,8 +62,8 @@ template Status SparseAttention::ComputeInternal(OpKernelContext* context) const { // Stream access: void* for GetScratchBuffer, Stream* for QkvToContext. #ifdef BUILD_CUDA_EP_AS_PLUGIN - onnxruntime::PluginStreamShim __stream_shim(GetComputeStream(context)); - auto* ort_stream = static_cast(&__stream_shim); + onnxruntime::PluginStreamShim plugin_stream_shim(GetComputeStream(context)); + auto* ort_stream = static_cast(&plugin_stream_shim); #else auto* ort_stream = context->GetComputeStream(); #endif diff --git a/onnxruntime/core/providers/cuda/object_detection/non_max_suppression.cc b/onnxruntime/core/providers/cuda/object_detection/non_max_suppression.cc index b55efc1180f10..1fbe0573301c4 100644 --- a/onnxruntime/core/providers/cuda/object_detection/non_max_suppression.cc +++ b/onnxruntime/core/providers/cuda/object_detection/non_max_suppression.cc @@ -62,7 +62,7 @@ Status NonMaxSuppression::ComputeInternal(OpKernelContext* ctx) const { IAllocatorUniquePtr d_selected_indices{}; IAllocatorUniquePtr h_number_selected_ptr{AllocateBufferOnCPUPinned(sizeof(int))}; auto* h_number_selected = static_cast(h_number_selected_ptr.get()); - auto* stream = ctx->GetComputeStream(); + auto* stream = GetComputeStream(ctx); ORT_RETURN_IF_ERROR(NonMaxSuppressionImpl( Stream(ctx), [this, stream](size_t bytes) { return GetScratchBuffer(bytes, stream); }, @@ -114,10 +114,10 @@ Status NonMaxSuppression::ComputeInternal(OpKernelContext* ctx) const { } } - ORT_RETURN_IF_ERROR(concat_sizes_gpu.CopyToGpu(ctx->GetComputeStream())); - ORT_RETURN_IF_ERROR(axis_dimension_input_output_mapping_gpu.CopyToGpu(ctx->GetComputeStream())); - ORT_RETURN_IF_ERROR(concat_sizes_range_gpu.CopyToGpu(ctx->GetComputeStream())); - ORT_RETURN_IF_ERROR(input_ptr.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(concat_sizes_gpu.CopyToGpu(GetComputeStream(ctx))); + ORT_RETURN_IF_ERROR(axis_dimension_input_output_mapping_gpu.CopyToGpu(GetComputeStream(ctx))); + ORT_RETURN_IF_ERROR(concat_sizes_range_gpu.CopyToGpu(GetComputeStream(ctx))); + ORT_RETURN_IF_ERROR(input_ptr.CopyToGpu(GetComputeStream(ctx))); ORT_RETURN_IF_ERROR(ConcatImpl(Stream(ctx), sizeof(int64_t), diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc index f7083b178a71e..605c45a133809 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -173,7 +173,7 @@ OrtStatus* ORT_API_CALL CudaEpFactory::GetSupportedDevicesImpl( /*static*/ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( OrtEpFactory* this_ptr, - const OrtHardwareDevice* const* /*devices*/, + const OrtHardwareDevice* const* devices, const OrtKeyValuePairs* const* /*ep_metadata*/, size_t num_devices, const OrtSessionOptions* session_options, @@ -189,7 +189,11 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( ORT_INVALID_ARGUMENT, "CUDA EP factory currently supports only one device at a time."); } - ORT_RETURN_IF_NOT(devices != nullptr && devices[0] != nullptr, "CUDA EP factory requires a valid device"); + if (devices == nullptr || devices[0] == nullptr) { + return factory->ort_api_.CreateStatus( + ORT_INVALID_ARGUMENT, + "CUDA EP factory requires a valid device."); + } // Parse configuration from session options. // The read helpers intentionally swallow errors: if a config entry is @@ -274,9 +278,15 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateAllocatorImpl( *allocator = nullptr; const char* name = ""; - factory.ort_api_.MemoryInfoGetName(memory_info, &name); + OrtStatus* status = factory.ort_api_.MemoryInfoGetName(memory_info, &name); + if (status != nullptr) { + return status; + } int req_device_id = 0; - factory.ort_api_.MemoryInfoGetId(memory_info, &req_device_id); + status = factory.ort_api_.MemoryInfoGetId(memory_info, &req_device_id); + if (status != nullptr) { + return status; + } if (strcmp(name, "Cuda") == 0) { auto cuda_allocator = std::make_unique(memory_info, req_device_id); diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.cc b/onnxruntime/core/providers/cuda/tensor/upsample.cc index ee0e2a66f48d3..7ba2eed09353d 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.cc +++ b/onnxruntime/core/providers/cuda/tensor/upsample.cc @@ -107,9 +107,10 @@ Status Upsample::BaseCompute(OpKernelContext* context, if (antialias_) { #ifdef BUILD_CUDA_EP_AS_PLUGIN const uint8_t* lookup_table = GetLookupTableShared(); - auto shared_lookup_table_ondevice = GetScratchBuffer(kLookupTableSize, GetComputeStream(context)); - CUDA_CALL_THROW(cudaMemcpyAsync(shared_lookup_table_ondevice.get(), lookup_table, kLookupTableSize, + auto shared_lookup_table_ondevice_buffer = GetScratchBuffer(kLookupTableSize, GetComputeStream(context)); + CUDA_CALL_THROW(cudaMemcpyAsync(shared_lookup_table_ondevice_buffer.get(), lookup_table, kLookupTableSize, cudaMemcpyHostToDevice, Stream(context))); + const auto* shared_lookup_table_ondevice = shared_lookup_table_ondevice_buffer.get(); #else const auto* shared_lookup_table_ondevice = shared_lookup_table_ondevice_.get(); #endif From 5dbba29db08eebed4a95e9c0540ce3e93a4a62f6 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 24 Mar 2026 01:03:45 -0700 Subject: [PATCH 10/48] add fused conv --- cmake/onnxruntime_providers_cuda_plugin.cmake | 3 +-- onnxruntime/contrib_ops/cuda/fused_conv.cc | 10 +++++----- onnxruntime/core/providers/cuda/plugin/cuda_ep.cc | 2 +- onnxruntime/core/providers/cuda/plugin/cuda_ep.h | 1 + .../core/providers/cuda/plugin/cuda_ep_factory.cc | 2 ++ .../core/providers/cuda/plugin/cuda_kernel_adapter.h | 8 +++++++- 6 files changed, 17 insertions(+), 9 deletions(-) diff --git a/cmake/onnxruntime_providers_cuda_plugin.cmake b/cmake/onnxruntime_providers_cuda_plugin.cmake index e840aa150bd1b..6d43519ec0a83 100644 --- a/cmake/onnxruntime_providers_cuda_plugin.cmake +++ b/cmake/onnxruntime_providers_cuda_plugin.cmake @@ -161,8 +161,7 @@ list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/shr # Exclude contrib ops using GetComputeStream() or framework type deps. # group_norm.cc still requires the real CudaTuningContext/Stream types. list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/diffusion/group_norm\\.cc$") -# fused_conv.cc still depends on provider-only cuDNN config APIs beyond stream access. -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/fused_conv\\.cc$") +# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/fused_conv\\.cc$") # list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/inverse\\.cc$") # list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/math/bias_dropout\\.cc$") list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/math/fft_ops\\.cc$") diff --git a/onnxruntime/contrib_ops/cuda/fused_conv.cc b/onnxruntime/contrib_ops/cuda/fused_conv.cc index 0554cc34933f1..34732798a656a 100644 --- a/onnxruntime/contrib_ops/cuda/fused_conv.cc +++ b/onnxruntime/contrib_ops/cuda/fused_conv.cc @@ -181,7 +181,7 @@ class FusedConv : public onnxruntime::cuda::CudaKernel { // Post slicing needed. Create and fill in the Conv results in an intermediate buffer. s_.memory_for_cudnn_conv_results = GetScratchBuffer(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size, - context->GetComputeStream()); + GetComputeStream(context)); s_.y_data = reinterpret_cast(s_.memory_for_cudnn_conv_results.get()); } else { // No post slicing needed. Fill the output tensor's buffer directly. @@ -338,7 +338,7 @@ class FusedConv : public onnxruntime::cuda::CudaKernel { } if (s_.post_slicing_required) { s_.memory_for_cudnn_conv_results = GetScratchBuffer( - TensorShape(s_.y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream()); + TensorShape(s_.y_dims_with_adjusted_pads).Size() * s_.element_size, GetComputeStream(context)); s_.y_data = reinterpret_cast(s_.memory_for_cudnn_conv_results.get()); } else { s_.y_data = reinterpret_cast(s_.Y->MutableData()); @@ -358,7 +358,7 @@ class FusedConv : public onnxruntime::cuda::CudaKernel { bool has_b = nullptr != s_.b_data; const auto alpha = onnxruntime::cuda::Consts::One; const auto beta = onnxruntime::cuda::Consts::Zero; - IAllocatorUniquePtr workspace = GetWorkSpace(context->GetComputeStream()); + IAllocatorUniquePtr workspace = GetWorkSpace(GetComputeStream(context)); auto cudnn_status = cudnnConvolutionBiasActivationForward(cudnnHandle, &alpha, s_.x_tensor, @@ -422,7 +422,7 @@ class FusedConv : public onnxruntime::cuda::CudaKernel { return Status::OK(); } - inline IAllocatorUniquePtr GetWorkSpace(onnxruntime::Stream* stream) const { + inline IAllocatorUniquePtr GetWorkSpace(void* stream) const { return GetScratchBuffer(s_.workspace_bytes, stream); } @@ -455,4 +455,4 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( } // namespace cuda } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index 0e6badaf7ef94..8a9a5a4548e9b 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -44,7 +44,7 @@ CudaEp::CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& lo // Seed adapter-level runtime options for migrated kernels. onnxruntime::cuda::SetCudaKernelAdapterRuntimeConfig( config_.use_tf32, config_.device_id, config_.enable_skip_layer_norm_strict_mode, - config_.cudnn_conv_algo, config_.cudnn_conv1d_pad_to_nc1d); + config_.cudnn_conv_algo, config_.cudnn_conv_use_max_workspace, config_.cudnn_conv1d_pad_to_nc1d); } CudaEp::~CudaEp() = default; diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.h b/onnxruntime/core/providers/cuda/plugin/cuda_ep.h index c973507d90f91..0e3fe81561af7 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.h @@ -23,6 +23,7 @@ class CudaEp : public OrtEp { bool enable_skip_layer_norm_strict_mode = false; ///< Strict mode for SkipLayerNorm kernel. int device_id = 0; ///< CUDA device ordinal. int cudnn_conv_algo = 0; ///< cuDNN convolution algorithm selection. + bool cudnn_conv_use_max_workspace = true; ///< Use maximum workspace for cuDNN conv algo search. bool cudnn_conv1d_pad_to_nc1d = false; ///< Pad 1D convolutions to NC1D format. }; diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc index 605c45a133809..33294851f50bf 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -245,12 +245,14 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( read_session_config_bool("prefer_nhwc", config.prefer_nhwc); read_session_config_bool("use_tf32", config.use_tf32); read_session_config_bool("enable_skip_layer_norm_strict_mode", config.enable_skip_layer_norm_strict_mode); + read_session_config_bool("cudnn_conv_use_max_workspace", config.cudnn_conv_use_max_workspace); read_session_config_bool("cudnn_conv1d_pad_to_nc1d", config.cudnn_conv1d_pad_to_nc1d); read_session_config_int("cudnn_conv_algo", config.cudnn_conv_algo); read_session_config_bool("ep.cuda.prefer_nhwc_layout", config.prefer_nhwc); read_session_config_bool("ep.cuda.use_tf32", config.use_tf32); read_session_config_bool("ep.cuda.enable_skip_layer_norm_strict_mode", config.enable_skip_layer_norm_strict_mode); + read_session_config_bool("ep.cuda.cudnn_conv_use_max_workspace", config.cudnn_conv_use_max_workspace); read_session_config_bool("ep.cuda.cudnn_conv1d_pad_to_nc1d", config.cudnn_conv1d_pad_to_nc1d); read_session_config_int("ep.cuda.cudnn_conv_algo", config.cudnn_conv_algo); diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index f509901c54cde..00effe11e6249 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -325,6 +325,7 @@ struct CudaKernelAdapterRuntimeConfig { std::atomic skip_layer_norm_strict_mode{false}; std::atomic device_id{0}; std::atomic cudnn_conv_algo{0}; + std::atomic cudnn_conv_use_max_workspace{true}; std::atomic cudnn_conv1d_pad_to_nc1d{false}; }; inline CudaKernelAdapterRuntimeConfig& GetCudaKernelAdapterRuntimeConfig() { @@ -389,6 +390,9 @@ class CUDAExecutionProvider : public onnxruntime::IExecutionProvider { int GetCudnnConvAlgo() const { return cuda::detail::GetCudaKernelAdapterRuntimeConfig().cudnn_conv_algo.load(std::memory_order_relaxed); } + bool GetCudnnConvUseMaxWorkspace() const { + return cuda::detail::GetCudaKernelAdapterRuntimeConfig().cudnn_conv_use_max_workspace.load(std::memory_order_relaxed); + } bool GetCudnnConv1dPadToNc1d() const { return cuda::detail::GetCudaKernelAdapterRuntimeConfig().cudnn_conv1d_pad_to_nc1d.load(std::memory_order_relaxed); } @@ -407,12 +411,14 @@ class CUDAExecutionProvider : public onnxruntime::IExecutionProvider { namespace cuda { inline void SetCudaKernelAdapterRuntimeConfig(bool use_tf32, int device_id, bool skip_layer_norm_strict_mode = false, - int cudnn_conv_algo = 0, bool cudnn_conv1d_pad_to_nc1d = false) { + int cudnn_conv_algo = 0, bool cudnn_conv_use_max_workspace = true, + bool cudnn_conv1d_pad_to_nc1d = false) { auto& config = detail::GetCudaKernelAdapterRuntimeConfig(); config.use_tf32.store(use_tf32, std::memory_order_relaxed); config.skip_layer_norm_strict_mode.store(skip_layer_norm_strict_mode, std::memory_order_relaxed); config.device_id.store(device_id, std::memory_order_relaxed); config.cudnn_conv_algo.store(cudnn_conv_algo, std::memory_order_relaxed); + config.cudnn_conv_use_max_workspace.store(cudnn_conv_use_max_workspace, std::memory_order_relaxed); config.cudnn_conv1d_pad_to_nc1d.store(cudnn_conv1d_pad_to_nc1d, std::memory_order_relaxed); } From ede149326752a411fd41229da4bcd2be42b43d38 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 24 Mar 2026 12:00:45 -0700 Subject: [PATCH 11/48] Add group norm and qordered ops. --- cmake/onnxruntime_providers_cuda_plugin.cmake | 76 +------------------ .../contrib_ops/cuda/diffusion/group_norm.cc | 46 +++++++++++ .../qordered_ops/qordered_attention.cc | 4 +- .../qordered_longformer_attention.cc | 14 ++-- .../qordered_ops/qordered_matmul.cc | 2 +- .../quantization/qordered_ops/qordered_qdq.cc | 4 +- onnxruntime/core/providers/cuda/cuda_kernel.h | 4 + .../cuda/plugin/cuda_kernel_adapter.h | 12 +++ 8 files changed, 75 insertions(+), 87 deletions(-) diff --git a/cmake/onnxruntime_providers_cuda_plugin.cmake b/cmake/onnxruntime_providers_cuda_plugin.cmake index 6d43519ec0a83..60c2faf4d2139 100644 --- a/cmake/onnxruntime_providers_cuda_plugin.cmake +++ b/cmake/onnxruntime_providers_cuda_plugin.cmake @@ -79,17 +79,6 @@ list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_common\\.cc$") list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_nhwc_kernels\\.cc$") list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_contrib_kernels\\.cc$") -# integer_gemm.cc: dynamic_cast replaced with GetCublasHandle(cudaStream_t). -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/integer_gemm\\.cc$") # REMOVED in Stage 5 - -# RNN ops: dual-build-compatible signatures are in place (void* alloc_stream, -# cudaStream_t, cudnnHandle_t), but the ORT C API lacks KernelInfoGetAttributeArray_string -# which rnn.h uses via GetAttrs("activations", ...). -# Re-excluded until the C API is extended to support string array attributes. -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/rnn/.*") - -list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/rnn/.*") - # Exclude files that use TensorSeq (incomplete type in plugin build). list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/identity_op\\.cc$") list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/sequence_op\\.cc$") @@ -99,38 +88,11 @@ list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/sequence_op\\.cc$") # in the CPU provider and is not linked into the plugin. list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/size\\.cc$") -# scatter_nd.cc: ValidateShapes inlined for plugin, GetComputeStream fixed. -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/scatter_nd\\.cc$") # REMOVED in Stage 5 - # Exclude llm/ — attention.cc calls QkvToContext which dereferences # onnxruntime::Stream* (not available in plugin build's adapter OpKernelContext). list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/llm/.*") list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/llm/.*") -# Exclude constant_of_shape — inherits from ConstantOfShapeBase (CPU provider) -# which is not linked into the plugin. -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/generator/constant_of_shape\\.cc$") - -# Exclude space_depth_ops — inherits from SpaceDepthBase (CPU provider). -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/space_depth_ops\\.cc$") - -# Exclude object_detection/ — NonMaxSuppression and RoiAlign inherit from CPU -# base classes (NonMaxSuppressionBase, RoiAlignBase) not linked into the plugin. -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/object_detection/.*") -# list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/object_detection/.*") - -# Exclude upsample.cc — UpsampleBase uses InputDefs() and -# OpKernelInfo::GetAllocator() not available in adapter. -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/upsample\\.cc$") - -# Exclude resize.cc — Resize inherits from Upsample (excluded above). -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/resize\\.cc$") - -# Exclude einsum — einsum_auxiliary_ops.cc calls ReductionOps::ReduceCompute -# which is framework-only (guarded by #ifndef BUILD_CUDA_EP_AS_PLUGIN). -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/math/einsum_utils/.*") -# list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/math/einsum_utils/.*") - # Permanently excluded — pure CPU ops, handled by GetCpuPreferredNodes. # shape_op.cc inherits from onnxruntime::OpKernel (framework) # which cannot convert to ep::adapter::OpKernel in the plugin build. @@ -143,52 +105,16 @@ list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/llm/.*") # Exclude contrib training ops (shrunken_gather depends on provider_api.h in header). list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/shrunken_gather\\.cc$") -# Exclude contrib bert ops that use GetComputeStream() or framework OpKernelContext. -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/attention\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/decoder_attention\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/decoder_masked_self_attention\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/embed_layer_norm\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/fast_gelu\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/group_query_attention\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/longformer_attention\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/multihead_attention\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/packed_attention\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/packed_multihead_attention\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/paged_attention\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/relative_attn_bias\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/bert/remove_padding\\.cc$") - # Exclude contrib ops using GetComputeStream() or framework type deps. -# group_norm.cc still requires the real CudaTuningContext/Stream types. -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/diffusion/group_norm\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/fused_conv\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/inverse\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/math/bias_dropout\\.cc$") list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/math/fft_ops\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/moe/moe\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/sparse/sparse_attention\\.cc$") list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/crop\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/dynamic_time_warping\\.cc$") list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/dynamicslice\\.cc$") -# Exclude contrib quantization ops with GetComputeStream() deps. -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/quantization/attention_quantization\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/quantization/matmul_bnb4\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/quantization/matmul_nbits\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/quantization/moe_quantization\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/quantization/qordered_ops/.*") -# Exclude contrib transformers/ (beam search, greedy search, sampling). +# Exclude contrib transformers/ (beam search, greedy search, sampling). Those need subgraph inference. list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/transformers/.*") list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/transformers/.*") -# Exclude gemm_float8.cc/.cu — ComputeInternal is in .cu which uses GetComputeStream(). -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/math/gemm_float8\\.cc$") -# list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/math/gemm_float8\\.cu$") - -# fused_matmul.cc: matmul.cc is now included, so fused_matmul can be too. -# list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/math/fused_matmul\\.cc$") # REMOVED in Stage 5 - # Create shared library target using the ORT helper function for plugins onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_plugin ${CUDA_PLUGIN_EP_CC_SRCS} diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index baa284f2fb859..f142bec4f3edc 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -65,6 +65,52 @@ struct DispatchGroupNorm { broadcast_skip, channels_per_block); } + +#ifdef BUILD_CUDA_EP_AS_PLUGIN + Status operator()(CudaKernel::PluginTuningContextStub* tuning_ctx, + void* ort_stream, + Tensor* output, + Tensor* add_out, + const Tensor* input, + const Tensor* skip, + const Tensor* bias, + const Tensor* gamma, + const Tensor* beta, + void* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_swish_activation, + bool broadcast_skip, + int channels_per_block) { + ORT_UNUSED_PARAMETER(tuning_ctx); + typedef typename ToCudaType::MappedType CudaT; + onnxruntime::PluginStreamShim plugin_stream_shim(ort_stream); + return LaunchGroupNormKernel( + nullptr, + &plugin_stream_shim, + reinterpret_cast(output->MutableData()), + add_out == nullptr ? nullptr : reinterpret_cast(add_out->MutableData()), + reinterpret_cast(input->Data()), + skip == nullptr ? nullptr : reinterpret_cast(skip->Data()), + bias == nullptr ? nullptr : reinterpret_cast(bias->Data()), + gamma->Data(), + beta->Data(), + workspace, + epsilon, + batch_size, + num_channels, + height, + width, + num_groups, + use_swish_activation, + broadcast_skip, + channels_per_block); + } +#endif }; } // namespace diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc index 3e93a527877c5..5ba8833d0d8b7 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc @@ -228,7 +228,7 @@ Status QOrderedAttention::ComputeInternal(OpKernelContext* context) const { output_shape[2] = static_cast(hidden_size); Tensor* output = context->Output(0, output_shape); - cublasLtHandle_t cublasLt = CublasLtHandle(); + cublasLtHandle_t cublasLt = this->GetCublasLtHandle(context); // Use GEMM for fully connection. int m = batch_size * sequence_length; int n = 3 * hidden_size; @@ -236,7 +236,7 @@ Status QOrderedAttention::ComputeInternal(OpKernelContext* context) const { int64_t size_of_attention_scores = ((int64_t)batch_size) * num_heads_ * sequence_length * sequence_length; // transposed qkv_layer, union(stacked, attention probs + attention scores) - auto gemm_buffer_quantized = GetScratchBuffer((int64_t)m * n + std::max((int64_t)m * n, 2 * size_of_attention_scores), context->GetComputeStream()); + auto gemm_buffer_quantized = GetScratchBuffer((int64_t)m * n + std::max((int64_t)m * n, 2 * size_of_attention_scores), GetComputeStream(context)); int8_t* stacked_qkv_layers = gemm_buffer_quantized.get() + ((int64_t)m * n); int8_t* tranposed_qkv_layers = gemm_buffer_quantized.get(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_longformer_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_longformer_attention.cc index 4e0140f34e869..65b09a98f139c 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_longformer_attention.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_longformer_attention.cc @@ -100,7 +100,7 @@ QOrderedLongformerAttention::ComputeInternal(OpKernelContext* context) const { Tensor* output = context->Output(0, shape); cublasHandle_t cublas = GetCublasHandle(context); - cublasLtHandle_t cublasLt = CublasLtHandle(); + cublasLtHandle_t cublasLt = this->GetCublasLtHandle(context); cudaStream_t stream = Stream(context); CUBLAS_RETURN_IF_ERROR(cublasSetStream(cublas, stream)); @@ -109,11 +109,11 @@ QOrderedLongformerAttention::ComputeInternal(OpKernelContext* context) const { // TODO: only calculate once per model. // Build Global Index - auto global_index_buffer = GetScratchBuffer(static_cast(batch_size) * static_cast(sequence_length), context->GetComputeStream()); - auto batch_global_num_buffer = GetScratchBuffer(batch_size, context->GetComputeStream()); + auto global_index_buffer = GetScratchBuffer(static_cast(batch_size) * static_cast(sequence_length), GetComputeStream(context)); + auto batch_global_num_buffer = GetScratchBuffer(batch_size, GetComputeStream(context)); size_t global_scratch_bytes = GetGlobalScratchSize(sequence_length); - auto global_scratch_buffer = GetScratchBuffer(global_scratch_bytes, context->GetComputeStream()); + auto global_scratch_buffer = GetScratchBuffer(global_scratch_bytes, GetComputeStream(context)); auto& device_prop = GetDeviceProp(); ORT_RETURN_IF_ERROR(BuildGlobalIndex(device_prop, @@ -152,7 +152,7 @@ QOrderedLongformerAttention::ComputeInternal(OpKernelContext* context) const { // Buffer for GEMM outputs of q, k, v, global_q, global_k and global_v // TODO(tianleiwu): compact global_q only need batch_size * window * hidden_size * element_size buffer size. size_t qkv_3 = qkv_size + qkv_size + 2 * qkv_count * sizeof(int8_t); - auto gemm_buffer = GetScratchBuffer(qkv_3, context->GetComputeStream()); + auto gemm_buffer = GetScratchBuffer(qkv_3, GetComputeStream(context)); const float* scale_input = context->Input(1)->Data(); const float* scale_weight = context->Input(3)->Data(); @@ -220,7 +220,7 @@ QOrderedLongformerAttention::ComputeInternal(OpKernelContext* context) const { window_, disable_compact_memory); - auto workspace_buffer = GetScratchBuffer(workSpaceSize + output_elements * element_size, context->GetComputeStream()); + auto workspace_buffer = GetScratchBuffer(workSpaceSize + output_elements * element_size, GetComputeStream(context)); MLFloat16* out_fp16 = (MLFloat16*)(((int8_t*)workspace_buffer.get()) + workSpaceSize); ORT_RETURN_IF_ERROR(LaunchLongformerAttentionKernel(device_prop, cublas, @@ -252,7 +252,7 @@ QOrderedLongformerAttention::ComputeInternal(OpKernelContext* context) const { *scale_output, batch_size, sequence_length, hidden_size)); // Defer release of pinned memory since cudaStreamSynchronize is not used here and kernel need access the buffer. - this->AddDeferredReleaseCPUPtr(pinned_buffer.release(), context->GetComputeStream()); + this->AddDeferredReleaseCPUPtr(pinned_buffer.release(), GetComputeStream(context)); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc index a64f628f245e6..520c205179dc7 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc @@ -146,7 +146,7 @@ Status QOrderedMatMul::ComputeInternal(OpKernelContext* context) const { } Tensor* tensor_Y = context->Output(0, shapeY); - cublasLtHandle_t cublasLt = CublasLtHandle(); + cublasLtHandle_t cublasLt = this->GetCublasLtHandle(context); cudaStream_t stream = Stream(context); auto& device_prop = GetDeviceProp(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq.cc index 347697e588b8b..d7ac36d16e8e2 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq.cc @@ -138,7 +138,7 @@ Status QuantizeWithOrder::ComputeInternal(OpKernelContext* context) const { input_tensor, (cublasLtOrder_t)order_input_, (cublasLtOrder_t)order_output_, rows, cols, batch, n)); const float* scale = context->Input(1)->Data(); Tensor* output_tensor = context->Output(0, input_tensor.Shape()); - cublasLtHandle_t cublasLt = CublasLtHandle(); + cublasLtHandle_t cublasLt = this->GetCublasLtHandle(context); cudaStream_t stream = Stream(context); const auto& device_prop = GetDeviceProp(); @@ -154,7 +154,7 @@ Status QuantizeWithOrder::ComputeInternal(OpKernelContext* context) const { *scale, gsl::narrow(batch), gsl::narrow(rows), gsl::narrow(cols))); } } else { - auto q8_buffer = GetScratchBuffer(order_input_ == order_output_ ? 0LL : n, context->GetComputeStream()); + auto q8_buffer = GetScratchBuffer(order_input_ == order_output_ ? 0LL : n, GetComputeStream(context)); int8_t* dst = (order_input_ == order_output_ ? output_tensor->MutableData() : q8_buffer.get()); if (input_tensor.IsDataType()) { ORT_RETURN_IF_ERROR(QOrderQuantize_Strict(stream, device_prop, (const __half*)input_tensor.Data(), dst, *scale, n)); diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index 2aaff1192073b..4435b64f81b06 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -144,6 +144,10 @@ class CudaKernel : public OpKernel { return stream ? GetCublasHandle(stream) : DefaultCublasHandle(); } + inline cublasLtHandle_t GetCublasLtHandle(OpKernelContext* /*ctx*/) const { + return provider_->PerThreadCublasLtHandle(); + } + bool UseTF32() const { return provider_->UseTF32(); } diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index 00effe11e6249..0e09a98c36224 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -669,6 +669,18 @@ class CudaKernel : public OpKernel { } cublasHandle_t GetCublasHandle(OpKernelContext* ctx) const { return GetCublasHandle(Stream(ctx)); } + static cublasLtHandle_t GetCublasLtHandle(cudaStream_t s) { + auto* sync = cuda_plugin::CudaSyncStream::FromCudaStream(s); + return sync ? sync->GetCublasLtHandle() : nullptr; + } + static inline cublasLtHandle_t GetCublasLtHandle(onnxruntime::CudaStream* stream) { + return stream ? GetCublasLtHandle(static_cast(reinterpret_cast(stream)->GetHandle())) : nullptr; + } + static inline cublasLtHandle_t GetCublasLtHandle(onnxruntime::Stream* stream) { + return stream ? GetCublasLtHandle(static_cast(stream->GetHandle())) : nullptr; + } + cublasLtHandle_t GetCublasLtHandle(OpKernelContext* ctx) const { return GetCublasLtHandle(Stream(ctx)); } + const cudaDeviceProp& GetDeviceProp() const { return device_prop_; } bool UseTF32() const { return use_tf32_; } bool IsArchAvailable(int arch) const { return device_prop_.major >= arch; } From f4b1881f21e9bdb81151e3dc32b120c027dbd487 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 24 Mar 2026 12:17:40 -0700 Subject: [PATCH 12/48] add ort stream adapter --- .../contrib_ops/cuda/bert/attention.cc | 9 ++---- .../cuda/bert/decoder_attention.cc | 9 ++---- .../cuda/bert/group_query_attention.cc | 10 ++----- .../cuda/bert/multihead_attention.cc | 9 ++---- .../contrib_ops/cuda/bert/paged_attention.cc | 10 ++----- .../contrib_ops/cuda/diffusion/group_norm.cc | 4 +-- .../quantization/attention_quantization.cc | 9 ++---- .../cuda/sparse/sparse_attention.cc | 10 ++----- onnxruntime/core/providers/cuda/cuda_kernel.h | 16 ++++++++++ .../cuda/plugin/cuda_kernel_adapter.h | 29 +++++++++++++++++++ 10 files changed, 61 insertions(+), 54 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index e46ead949ce4f..69f319fdd977d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -58,12 +58,7 @@ Attention::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionB template Status Attention::ComputeInternal(OpKernelContext* context) const { -#ifdef BUILD_CUDA_EP_AS_PLUGIN - onnxruntime::PluginStreamShim plugin_stream_shim(GetComputeStream(context)); - auto* ort_stream = static_cast(&plugin_stream_shim); -#else - auto* ort_stream = context->GetComputeStream(); -#endif + auto ort_stream = GetOrtStream(context); const Tensor* input = context->Input(0); const Tensor* weights = context->Input(1); @@ -328,7 +323,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { } cudnnHandle_t cudnn = GetCudnnHandle(context); - return QkvToContext(device_prop, cublas, cudnn, ort_stream, parameters, data); + return QkvToContext(device_prop, cublas, cudnn, ort_stream.get(), parameters, data); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc index 6478a3cca78a5..e13f13fc8b245 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc @@ -175,12 +175,7 @@ DecoderAttention::DecoderAttention(const OpKernelInfo& info) : CudaKernel(inf template Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { -#ifdef BUILD_CUDA_EP_AS_PLUGIN - onnxruntime::PluginStreamShim plugin_stream_shim(GetComputeStream(context)); - auto* ort_stream = static_cast(&plugin_stream_shim); -#else - auto* ort_stream = context->GetComputeStream(); -#endif + auto ort_stream = GetOrtStream(context); const Tensor* query(context->Input(0)); const Tensor* key(context->Input(1)); @@ -378,7 +373,7 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { return LaunchDecoderAttentionKernel( device_prop, UseTF32(), - ort_stream, + ort_stream.get(), cublas, element_size, batch_size, diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 289ed0fc55f41..3b6b5f9079ebe 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -142,13 +142,7 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) // 11. head_sink (Tensor) - Attention sink for GPT-OSS template Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { - // Stream access: void* for GetScratchBuffer, Stream* for QkvToContext. -#ifdef BUILD_CUDA_EP_AS_PLUGIN - onnxruntime::PluginStreamShim plugin_stream_shim(GetComputeStream(context)); - auto* ort_stream = static_cast(&plugin_stream_shim); -#else - auto* ort_stream = context->GetComputeStream(); -#endif + auto ort_stream = GetOrtStream(context); const Tensor* query = context->Input(0); const Tensor* key = context->Input(1); @@ -564,7 +558,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons cublasHandle_t cublas = GetCublasHandle(context); ORT_RETURN_IF_ERROR((QkvToContext( - device_prop, cublas, ort_stream, parameters, data))); + device_prop, cublas, ort_stream.get(), parameters, data))); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index e06437c0f07e2..a2af4831a3a00 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -88,12 +88,7 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) template Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { -#ifdef BUILD_CUDA_EP_AS_PLUGIN - onnxruntime::PluginStreamShim plugin_stream_shim(GetComputeStream(context)); - auto* ort_stream = static_cast(&plugin_stream_shim); -#else - auto* ort_stream = context->GetComputeStream(); -#endif + auto ort_stream = GetOrtStream(context); const Tensor* query = context->Input(0); const Tensor* key = context->Input(1); @@ -564,7 +559,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) cons cudnnHandle_t cudnn = GetCudnnHandle(context); DUMP_STRING("Run QkvToContext from MHA CUDA"); return QkvToContext( - device_prop, cublas, cudnn, ort_stream, parameters, data); + device_prop, cublas, cudnn, ort_stream.get(), parameters, data); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc b/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc index 7a3acc2867745..5df2c8b438771 100644 --- a/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc @@ -54,13 +54,7 @@ PagedAttention::PagedAttention(const OpKernelInfo& info) template Status PagedAttention::ComputeInternal(OpKernelContext* context) const { - // Stream access: void* for GetScratchBuffer, Stream* for QkvToContext. -#ifdef BUILD_CUDA_EP_AS_PLUGIN - onnxruntime::PluginStreamShim plugin_stream_shim(GetComputeStream(context)); - auto* ort_stream = static_cast(&plugin_stream_shim); -#else - auto* ort_stream = context->GetComputeStream(); -#endif + auto ort_stream = GetOrtStream(context); const Tensor* query = context->Input(0); const Tensor* key = context->Input(1); @@ -218,7 +212,7 @@ Status PagedAttention::ComputeInternal(OpKernelContext* context) const { cublasHandle_t cublas = GetCublasHandle(context); return QkvToContext( - device_prop, cublas, ort_stream, parameters, data); + device_prop, cublas, ort_stream.get(), parameters, data); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index f142bec4f3edc..72a7a5f164ead 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -88,10 +88,10 @@ struct DispatchGroupNorm { int channels_per_block) { ORT_UNUSED_PARAMETER(tuning_ctx); typedef typename ToCudaType::MappedType CudaT; - onnxruntime::PluginStreamShim plugin_stream_shim(ort_stream); + onnxruntime::OrtStreamAdapter ort_stream_adapter(ort_stream); return LaunchGroupNormKernel( nullptr, - &plugin_stream_shim, + ort_stream_adapter.get(), reinterpret_cast(output->MutableData()), add_out == nullptr ? nullptr : reinterpret_cast(add_out->MutableData()), reinterpret_cast(input->Data()), diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index 551cd3dcdb9e9..3dcc03e9597e3 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -95,12 +95,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { // Output 0 - output : (batch_size, sequence_length, hidden_size) // Output 1 - present : (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size) -#ifdef BUILD_CUDA_EP_AS_PLUGIN - onnxruntime::PluginStreamShim plugin_stream_shim(GetComputeStream(context)); - auto* ort_stream = static_cast(&plugin_stream_shim); -#else - auto* ort_stream = context->GetComputeStream(); -#endif + auto ort_stream = GetOrtStream(context); const Tensor* input = context->Input(0); const Tensor* weights = context->Input(1); @@ -228,7 +223,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { } cudnnHandle_t cudnn = GetCudnnHandle(context); - return QkvToContext(GetDeviceProp(), cublas, cudnn, ort_stream, parameters, data); + return QkvToContext(GetDeviceProp(), cublas, cudnn, ort_stream.get(), parameters, data); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc index 63ed7ce189a7b..656fde2f46ab8 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc @@ -60,13 +60,7 @@ SparseAttention::SparseAttention(const OpKernelInfo& info) template Status SparseAttention::ComputeInternal(OpKernelContext* context) const { - // Stream access: void* for GetScratchBuffer, Stream* for QkvToContext. -#ifdef BUILD_CUDA_EP_AS_PLUGIN - onnxruntime::PluginStreamShim plugin_stream_shim(GetComputeStream(context)); - auto* ort_stream = static_cast(&plugin_stream_shim); -#else - auto* ort_stream = context->GetComputeStream(); -#endif + auto ort_stream = GetOrtStream(context); auto& device_prop = GetDeviceProp(); if constexpr (std::is_same::value) { @@ -324,7 +318,7 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { data.active_q_blocks = active_q_blocks; } - return QkvToContext(device_prop, ort_stream, parameters, data); + return QkvToContext(device_prop, ort_stream.get(), parameters, data); } } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index 4435b64f81b06..f3627b2f97229 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -14,6 +14,18 @@ namespace onnxruntime { namespace cuda { +class OrtStreamAdapter { + public: + explicit OrtStreamAdapter(onnxruntime::Stream* stream) : stream_(stream) {} + explicit OrtStreamAdapter(void* stream) : stream_(static_cast(stream)) {} + + onnxruntime::Stream* get() const { return stream_; } + operator onnxruntime::Stream*() const { return stream_; } + + private: + onnxruntime::Stream* stream_; +}; + #ifndef CUDA_STREAM_FROM_CTX // Helper for kernels that need a cudaStream_t from OpKernelContext in both framework and plugin builds. #define CUDA_STREAM_FROM_CTX(ctx) static_cast(GetComputeStream(ctx)) @@ -103,6 +115,10 @@ class CudaKernel : public OpKernel { return ctx ? ctx->GetComputeStream() : nullptr; } + inline OrtStreamAdapter GetOrtStream(OpKernelContext* ctx) const { + return OrtStreamAdapter(GetComputeStream(ctx)); + } + inline cudaStream_t Stream(OpKernelContext* ctx) const { auto* stream = ctx->GetComputeStream(); return stream ? static_cast(stream->GetHandle()) : nullptr; diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index 0e09a98c36224..57032831d89b1 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -59,6 +59,31 @@ struct PluginStreamShim : public onnxruntime::Stream { OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, 0)) {} }; + +class OrtStreamAdapter { + public: + explicit OrtStreamAdapter(void* cuda_stream_handle) + : plugin_stream_shim_(cuda_stream_handle), stream_(&plugin_stream_shim_) {} + + onnxruntime::Stream* get() const { return stream_; } + operator onnxruntime::Stream*() const { return stream_; } + + private: + PluginStreamShim plugin_stream_shim_; + onnxruntime::Stream* stream_; +}; +#else +class OrtStreamAdapter { + public: + explicit OrtStreamAdapter(void* cuda_stream_handle) + : stream_(static_cast(cuda_stream_handle)) {} + + onnxruntime::Stream* get() const { return stream_; } + operator onnxruntime::Stream*() const { return stream_; } + + private: + onnxruntime::Stream* stream_; +}; #endif } // namespace onnxruntime @@ -645,6 +670,10 @@ class CudaKernel : public OpKernel { return ctx->GetGPUComputeStream(); } + inline onnxruntime::OrtStreamAdapter GetOrtStream(OpKernelContext* ctx) const { + return onnxruntime::OrtStreamAdapter(GetComputeStream(ctx)); + } + static cudnnHandle_t GetCudnnHandle(cudaStream_t s) { auto* sync = cuda_plugin::CudaSyncStream::FromCudaStream(s); return sync ? sync->GetCudnnHandle() : nullptr; From 2869301bbe5947ad74a1886ca934b6229e27e8b0 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 24 Mar 2026 12:27:21 -0700 Subject: [PATCH 13/48] remove duplicated link cudnn; use adapter in llm attention --- cmake/onnxruntime_providers_cuda_plugin.cmake | 9 ++-- .../core/providers/cuda/llm/attention.cc | 53 ++++++++++--------- 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/cmake/onnxruntime_providers_cuda_plugin.cmake b/cmake/onnxruntime_providers_cuda_plugin.cmake index 60c2faf4d2139..9abd325cdfd22 100644 --- a/cmake/onnxruntime_providers_cuda_plugin.cmake +++ b/cmake/onnxruntime_providers_cuda_plugin.cmake @@ -88,8 +88,9 @@ list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/sequence_op\\.cc$") # in the CPU provider and is not linked into the plugin. list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/size\\.cc$") -# Exclude llm/ — attention.cc calls QkvToContext which dereferences -# onnxruntime::Stream* (not available in plugin build's adapter OpKernelContext). +# Exclude llm/ for now. Stream handling in core/providers/cuda/llm/attention.cc +# is now adapter-safe, but the kernel still uses framework-only Node::OutputDefs() +# introspection in its constructor, and ep::adapter::Node does not expose that yet. list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/llm/.*") list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/llm/.*") @@ -98,7 +99,8 @@ list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/llm/.*") # which cannot convert to ep::adapter::OpKernel in the plugin build. list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/shape_op\\.cc$") -# Exclude contrib llm/ — uses onnxruntime::Stream* in QkvToContext. +# Exclude contrib llm/ for now. The core CUDA llm kernels are adapter-safe, but +# contrib llm kernels still need their own plugin pass. list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/llm/.*") list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/llm/.*") @@ -197,7 +199,6 @@ target_link_libraries(onnxruntime_providers_cuda_plugin PRIVATE CUDA::cublasLt CUDNN::cudnn_all cudnn_frontend - ${CUDA_PLUGIN_CUDNN_LIBRARY} Boost::mp11 safeint_interface onnxruntime_framework diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc index 0ed6210fb4d29..b0776f77a39e7 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.cc +++ b/onnxruntime/core/providers/cuda/llm/attention.cc @@ -138,7 +138,7 @@ Status Attention::ConvertAttnMaskToBias( using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType::type; int64_t num_elements = attn_mask->Shape().Size(); converted_mask_buffer = GetScratchBuffer( - num_elements * sizeof(NativeCudaT), context->GetComputeStream()); + num_elements * sizeof(NativeCudaT), GetComputeStream(context)); float mask_filter_value = static_cast(std::numeric_limits::lowest()); ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias( attn_mask->Data(), @@ -219,7 +219,7 @@ Status Attention::RunFlashAttention( const attention_helper::AttentionParameters& parameters) const { #if USE_FLASH_ATTENTION auto& device_prop = GetDeviceProp(); - auto cuda_stream = static_cast(context->GetComputeStream()->GetHandle()); + auto cuda_stream = Stream(context); const bool is_bf16 = std::is_same::value; const bool is_bsnh = parameters.transpose_output; // 3D inputs → BSNH @@ -233,9 +233,9 @@ Status Attention::RunFlashAttention( parameters.total_sequence_length, parameters.q_num_heads, parameters.head_size, device_prop.multiProcessorCount); - auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); - auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); - auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); + auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, GetComputeStream(context)); + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, GetComputeStream(context)); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, GetComputeStream(context)); if (softmax_lse_accum_bytes > 0) { CUDA_RETURN_IF_ERROR(cudaMemsetAsync(softmax_lse_accum_buffer.get(), 0, @@ -252,7 +252,7 @@ Status Attention::RunFlashAttention( if (!is_bsnh) { size_t q_bytes = sizeof(T) * parameters.batch_size * parameters.q_sequence_length * parameters.q_num_heads * parameters.head_size; - q_bsnh_buffer = GetScratchBuffer(q_bytes, context->GetComputeStream()); + q_bsnh_buffer = GetScratchBuffer(q_bytes, GetComputeStream(context)); ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH( parameters.batch_size, parameters.q_sequence_length, parameters.q_num_heads, parameters.head_size, @@ -267,7 +267,7 @@ Status Attention::RunFlashAttention( if (!is_bsnh) { size_t out_bytes = sizeof(T) * parameters.batch_size * parameters.q_sequence_length * parameters.q_num_heads * parameters.v_head_size; - out_bsnh_buffer = GetScratchBuffer(out_bytes, context->GetComputeStream()); + out_bsnh_buffer = GetScratchBuffer(out_bytes, GetComputeStream(context)); out_data = out_bsnh_buffer.get(); } @@ -280,7 +280,7 @@ Status Attention::RunFlashAttention( "(past_sequence_length must be 0, got ", parameters.past_sequence_length, ")."); - auto seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, context->GetComputeStream()); + auto seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, GetComputeStream(context)); ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToFlashSeqlensK( nonpad_kv_seqlen->Data(), seqlens_k_buffer.get(), @@ -334,7 +334,7 @@ Status Attention::RunFlashAttention( // Step 1: Compute per-batch past sequence lengths for the concat kernel. // The concat kernel needs past_seq_lens to know where past data ends and new begins. - auto past_seqlens_buffer = GetScratchBuffer(parameters.batch_size, context->GetComputeStream()); + auto past_seqlens_buffer = GetScratchBuffer(parameters.batch_size, GetComputeStream(context)); if (attn_mask != nullptr && attn_mask->IsDataType()) { size_t mask_dims = attn_mask->Shape().NumDimensions(); auto dims = attn_mask->Shape().GetDims(); @@ -364,8 +364,8 @@ Status Attention::RunFlashAttention( parameters.kv_num_heads * parameters.head_size; size_t v_bytes = sizeof(T) * parameters.batch_size * parameters.kv_sequence_length * parameters.kv_num_heads * parameters.v_head_size; - k_bsnh_buffer = GetScratchBuffer(k_bytes, context->GetComputeStream()); - v_bsnh_buffer = GetScratchBuffer(v_bytes, context->GetComputeStream()); + k_bsnh_buffer = GetScratchBuffer(k_bytes, GetComputeStream(context)); + v_bsnh_buffer = GetScratchBuffer(v_bytes, GetComputeStream(context)); ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH( parameters.batch_size, parameters.kv_sequence_length, parameters.kv_num_heads, parameters.head_size, @@ -413,7 +413,7 @@ Status Attention::RunFlashAttention( // Step 4: Compute total seqlens for mha_fwd_kvcache. // With k_new=nullptr, the kernel treats seqlens_k as the total valid token count // (not pre-append count), so we need past + new. - auto seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, context->GetComputeStream()); + auto seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, GetComputeStream(context)); if (attn_mask != nullptr && attn_mask->IsDataType()) { size_t mask_dims = attn_mask->Shape().NumDimensions(); auto dims = attn_mask->Shape().GetDims(); @@ -567,7 +567,7 @@ Status Attention::RunMemoryEfficientAttention( ORT_UNUSED_PARAMETER(past_key); ORT_UNUSED_PARAMETER(past_value); auto& device_prop = GetDeviceProp(); - auto cuda_stream = static_cast(context->GetComputeStream()->GetHandle()); + auto cuda_stream = Stream(context); const bool is_bsnh = parameters.transpose_output; const int sm = device_prop.major * 10 + device_prop.minor; @@ -581,7 +581,7 @@ Status Attention::RunMemoryEfficientAttention( if (!is_bsnh) { size_t q_bytes = sizeof(T) * parameters.batch_size * parameters.q_sequence_length * parameters.q_num_heads * parameters.head_size; - q_bsnh_buffer = GetScratchBuffer(q_bytes, context->GetComputeStream()); + q_bsnh_buffer = GetScratchBuffer(q_bytes, GetComputeStream(context)); ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH( parameters.batch_size, parameters.q_sequence_length, parameters.q_num_heads, parameters.head_size, @@ -596,7 +596,7 @@ Status Attention::RunMemoryEfficientAttention( if (!is_bsnh) { size_t out_bytes = sizeof(T) * parameters.batch_size * parameters.q_sequence_length * parameters.q_num_heads * parameters.v_head_size; - out_bsnh_buffer = GetScratchBuffer(out_bytes, context->GetComputeStream()); + out_bsnh_buffer = GetScratchBuffer(out_bytes, GetComputeStream(context)); out_data = out_bsnh_buffer.get(); } @@ -622,8 +622,8 @@ Status Attention::RunMemoryEfficientAttention( static_cast(parameters.total_sequence_length) * static_cast(parameters.q_num_heads) * static_cast(parameters.head_size); - k_expand_buffer = GetScratchBuffer(expanded_kv_elements * sizeof(T), context->GetComputeStream()); - v_expand_buffer = GetScratchBuffer(expanded_kv_elements * sizeof(T), context->GetComputeStream()); + k_expand_buffer = GetScratchBuffer(expanded_kv_elements * sizeof(T), GetComputeStream(context)); + v_expand_buffer = GetScratchBuffer(expanded_kv_elements * sizeof(T), GetComputeStream(context)); onnxruntime::contrib::GroupQueryAttentionParameters ungroup_params = {}; ungroup_params.batch_size = parameters.batch_size; @@ -661,7 +661,7 @@ Status Attention::RunMemoryEfficientAttention( if (nonpad_kv_seqlen != nullptr) { // Convert nonpad_kv_seqlen to seqlens_k for custom right padding. // MEA expects actual token count (not count-1), so use FlashSeqlensK variant. - auto seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, context->GetComputeStream()); + auto seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, GetComputeStream(context)); ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToFlashSeqlensK( nonpad_kv_seqlen->Data(), seqlens_k_buffer.get(), @@ -710,7 +710,7 @@ Status Attention::RunMemoryEfficientAttention( parameters.v_head_size, sizeof(T) == sizeof(float))) { size_t workspace_bytes = sizeof(float) * parameters.batch_size * parameters.q_sequence_length * parameters.q_num_heads * parameters.v_head_size; - workspace_buffer = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + workspace_buffer = GetScratchBuffer(workspace_bytes, GetComputeStream(context)); p.workspace = workspace_buffer.get(); } else { p.workspace = nullptr; @@ -758,7 +758,7 @@ Status Attention::RunMemoryEfficientAttention( parameters.v_head_size, sizeof(T) == sizeof(float))) { size_t workspace_bytes = sizeof(float) * parameters.batch_size * parameters.q_sequence_length * parameters.q_num_heads * parameters.v_head_size; - workspace_buffer = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + workspace_buffer = GetScratchBuffer(workspace_bytes, GetComputeStream(context)); p.workspace = workspace_buffer.get(); } else { p.workspace = nullptr; @@ -845,7 +845,8 @@ Status Attention::RunUnfusedAttention( const attention_helper::AttentionParameters& parameters) const { using CudaT = typename ToCudaType::MappedType; auto& device_prop = GetDeviceProp(); - auto cuda_stream = static_cast(context->GetComputeStream()->GetHandle()); + auto cuda_stream = Stream(context); + auto ort_stream = GetOrtStream(context); // Bridge to contrib::AttentionParameters for the MHA unfused path onnxruntime::contrib::AttentionParameters contribop_parameters; @@ -927,7 +928,7 @@ Status Attention::RunUnfusedAttention( int64_t bias_elements = static_cast(parameters.batch_size) * parameters.q_sequence_length * parameters.total_sequence_length; - converted_mask_buffer = GetScratchBuffer(bias_elements * sizeof(NativeCudaT), context->GetComputeStream()); + converted_mask_buffer = GetScratchBuffer(bias_elements * sizeof(NativeCudaT), GetComputeStream(context)); ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToAttentionBias( nonpad_kv_seqlen->Data(), reinterpret_cast(converted_mask_buffer.get()), @@ -957,7 +958,7 @@ Status Attention::RunUnfusedAttention( if (attn_mask->IsDataType()) { // Convert bool mask to additive bias in a temp buffer, then add in-place. - mask_bias_buffer = GetScratchBuffer(mask_elements * sizeof(NativeCudaT), context->GetComputeStream()); + mask_bias_buffer = GetScratchBuffer(mask_elements * sizeof(NativeCudaT), GetComputeStream(context)); ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias( attn_mask->Data(), reinterpret_cast(mask_bias_buffer.get()), @@ -991,7 +992,7 @@ Status Attention::RunUnfusedAttention( if (attn_mask->IsDataType()) { using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType::type; int64_t num_elements = attn_mask->Shape().Size(); - converted_mask_buffer = GetScratchBuffer(num_elements * sizeof(NativeCudaT), context->GetComputeStream()); + converted_mask_buffer = GetScratchBuffer(num_elements * sizeof(NativeCudaT), GetComputeStream(context)); ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias( attn_mask->Data(), reinterpret_cast(converted_mask_buffer.get()), @@ -1025,7 +1026,7 @@ Status Attention::RunUnfusedAttention( contribop_parameters.total_sequence_length, nullptr, false, false, false, false, false, no_qkv_workspace); - auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + auto work_space = GetScratchBuffer(workspace_bytes, GetComputeStream(context)); data.has_qkv_workspace = !no_qkv_workspace; data.workspace = reinterpret_cast(work_space.get()); @@ -1035,7 +1036,7 @@ Status Attention::RunUnfusedAttention( cudnnHandle_t cudnn = GetCudnnHandle(context); return onnxruntime::contrib::cuda::QkvToContext( - device_prop, cublas, cudnn, context->GetComputeStream(), contribop_parameters, data); + device_prop, cublas, cudnn, ort_stream.get(), contribop_parameters, data); } // ============================================================================ From 0bb6422076de472789d08f92846ce090291d52ec Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 24 Mar 2026 13:57:08 -0700 Subject: [PATCH 14/48] onnx attention op --- cmake/onnxruntime_providers_cuda_plugin.cmake | 6 ------ include/onnxruntime/ep/adapter/allocator.h | 2 ++ include/onnxruntime/ep/adapter/node.h | 10 ++++++++++ .../providers/cpu/nn/deform_conv_attributes.h | 7 ++++--- onnxruntime/core/providers/cuda/llm/attention.cc | 16 +++++++++++++++- 5 files changed, 31 insertions(+), 10 deletions(-) diff --git a/cmake/onnxruntime_providers_cuda_plugin.cmake b/cmake/onnxruntime_providers_cuda_plugin.cmake index 9abd325cdfd22..6814374b2d3ed 100644 --- a/cmake/onnxruntime_providers_cuda_plugin.cmake +++ b/cmake/onnxruntime_providers_cuda_plugin.cmake @@ -88,12 +88,6 @@ list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/sequence_op\\.cc$") # in the CPU provider and is not linked into the plugin. list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/size\\.cc$") -# Exclude llm/ for now. Stream handling in core/providers/cuda/llm/attention.cc -# is now adapter-safe, but the kernel still uses framework-only Node::OutputDefs() -# introspection in its constructor, and ep::adapter::Node does not expose that yet. -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/llm/.*") -list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/llm/.*") - # Permanently excluded — pure CPU ops, handled by GetCpuPreferredNodes. # shape_op.cc inherits from onnxruntime::OpKernel (framework) # which cannot convert to ep::adapter::OpKernel in the plugin build. diff --git a/include/onnxruntime/ep/adapter/allocator.h b/include/onnxruntime/ep/adapter/allocator.h index 4f107ae72c0e9..c1d4bcaf77017 100644 --- a/include/onnxruntime/ep/adapter/allocator.h +++ b/include/onnxruntime/ep/adapter/allocator.h @@ -7,6 +7,8 @@ #error "This header should not be included directly. Include ep/adapters.h instead." #endif +#include + #include "core/framework/allocator.h" namespace onnxruntime { diff --git a/include/onnxruntime/ep/adapter/node.h b/include/onnxruntime/ep/adapter/node.h index cdeb9209389a3..17513d3a14dfa 100644 --- a/include/onnxruntime/ep/adapter/node.h +++ b/include/onnxruntime/ep/adapter/node.h @@ -36,6 +36,16 @@ struct Node { return kernel_info_.GetOperatorSinceVersion(); } + /** Gets the number of outputs. */ + size_t OutputCount() const noexcept { + return kernel_info_.GetOutputCount(); + } + + /** Gets whether an output exists or is an omitted optional output. */ + bool OutputExists(size_t index) const { + return index < OutputCount() && !kernel_info_.GetOutputName(index).empty(); + } + private: const Ort::ConstKernelInfo kernel_info_; }; diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h index 8bc891bb4f377..a5c0606a55e7d 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h @@ -15,15 +15,16 @@ namespace onnxruntime { // See https://onnx.ai/onnx/operators/onnx__DeformConv.html // Used by both CPU and CUDA implementations (CUDA includes from here). struct DeformConvAttributes { - explicit DeformConvAttributes(const OpKernelInfo& info) { + template + explicit DeformConvAttributes(const KernelInfoType& info) { // Optional attributes. // If not present, they will be empty/default, and handled in Compute/ComputeInternal. (void)info.GetAttrs("kernel_shape", kernel_shape); (void)info.GetAttrs("strides", strides); (void)info.GetAttrs("pads", pads); (void)info.GetAttrs("dilations", dilations); - group = info.GetAttrOrDefault("group", 1); - offset_group = info.GetAttrOrDefault("offset_group", 1); + group = info.template GetAttrOrDefault("group", 1); + offset_group = info.template GetAttrOrDefault("offset_group", 1); } TensorShapeVector kernel_shape; diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc index b0776f77a39e7..27b81691c70af 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.cc +++ b/onnxruntime/core/providers/cuda/llm/attention.cc @@ -19,6 +19,19 @@ using namespace onnxruntime::cuda; namespace onnxruntime { namespace cuda { +namespace llm_attention_detail { + +template +bool HasOutput(const NodeType& node, size_t output_index) { + if constexpr (requires(const NodeType& candidate) { candidate.OutputCount(); candidate.OutputExists(output_index); }) { + return node.OutputCount() > output_index && node.OutputExists(output_index); + } else { + return node.OutputDefs().size() > output_index && node.OutputDefs()[output_index]->Exists(); + } +} + +} // namespace llm_attention_detail + #define REGISTER_KERNEL_TYPED(T) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ Attention, \ @@ -60,7 +73,8 @@ Attention::Attention(const OpKernelInfo& info) : CudaKernel(info) { kv_num_heads_ = static_cast(info.GetAttrOrDefault("kv_num_heads", 0)); q_num_heads_ = static_cast(info.GetAttrOrDefault("q_num_heads", 0)); int mode = static_cast(info.GetAttrOrDefault("qk_matmul_output_mode", 0)); - qk_matmul_output_mode_ = info.node().OutputDefs().size() >= 4 && info.node().OutputDefs()[3]->Exists() + const auto& node = info.node(); + qk_matmul_output_mode_ = llm_attention_detail::HasOutput(node, 3) ? static_cast(mode) : attention_helper::QKMatMulOutputMode::kNone; ORT_ENFORCE(qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kNone || From ced3fdbe5b775afd11973aa781c1a4e8bbaf5d6b Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 24 Mar 2026 16:30:20 -0700 Subject: [PATCH 15/48] redesign CudaKernelAdapterRuntimeConfig map --- .../contrib_ops/cuda/bert/skip_layer_norm.cc | 2 +- .../cuda/plugin/cuda_allocator_plugin.cc | 4 +- .../core/providers/cuda/plugin/cuda_ep.cc | 7 +- .../providers/cuda/plugin/cuda_ep_factory.cc | 23 ++-- .../cuda/plugin/cuda_kernel_adapter.h | 118 ++++++++++++------ .../providers/cuda/plugin/cuda_plugin_utils.h | 14 +++ .../cuda/plugin/cuda_stream_plugin.cc | 17 +-- 7 files changed, 129 insertions(+), 56 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc index 1cbb44a82f97b..859bc43c0c8d3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @@ -44,7 +44,7 @@ SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) ORT_ENFORCE(epsilon_ >= 0); #ifdef BUILD_CUDA_EP_AS_PLUGIN - strict_ = onnxruntime::cuda::GetCudaKernelAdapterSkipLayerNormStrictMode(); + strict_ = onnxruntime::cuda::GetCudaKernelAdapterSkipLayerNormStrictMode(op_kernel_info.GetExecutionProvider()); #else const CUDAExecutionProvider* cuda_ep = static_cast(op_kernel_info.GetExecutionProvider()); strict_ = cuda_ep->IsSkipLayerNormInStrictMode(); diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc index 5da80f4121c9a..68f4dce39414f 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc @@ -28,7 +28,9 @@ CudaDeviceAllocator::CudaDeviceAllocator(const OrtMemoryInfo* memory_info, int d auto* alloc = static_cast(this_ptr); void* p = nullptr; if (size == 0) return nullptr; - cudaSetDevice(alloc->device_id_); + if (cudaSetDevice(alloc->device_id_) != cudaSuccess) { + return nullptr; + } cudaError_t err = cudaMalloc(&p, size); if (err != cudaSuccess) { return nullptr; diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index 8a9a5a4548e9b..7b1227408b503 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -42,12 +42,15 @@ CudaEp::CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& lo ORT_FILE, __LINE__, __FUNCTION__)); // Seed adapter-level runtime options for migrated kernels. - onnxruntime::cuda::SetCudaKernelAdapterRuntimeConfig( + onnxruntime::cuda::SetCudaKernelAdapterRuntimeConfigForProvider( + static_cast(static_cast(this)), config_.use_tf32, config_.device_id, config_.enable_skip_layer_norm_strict_mode, config_.cudnn_conv_algo, config_.cudnn_conv_use_max_workspace, config_.cudnn_conv1d_pad_to_nc1d); } -CudaEp::~CudaEp() = default; +CudaEp::~CudaEp() { + onnxruntime::cuda::detail::RemoveCudaKernelAdapterRuntimeConfigForProvider(static_cast(static_cast(this))); +} /*static*/ const char* ORT_API_CALL CudaEp::GetNameImpl(const OrtEp* this_ptr) noexcept { diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc index 33294851f50bf..503f9245ac905 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -131,11 +131,12 @@ OrtStatus* ORT_API_CALL CudaEpFactory::GetSupportedDevicesImpl( cudaError_t err = cudaGetDeviceCount(&cuda_device_count); if (err == cudaSuccess && cuda_device_count > 0 && current_device_id < cuda_device_count) { cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, current_device_id); - factory->ort_api_.AddKeyValuePair(ep_metadata, "cuda_device_name", prop.name); - factory->ort_api_.AddKeyValuePair( - ep_metadata, "cuda_compute_capability", - (std::to_string(prop.major) + "." + std::to_string(prop.minor)).c_str()); + if (cudaGetDeviceProperties(&prop, current_device_id) == cudaSuccess) { + factory->ort_api_.AddKeyValuePair(ep_metadata, "cuda_device_name", prop.name); + factory->ort_api_.AddKeyValuePair( + ep_metadata, "cuda_compute_capability", + (std::to_string(prop.major) + "." + std::to_string(prop.minor)).c_str()); + } } OrtEpDevice* ep_device = nullptr; @@ -236,6 +237,13 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( try { value = std::stoi(buf.data()); } catch (...) { + if (logger) { + std::string msg = std::string("Failed to parse session config for key: ") + key + ". Using default value."; + OrtStatus* st = factory->ort_api_.Logger_LogMessage(logger, ORT_LOGGING_LEVEL_WARNING, msg.c_str(), "cuda_ep_factory.cc", __LINE__, "CudaEpFactory"); + if (st) { + factory->ort_api_.ReleaseStatus(st); + } + } } }; @@ -334,13 +342,14 @@ bool ORT_API_CALL CudaEpFactory::IsStreamAwareImpl(const OrtEpFactory* /*this_pt /*static*/ OrtStatus* ORT_API_CALL CudaEpFactory::CreateSyncStreamForDeviceImpl( OrtEpFactory* this_ptr, - const OrtMemoryDevice* /*memory_device*/, + const OrtMemoryDevice* memory_device, const OrtKeyValuePairs* /*stream_options*/, OrtSyncStreamImpl** stream) noexcept { EXCEPTION_TO_STATUS_BEGIN auto* factory = static_cast(this_ptr); - auto cuda_stream = std::make_unique(*factory, factory->device_id_, nullptr); + int req_device_id = factory->ep_api_.MemoryDevice_GetDeviceId(memory_device); + auto cuda_stream = std::make_unique(*factory, req_device_id, nullptr); // Initialize CUDA handles (stream, cuBLAS, cuDNN) RETURN_IF_ERROR(cuda_stream->InitHandles()); diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index 57032831d89b1..bc8fe835d5ddd 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -334,6 +334,7 @@ struct PluginNoOpLogStream { #include #include #include +#include namespace onnxruntime { namespace cuda { @@ -345,17 +346,51 @@ namespace cuda { // =================================================================== namespace detail { +// All fields are written once during CudaEp construction (under unique_lock) +// and only read afterwards, so std::atomic is not needed — the shared_mutex +// in ProviderConfigStore provides the necessary happens-before guarantee. struct CudaKernelAdapterRuntimeConfig { - std::atomic use_tf32{true}; - std::atomic skip_layer_norm_strict_mode{false}; - std::atomic device_id{0}; - std::atomic cudnn_conv_algo{0}; - std::atomic cudnn_conv_use_max_workspace{true}; - std::atomic cudnn_conv1d_pad_to_nc1d{false}; + bool use_tf32 = true; + bool skip_layer_norm_strict_mode = false; + int cudnn_conv_algo = 0; + bool cudnn_conv_use_max_workspace = true; + bool cudnn_conv1d_pad_to_nc1d = false; + int device_id = 0; + cudaDeviceProp device_prop{}; }; -inline CudaKernelAdapterRuntimeConfig& GetCudaKernelAdapterRuntimeConfig() { - static CudaKernelAdapterRuntimeConfig config; - return config; +// Shared storage for per-provider runtime configurations. +// Both Get and Remove must operate on the same static map instance, +// so we centralise them in a single struct with static lifetime. +struct ProviderConfigStore { + std::shared_mutex mutex; + std::unordered_map> configs; + + static ProviderConfigStore& Instance() { + static ProviderConfigStore store; + return store; + } +}; + +inline CudaKernelAdapterRuntimeConfig& GetCudaKernelAdapterRuntimeConfigForProvider(const void* provider) { + auto& store = ProviderConfigStore::Instance(); + std::shared_lock lock(store.mutex); + auto it = store.configs.find(provider); + if (it != store.configs.end()) { + return *it->second; + } + lock.unlock(); + std::unique_lock unique_lock(store.mutex); + auto& ptr = store.configs[provider]; + if (!ptr) { + ptr = std::make_unique(); + } + return *ptr; +} + +inline void RemoveCudaKernelAdapterRuntimeConfigForProvider(const void* provider) { + auto& store = ProviderConfigStore::Instance(); + std::unique_lock lock(store.mutex); + store.configs.erase(provider); } template struct SizeOf { @@ -406,6 +441,18 @@ inline const cudaDeviceProp& GetDevicePropForDevice(int device_id) { // Provides the minimal API surface that migrated kernels expect // (GetCudnnConvAlgo, UseTF32, GetDeviceProp, etc.) without the full // CUDAExecutionProvider class from onnxruntime/core/providers/cuda/. +// +// DESIGN NOTE: Why does this class have no state/member variables? +// In the plugin build, the object returned by `info.GetExecutionProvider()` +// is an opaque C-API struct (`OrtEp*`/`CudaEp*`), NOT this class. +// The raw kernel code performs `static_cast` on it. +// If this shim class defined any member variables (e.g., `config_`), the +// compiler would read them at specific byte offsets relative to `this`, causing +// memory layout UB (garbage reads/segfaults) since the underlying object in +// memory is actually an `OrtEp`. +// Therefore, `CUDAExecutionProvider` here must remain a pure "phantom shim." +// To safely access state (like TF32 settings), it dynamically queries a static +// map keyed by its own `this` pointer (which equals the `CudaEp*` memory address). // =================================================================== // Shim for CUDAExecutionProvider required by conv.cc, einsum, and others @@ -413,43 +460,44 @@ class CUDAExecutionProvider : public onnxruntime::IExecutionProvider { public: explicit CUDAExecutionProvider(const std::string& name) : onnxruntime::IExecutionProvider{name} {} int GetCudnnConvAlgo() const { - return cuda::detail::GetCudaKernelAdapterRuntimeConfig().cudnn_conv_algo.load(std::memory_order_relaxed); + return cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(this).cudnn_conv_algo; } bool GetCudnnConvUseMaxWorkspace() const { - return cuda::detail::GetCudaKernelAdapterRuntimeConfig().cudnn_conv_use_max_workspace.load(std::memory_order_relaxed); + return cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(this).cudnn_conv_use_max_workspace; } bool GetCudnnConv1dPadToNc1d() const { - return cuda::detail::GetCudaKernelAdapterRuntimeConfig().cudnn_conv1d_pad_to_nc1d.load(std::memory_order_relaxed); + return cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(this).cudnn_conv1d_pad_to_nc1d; } bool UseTF32() const { - return cuda::detail::GetCudaKernelAdapterRuntimeConfig().use_tf32.load(std::memory_order_relaxed); + return cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(this).use_tf32; } bool IsFuseConvBias() const { return false; } const cudaDeviceProp& GetDeviceProp() const { - int device_id = cuda::detail::GetCudaKernelAdapterRuntimeConfig().device_id.load(std::memory_order_relaxed); - return cuda::detail::GetDevicePropForDevice(device_id); + return cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(this).device_prop; } }; namespace cuda { -inline void SetCudaKernelAdapterRuntimeConfig(bool use_tf32, int device_id, bool skip_layer_norm_strict_mode = false, - int cudnn_conv_algo = 0, bool cudnn_conv_use_max_workspace = true, - bool cudnn_conv1d_pad_to_nc1d = false) { - auto& config = detail::GetCudaKernelAdapterRuntimeConfig(); - config.use_tf32.store(use_tf32, std::memory_order_relaxed); - config.skip_layer_norm_strict_mode.store(skip_layer_norm_strict_mode, std::memory_order_relaxed); - config.device_id.store(device_id, std::memory_order_relaxed); - config.cudnn_conv_algo.store(cudnn_conv_algo, std::memory_order_relaxed); - config.cudnn_conv_use_max_workspace.store(cudnn_conv_use_max_workspace, std::memory_order_relaxed); - config.cudnn_conv1d_pad_to_nc1d.store(cudnn_conv1d_pad_to_nc1d, std::memory_order_relaxed); +inline void SetCudaKernelAdapterRuntimeConfigForProvider(const void* provider, bool use_tf32, int device_id, + bool skip_layer_norm_strict_mode = false, + int cudnn_conv_algo = 0, bool cudnn_conv_use_max_workspace = true, + bool cudnn_conv1d_pad_to_nc1d = false) { + auto& config = detail::GetCudaKernelAdapterRuntimeConfigForProvider(provider); + config.use_tf32 = use_tf32; + config.skip_layer_norm_strict_mode = skip_layer_norm_strict_mode; + config.cudnn_conv_algo = cudnn_conv_algo; + config.cudnn_conv_use_max_workspace = cudnn_conv_use_max_workspace; + config.cudnn_conv1d_pad_to_nc1d = cudnn_conv1d_pad_to_nc1d; + config.device_id = device_id; + PL_CUDA_CALL_THROW(cudaGetDeviceProperties(&config.device_prop, device_id)); } -inline bool GetCudaKernelAdapterSkipLayerNormStrictMode() { - const auto& config = detail::GetCudaKernelAdapterRuntimeConfig(); - return config.skip_layer_norm_strict_mode.load(std::memory_order_relaxed); +inline bool GetCudaKernelAdapterSkipLayerNormStrictMode(const void* provider) { + const auto& config = detail::GetCudaKernelAdapterRuntimeConfigForProvider(provider); + return config.skip_layer_norm_strict_mode; } // Global aliases and shims @@ -626,15 +674,11 @@ namespace cuda { // re-open onnxruntime::cuda class CudaKernel : public OpKernel { public: explicit CudaKernel(const OpKernelInfo& info) : OpKernel(info), info_(info) { - const auto& config = detail::GetCudaKernelAdapterRuntimeConfig(); - use_tf32_ = config.use_tf32.load(std::memory_order_relaxed); - device_id_ = config.device_id.load(std::memory_order_relaxed); - int cur = device_id_; - if (cudaGetDevice(&cur) == cudaSuccess) device_id_ = cur; - if (cudaGetDeviceProperties(&device_prop_, device_id_) != cudaSuccess) { - std::memset(&device_prop_, 0, sizeof(device_prop_)); - device_prop_.major = -1; - } + const auto* provider = info.GetExecutionProvider(); + const auto& config = detail::GetCudaKernelAdapterRuntimeConfigForProvider(provider); + use_tf32_ = config.use_tf32; + device_id_ = config.device_id; + device_prop_ = config.device_prop; } virtual ~CudaKernel() = default; Status Compute(OpKernelContext* ctx) const { diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_plugin_utils.h b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_utils.h index 82a792faa3258..0e4808d07046d 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_plugin_utils.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_utils.h @@ -30,6 +30,20 @@ } while (0) #endif +// Throwing variant for use in constructors and non-OrtStatus contexts. +// Analogous to CUDA_CALL_THROW in the non-plugin build. +#ifndef PL_CUDA_CALL_THROW +#define PL_CUDA_CALL_THROW(cuda_call_expr) \ + do { \ + cudaError_t _cuda_err = (cuda_call_expr); \ + if (_cuda_err != cudaSuccess) { \ + throw std::runtime_error( \ + std::string("CUDA error: ") + cudaGetErrorName(_cuda_err) + ": " + \ + cudaGetErrorString(_cuda_err)); \ + } \ + } while (0) +#endif + #ifndef PL_CUBLAS_RETURN_IF_ERROR #define PL_CUBLAS_RETURN_IF_ERROR(cublas_call_expr) \ do { \ diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc index 06d16cb92ca8f..2c2ee50029816 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc @@ -4,6 +4,7 @@ #include "cuda_stream_plugin.h" #include "cuda_ep_factory.h" #include +#include namespace onnxruntime { namespace cuda_plugin { @@ -20,8 +21,8 @@ StreamMap& GetStreamMap() { return stream_map; } -std::mutex& GetStreamMapMutex() { - static std::mutex stream_map_mutex; +std::shared_mutex& GetStreamMapMutex() { + static std::shared_mutex stream_map_mutex; return stream_map_mutex; } } // namespace @@ -55,7 +56,7 @@ CudaSyncStream::~CudaSyncStream() { } OrtStatus* CudaSyncStream::InitHandles() { - cudaSetDevice(device_id_); + PL_CUDA_RETURN_IF_ERROR(cudaSetDevice(device_id_)); PL_CUDA_RETURN_IF_ERROR(cudaStreamCreateWithFlags(&cuda_stream_, cudaStreamNonBlocking)); RegisterStream(cuda_stream_, this); @@ -122,7 +123,7 @@ void CudaSyncStream::CleanupDeferredCPUBuffers() { } auto& stream_map = GetStreamMap(); - std::lock_guard lock(GetStreamMapMutex()); + std::shared_lock lock(GetStreamMapMutex()); auto it = stream_map.find(stream); if (it != stream_map.end()) { return it->second; @@ -132,13 +133,13 @@ void CudaSyncStream::CleanupDeferredCPUBuffers() { /*static*/ void CudaSyncStream::RegisterStream(cudaStream_t stream, CudaSyncStream* sync_stream) { auto& stream_map = GetStreamMap(); - std::lock_guard lock(GetStreamMapMutex()); + std::unique_lock lock(GetStreamMapMutex()); stream_map[stream] = sync_stream; } /*static*/ void CudaSyncStream::UnregisterStream(cudaStream_t stream) { auto& stream_map = GetStreamMap(); - std::lock_guard lock(GetStreamMapMutex()); + std::unique_lock lock(GetStreamMapMutex()); stream_map.erase(stream); } @@ -156,8 +157,8 @@ CudaSyncNotification::CudaSyncNotification(CudaSyncStream& stream) Release = ReleaseImpl; // Create a CUDA event for synchronization (disable timing for performance) - cudaSetDevice(stream_.GetDeviceId()); - cudaEventCreateWithFlags(&event_, cudaEventDisableTiming); + PL_CUDA_CALL_THROW(cudaSetDevice(stream_.GetDeviceId())); + PL_CUDA_CALL_THROW(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); } CudaSyncNotification::~CudaSyncNotification() { From 6307a1522647a9756cb1590eccc8f95ac6bb494d Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 24 Mar 2026 17:02:08 -0700 Subject: [PATCH 16/48] refactor ConstantOfShape and other feedbacks --- .../cpu/generator/constant_of_shape_base.h | 150 ++++++++++++------ .../cuda/generator/constant_of_shape.h | 92 ++--------- .../cuda/plugin/cuda_allocator_plugin.cc | 10 +- .../cuda/plugin/cuda_allocator_plugin.h | 27 +++- .../cuda/plugin/cuda_controlflow_plugin.cc | 23 +-- .../cuda/plugin/cuda_controlflow_plugin.cu | 33 ++-- .../cuda/plugin/cuda_controlflow_plugin.h | 8 +- .../providers/cuda/plugin/cuda_ep_factory.cc | 31 +++- .../cuda/plugin/cuda_kernel_adapter.h | 27 +--- .../cuda/plugin/cuda_stream_plugin.cc | 9 ++ 10 files changed, 198 insertions(+), 212 deletions(-) diff --git a/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h b/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h index f08f134d0c080..c30721c909a97 100644 --- a/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h +++ b/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h @@ -66,32 +66,41 @@ using ConstantOfShapeDefaultOutputTypesOpset23 = uint8_t, uint16_t, uint32_t, uint64_t, bool>; -template -class ConstantOfShapeBase { +#define ORT_CONSTANT_OF_SHAPE_VALUE_TYPES(M) \ + M(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL) \ + M(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) \ + M(MLFloat16, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) \ + M(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) \ + M(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) \ + M(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16) \ + M(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) \ + M(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) \ + M(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) \ + M(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16) \ + M(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32) \ + M(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64) \ + M(BFloat16, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16) + +template +struct ConstantOfShapeOrtType; + +#define DEFINE_CONSTANT_OF_SHAPE_ORT_TYPE(c_type, ort_type) \ + template <> \ + struct ConstantOfShapeOrtType { \ + static constexpr ONNXTensorElementDataType value = ort_type; \ + }; + +ORT_CONSTANT_OF_SHAPE_VALUE_TYPES(DEFINE_CONSTANT_OF_SHAPE_ORT_TYPE) + +#undef DEFINE_CONSTANT_OF_SHAPE_ORT_TYPE + +class ConstantOfShapeCore { protected: - ConstantOfShapeBase(const OpKernelInfo& info) { -#ifndef SHARED_PROVIDER - ONNX_NAMESPACE::TensorProto t_proto; - auto* t_proto_p = &t_proto; -#else - auto t_proto = ONNX_NAMESPACE::TensorProto::Create(); - auto* t_proto_p = t_proto.get(); -#endif - if (info.GetAttr("value", t_proto_p).IsOK()) { - for (auto dim : t_proto_p->dims()) { - ORT_ENFORCE(dim == 1, "The value attribute of ConstantOfShape must be a single-element tensor"); - } - SetValueFromTensorProto(*t_proto_p); - } else { - float f_value = 0.f; - SetValue(sizeof(float), reinterpret_cast(&f_value)); - } - } - void* GetValuePtr() const { return p_value_; } - static Status PrepareCompute(OpKernelContext* ctx, Tensor** output_tensor) { - const auto shape_tensor = ctx->Input(0); + template + static Status PrepareCompute(ContextType* ctx, Tensor** output_tensor) { + const auto shape_tensor = ctx->template Input(0); const auto& input_shape = shape_tensor->Shape(); // If empty the output is a scalar with empty shape @@ -99,7 +108,7 @@ class ConstantOfShapeBase { // one value ORT_RETURN_IF_NOT(input_shape.NumDimensions() > 0, "Must have a valid input shape."); - const auto span = shape_tensor->DataAsSpan(); + const auto span = shape_tensor->template DataAsSpan(); TensorShape output_shape(span); (*output_tensor) = ctx->Output(0, output_shape); @@ -107,31 +116,48 @@ class ConstantOfShapeBase { return Status::OK(); } - private: - union SizeBasedValue { - int8_t int8_; - int16_t int16_; - int32_t int32_; - int64_t int64_; - } s_value_; - void* p_value_; + void SetDefaultValue() { + float f_value = 0.f; + SetValue(sizeof(float), &f_value); + } + + template + void SetValueFromOrtTensor(ONNXTensorElementDataType tensor_type, const void* data) { + bool handled = false; + switch (tensor_type) { +#define CASE_SET_ORT_VALUE(c_type, ort_type) \ + case ConstantOfShapeOrtType::value: { \ + if (utils::HasType()) { \ + SetValue(sizeof(c_type), data); \ + handled = true; \ + } \ + break; \ + } + ORT_CONSTANT_OF_SHAPE_VALUE_TYPES(CASE_SET_ORT_VALUE) +#undef CASE_SET_ORT_VALUE + default: + ORT_THROW("Unsupported value attribute datatype: ", static_cast(tensor_type)); + } - void SetValue(size_t size, void* value) { + ORT_ENFORCE(handled, "Unsupported value attribute datatype in this build: ", static_cast(tensor_type)); + } + + void SetValue(size_t size, const void* value) { switch (size) { case sizeof(int8_t): - s_value_.int8_ = *(reinterpret_cast(value)); + s_value_.int8_ = *(reinterpret_cast(value)); p_value_ = reinterpret_cast(&(s_value_.int8_)); break; case sizeof(int16_t): - s_value_.int16_ = *(reinterpret_cast(value)); + s_value_.int16_ = *(reinterpret_cast(value)); p_value_ = reinterpret_cast(&(s_value_.int16_)); break; case sizeof(int32_t): - s_value_.int32_ = *(reinterpret_cast(value)); + s_value_.int32_ = *(reinterpret_cast(value)); p_value_ = reinterpret_cast(&(s_value_.int32_)); break; case sizeof(int64_t): - s_value_.int64_ = *(reinterpret_cast(value)); + s_value_.int64_ = *(reinterpret_cast(value)); p_value_ = reinterpret_cast(&(s_value_.int64_)); break; default: @@ -139,10 +165,43 @@ class ConstantOfShapeBase { } } + private: + union SizeBasedValue { + int8_t int8_; + int16_t int16_; + int32_t int32_; + int64_t int64_; + }; + mutable SizeBasedValue s_value_{}; + mutable void* p_value_ = nullptr; +}; + +template +class ConstantOfShapeBase : public ConstantOfShapeCore { + protected: + ConstantOfShapeBase(const OpKernelInfo& info) { +#ifndef SHARED_PROVIDER + ONNX_NAMESPACE::TensorProto t_proto; + auto* t_proto_p = &t_proto; +#else + auto t_proto = ONNX_NAMESPACE::TensorProto::Create(); + auto* t_proto_p = t_proto.get(); +#endif + if (info.GetAttr("value", t_proto_p).IsOK()) { + for (auto dim : t_proto_p->dims()) { + ORT_ENFORCE(dim == 1, "The value attribute of ConstantOfShape must be a single-element tensor"); + } + SetValueFromTensorProto(*t_proto_p); + } else { + SetDefaultValue(); + } + } + void SetValueFromTensorProto(const ONNX_NAMESPACE::TensorProto&); }; -#define CASE_FETCH_VALUE_DATA(c_type) \ +// ort_type parameter unused here but required for ORT_CONSTANT_OF_SHAPE_VALUE_TYPES X-macro conformance. +#define CASE_FETCH_VALUE_DATA(c_type, ort_type) \ case utils::ToTensorProtoElementType(): { \ if (utils::HasType()) { \ c_type val; \ @@ -164,19 +223,7 @@ void ConstantOfShapeBase::SetValueFromTensorProto(const O const size_t raw_data_len = utils::HasRawData(t_proto) ? t_proto.raw_data().size() : 0; bool handled = false; switch (tensor_type) { - CASE_FETCH_VALUE_DATA(bool) - CASE_FETCH_VALUE_DATA(float) - CASE_FETCH_VALUE_DATA(MLFloat16) - CASE_FETCH_VALUE_DATA(double) - CASE_FETCH_VALUE_DATA(int8_t) - CASE_FETCH_VALUE_DATA(int16_t) - CASE_FETCH_VALUE_DATA(int32_t) - CASE_FETCH_VALUE_DATA(int64_t) - CASE_FETCH_VALUE_DATA(uint8_t) - CASE_FETCH_VALUE_DATA(uint16_t) - CASE_FETCH_VALUE_DATA(uint32_t) - CASE_FETCH_VALUE_DATA(uint64_t) - CASE_FETCH_VALUE_DATA(BFloat16) + ORT_CONSTANT_OF_SHAPE_VALUE_TYPES(CASE_FETCH_VALUE_DATA) default: ORT_THROW("Unsupported value attribute datatype: ", tensor_type); } @@ -185,5 +232,6 @@ void ConstantOfShapeBase::SetValueFromTensorProto(const O } #undef CASE_FETCH_VALUE_DATA +#undef ORT_CONSTANT_OF_SHAPE_VALUE_TYPES } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/generator/constant_of_shape.h b/onnxruntime/core/providers/cuda/generator/constant_of_shape.h index 10d9b979a82bd..7f0268923cb3e 100644 --- a/onnxruntime/core/providers/cuda/generator/constant_of_shape.h +++ b/onnxruntime/core/providers/cuda/generator/constant_of_shape.h @@ -4,9 +4,7 @@ #pragma once #include "core/providers/cuda/cuda_kernel.h" -#ifndef BUILD_CUDA_EP_AS_PLUGIN #include "core/providers/cpu/generator/constant_of_shape_base.h" -#endif #include "core/providers/cuda/shared_inc/cuda_utils.h" namespace onnxruntime { @@ -14,10 +12,11 @@ namespace cuda { #ifdef BUILD_CUDA_EP_AS_PLUGIN -// Plugin build: self-contained ConstantOfShape without ConstantOfShapeBase dependency. -// ConstantOfShapeBase uses TensorProto/UnpackTensor utilities not available in the plugin, -// so we read the 'value' attribute via the ORT C API (KernelInfoGetAttribute_tensor) instead. -class ConstantOfShape final : public CudaKernel { +// Plugin build: keep the attribute fetch self-contained while reusing the shared +// ConstantOfShapeCore helpers for default handling and supported type mapping. +// ConstantOfShapeBase still depends on TensorProto/UnpackTensor utilities that the +// plugin build avoids, so the plugin path reads the attribute via the ORT C API instead. +class ConstantOfShape final : public ConstantOfShapeCore, public CudaKernel { public: explicit ConstantOfShape(const OpKernelInfo& info) : CudaKernel(info) { InitValue(info); @@ -27,53 +26,7 @@ class ConstantOfShape final : public CudaKernel { Status ComputeInternal(OpKernelContext* ctx) const override; - protected: - void* GetValuePtr() const { return p_value_; } - - static Status PrepareCompute(OpKernelContext* ctx, Tensor** output_tensor) { - const auto* shape_tensor = ctx->Input(0); - const auto& input_shape = shape_tensor->Shape(); - ORT_RETURN_IF_NOT(input_shape.NumDimensions() > 0, "Must have a valid input shape."); - const auto span = shape_tensor->DataAsSpan(); - TensorShape output_shape(span); - (*output_tensor) = ctx->Output(0, output_shape); - return Status::OK(); - } - private: - union SizeBasedValue { - int8_t int8_; - int16_t int16_; - int32_t int32_; - int64_t int64_; - }; - - mutable SizeBasedValue s_value_{}; - mutable void* p_value_ = nullptr; - - void SetValue(size_t size, const void* value) { - switch (size) { - case sizeof(int8_t): - s_value_.int8_ = *(reinterpret_cast(value)); - p_value_ = reinterpret_cast(&(s_value_.int8_)); - break; - case sizeof(int16_t): - s_value_.int16_ = *(reinterpret_cast(value)); - p_value_ = reinterpret_cast(&(s_value_.int16_)); - break; - case sizeof(int32_t): - s_value_.int32_ = *(reinterpret_cast(value)); - p_value_ = reinterpret_cast(&(s_value_.int32_)); - break; - case sizeof(int64_t): - s_value_.int64_ = *(reinterpret_cast(value)); - p_value_ = reinterpret_cast(&(s_value_.int64_)); - break; - default: - ORT_THROW("Unsupported value attribute datatype with size: ", size); - } - } - void InitValue(const OpKernelInfo& info) { Ort::AllocatorWithDefaultOptions allocator; auto ort_info = info.GetKernelInfo(); @@ -84,40 +37,13 @@ class ConstantOfShape final : public CudaKernel { ORT_ENFORCE(elem_count == 1 || elem_count == 0, "The value attribute of ConstantOfShape must be a single-element tensor"); if (elem_count == 1) { - const void* data = value_tensor.GetTensorRawData(); - size_t elem_size = GetElementSize(type_and_shape.GetElementType()); - SetValue(elem_size, data); + SetValueFromOrtTensor( + type_and_shape.GetElementType(), value_tensor.GetTensorRawData()); } else { - float f_value = 0.f; - SetValue(sizeof(float), &f_value); + SetDefaultValue(); } } catch (const Ort::Exception&) { - float f_value = 0.f; - SetValue(sizeof(float), &f_value); - } - } - - static size_t GetElementSize(ONNXTensorElementDataType type) { - switch (type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - return 1; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: - return 2; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - return 4; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: - return 8; - default: - ORT_THROW("Unsupported element type for ConstantOfShape: ", static_cast(type)); + SetDefaultValue(); } } }; diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc index 68f4dce39414f..895916bb76b71 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc @@ -12,8 +12,7 @@ namespace cuda_plugin { // --------------------------------------------------------------------------- CudaDeviceAllocator::CudaDeviceAllocator(const OrtMemoryInfo* memory_info, int device_id) - : OrtAllocator{}, - memory_info_(memory_info), + : CudaAllocatorBase(CudaAllocatorKind::kDevice, memory_info), device_id_(device_id) { version = ORT_API_VERSION; Alloc = AllocImpl; @@ -48,7 +47,7 @@ CudaDeviceAllocator::CudaDeviceAllocator(const OrtMemoryInfo* memory_info, int d /*static*/ const OrtMemoryInfo* ORT_API_CALL CudaDeviceAllocator::InfoImpl(const OrtAllocator* this_ptr) noexcept { const auto* alloc = static_cast(this_ptr); - return alloc->memory_info_; + return alloc->GetMemoryInfo(); } /*static*/ void* ORT_API_CALL CudaDeviceAllocator::ReserveImpl(OrtAllocator* this_ptr, size_t size) noexcept { @@ -62,8 +61,7 @@ CudaDeviceAllocator::CudaDeviceAllocator(const OrtMemoryInfo* memory_info, int d // --------------------------------------------------------------------------- CudaPinnedAllocator::CudaPinnedAllocator(const OrtMemoryInfo* memory_info) - : OrtAllocator{}, - memory_info_(memory_info) { + : CudaAllocatorBase(CudaAllocatorKind::kPinned, memory_info) { version = ORT_API_VERSION; Alloc = AllocImpl; Free = FreeImpl; @@ -91,7 +89,7 @@ CudaPinnedAllocator::CudaPinnedAllocator(const OrtMemoryInfo* memory_info) /*static*/ const OrtMemoryInfo* ORT_API_CALL CudaPinnedAllocator::InfoImpl(const OrtAllocator* this_ptr) noexcept { const auto* alloc = static_cast(this_ptr); - return alloc->memory_info_; + return alloc->GetMemoryInfo(); } /*static*/ void* ORT_API_CALL CudaPinnedAllocator::ReserveImpl(OrtAllocator* this_ptr, size_t size) noexcept { diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.h b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.h index c39270774b992..371b097f2eedd 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.h @@ -8,9 +8,29 @@ namespace onnxruntime { namespace cuda_plugin { +enum class CudaAllocatorKind { + kDevice, + kPinned, +}; + +class CudaAllocatorBase : public OrtAllocator { + public: + explicit CudaAllocatorBase(CudaAllocatorKind kind, const OrtMemoryInfo* memory_info) + : OrtAllocator{}, + kind_(kind), + memory_info_(memory_info) {} + + CudaAllocatorKind GetKind() const { return kind_; } + const OrtMemoryInfo* GetMemoryInfo() const { return memory_info_; } + + private: + CudaAllocatorKind kind_; + const OrtMemoryInfo* memory_info_; +}; + /// CUDA device memory allocator using cudaMalloc/cudaFree. /// Lifetime is managed by the EP factory (ReleaseAllocatorImpl), not by a Release callback. -class CudaDeviceAllocator : public OrtAllocator { +class CudaDeviceAllocator final : public CudaAllocatorBase { public: CudaDeviceAllocator(const OrtMemoryInfo* memory_info, int device_id); ~CudaDeviceAllocator() = default; @@ -21,13 +41,12 @@ class CudaDeviceAllocator : public OrtAllocator { static const OrtMemoryInfo* ORT_API_CALL InfoImpl(const OrtAllocator* this_ptr) noexcept; static void* ORT_API_CALL ReserveImpl(OrtAllocator* this_ptr, size_t size) noexcept; - const OrtMemoryInfo* memory_info_; int device_id_; }; /// CUDA pinned (host) memory allocator using cudaHostAlloc/cudaFreeHost. /// Lifetime is managed by the EP factory (ReleaseAllocatorImpl), not by a Release callback. -class CudaPinnedAllocator : public OrtAllocator { +class CudaPinnedAllocator final : public CudaAllocatorBase { public: CudaPinnedAllocator(const OrtMemoryInfo* memory_info); ~CudaPinnedAllocator() = default; @@ -37,8 +56,6 @@ class CudaPinnedAllocator : public OrtAllocator { static void ORT_API_CALL FreeImpl(OrtAllocator* this_ptr, void* p) noexcept; static const OrtMemoryInfo* ORT_API_CALL InfoImpl(const OrtAllocator* this_ptr) noexcept; static void* ORT_API_CALL ReserveImpl(OrtAllocator* this_ptr, size_t size) noexcept; - - const OrtMemoryInfo* memory_info_; }; } // namespace cuda_plugin diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc index f03e69645df9c..2d066002d7496 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc @@ -115,15 +115,8 @@ OrtStatus* ORT_API_CALL PluginLoopHelper::ConcatOutputImpl( if (cur_bytes != bytes_per_iteration) { return Ort::Status("Inconsistent size in loop output iteration", ORT_FAIL).release(); } - cudaError_t err = cudaMemcpyAsync(cur, val.GetTensorRawData(), bytes_per_iteration, - cudaMemcpyDeviceToDevice, cuda_stream); - if (err != cudaSuccess) { - return Ort::Status((std::string("cudaMemcpyAsync failed in Loop ConcatOutput: ") + - cudaGetErrorString(err)) - .c_str(), - ORT_FAIL) - .release(); - } + PL_CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cur, val.GetTensorRawData(), bytes_per_iteration, + cudaMemcpyDeviceToDevice, cuda_stream)); cur += bytes_per_iteration; } @@ -214,12 +207,12 @@ OrtStatus* ORT_API_CALL PluginScanHelper::TransposeImpl( void* output_data = output.GetTensorMutableData(); // Launch the GPU transpose kernel - status = LaunchTransposeKernel(input_data, output_data, - input_shape.data(), permutation, - num_dims, element_size, total_elements, - cuda_stream); - if (!status.IsOK()) { - return Ort::Status(status.ErrorMessage().c_str(), ORT_EP_FAIL).release(); + OrtStatus* ort_status = LaunchTransposeKernel(input_data, output_data, + input_shape.data(), permutation, + num_dims, element_size, total_elements, + cuda_stream); + if (ort_status != nullptr) { + return ort_status; } return nullptr; diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu index 77eed62068413..5e4b7acc2f95a 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu +++ b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu @@ -11,7 +11,7 @@ #include #include -#include "core/common/status.h" +#include "cuda_plugin_utils.h" namespace onnxruntime { namespace cuda { @@ -19,15 +19,6 @@ namespace plugin { namespace { -Status CudaStatus(cudaError_t cuda_status, const char* operation) { - if (cuda_status == cudaSuccess) { - return Status::OK(); - } - - return common::Status(common::ONNXRUNTIME, common::FAIL, - std::string("Scan Transpose: ") + operation + " failed: " + cudaGetErrorString(cuda_status)); -} - // Maximum number of dimensions supported by the transpose kernel. // Most real-world tensors have <= 8 dimensions. constexpr int kMaxTransposeDims = 8; @@ -72,18 +63,16 @@ __global__ void TransposeNDKernel(const char* __restrict__ input, memcpy(dst, src, element_size); } -Status LaunchTransposeKernel(const void* input, void* output, - const int64_t* input_shape, const size_t* permutation, - size_t num_dims, size_t element_size, size_t total_elements, - cudaStream_t stream) { +OrtStatus* LaunchTransposeKernel(const void* input, void* output, + const int64_t* input_shape, const size_t* permutation, + size_t num_dims, size_t element_size, size_t total_elements, + cudaStream_t stream) { if (total_elements == 0 || num_dims == 0) { - return Status::OK(); + return nullptr; } if (num_dims > static_cast(kMaxTransposeDims)) { - return common::Status(common::ONNXRUNTIME, common::FAIL, - "Scan Transpose: rank " + std::to_string(num_dims) + - " exceeds the supported maximum rank of " + std::to_string(kMaxTransposeDims)); + return Ort::Status("Scan Transpose: rank exceeds the supported maximum rank", ORT_FAIL).release(); } TransposeArgs args; @@ -116,11 +105,9 @@ Status LaunchTransposeKernel(const void* input, void* output, element_size, total_elements); - auto status = CudaStatus(cudaGetLastError(), "TransposeNDKernel launch"); - if (!status.IsOK()) { - return status; - } - return Status::OK(); + PL_CUDA_RETURN_IF_ERROR(cudaGetLastError()); + + return nullptr; } } // namespace plugin diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.h b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.h index f5e203322a0c3..da6fb94023333 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.h @@ -87,10 +87,10 @@ class PluginScanKernel : public OpKernel { }; // GPU transpose helper (defined in cuda_controlflow_plugin.cu) -Status LaunchTransposeKernel(const void* input, void* output, - const int64_t* input_shape, const size_t* permutation, - size_t num_dims, size_t element_size, size_t total_elements, - cudaStream_t stream); +OrtStatus* LaunchTransposeKernel(const void* input, void* output, + const int64_t* input_shape, const size_t* permutation, + size_t num_dims, size_t element_size, size_t total_elements, + cudaStream_t stream); } // namespace plugin } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc index 503f9245ac905..fa5173e7e8736 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -18,6 +18,10 @@ CudaEpFactory::CudaEpFactory(const OrtApi& ort_api, const OrtEpApi& ep_api, pinned_memory_info_{nullptr} { ort_version_supported = ORT_API_VERSION; + if (!::onnxruntime::ep::adapter::LoggingManager::HasDefaultLogger()) { + ::onnxruntime::ep::adapter::LoggingManager::CreateDefaultLogger(&default_logger); + } + // Assign callback function pointers GetName = GetNameImpl; GetVendor = GetVendorImpl; @@ -318,9 +322,18 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateAllocatorImpl( /*static*/ void ORT_API_CALL CudaEpFactory::ReleaseAllocatorImpl( OrtEpFactory* /*this_ptr*/, OrtAllocator* allocator) noexcept { - // We know the allocator was created by us, so cast and delete. - // OrtAllocator itself has no Release method. - delete allocator; + if (!allocator) return; + auto* typed_allocator = static_cast(allocator); + switch (typed_allocator->GetKind()) { + case CudaAllocatorKind::kDevice: + delete static_cast(allocator); + return; + case CudaAllocatorKind::kPinned: + delete static_cast(allocator); + return; + default: + ORT_ENFORCE(false, "Unknown CudaAllocatorKind in ReleaseAllocatorImpl"); + } } /*static*/ @@ -328,7 +341,17 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateDataTransferImpl( OrtEpFactory* this_ptr, OrtDataTransferImpl** data_transfer) noexcept { auto& factory = *static_cast(this_ptr); - const OrtMemoryDevice* gpu_device = factory.ep_api_.MemoryInfo_GetMemoryDevice(factory.default_memory_info_); + + // Use the device ID this factory was created for + Ort::MemoryInfo device_memory_info{"Cuda", + OrtMemoryInfoDeviceType_GPU, + factory.vendor_id_, + static_cast(factory.device_id_), + OrtDeviceMemoryType_DEFAULT, + 0, + OrtAllocatorType::OrtDeviceAllocator}; + + const OrtMemoryDevice* gpu_device = factory.ep_api_.MemoryInfo_GetMemoryDevice(device_memory_info); auto data_transfer_impl = std::make_unique(factory.ort_api_, factory.ep_api_, gpu_device); *data_transfer = data_transfer_impl.release(); return nullptr; diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index bc8fe835d5ddd..a61a05719794c 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -231,7 +231,7 @@ class PluginKernelCollector { return Status::OK(); \ })); \ } \ - static const bool ORT_ADAPTER_CONCAT(_autoreg_##name##_, __COUNTER__) = \ + static const bool ORT_ADAPTER_CONCAT(ORT_ADAPTER_AUTOREG_##name##_, __COUNTER__) = \ (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ &BuildKernelCreateInfo), \ true); @@ -250,7 +250,7 @@ class PluginKernelCollector { return Status::OK(); \ })); \ } \ - static const bool ORT_ADAPTER_CONCAT(_autoreg_##name##_, __COUNTER__) = \ + static const bool ORT_ADAPTER_CONCAT(ORT_ADAPTER_AUTOREG_##name##_, __COUNTER__) = \ (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ &BuildKernelCreateInfo), \ @@ -270,7 +270,7 @@ class PluginKernelCollector { return Status::OK(); \ })); \ } \ - static const bool ORT_ADAPTER_CONCAT(_autoreg_##name##_##type##_, __COUNTER__) = \ + static const bool ORT_ADAPTER_CONCAT(ORT_ADAPTER_AUTOREG_##name##_##type##_, __COUNTER__) = \ (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ &BuildKernelCreateInfo), \ true); @@ -290,7 +290,7 @@ class PluginKernelCollector { return Status::OK(); \ })); \ } \ - static const bool ORT_ADAPTER_CONCAT(_autoreg_##name##_##type##_, __COUNTER__) = \ + static const bool ORT_ADAPTER_CONCAT(ORT_ADAPTER_AUTOREG_##name##_##type##_, __COUNTER__) = \ (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ &BuildKernelCreateInfo), \ @@ -305,23 +305,8 @@ class PluginKernelCollector { // Explicit function instantiation — called once per unique class in each .cc file #define ONNX_OPERATOR_TYPED_KERNEL_COMPUTE_INSTANTIATION(cls) template Status cls::ComputeInternal(OpKernelContext* context) const; -#undef CREATE_MESSAGE -#undef LOGS -#undef LOGS_DEFAULT -#undef ORT_LOG_MESSAGE - -namespace onnxruntime { -namespace cuda { -struct PluginNoOpLogStream { - template - PluginNoOpLogStream& operator<<(const T&) { return *this; } -}; -} // namespace cuda -} // namespace onnxruntime - -#ifndef LOGS_DEFAULT -#define LOGS_DEFAULT(severity) ::onnxruntime::cuda::PluginNoOpLogStream() -#endif +// The plugin utilizes ep::adapter::LoggingManager for LOGS_DEFAULT, +// which is initialized in CudaEpFactory::CudaEpFactory. #include #include diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc index 2c2ee50029816..fb144f5903d28 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc @@ -122,10 +122,19 @@ void CudaSyncStream::CleanupDeferredCPUBuffers() { return nullptr; } + // Thread-local TLS cache to mitigate lock contention on the hot path + thread_local cudaStream_t tls_last_stream = nullptr; + thread_local CudaSyncStream* tls_last_sync_stream = nullptr; + if (stream == tls_last_stream) { + return tls_last_sync_stream; + } + auto& stream_map = GetStreamMap(); std::shared_lock lock(GetStreamMapMutex()); auto it = stream_map.find(stream); if (it != stream_map.end()) { + tls_last_stream = stream; + tls_last_sync_stream = it->second; return it->second; } return nullptr; From b3fdf2544a96f512e0e3b9128707bc2399ad78f5 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 24 Mar 2026 23:50:12 -0700 Subject: [PATCH 17/48] update doc; fix test --- docs/cuda_plugin_ep/cuda_plugin_ep_design.md | 231 ++++++++---------- include/onnxruntime/ep/api.h | 21 +- .../cpu/generator/constant_of_shape_base.h | 5 + .../cuda/plugin/cuda_allocator_plugin.cc | 9 + .../cuda/plugin/cuda_data_transfer_plugin.cc | 19 ++ .../providers/cuda/plugin/cuda_ep_factory.cc | 33 ++- .../providers/cuda/plugin/cuda_ep_factory.h | 7 + .../cuda/plugin/cuda_kernel_adapter.h | 59 ++++- .../cuda/plugin/cuda_stream_plugin.cc | 21 +- .../providers/cuda/reduction/reduction_ops.cc | 2 - .../python/onnxruntime_pybind_state.cc | 25 ++ .../cuda/reduction/reduction_ops.cc | 5 +- 12 files changed, 280 insertions(+), 157 deletions(-) diff --git a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md index d07c1216ff5c9..c7f25b8cedd0c 100644 --- a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md +++ b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md @@ -9,7 +9,7 @@ The CUDA Plugin EP is an alternative build of the ONNX Runtime CUDA Execution Pr - Support all operators currently supported by the in-tree CUDA EP (tunable ops are low priority) - Minimize changes to existing CUDA kernel source files -**Current status:** ~80% of CUDA kernels compile in the plugin build. Excluded operators are documented in [Section 7](#7-excluded-operators). +**Current status:** The plugin build is functional on this branch and the focused plugin validation script (`./cuda_plugin.sh --build --test_plugin`) passes. Most core CUDA kernels now compile in the plugin build; the remaining source-level exclusions are documented in [Section 7](#7-excluded-operators). --- @@ -40,26 +40,27 @@ Each build target uses different preprocessor defines that control how framework ### 2.3 Class Hierarchy ``` -OrtEpFactory OrtEp - ↑ ↑ -CudaEpFactory adapter::Ep (holds unique_ptr) - │ ↑ - │ CudaEp - │ │ - │ └─ owns ──→ CUDAExecutionProvider - │ (: IExecutionProvider) - │ ├─ config members - │ ├─ device properties - │ └─ stream→handle map - │ - └─ creates ──→ CudaSyncStream (owns cublasHandle_t, cudnnHandle_t, cublasLtHandle_t) +OrtEpFactory OrtEp + ↑ ↑ +CudaEpFactory CudaEp + │ │ + ├─ creates OrtEpDevice ├─ stores session-derived Config + ├─ creates CudaSyncStream └─ seeds adapter runtime config for kernels + ├─ caches kernel registry + ├─ caches stable OrtMemoryInfo objects + └─ maps OrtHardwareDevice* → CUDA ordinal + +Migrated CUDA kernels + └─ use CudaKernel / cuda_kernel_adapter.h + ├─ read provider config through a phantom CUDAExecutionProvider shim + └─ resolve stream-local handles via CudaSyncStream::FromCudaStream() ``` Key ownership relationships: -- `CudaEpFactory` creates `CudaEp` instances and `CudaSyncStream` objects -- `CudaEp` inherits from `ep::adapter::Ep` and owns a `CUDAExecutionProvider` instance (accessible via `EpImpl()`) -- `CUDAExecutionProvider` is a plugin-local class (not the framework one) that inherits from `IExecutionProvider` and provides the full API surface CUDA kernels need -- `CudaSyncStream` owns CUDA/cuBLAS/cuDNN handles per stream +- `CudaEpFactory` implements raw `OrtEpFactory` callbacks and owns shared factory-level state such as the kernel registry, cached `OrtMemoryInfo` instances, and the hardware-device to CUDA-ordinal map. +- `CudaEp` inherits directly from `OrtEp`; it does not derive from `ep::adapter::Ep` and does not own a separate framework `IExecutionProvider` object. +- The plugin-local `CUDAExecutionProvider` in `cuda_kernel_adapter.h` is a zero-state compatibility shim used by migrated kernels. Runtime state is stored in adapter-side maps keyed by the `CudaEp` address. +- `CudaSyncStream` owns `cudaStream_t`, `cublasHandle_t`, `cudnnHandle_t`, and `cublasLtHandle_t` for each sync stream created through the EP API. ### 2.4 Plugin DLL Entry Points @@ -124,7 +125,7 @@ This 700+ line header provides everything CUDA kernels need that would normally | Kernel registration | Self-registering `ONNX_OPERATOR_*_KERNEL_EX` macro overrides via `PluginKernelCollector` | | CPU shims | Lightweight reimplementations of CPU helpers not linked into plugin | | Math helpers | `HalfGemmOptions`, `CublasMathModeSetter` | -| Stream shim | `PluginStreamShim` wrapping raw `cudaStream_t` as `onnxruntime::Stream*` | +| Stream shim | `OrtStreamAdapter`/`PluginStreamShim` to present a framework-compatible `Stream*` view over a raw `cudaStream_t` where needed | ### 3.5 Kernel Registration @@ -216,7 +217,7 @@ This allows the base class constructor to work with both the framework `OpKernel Some CPU base classes have heavy dependencies (protobuf, `UnpackTensor`) that make inlining impractical: -- **`ConstantOfShapeBase`** — depends on `TensorProto` and `UnpackTensor`. Plugin uses a self-contained duplicate class in `constant_of_shape.h` guarded by `#ifdef BUILD_CUDA_EP_AS_PLUGIN`. +- **`ConstantOfShapeBase`** — depends on `TensorProto` and `UnpackTensor`. The plugin path in `constant_of_shape.h` stays self-contained: it reuses `ConstantOfShapeCore` but fetches the `value` attribute through the ORT C++ API instead of depending on the full CPU base implementation. - **`UpsampleBase`** — partially addressed: `AdjustOutputSizeAsPolicy` moved to header (#27628). Still depends on `InputDefs()` and `OpKernelInfo::GetAllocator()` which are not in the adapter. --- @@ -225,27 +226,33 @@ Some CPU base classes have heavy dependencies (protobuf, `UnpackTensor`) that ma ### 5.1 Stream Ownership -`CudaSyncStream` is the plugin's CUDA stream implementation: +`CudaSyncStream` is the plugin's CUDA sync-stream implementation: - Owns `cudaStream_t`, `cublasHandle_t`, `cudnnHandle_t`, `cublasLtHandle_t` -- Created by `CudaEpFactory::CreateSyncStreamForDevice` -- Registered with `CUDAExecutionProvider` for handle lookup +- Is created by `CudaEpFactory::CreateSyncStreamForDevice` +- Registers itself in a global `cudaStream_t -> CudaSyncStream*` map so migrated kernels can recover per-stream handles from a raw CUDA stream +- Defers host-buffer cleanup until `OnSessionRunEnd()` after the stream is synchronized ### 5.2 Handle Access Path ``` CudaKernel::GetCublasHandle(OpKernelContext* ctx) - → Stream(ctx) // raw cudaStream_t from ctx - → CUDAExecutionProvider::GetActiveProvider() // static pointer to active EP - → provider->GetCublasHandle(cudaStream_t) // stream→handle map lookup + → Stream(ctx) // raw cudaStream_t from adapter ctx + → CudaSyncStream::FromCudaStream() // global stream map + TLS cache + → sync_stream->GetCublasHandle() ``` -The `CUDAExecutionProvider` maintains a `std::unordered_map` for handle lookups. +The stream lookup path uses a thread-local last-hit cache plus a generation counter so destroyed streams invalidate stale TLS entries without requiring per-thread cleanup. + +For code paths that need handles without an active stream, `cuda_kernel_adapter.h` also provides thread-local default cuBLAS/cuDNN handles via `GetDefaultCudaHandlesForDevice(device_id)`. ### 5.3 Provider Access -Kernels access the provider through two paths: -1. **`CudaKernel::provider_`** — set in the constructor from `info.GetExecutionProvider()` -2. **`CUDAExecutionProvider::GetActiveProvider()`** — static atomic pointer (for `.cu` code that doesn't have a `CudaKernel` instance) +Kernels access provider configuration through the pointer returned by `info.GetExecutionProvider()`, but in the plugin build that pointer is treated as a phantom `CUDAExecutionProvider` shim. The shim must remain layout-compatible with `IExecutionProvider` and carries no member state; runtime configuration is stored in the adapter-side `ProviderConfigStore`, keyed by the provider address. + +For stream bridging, the preferred helpers are: +- `Stream(ctx)` when the kernel only needs a raw `cudaStream_t` +- `GetComputeStream(ctx)` when the kernel API already accepts the adapter's opaque stream pointer +- `GetOrtStream(ctx)` when framework-style `Stream*` plumbing is still needed by shared helper code ### 5.4 CUDA Graph Support @@ -286,28 +293,24 @@ Session::Run() #### 5.4.2 Current Plugin EP Behavior — API Gap -The `OrtEp` C API (`onnxruntime_ep_c_api.h`) provides `OnRunStart` and `OnRunEnd` callbacks but **does not include**: +The `OrtEp` C API (`onnxruntime_ep_c_api.h`) still does not include: - `IsGraphCaptureEnabled()` - `IsGraphCaptured(annotation_id)` - `ReplayGraph(annotation_id)` -The `PluginExecutionProvider` bridge (`ep_plugin_provider_interfaces.cc`) does not override these `IExecutionProvider` virtual methods, so they return the base class defaults (`false`, `false`, `Status::OK()`). - -**Consequence**: The session's `cached_execution_provider_for_graph_replay_` is never set for the plugin EP. The session-level replay bypass **never activates**. Even after the plugin captures a CUDA graph via `OnRunStart`/`OnRunEnd`, subsequent runs still go through the full kernel dispatch pipeline — the captured graph sits unused. +The current plugin EP does not implement CUDA graph callbacks at all: `CudaEp` sets `OnRunStart = nullptr` and `OnRunEnd = nullptr`, and the previously proposed graph-specific plugin files are not part of the branch. As a result, CUDA graph support is currently disabled rather than partially implemented. -The current plugin implementation has a partial mitigation: it captures the graph and replays it once (in `OnRunEnd` after capture). But on subsequent runs, `OnRunEnd` sees the graph is already captured and does nothing. +#### 5.4.3 Current Branch Design -#### 5.4.3 Revised Design — Remove EP-Level Graph Management +Given the API gap, the current branch uses the simplest correct design: -Given the API gap, the correct design for the plugin EP is: - -> **The plugin EP should NOT manage CUDA graph capture/replay internally.** CUDA graph support requires session-level cooperation that is not available through the current `OrtEp` C API. +> **The plugin EP does not manage CUDA graph capture/replay internally.** CUDA graph support remains deferred until the `OrtEp` C API grows the required session-cooperative callbacks. **Rationale:** 1. The `OrtEp` C API has no `IsGraphCaptureEnabled`/`IsGraphCaptured`/`ReplayGraph` callbacks. Without these, the session cannot know that the EP supports graph capture, cannot bypass kernel dispatch for replay, and cannot trigger the recursive warm-up sequence. -2. Implementing capture in `OnRunStart`/`OnRunEnd` without the session-level replay bypass is **incorrect** — the captured graph would never be replayed on subsequent runs (the session always dispatches kernels normally). +2. The plugin branch intentionally removed graph-specific implementation files instead of keeping an incomplete capture-only path. 3. The session's graph validation logic (all nodes on one EP, no control flow) is also not triggered without `IsGraphCaptureEnabled()`. @@ -316,10 +319,9 @@ Given the API gap, the correct design for the plugin EP is: | Option | Description | Effort | Status | |--------|------------|--------|--------| | **A. Extend the OrtEp C API** | Add `IsGraphCaptureEnabled`, `IsGraphCaptured`, `ReplayGraph` to `OrtEp`. Update `PluginExecutionProvider` to delegate to these. | Medium — requires ORT core changes | Preferred long-term solution | -| **B. Disable graph capture in plugin EP** | Remove `CUDAGraphManager` and graph-related code from the plugin. Document as a known limitation. Re-enable when Option A is available. | Small | Recommended for now | -| **C. Keep capture-only (no replay)** | Keep the current code but document that it only captures + replays once (the first time), with no subsequent replay optimization. | None | Misleading — gives false confidence | +| **B. Keep graph support disabled in the plugin EP** | Leave graph files and hooks out of the plugin build until Option A exists. | Small | Current branch behavior | -**Recommendation**: Option B for the current release, with Option A tracked as a public API enhancement request. +**Recommendation**: Keep Option B in place until Option A is available. #### 5.4.4 What Needs to Change in ORT Core (Option A) @@ -364,17 +366,15 @@ This would plug into the existing `cached_execution_provider_for_graph_replay_` | Component | Status | Notes | |-----------|--------|-------| -| `cuda_graph_plugin.h/.cc` | Implemented | `CUDAGraphManager` adapted from bundled EP. Captures/replays correctly. | -| `CudaEp::OnRunStartImpl` | Implemented | Reads `gpu_graph_id`, manages warm-up, begins capture. | -| `CudaEp::OnRunEndImpl` | Implemented | Ends capture, first replay. No subsequent replay. | -| Session-level replay bypass | **Not functional** | `OrtEp` API lacks `IsGraphCaptureEnabled`/`IsGraphCaptured`/`ReplayGraph`. | -| Tests | Pass (capture + first replay) | `test_cuda_plugin_cuda_graph()` tests warm-up, capture, and `gpu_graph_id=-1` disable. | +| `cuda_graph_plugin.h/.cc` | **Removed** | Not present in the current branch. | +| `CudaEp::OnRunStart` / `OnRunEnd` | **Disabled** | `CudaEp` installs `nullptr` for both callbacks. | +| Session-level replay bypass | **Unavailable** | `OrtEp` API still lacks `IsGraphCaptureEnabled`/`IsGraphCaptured`/`ReplayGraph`. | +| Tests | Not included | The plugin test script has no CUDA graph stage. | **Action items:** -1. Keep `cuda_graph_plugin.h/.cc` and `CudaEp` graph state machine code — it is correct and will be needed when the API gap is closed. -2. Default `enable_cuda_graph` to `false` in the plugin EP config and document the limitation. -3. File an ORT core feature request to add `IsGraphCaptureEnabled`/`IsGraphCaptured`/`ReplayGraph` to the `OrtEp` C API. -4. When the API is extended, wire up the existing `CUDAGraphManager` through the new callbacks. +1. Keep CUDA graph support disabled in the plugin build until the `OrtEp` C API grows the required replay hooks. +2. Add `IsGraphCaptureEnabled`/`IsGraphCaptured`/`ReplayGraph` to the `OrtEp` C API. +3. Reintroduce plugin-side graph management only after the public API is capable of session-cooperative replay. --- @@ -397,18 +397,18 @@ The adapter layer provides thin wrappers around the ORT C API that present a C++ ## 7. Excluded Operators -The following operators are excluded from the plugin build. Each exclusion has a specific technical reason and a path to inclusion. +Section 7 reflects the current source exclusions in `cmake/onnxruntime_providers_cuda_plugin.cmake`, plus the small set of intentionally out-of-scope directories. This is the source of truth for what the plugin build omits today. ### 7.1 Infrastructure (Permanently Excluded — Replaced by Plugin Equivalents) | File | Reason | |------|--------| -| `cuda_execution_provider.cc` | Replaced by `cuda_ep_provider.h` + `cuda_ep.h` | +| `cuda_execution_provider.cc` | Replaced by `cuda_ep.h/.cc` and the plugin adapter/runtime shim | | `cuda_provider_factory.cc` | Replaced by `cuda_ep_factory.cc` | | `cuda_provider_interface.cc` | Not needed in plugin architecture | | `cuda_stream_handle.cc` | Replaced by `cuda_stream_plugin.cc` | | `cuda_execution_provider_info.cc` | Config parsed directly in `CudaEp::Config` | -| `cuda_graph.cc` | Replaced by `cuda_graph_plugin.cc` | +| `cuda_graph.cc` | CUDA graph support deferred (files removed pending OrtEp API extension) | | `cuda_mempool_arena.cc` | Plugin uses `cudaMalloc`/`cudaFree` directly | | `cuda_common.cc` | Utility functions shimmed in `cuda_kernel_adapter.h` | | `cuda_nhwc_kernels.cc` | Replaced by `PluginKernelCollector` auto-registration | @@ -423,83 +423,46 @@ The following operators are excluded from the plugin build. Each exclusion has a ### 7.3 Operators Excluded Due to Missing Features -| File | Exclusion Reason | What's Needed to Include | -|------|-----------------|--------------------------| -| `controlflow/*` | CPU base class If/Loop/Scan not linked | Plugin has own wrappers in `cuda_controlflow_plugin.cc` via `OrtEpApi`. Already functional. | -| `tunable/*` | Depends on real `CudaTuningContext` | Implement plugin-side `ITuningContext` that delegates to ORT tuning APIs. Low priority. | -| `rnn/*` | ORT C API lacks `KernelInfoGetAttributeArray_string` | Extend C API with string-array attribute support. | -| `math/einsum.cc`, `math/einsum_utils/*` | `einsum_auxiliary_ops.cc` calls `ReductionOps::ReduceCompute` (framework-only) | Extract `ReduceCompute` into a shared interface or reimplement the reduction path. | -| `tensor/identity_op.cc` | Uses `TensorSeq` (incomplete type in plugin) | Add `TensorSeq` adapter to the EP adapter layer. | -| `tensor/sequence_op.cc` | Uses `TensorSeq` (incomplete type in plugin) | Same as above. | -| `tensor/space_depth_ops.cc` | Inherits `SpaceDepthBase` (CPU provider) | Constructor templatized on `KernelInfoType` (#27628). Remaining: inline `SpaceDepthCompute` validation logic. | -| `tensor/upsample.cc` | `UpsampleBase` uses `InputDefs()` and `OpKernelInfo::GetAllocator()` | `AdjustOutputSizeAsPolicy` moved to header (#27628). Remaining: extend adapter with `GetAllocator()` and `InputDefs()`. | -| `tensor/resize.cc` | Inherits from `Upsample` (excluded above) | Fix `Upsample` first, then `Resize` follows. | -| `generator/constant_of_shape.cc` | `ConstantOfShapeBase` depends on `TensorProto`/`UnpackTensor` | Plugin already has self-contained implementation in `constant_of_shape.h` via `#ifdef BUILD_CUDA_EP_AS_PLUGIN`. The `.cc` is excluded but the kernel works. | -| `object_detection/*` | `NonMaxSuppressionBase`, `RoiAlignBase` from CPU provider | `NonMaxSuppressionBaseImpl` template (#27617), `RoiAlignBase` constructor templatized (#27628). Remaining: integration verification. | -| `llm/*` | Attention ops dereference `onnxruntime::Stream*` (not adapter-compatible) | Extend adapter `OpKernelContext::GetComputeStream()` to return a full `Stream*` implementation. | -| `contrib_ops/cuda/llm/*` | Same as above | Same as above. | -| `contrib_ops/cuda/bert/attention.cc` | `GetComputeStream()` returns real `Stream*` which is needed | `AttentionBase` helpers moved to header (#27628). Remaining: `Stream*` adapter extension for `QkvToContext`. | -| `contrib_ops/cuda/bert/decoder_attention.cc` | Same | Same. | -| `contrib_ops/cuda/bert/decoder_masked_self_attention.cc` | Same | Same. | -| `contrib_ops/cuda/bert/embed_layer_norm.cc` | `EmbedLayerNormHelper` CPU base class | Already refactored helper; needs `GetComputeStream()` fix. | -| `contrib_ops/cuda/bert/fast_gelu.cc` | Was excluded due to `bias_gelu_helper` CPU base class dep | `bias_gelu_helper::CheckInputs` is now inlined. Remove this exclusion and verify. | -| `contrib_ops/cuda/bert/group_query_attention.cc` | `GetComputeStream()` / attention infra | Same `Stream*` adapter extension. | -| `contrib_ops/cuda/bert/longformer_attention.cc` | `LongformerAttentionBase::CheckInputs` moved to header (#27628) | `Stream*` adapter extension. | -| `contrib_ops/cuda/bert/multihead_attention.cc` | Same | Same. | -| `contrib_ops/cuda/bert/packed_attention.cc` | Same | Same. | -| `contrib_ops/cuda/bert/packed_multihead_attention.cc` | Same | Same. | -| `contrib_ops/cuda/bert/paged_attention.cc` | Same | Same. | -| `contrib_ops/cuda/bert/relative_attn_bias.cc` | Same | Same. | -| `contrib_ops/cuda/bert/remove_padding.cc` | Same | Same. | -| `contrib_ops/cuda/diffusion/group_norm.cc` | `GetComputeStream()` | Same `Stream*` adapter extension. | -| `contrib_ops/cuda/fused_conv.cc` | Framework type deps | Audit specific deps; likely `Stream*` related. | -| `contrib_ops/cuda/inverse.cc` | Framework type deps | Audit specific deps. | -| `contrib_ops/cuda/math/bias_dropout.cc` | `GetComputeStream()` | Same `Stream*` adapter extension. | -| `contrib_ops/cuda/math/fft_ops.cc` | `GetComputeStream()` | Same. | -| `contrib_ops/cuda/math/gemm_float8.cc`/`.cu` | `GetComputeStream()` in `.cu` file | Same, plus NVCC compatibility. | -| `contrib_ops/cuda/moe/moe.cc` | `GetComputeStream()` | Same. | -| `contrib_ops/cuda/sparse/sparse_attention.cc` | `GetComputeStream()` | Same. | -| `contrib_ops/cuda/tensor/crop.cc` | `CropBase` constructor templatized (#27628). No `GetComputeStream()` usage. | Verify compilation — very low effort. | -| `contrib_ops/cuda/tensor/dynamic_time_warping.cc` | `GetComputeStream()` | Same. | -| `contrib_ops/cuda/tensor/dynamicslice.cc` | `GetComputeStream()` | Same. | -| `contrib_ops/cuda/tensor/shrunken_gather.cc` | Training op, `provider_api.h` header dep | Low priority (training). | -| `contrib_ops/cuda/quantization/attention_quantization.cc` | `GetComputeStream()` | Same. | -| `contrib_ops/cuda/quantization/matmul_bnb4.cc` | `GetComputeStream()` | Same. | -| `contrib_ops/cuda/quantization/matmul_nbits.cc` | `GetComputeStream()` | Same. | -| `contrib_ops/cuda/quantization/moe_quantization.cc` | `GetComputeStream()` | Same. | -| `contrib_ops/cuda/quantization/qordered_ops/*` | `GetComputeStream()` | Same. | -| `contrib_ops/cuda/transformers/*` | Beam search, greedy search, sampling | Complex framework deps; needs significant adapter work. | -| `aten_ops/*` | ATen interop | Out of scope for plugin. | -| `collective/*` | NCCL collective ops | Out of scope for plugin. | +| File / Pattern | Why It Is Excluded Today | What Would Unblock It | +|----------------|--------------------------|------------------------| +| `core/providers/cuda/controlflow/*` | Framework controlflow kernels are omitted from the source list | Plugin equivalents already exist in `cuda_controlflow_plugin.cc`; the framework sources stay excluded by design | +| `tunable/*` | Depends on the real tuning context and framework CUDA EP infrastructure | Add a plugin-capable tuning context and remove the remaining tunable guards | +| `math/einsum.cc` | The top-level framework einsum source is still excluded | Provide a plugin-safe top-level einsum provider path; `einsum_utils/*` are no longer excluded | +| `tensor/identity_op.cc` | Uses `TensorSeq`, which is still not adapter-safe here | Add `TensorSeq` adapter coverage | +| `tensor/sequence_op.cc` | Uses `TensorSeq`, which is still not adapter-safe here | Same as above | +| `contrib_ops/cuda/llm/*` | Contrib LLM kernels still need their own plugin migration pass | Finish contrib-LLM-specific adapter work | +| `contrib_ops/cuda/tensor/shrunken_gather.cc` | Training header path still depends on framework/provider API wiring | Low-priority training-specific adapter work | +| `contrib_ops/cuda/math/fft_ops.cc` | Still excluded in CMake due to remaining framework/stream assumptions | Finish FFT-specific adapter cleanup | +| `contrib_ops/cuda/tensor/crop.cc` | Still excluded in CMake even though the constructor-side helper work is mostly done | Finish and validate the remaining plugin-safe path, then remove the CMake exclusion | +| `contrib_ops/cuda/tensor/dynamicslice.cc` | Still excluded in CMake due to remaining framework assumptions | Finish dynamicslice-specific adapter cleanup | +| `contrib_ops/cuda/transformers/*` | Beam search / greedy search / sampling require broader framework integration | Significant adapter and subgraph support work | +| `onnxruntime/contrib_ops/cuda/aten_ops/*` | ATen interop is out of scope for the plugin build | Separate ATen plugin strategy | +| `onnxruntime/contrib_ops/cuda/collective/*` | Collective/NCCL path is out of scope for the plugin build | Separate collective/NCCL plugin strategy | ### 7.4 Common Exclusion Themes -The majority of excluded operators fall into a few categories: +The current exclusions fall into a few categories: -1. **`GetComputeStream()` returning `onnxruntime::Stream*`** (~25 ops) — The adapter's `GetComputeStream()` returns a `PluginCudaComputeStreamShim` which wraps a raw `cudaStream_t`. Many attention/LLM ops dereference `Stream*` expecting a `CudaStream` with extra members. **Unblocking this is the single highest-impact change.** +1. **Tunable/framework-dependent infrastructure** — `tunable/*`, contrib transformers, and some contrib LLM paths still rely on framework-only execution-provider services. -2. **CPU base class inheritance** (~5 ops) — Some ops inherit from CPU base classes not linked into the plugin. Most have been refactored with the inline-header pattern. `SpaceDepthBase` and `RoiAlignBase` constructors are now templatized (#27628); `NonMaxSuppressionBase` refactored to a template (#27617); `UpsampleBase::AdjustOutputSizeAsPolicy` moved to header (#27628). Remaining: `UpsampleBase` `InputDefs()`/`GetAllocator()`. +2. **Remaining adapter gaps** — `TensorSeq`, some contrib FFT/crop/dynamicslice paths, and contrib-LLM-specific plumbing still need dedicated adapter work. -3. **Missing C API features** (~2 ops) — RNN ops need string-array attribute support via the C API. +3. **Deliberate scope cuts** — ATen and collective/NCCL sources remain intentionally out of scope for the standalone CUDA plugin. -4. **Framework-only code paths** (~3 ops) — Einsum's reduction path, tunable infrastructure. +4. **Top-level framework wrappers still excluded** — `math/einsum.cc` remains excluded even though supporting pieces such as `einsum_utils/*` are now plugin-safe. --- ## 8. Remaining `#ifdef` Guards in Kernel Code -After refactoring, only 6 files contain `BUILD_CUDA_EP_AS_PLUGIN` or `ORT_USE_EP_API_ADAPTERS` guards: +The branch still contains a small set of plugin guards in both infrastructure and operator code. The important pattern has not changed: -| File | Guard | Purpose | Removable? | -|------|-------|---------|------------| -| `cuda_kernel.h` | Both | Three-way gate: plugin → adapter; in-tree → real CudaKernel | No — infrastructure | -| `cuda_common.h` | Both | Logging macros, error macros, `HalfGemmOptions` | No — infrastructure | -| `cuda_execution_provider.h` | `ORT_USE_EP_API_ADAPTERS` | Skip entire class in plugin build | No — infrastructure | -| `generator/constant_of_shape.h` | `BUILD_CUDA_EP_AS_PLUGIN` | Self-contained plugin implementation | No — can't inline `ConstantOfShapeBase` | -| `math/matmul.cc` | `ORT_USE_EP_API_ADAPTERS` | Guards `FuncManager` registration (tunable) | Only when tunable is supported | -| `math/gemm.cc` | `ORT_USE_EP_API_ADAPTERS` | Guards `FuncManager` registration (tunable) | Only when tunable is supported | +- Infrastructure files such as `cuda_kernel.h`, `cuda_common.h`, and `cudnn_common.h` still need build-mode gates. +- `generator/constant_of_shape.h` still needs a plugin-specific path because `ConstantOfShapeBase` depends on framework-only tensor-attribute helpers. +- Tunable kernels such as `math/matmul.cc` still gate framework-only registration paths. +- A few tensor kernels (`pad.cc`, `tile.cc`, `unsqueeze.cc`, `upsample.*`, `space_depth_ops.h`, `scatter_nd.*`) still contain localized plugin guards where adapter and framework paths have not fully converged. -All kernel-level `#ifdef` guards in operator `.cc` files have been eliminated through the inline-header refactoring pattern, except for `matmul.cc`, `gemm.cc` (tunable dispatch), and `constant_of_shape.h` (protobuf dependency). +The broad trend remains positive: most operator-level plugin conditionals were removed by moving reusable CPU/helper logic into shared headers and by centralizing stream bridging in `CudaKernel` helpers. --- @@ -567,17 +530,17 @@ The plugin is then available as `CudaPluginExecutionProvider` in session provide ### 10.1 Test Script -`onnxruntime/test/python/transformers/test_cuda_plugin_ep.py` provides multi-stage testing: +`onnxruntime/test/python/transformers/test_cuda_plugin_ep.py` provides the current focused plugin validation flow: | Stage | What It Tests | |-------|---------------| +| Registration | Dynamic loading via `register_execution_provider_library()` and EP device discovery | | Stage 2 | Basic ops: Add, MatMul, Gemm, Conv | | Stage 3 | NHWC layout: Conv, BatchNorm, MaxPool, AveragePool | -| Stage 4 | CUDA Graph capture/replay | | Stage 5A | Standard ops: Reshape, Split, Concat, Gather, Unsqueeze | | Stage 5B | More ops: Tile, CumSum, ConstantOfShape, SpaceToDepth, Pad, Slice, Resize, Sum | | Stage 5C | CPU base class ops: Upsample, DepthToSpace | -| Stage 5D | Contrib ops: FastGelu, BiasDropout, SkipLayerNorm | +| Stage 5D | Contrib ops: FastGelu, SkipLayerNorm (BiasDropout is currently skipped as a known issue in the script) | ### 10.2 Running Tests @@ -589,6 +552,8 @@ cd /tmp python /path/to/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py ``` +The current branch has been validated with `./cuda_plugin.sh --build --test_plugin`, which runs this script against the locally built plugin library. + ### 10.3 Parity Report `tools/ci_build/cuda_plugin_parity_report.py` generates a report comparing registered kernels between the in-tree CUDA EP and the plugin EP, identifying gaps. @@ -636,12 +601,14 @@ static void ComputePadsImpl(KernelContextType& ctx, ...) { ... } The CUDA kernel calls `ComputePadsImpl(*ctx, ...)` directly. -### 11.4 If the kernel uses GetComputeStream() +### 11.4 If the kernel uses stream helpers -Check whether the kernel actually dereferences the `Stream*` or just needs the raw `cudaStream_t`: +Prefer the shared helpers in `CudaKernel` instead of introducing new plugin-only stream shims: -- If it only needs `stream->GetHandle()` → use `Stream(ctx)` instead (returns `cudaStream_t`) -- If it dereferences `CudaStream*` members → the kernel is blocked until the `Stream*` adapter is extended +- If the code only needs a raw CUDA stream, use `Stream(ctx)`. +- If the shared helper API accepts the adapter's opaque stream handle, use `GetComputeStream(ctx)`. +- If framework-style helper code still expects `onnxruntime::Stream*`, use `GetOrtStream(ctx)`. +- Prefer `GetCublasHandle(ctx)`, `GetCudnnHandle(ctx)`, and `GetCublasLtHandle(ctx)` over re-discovering handles from the stream manually. ### 11.5 If the kernel uses handle accessors @@ -659,8 +626,7 @@ Use the plugin-compatible overloads already in `CudaKernel`: ``` onnxruntime/core/providers/cuda/plugin/ ├── cuda_kernel_adapter.h # CudaKernel base, macros, CPU shims (force-included) -├── cuda_ep_provider.h # Plugin-local CUDAExecutionProvider -├── cuda_ep.h / .cc # CudaEp : adapter::Ep +├── cuda_ep.h / .cc # CudaEp : OrtEp implementation ├── cuda_ep_factory.h / .cc # CudaEpFactory : OrtEpFactory ├── cuda_plugin_ep.cc # DLL entry points (CreateEpFactories/ReleaseEpFactory) ├── cuda_plugin_kernels.h / .cu # Kernel registry creation @@ -668,10 +634,7 @@ onnxruntime/core/providers/cuda/plugin/ ├── cuda_allocator_plugin.h / .cc # Device/pinned allocators ├── cuda_data_transfer_plugin.h / .cc # GPU↔CPU data transfer ├── cuda_controlflow_plugin.h / .cc / .cu # If/Loop/Scan wrappers -├── cuda_graph_plugin.h / .cc # CUDA Graph support ├── cuda_plugin_utils.h # Common macros, error handling -├── cuda_iallocator_plugin.h # IAllocator declarations -├── cuda_idata_transfer_plugin.h # IDataTransfer declarations └── provider_api_shims.cc # Reimplemented utility functions include/onnxruntime/ep/ @@ -696,14 +659,14 @@ include/onnxruntime/ep/ ## 13. Future Work -1. **`Stream*` adapter** — Extend the adapter `OpKernelContext::GetComputeStream()` to return a full `Stream*` that attention/LLM ops can use. This unblocks ~25 operators. +1. **Contrib LLM migration pass** — The core CUDA LLM attention path is now adapter-safe, but `contrib_ops/cuda/llm/*` is still excluded as a separate follow-up. 2. **Tunable ops** — Implement a plugin-side `ITuningContext` and remove the `ORT_USE_EP_API_ADAPTERS` guards in `matmul.cc`/`gemm.cc`. -3. **String-array C API** — Add `KernelInfoGetAttributeArray_string` to the ORT C API to unblock RNN ops. +3. **TensorSeq adapter coverage** — Add enough sequence/tensor-sequence support to unblock `identity_op.cc` and `sequence_op.cc`. -4. **Remaining CPU base classes** — Inline `SpaceDepthBase`, `UpsampleBase`, and object detection base classes. +4. **Remaining contrib exclusions** — Remove the CMake exclusions for FFT, crop, and dynamicslice once their remaining framework assumptions are gone. 5. **CI integration** — Add plugin build + test to the CI pipeline. -6. **CUDA Graph API for plugin EPs** — Add `IsGraphCaptureEnabled`, `IsGraphCaptured`, and `ReplayGraph` callbacks to the `OrtEp` C API (see [Section 5.4.4](#544-what-needs-to-change-in-ort-core-option-a)). This is required for efficient CUDA graph replay in the plugin EP. The capture/replay infrastructure (`cuda_graph_plugin.h/.cc`, `CudaEp` state machine) is already implemented and will activate once the API is extended. +6. **CUDA Graph API for plugin EPs** — Add `IsGraphCaptureEnabled`, `IsGraphCaptured`, and `ReplayGraph` callbacks to the `OrtEp` C API (see [Section 5.4.4](#544-what-needs-to-change-in-ort-core-option-a)). This is required for efficient CUDA graph replay in the plugin EP. The capture/replay infrastructure will be reintroduced once the API is extended. diff --git a/include/onnxruntime/ep/api.h b/include/onnxruntime/ep/api.h index 8c8951490a55d..4e2b97d544e3b 100644 --- a/include/onnxruntime/ep/api.h +++ b/include/onnxruntime/ep/api.h @@ -3,6 +3,7 @@ #pragma once +#include #include #pragma push_macro("ORT_API_MANUAL_INIT") @@ -35,18 +36,20 @@ inline const ApiPtrs& Api() { /// /// Initialize the EP API pointers and global OrtEnv if not already done. +/// Thread-safe via std::call_once. /// inline void ApiInit(const OrtApiBase* ort_api_base) { - // Manual init for the C++ API - const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); - const OrtEpApi* ep_api = ort_api->GetEpApi(); - const OrtModelEditorApi* model_editor_api = ort_api->GetModelEditorApi(); - Ort::InitApi(ort_api); - - // Initialize the global API instance - if (!detail::g_api_ptrs) { + static std::once_flag init_flag; + std::call_once(init_flag, [&]() { + // Manual init for the C++ API + const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); + const OrtEpApi* ep_api = ort_api->GetEpApi(); + const OrtModelEditorApi* model_editor_api = ort_api->GetModelEditorApi(); + Ort::InitApi(ort_api); + + // Initialize the global API instance detail::g_api_ptrs.emplace(*ort_api, *ep_api, *model_editor_api); - } + }); } } // namespace ep diff --git a/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h b/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h index c30721c909a97..f979b0e6230b5 100644 --- a/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h +++ b/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h @@ -3,6 +3,11 @@ #pragma once +// SHARED_PROVIDER is defined in the in-tree CUDA EP shared library build +// (onnxruntime_providers_cuda). It gates out framework headers that are +// re-exported via the DLL-boundary proxy. The plugin EP build uses a +// different flag (BUILD_CUDA_EP_AS_PLUGIN) and the force-include adapter +// headers instead. Both builds need these headers excluded. #ifndef SHARED_PROVIDER #include "core/common/common.h" #include "core/common/type_list.h" diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc index 895916bb76b71..6708bbd61dacf 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc @@ -27,10 +27,16 @@ CudaDeviceAllocator::CudaDeviceAllocator(const OrtMemoryInfo* memory_info, int d auto* alloc = static_cast(this_ptr); void* p = nullptr; if (size == 0) return nullptr; + // Save and restore CUDA device context to avoid corrupting the calling + // thread's device state in multi-GPU scenarios. + int prev_device = -1; + cudaGetDevice(&prev_device); if (cudaSetDevice(alloc->device_id_) != cudaSuccess) { + cudaSetDevice(prev_device); return nullptr; } cudaError_t err = cudaMalloc(&p, size); + cudaSetDevice(prev_device); if (err != cudaSuccess) { return nullptr; } @@ -40,8 +46,11 @@ CudaDeviceAllocator::CudaDeviceAllocator(const OrtMemoryInfo* memory_info, int d /*static*/ void ORT_API_CALL CudaDeviceAllocator::FreeImpl(OrtAllocator* this_ptr, void* p) noexcept { auto* alloc = static_cast(this_ptr); if (p != nullptr) { + int prev_device = -1; + cudaGetDevice(&prev_device); cudaSetDevice(alloc->device_id_); cudaFree(p); + cudaSetDevice(prev_device); } } diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc index 696712b2ea693..37810e0c9d6d8 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc @@ -51,6 +51,7 @@ CudaDataTransfer::CudaDataTransfer(const OrtApi& ort_api, const OrtEpApi& ep_api EXCEPTION_TO_STATUS_BEGIN auto* dt = static_cast(this_ptr); + bool need_stream_sync = false; for (size_t i = 0; i < count; ++i) { Ort::ConstValue src{src_tensors[i]}; @@ -73,6 +74,7 @@ CudaDataTransfer::CudaDataTransfer(const OrtApi& ort_api, const OrtEpApi& ep_api const OrtMemoryDevice* dst_dev = dt->ep_api_.MemoryInfo_GetMemoryDevice(dst_mem_info); auto src_dev_type = dt->ep_api_.MemoryDevice_GetDeviceType(src_dev); auto dst_dev_type = dt->ep_api_.MemoryDevice_GetDeviceType(dst_dev); + auto src_mem_type = dt->ep_api_.MemoryDevice_GetMemoryType(src_dev); cudaMemcpyKind copy_kind; if (src_dev_type == OrtMemoryInfoDeviceType_CPU && dst_dev_type == OrtMemoryInfoDeviceType_GPU) { @@ -91,10 +93,27 @@ CudaDataTransfer::CudaDataTransfer(const OrtApi& ort_api, const OrtEpApi& ep_api Ort::GetApi().SyncStream_GetHandle(streams[i])); PL_CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, copy_kind, cuda_stream)); } else { + if (copy_kind == cudaMemcpyDeviceToDevice && dst_data == src_data) { + continue; + } + PL_CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, copy_kind)); + + if (copy_kind == cudaMemcpyDeviceToDevice) { + // Match the built-in CUDA EP: cudaMemcpy D2D launches on the default + // stream but does not guarantee host-side completion on return. + need_stream_sync = true; + } else if (copy_kind == cudaMemcpyHostToDevice && src_mem_type != OrtDeviceMemoryType_HOST_ACCESSIBLE) { + // Pageable host memory may still be in flight after cudaMemcpy returns. + need_stream_sync = true; + } } } + if (need_stream_sync) { + PL_CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); + } + return nullptr; EXCEPTION_TO_STATUS_END diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc index fa5173e7e8736..dea3b75078a1a 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -112,6 +112,7 @@ OrtStatus* ORT_API_CALL CudaEpFactory::GetSupportedDevicesImpl( size_t& num_ep_devices = *p_num_ep_devices; num_ep_devices = 0; + int cuda_device_index = 0; for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { const OrtHardwareDevice& device = *hw_devices[i]; auto hw_type = factory->ort_api_.HardwareDevice_Type(&device); @@ -125,7 +126,12 @@ OrtStatus* ORT_API_CALL CudaEpFactory::GetSupportedDevicesImpl( continue; // Skip non-NVIDIA GPUs } - int32_t current_device_id = factory->ort_api_.HardwareDevice_DeviceId(&device); + int32_t current_device_id = cuda_device_index++; + + { + std::lock_guard lock(factory->device_map_mutex_); + factory->hw_device_to_cuda_index_[&device] = current_device_id; + } OrtKeyValuePairs* ep_metadata = nullptr; factory->ort_api_.CreateKeyValuePairs(&ep_metadata); @@ -160,9 +166,15 @@ OrtStatus* ORT_API_CALL CudaEpFactory::GetSupportedDevicesImpl( /*alignment is default*/ 0, OrtAllocatorType::OrtDeviceAllocator}; + OrtMemoryInfo* raw_memory_info = device_memory_info; + { + std::lock_guard lock(factory->cached_memory_info_mutex_); + factory->cached_memory_infos_.push_back(std::move(device_memory_info)); + } + // Register allocator info for GPU device memory RETURN_IF_ERROR(factory->ep_api_.EpDevice_AddAllocatorInfo( - ep_device, device_memory_info)); + ep_device, raw_memory_info)); // Register allocator info for CPU pinned memory (host accessible) RETURN_IF_ERROR(factory->ep_api_.EpDevice_AddAllocatorInfo( @@ -204,7 +216,15 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( // The read helpers intentionally swallow errors: if a config entry is // absent or malformed the default value in Config is kept. CudaEp::Config config{}; - config.device_id = factory->ort_api_.HardwareDevice_DeviceId(devices[0]); + + config.device_id = 0; // Default + { + std::lock_guard lock(factory->device_map_mutex_); + auto it = factory->hw_device_to_cuda_index_.find(devices[0]); + if (it != factory->hw_device_to_cuda_index_.end()) { + config.device_id = it->second; + } + } auto read_session_config_bool = [&](const char* key, bool& value) { size_t size = 0; @@ -302,13 +322,13 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateAllocatorImpl( return status; } - if (strcmp(name, "Cuda") == 0) { + if (name != nullptr && strcmp(name, "Cuda") == 0) { auto cuda_allocator = std::make_unique(memory_info, req_device_id); *allocator = cuda_allocator.release(); return nullptr; } - if (strcmp(name, "CudaPinned") == 0) { + if (name != nullptr && strcmp(name, "CudaPinned") == 0) { auto pinned_allocator = std::make_unique(memory_info); *allocator = pinned_allocator.release(); return nullptr; @@ -332,7 +352,8 @@ void ORT_API_CALL CudaEpFactory::ReleaseAllocatorImpl( delete static_cast(allocator); return; default: - ORT_ENFORCE(false, "Unknown CudaAllocatorKind in ReleaseAllocatorImpl"); + // Cannot throw in noexcept function + break; } } diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h index 96ec789d9abed..2562f5667e047 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h @@ -90,8 +90,15 @@ class CudaEpFactory : public OrtEpFactory { // Memory info for GPU device and CPU pinned memory Ort::MemoryInfo default_memory_info_{nullptr}; Ort::MemoryInfo pinned_memory_info_{nullptr}; + + std::mutex cached_memory_info_mutex_; + std::vector cached_memory_infos_; int device_id_ = 0; + // Map ORT hardware device pointers to internal CUDA ordinal indices + std::mutex device_map_mutex_; + std::unordered_map hw_device_to_cuda_index_; + // Kernel registry (cached, shared across EP instances) OrtKernelRegistry* kernel_registry_ = nullptr; std::mutex registry_mutex_; diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index a61a05719794c..19b10f57c40d5 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -191,6 +191,11 @@ namespace cuda { /// Singleton collector for BuildKernelCreateInfoFn pointers. /// Each compiled kernel .cc file's macro expansion auto-registers here. +/// +/// Thread-safety: Instance() uses a function-local static (C++11 §6.7/4: +/// constructed exactly once, even under concurrent first-access). All Add() +/// calls happen during static initialisation of each translation unit +/// (before main()), which is single-threaded per the C++ standard. class PluginKernelCollector { public: static PluginKernelCollector& Instance() { @@ -403,6 +408,43 @@ IConstantBuffer* GetConstOnesBufferForDevice(int device_id) { return buffer.get(); } +struct DefaultCudaHandles { + cublasHandle_t cublas = nullptr; + cudnnHandle_t cudnn = nullptr; + + ~DefaultCudaHandles() { + if (cublas != nullptr) { + cublasDestroy(cublas); + } + if (cudnn != nullptr) { + cudnnDestroy(cudnn); + } + } +}; + +inline DefaultCudaHandles& GetDefaultCudaHandlesForDevice(int device_id) { + thread_local std::unordered_map handles_by_device; + auto [it, inserted] = handles_by_device.try_emplace(device_id); + if (inserted) { + int prev_device = -1; + cudaGetDevice(&prev_device); + PL_CUDA_CALL_THROW(cudaSetDevice(device_id)); + if (cublasCreate(&it->second.cublas) != CUBLAS_STATUS_SUCCESS) { + cudaSetDevice(prev_device); + ORT_THROW("Failed to create default cuBLAS handle for CUDA plugin device ", device_id); + } + if (cudnnCreate(&it->second.cudnn) != CUDNN_STATUS_SUCCESS) { + cublasDestroy(it->second.cublas); + it->second.cublas = nullptr; + cudaSetDevice(prev_device); + ORT_THROW("Failed to create default cuDNN handle for CUDA plugin device ", device_id); + } + PL_CUDA_CALL_THROW(cudaSetDevice(prev_device)); + } + + return it->second; +} + inline const cudaDeviceProp& GetDevicePropForDevice(int device_id) { static std::mutex mutex; static std::unordered_map> props; @@ -444,6 +486,15 @@ inline const cudaDeviceProp& GetDevicePropForDevice(int device_id) { class CUDAExecutionProvider : public onnxruntime::IExecutionProvider { public: explicit CUDAExecutionProvider(const std::string& name) : onnxruntime::IExecutionProvider{name} {} + + // SAFETY: This class must remain empty (no added member variables beyond + // IExecutionProvider). In the plugin build, an OrtEp*/CudaEp* is cast to + // CUDAExecutionProvider*. Adding members would cause the compiler to read + // them at incorrect byte offsets, silently corrupting data. All runtime + // state is stored in ProviderConfigStore, keyed by `this`. + // If the static_assert below fires, move the new state into + // CudaKernelAdapterRuntimeConfig instead of adding members here. + int GetCudnnConvAlgo() const { return cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(this).cudnn_conv_algo; } @@ -464,6 +515,10 @@ class CUDAExecutionProvider : public onnxruntime::IExecutionProvider { } }; +// Verify CUDAExecutionProvider has no added members — see phantom cast design note. +static_assert(sizeof(CUDAExecutionProvider) == sizeof(onnxruntime::IExecutionProvider), + "CUDAExecutionProvider must not add member variables."); + namespace cuda { inline void SetCudaKernelAdapterRuntimeConfigForProvider(const void* provider, bool use_tf32, int device_id, @@ -677,8 +732,8 @@ class CudaKernel : public OpKernel { virtual Status ComputeInternal(OpKernelContext* ctx) const = 0; inline cudaStream_t DefaultCudaStream() const { return Stream(static_cast(nullptr)); } - inline cublasHandle_t DefaultCublasHandle() const { return GetCublasHandle(static_cast(nullptr)); } - inline cudnnHandle_t DefaultCudnnHandle() const { return GetCudnnHandle(static_cast(nullptr)); } + inline cublasHandle_t DefaultCublasHandle() const { return detail::GetDefaultCudaHandlesForDevice(device_id_).cublas; } + inline cudnnHandle_t DefaultCudnnHandle() const { return detail::GetDefaultCudaHandlesForDevice(device_id_).cudnn; } inline Status CopyTensor(const onnxruntime::Tensor& src, onnxruntime::Tensor& dst, onnxruntime::Stream& stream) const { if (src.Shape().Size() == 0) return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc index fb144f5903d28..0339f8add1bfd 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc @@ -3,6 +3,7 @@ #include "cuda_stream_plugin.h" #include "cuda_ep_factory.h" +#include #include #include @@ -25,6 +26,13 @@ std::shared_mutex& GetStreamMapMutex() { static std::shared_mutex stream_map_mutex; return stream_map_mutex; } + +// Monotonically increasing generation counter, bumped on every UnregisterStream +// so that TLS caches can detect stale entries without acquiring a lock. +std::atomic& GetStreamMapGeneration() { + static std::atomic generation{0}; + return generation; +} } // namespace // --------------------------------------------------------------------------- @@ -122,10 +130,16 @@ void CudaSyncStream::CleanupDeferredCPUBuffers() { return nullptr; } - // Thread-local TLS cache to mitigate lock contention on the hot path + // Thread-local TLS cache to mitigate lock contention on the hot path. + // The generation counter is bumped on every UnregisterStream() so that + // stale TLS entries (pointing to destroyed CudaSyncStream objects) are + // automatically invalidated without requiring per-thread notification. thread_local cudaStream_t tls_last_stream = nullptr; thread_local CudaSyncStream* tls_last_sync_stream = nullptr; - if (stream == tls_last_stream) { + thread_local uint64_t tls_generation = 0; + + uint64_t current_gen = GetStreamMapGeneration().load(std::memory_order_acquire); + if (stream == tls_last_stream && tls_generation == current_gen) { return tls_last_sync_stream; } @@ -135,6 +149,7 @@ void CudaSyncStream::CleanupDeferredCPUBuffers() { if (it != stream_map.end()) { tls_last_stream = stream; tls_last_sync_stream = it->second; + tls_generation = current_gen; return it->second; } return nullptr; @@ -150,6 +165,8 @@ void CudaSyncStream::CleanupDeferredCPUBuffers() { auto& stream_map = GetStreamMap(); std::unique_lock lock(GetStreamMapMutex()); stream_map.erase(stream); + // Bump generation so TLS caches in other threads are invalidated. + GetStreamMapGeneration().fetch_add(1, std::memory_order_release); } // --------------------------------------------------------------------------- diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index 9a828c701d07f..dfa43130868dc 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -808,7 +808,6 @@ SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(int8_t) SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(uint8_t) namespace ReductionOps { -#ifndef BUILD_CUDA_EP_AS_PLUGIN template std::unique_ptr ReduceCompute(const AllocatorPtr& gpu_allocator, cudnnReduceTensorOp_t cudnn_reduce_op, AllocatorPtr allocator, const Tensor& input, gsl::span axes, @@ -864,7 +863,6 @@ template std::unique_ptr ReduceCompute CreateExecutionProviderFactory return nullptr; } + bool has_requested_device_id = false; + int requested_device_id = 0; + if (const auto provider_it = provider_options_map.find(type); provider_it != provider_options_map.end()) { + if (const auto device_id_it = provider_it->second.find("device_id"); device_id_it != provider_it->second.end()) { + try { + requested_device_id = std::stoi(device_id_it->second); + has_requested_device_id = requested_device_id >= 0; + } catch (const std::exception& ex) { + LOGS_DEFAULT(WARNING) << "Ignoring invalid device_id provider option '" << device_id_it->second + << "' for registered plugin EP '" << type << "': " << ex.what(); + } + } + } + const OrtEpDevice* selected_device = nullptr; for (const OrtEpDevice* ep_device : ep_devices) { if (!ep_device || ep_device->ep_name != type) { continue; } + if (has_requested_device_id) { + Ort::ConstEpDevice current_device(ep_device); + if (static_cast(current_device.Device().DeviceId()) != requested_device_id) { + continue; + } + } + if (selected_device == nullptr) { selected_device = ep_device; break; @@ -605,6 +626,10 @@ static std::shared_ptr CreateExecutionProviderFactory } if (selected_device == nullptr) { + if (has_requested_device_id) { + LOGS_DEFAULT(WARNING) << "No registered plugin EP device found for '" << type + << "' with device_id=" << requested_device_id; + } return nullptr; } diff --git a/orttraining/orttraining/training_ops/cuda/reduction/reduction_ops.cc b/orttraining/orttraining/training_ops/cuda/reduction/reduction_ops.cc index ff8f5f81e4e37..0c59efaee62c8 100644 --- a/orttraining/orttraining/training_ops/cuda/reduction/reduction_ops.cc +++ b/orttraining/orttraining/training_ops/cuda/reduction/reduction_ops.cc @@ -58,8 +58,9 @@ Status ReduceKernel::ComputeImplEx(OpKernelContext* ctx, cudnn Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims); const bool fast_reduction = fast_reduction_ && !ctx->GetUseDeterministicCompute(); - return ReduceComputeCore(Info().GetAllocator(OrtMemType::OrtMemTypeDefault), *X, prepare_reduce_metadata, *Y, cudnn_reduce_op, axes, - calculate_log_, calculate_sqt_, log_sum_exp_, fast_reduction, ctx->GetComputeStream()); + return ReduceComputeCore(Info().GetAllocator(OrtMemType::OrtMemTypeDefault), this, *X, prepare_reduce_metadata, *Y, cudnn_reduce_op, axes, + calculate_log_, calculate_sqt_, log_sum_exp_, fast_reduction, + Stream(ctx), GetComputeStream(ctx), GetCudnnHandle(ctx)); } template <> From 876c3b620f9bf9ff433945e62861d5b736702be6 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 25 Mar 2026 00:42:59 -0700 Subject: [PATCH 18/48] comments --- include/onnxruntime/ep/adapter/op_kernel.h | 3 +++ onnxruntime/contrib_ops/cuda/bert/attention.cc | 2 ++ onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc | 4 ++++ onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc | 2 ++ onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc | 3 +++ .../contrib_ops/cuda/quantization/matmul_nbits.h | 10 ++++++---- .../providers/cuda/plugin/cuda_allocator_plugin.h | 11 +++++++++-- .../providers/cuda/plugin/cuda_controlflow_plugin.cc | 3 +++ .../cuda/plugin/cuda_data_transfer_plugin.h | 4 ++++ onnxruntime/core/providers/cuda/plugin/cuda_ep.cc | 11 ++++++++++- .../core/providers/cuda/plugin/cuda_ep_factory.h | 2 ++ .../core/providers/cuda/plugin/cuda_stream_plugin.h | 12 ++++++++++++ .../core/providers/cuda/plugin/provider_api_shims.cc | 5 ++++- .../core/providers/cuda/reduction/reduction_ops.cc | 3 +++ onnxruntime/core/providers/cuda/tensor/pad.cc | 6 ++++++ onnxruntime/core/providers/cuda/tensor/scatter_nd.cc | 5 +++-- .../core/providers/cuda/tensor/space_depth_ops.h | 11 ++++++++--- onnxruntime/core/providers/cuda/tensor/split.cc | 3 +++ onnxruntime/core/providers/cuda/tensor/tile.cc | 3 +++ onnxruntime/core/providers/cuda/tensor/unsqueeze.cc | 4 ++++ onnxruntime/core/providers/cuda/tensor/upsample.cc | 3 +++ 21 files changed, 97 insertions(+), 13 deletions(-) diff --git a/include/onnxruntime/ep/adapter/op_kernel.h b/include/onnxruntime/ep/adapter/op_kernel.h index fece85435733d..969f34ae820ba 100644 --- a/include/onnxruntime/ep/adapter/op_kernel.h +++ b/include/onnxruntime/ep/adapter/op_kernel.h @@ -93,6 +93,8 @@ struct OpKernelContext { input_tensors_[index] = CreateTensorFromApiValue(const_cast(static_cast(input))); return &input_tensors_[index]; } + /// Get a required (non-optional) input tensor. Throws if the input is null. + /// Use Input() for optional inputs that may legitimately be absent. template >> const T& RequiredInput(int index) const { @@ -116,6 +118,7 @@ struct OpKernelContext { output_tensors_[index] = CreateTensorFromApiValue(output); return &output_tensors_[index]; } + /// Get a required (non-optional) output tensor. Throws if the output is null. Tensor& RequiredOutput(int index, const TensorShape& shape) { auto* output = Output(index, shape); ORT_ENFORCE(output != nullptr, "Required output ", index, " is null"); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 69f319fdd977d..3e6a78775f4a2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -246,6 +246,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { int n = (parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size); int k = parameters.input_hidden_size; #ifdef BUILD_CUDA_EP_AS_PLUGIN + // Plugin build: use GetScratchBuffer (adapter-compatible) instead of + // IAllocator::MakeUniquePtr which requires the full allocator interface. IAllocatorUniquePtr gemm_buffer = GetScratchBuffer(static_cast(m * n) * sizeof(T), GetComputeStream(context)); #else IAllocatorUniquePtr gemm_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(m * n) * sizeof(T), false, context->GetComputeStream()); diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc index 3c1feb7af956a..0a5e2ef55197b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc @@ -31,6 +31,10 @@ REGISTER_KERNEL_TYPED(double) using namespace ONNX_NAMESPACE; #ifdef BUILD_CUDA_EP_AS_PLUGIN +// PLUGIN BUILD ADAPTATION: bias_gelu_helper::CheckInputs lives in the CPU +// provider and cannot be linked into the plugin. Reimplement the same input +// validation (rank checks, bias shape matching) inline. +// Keep in sync with contrib_ops/cpu/bert/bias_gelu_helper.h. static Status CheckInputsForPlugin(const OpKernelContext* context) { const Tensor* input = context->Input(0); const Tensor* bias = context->Input(1); diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc index 859bc43c0c8d3..aefd86a6ebd10 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @@ -44,6 +44,8 @@ SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) ORT_ENFORCE(epsilon_ >= 0); #ifdef BUILD_CUDA_EP_AS_PLUGIN + // Plugin adapter cannot static_cast to CUDAExecutionProvider directly. + // Use the adapter shim that reads the config from the per-EP runtime map. strict_ = onnxruntime::cuda::GetCudaKernelAdapterSkipLayerNormStrictMode(op_kernel_info.GetExecutionProvider()); #else const CUDAExecutionProvider* cuda_ep = static_cast(op_kernel_info.GetExecutionProvider()); diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index 72a7a5f164ead..8f729f913c036 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -67,6 +67,9 @@ struct DispatchGroupNorm { } #ifdef BUILD_CUDA_EP_AS_PLUGIN + // Plugin overload: accepts PluginTuningContextStub* (unused) and raw void* + // stream handle instead of IKernelExplorer*/Stream* which are not available + // in the plugin build. Uses OrtStreamAdapter to bridge to the _impl kernel. Status operator()(CudaKernel::PluginTuningContextStub* tuning_ctx, void* ort_stream, Tensor* output, diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h index 64969ae499bf7..3345856fad98b 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h @@ -50,13 +50,15 @@ class MatMulNBits final : public CudaKernel { constexpr int kInputIndexBias = 5; #ifdef BUILD_CUDA_EP_AS_PLUGIN - // Plugin adapter Node does not have InputDefs(). Defer existence checks to ComputeInternal - // where we can check if the actual input tensor is null or not. - ORT_UNUSED_PARAMETER(kInputIndexScale); // used only in non-plugin path + // PLUGIN BUILD ADAPTATION: The adapter Node does not expose InputDefs(), + // so we cannot check whether optional inputs (zero_points, g_idx, bias) + // truly exist at construction time. Instead, we check input count here + // and verify actual tensor presence in ComputeInternal. + ORT_UNUSED_PARAMETER(kInputIndexScale); // only used in non-plugin path for type checking has_zero_points_ = info.GetInputCount() > kInputIndexZeroPoints; has_g_idx_ = info.GetInputCount() > kInputIndexGroupIndex; has_bias_ = info.GetInputCount() > kInputIndexBias; - // is_zero_points_scale_same_type_ defaults to false; runtime will handle differences. + // is_zero_points_scale_same_type_ defaults to false; checked at runtime in plugin path. #else has_zero_points_ = info.GetInputCount() > kInputIndexZeroPoints && info.node().InputDefs()[kInputIndexZeroPoints]->Exists(); has_g_idx_ = info.GetInputCount() > kInputIndexGroupIndex && info.node().InputDefs()[kInputIndexGroupIndex]->Exists(); diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.h b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.h index 371b097f2eedd..8a136d657b661 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.h @@ -1,6 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// CUDA device and pinned memory allocator implementations for the plugin EP. +// Provides CudaDeviceAllocator (cudaMalloc/cudaFree) and CudaPinnedAllocator +// (cudaHostAlloc/cudaFreeHost) conforming to the OrtAllocator interface. +// No arena or caching layer; every allocation goes directly to CUDA. + #pragma once #include "cuda_plugin_utils.h" @@ -8,11 +13,13 @@ namespace onnxruntime { namespace cuda_plugin { +/// Allocator type: device memory (GPU) or pinned (page-locked host) memory. enum class CudaAllocatorKind { - kDevice, - kPinned, + kDevice, ///< GPU device memory via cudaMalloc + kPinned, ///< Page-locked host memory via cudaHostAlloc }; +/// Base class for CUDA allocators implementing the OrtAllocator C interface. class CudaAllocatorBase : public OrtAllocator { public: explicit CudaAllocatorBase(CudaAllocatorKind kind, const OrtMemoryInfo* memory_info) diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc index 2d066002d7496..62646a210e1a0 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc @@ -14,6 +14,9 @@ namespace plugin { namespace { +/// Determine byte size of a single element for the given ONNX data type. +/// Used by Scan transpose kernel to allocate and copy tensor data. +/// Returns error for sub-byte types (INT4, UINT4) and strings. Status GetTensorElementStorageSize(ONNXTensorElementDataType elem_type, size_t& element_size) { switch (elem_type) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.h b/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.h index cd662b105973d..2fcf40be06fc2 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.h @@ -1,6 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// CUDA data transfer implementation for CPU<->GPU and GPU<->GPU memory copies. +// Implements OrtDataTransferImpl to handle synchronous and async copies +// via cudaMemcpy/cudaMemcpyAsync. + #pragma once #include "cuda_plugin_utils.h" diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index 7b1227408b503..cf43d6c7721cc 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -41,7 +41,10 @@ CudaEp::CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& lo "CUDA Plugin EP created", ORT_FILE, __LINE__, __FUNCTION__)); - // Seed adapter-level runtime options for migrated kernels. + // Store per-EP runtime configuration (TF32, device ID, tuning options, etc.) + // in a global map keyed by OrtEp pointer. Migrated kernels retrieve these + // settings at runtime via GetCudaKernelAdapterRuntimeConfig() without needing + // to thread the config through multiple layers of framework code. onnxruntime::cuda::SetCudaKernelAdapterRuntimeConfigForProvider( static_cast(static_cast(this)), config_.use_tf32, config_.device_id, config_.enable_skip_layer_norm_strict_mode, @@ -73,6 +76,12 @@ OrtStatus* ORT_API_CALL CudaEp::GetCapabilityImpl( return nullptr; } + // Three-phase filtering determines which graph nodes run on this EP: + // Phase 1: Collect tentative nodes that have a registered CUDA kernel. + // Phase 2: Filter out CPU-preferred nodes (cheap ops where device-to-host + // copy overhead would exceed the compute benefit). + // Phase 3: Register remaining nodes as supported by this EP. + // Phase 1: Collect tentative nodes — those for which we have a registered kernel. std::vector tentative_nodes; tentative_nodes.reserve(all_nodes.size()); diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h index 2562f5667e047..12c9c0e05b106 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h @@ -30,6 +30,8 @@ class CudaEpFactory : public OrtEpFactory { const std::string& GetEpName() const { return ep_name_; } /// Get or create the shared kernel registry for this factory. + /// Lazily created on first call; subsequent calls return the cached instance. + /// Thread-safe: protected by registry_mutex_. OrtStatus* GetKernelRegistryForEp(CudaEp& ep, const OrtKernelRegistry** out_kernel_registry); diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h index 347ac0ede0dfa..4014ead5feb03 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h @@ -1,6 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// CUDA stream and event-based synchronization primitives for the plugin EP. +// CudaSyncStream wraps a cudaStream_t plus cuBLAS/cuDNN/cuBLASLt handles. +// CudaSyncNotification wraps a cudaEvent_t for cross-stream synchronization. +// A global stream registry (with TLS-cached lookups) allows migrated kernels +// to obtain their compute handles from a raw cudaStream_t. + #pragma once #include "cuda_plugin_utils.h" @@ -32,6 +38,9 @@ class CudaSyncStream : public OrtSyncStreamImpl { void EnqueueDeferredCPUBuffer(void* cpu_buffer); OrtStatus* InitHandles(); + /// Look up the CudaSyncStream wrapper from a raw cudaStream_t handle. + /// Uses a thread-local TLS cache with a generation counter to avoid lock + /// contention on this hot path (called on every kernel launch). static CudaSyncStream* FromCudaStream(cudaStream_t stream); private: @@ -53,6 +62,9 @@ class CudaSyncStream : public OrtSyncStreamImpl { cudnnHandle_t cudnn_handle_ = nullptr; cublasLtHandle_t cublas_lt_handle_ = nullptr; + // CPU buffers whose deallocation is deferred to OnSessionRunEnd. + // Pinned memory must remain valid until all async device operations that + // reference it have completed, so we synchronize the stream first. std::vector deferred_cpu_buffers_; }; diff --git a/onnxruntime/core/providers/cuda/plugin/provider_api_shims.cc b/onnxruntime/core/providers/cuda/plugin/provider_api_shims.cc index 7e35096039341..2d6851aae07d2 100644 --- a/onnxruntime/core/providers/cuda/plugin/provider_api_shims.cc +++ b/onnxruntime/core/providers/cuda/plugin/provider_api_shims.cc @@ -2,7 +2,10 @@ // Licensed under the MIT License. // Provider API shims used by migrated CUDA kernels. -// Direct implementations — no SHARED_PROVIDER bridge needed. +// Provides direct implementations of utility functions that in-tree kernels +// obtain via the SHARED_PROVIDER bridge (GetEnvironmentVar, floatToHalf, +// halfToFloat). Plugin builds skip SHARED_PROVIDER entirely, so these thin +// wrappers ensure the migrated kernel code compiles and links. #include #include diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index dfa43130868dc..127cfcc557fd5 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -335,6 +335,9 @@ Status PrepareForReduce(const Tensor* X, return Status::OK(); } +// Unified scratch buffer allocation: when a CudaKernel pointer is available +// (plugin build path), use GetScratchBuffer which routes through the adapter +// allocator. Otherwise fall back to IAllocator::MakeUniquePtr (in-tree path). template inline IAllocatorUniquePtr AllocateScratchBuffer( const AllocatorPtr& gpu_allocator, const CudaKernel* kernel, size_t count, void* compute_stream) { diff --git a/onnxruntime/core/providers/cuda/tensor/pad.cc b/onnxruntime/core/providers/cuda/tensor/pad.cc index f48316ee5e316..73c8433bbefc8 100644 --- a/onnxruntime/core/providers/cuda/tensor/pad.cc +++ b/onnxruntime/core/providers/cuda/tensor/pad.cc @@ -115,6 +115,9 @@ namespace cuda { using PadsVector = PadBase::PadsVector; +// In the plugin build, PadBase::ComputePads is not accessible because it +// depends on CPU provider internals. ComputePadsImpl is a minimal inline +// equivalent. Keep in sync with PadBase::ComputePads in pad.h. template static void ComputePadsLocal(KernelContextType& ctx, size_t data_rank, @@ -127,6 +130,9 @@ static void ComputePadsLocal(KernelContextType& ctx, #endif } +// In the plugin build, PadBase::HandleDimValueZero lives in CPU provider code +// that cannot be linked into the plugin. Inline the same validation here. +// Keep in sync with PadBase::HandleDimValueZero in pad.h. static Status HandleDimValueZeroLocal(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) { diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc b/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc index 81745733b16bc..01c7229783b33 100644 --- a/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc +++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc @@ -8,8 +8,9 @@ #include "core/providers/cpu/tensor/utils.h" #ifdef BUILD_CUDA_EP_AS_PLUGIN -// In the plugin build, SCATTER_ND_VALIDATE_SHAPES is not accessible -// (it lives in the CPU provider). Provide an inline equivalent. +// PLUGIN BUILD ADAPTATION: SCATTER_ND_VALIDATE_SHAPES is defined in the CPU +// provider (scatter_nd.h) which cannot be linked into the plugin. Inline the +// same validation logic here. Keep in sync with ScatterND::ValidateShapes. namespace onnxruntime { namespace scatter_nd_plugin { inline Status ValidateShapes(const TensorShape& input_shape, diff --git a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h index 57914ba3af321..3a054175db9da 100644 --- a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h +++ b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h @@ -12,9 +12,10 @@ namespace onnxruntime { namespace cuda { #ifdef BUILD_CUDA_EP_AS_PLUGIN -// Plugin-local equivalent of SpaceDepthBase. -// The CPU header cannot be included in the plugin build because it pulls in -// core/framework/op_kernel.h which conflicts with the adapter types. +// PLUGIN BUILD ADAPTATION: SpaceDepthBase (in cpu/tensor/space_depth_ops.h) +// cannot be included because it pulls in core/framework/op_kernel.h which +// conflicts with the adapter types. This inline namespace reimplements the +// validation and dimension-calculation logic. Keep in sync with SpaceDepthBase. namespace detail { template @@ -84,6 +85,8 @@ class SpaceToDepth final : public CudaKernel #endif { #ifdef BUILD_CUDA_EP_AS_PLUGIN + // Plugin builds cannot inherit from SpaceDepthBase, so extract the + // blocksize attribute directly from OpKernelInfo. ORT_ENFORCE(info.GetAttr("blocksize", &blocksize_).IsOK(), "Attribute blocksize is not set."); #endif @@ -124,6 +127,8 @@ class DepthToSpace final : public CudaKernel #endif { #ifdef BUILD_CUDA_EP_AS_PLUGIN + // Plugin builds cannot inherit from SpaceDepthBase, so extract the + // blocksize attribute directly from OpKernelInfo. ORT_ENFORCE(info.GetAttr("blocksize", &blocksize_).IsOK(), "Attribute blocksize is not set."); #endif diff --git a/onnxruntime/core/providers/cuda/tensor/split.cc b/onnxruntime/core/providers/cuda/tensor/split.cc index 9322c261761a7..06b0c7e50f919 100644 --- a/onnxruntime/core/providers/cuda/tensor/split.cc +++ b/onnxruntime/core/providers/cuda/tensor/split.cc @@ -42,6 +42,9 @@ ONNX_OPERATOR_KERNEL_EX(Split, .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), Split_18); +// Self-contained split preparation that replaces SplitBase::PrepareForCompute. +// The base class method is not available in plugin builds because SplitBase +// depends on CPU provider internals. Keep logic in sync with SplitBase. Status SplitKernel::PrepareForComputeLocal(const TensorShape& input_shape, int num_outputs, int64_t& axis, diff --git a/onnxruntime/core/providers/cuda/tensor/tile.cc b/onnxruntime/core/providers/cuda/tensor/tile.cc index dc33fcb286acf..b07ca4f61cca0 100644 --- a/onnxruntime/core/providers/cuda/tensor/tile.cc +++ b/onnxruntime/core/providers/cuda/tensor/tile.cc @@ -12,6 +12,9 @@ namespace cuda { namespace { #ifdef BUILD_CUDA_EP_AS_PLUGIN +// PLUGIN BUILD ADAPTATION: TileOp::IsTileMemcpy (CPU provider) cannot be +// linked into the plugin. Reimplement the memcpy fast-path check here. +// Keep in sync with TileOp::IsTileMemcpy in cpu/tensor/tile.cc. bool IsTileMemcpyForPlugin(const TensorShape& input_shape, const int64_t* repeats, size_t rank, diff --git a/onnxruntime/core/providers/cuda/tensor/unsqueeze.cc b/onnxruntime/core/providers/cuda/tensor/unsqueeze.cc index d4080b7ef49e5..64f951c70fe15 100644 --- a/onnxruntime/core/providers/cuda/tensor/unsqueeze.cc +++ b/onnxruntime/core/providers/cuda/tensor/unsqueeze.cc @@ -8,6 +8,10 @@ namespace cuda { namespace { +// PLUGIN BUILD ADAPTATION: PrepareCompute() is inherited from UnsqueezeBase +// in the non-plugin build, but the base class cannot be used in plugin builds +// because it depends on core/framework/op_kernel.h internals. This standalone +// function reimplements the same axes-parsing and output-shape computation. Status PrepareComputeForPlugin(OpKernelContext* ctx, UnsqueezeBase::Prepare& p, const TensorShapeVector& axes_attr) { const auto* input = ctx->Input(0); ORT_ENFORCE(input != nullptr); diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.cc b/onnxruntime/core/providers/cuda/tensor/upsample.cc index 7ba2eed09353d..36e89cd38e72b 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.cc +++ b/onnxruntime/core/providers/cuda/tensor/upsample.cc @@ -106,6 +106,9 @@ Status Upsample::BaseCompute(OpKernelContext* context, if (antialias_) { #ifdef BUILD_CUDA_EP_AS_PLUGIN + // Plugin builds copy the lookup table to a per-call scratch buffer because + // plugin kernels cannot safely hold persistent device pointers across sessions. + // Non-plugin builds cache the table in a member IAllocator::UniquePtr. const uint8_t* lookup_table = GetLookupTableShared(); auto shared_lookup_table_ondevice_buffer = GetScratchBuffer(kLookupTableSize, GetComputeStream(context)); CUDA_CALL_THROW(cudaMemcpyAsync(shared_lookup_table_ondevice_buffer.get(), lookup_table, kLookupTableSize, From 45ed1c1d5dd753e4e8fcae255557b49c044e5a75 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 25 Mar 2026 08:55:55 -0700 Subject: [PATCH 19/48] refine --- include/onnxruntime/ep/adapter/node.h | 3 +++ .../contrib_ops/cuda/bert/attention.cc | 12 ----------- onnxruntime/contrib_ops/cuda/moe/moe.cc | 15 ------------- .../cuda/quantization/moe_quantization.cc | 14 ------------- .../cuda/plugin/cuda_allocator_plugin.h | 4 ++-- .../providers/cuda/plugin/cuda_ep_factory.cc | 8 +++++-- .../providers/cuda/plugin/cuda_ep_factory.h | 3 ++- .../cuda/plugin/cuda_kernel_adapter.h | 21 +++++++++++++++---- .../python/onnxruntime_pybind_module.cc | 7 +++++-- onnxruntime/test/unittest_util/base_tester.cc | 3 ++- .../unittest_util/test_dynamic_plugin_ep.h | 2 ++ 11 files changed, 39 insertions(+), 53 deletions(-) diff --git a/include/onnxruntime/ep/adapter/node.h b/include/onnxruntime/ep/adapter/node.h index 17513d3a14dfa..91aff7d670b2f 100644 --- a/include/onnxruntime/ep/adapter/node.h +++ b/include/onnxruntime/ep/adapter/node.h @@ -43,6 +43,9 @@ struct Node { /** Gets whether an output exists or is an omitted optional output. */ bool OutputExists(size_t index) const { + // KernelInfo_GetOutputName returns an empty string for omitted optional + // outputs, which lets adapter consumers mirror NodeArg::Exists() without + // pulling in full NodeArg metadata. return index < OutputCount() && !kernel_info_.GetOutputName(index).empty(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 3e6a78775f4a2..83b6237dcc2a6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -240,18 +240,10 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; - AllocatorPtr allocator; - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); int m = batch_size * sequence_length; int n = (parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size); int k = parameters.input_hidden_size; -#ifdef BUILD_CUDA_EP_AS_PLUGIN - // Plugin build: use GetScratchBuffer (adapter-compatible) instead of - // IAllocator::MakeUniquePtr which requires the full allocator interface. IAllocatorUniquePtr gemm_buffer = GetScratchBuffer(static_cast(m * n) * sizeof(T), GetComputeStream(context)); -#else - IAllocatorUniquePtr gemm_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(m * n) * sizeof(T), false, context->GetComputeStream()); -#endif CudaT one = ToCudaType::FromFloat(1.0f); CudaT zero = ToCudaType::FromFloat(0.0f); @@ -283,11 +275,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { use_memory_efficient_attention, use_cudnn_flash_attention, false); -#ifdef BUILD_CUDA_EP_AS_PLUGIN IAllocatorUniquePtr work_space = GetScratchBuffer(workSpaceSize, GetComputeStream(context)); -#else - IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, workSpaceSize, false, context->GetComputeStream()); -#endif data.gemm_buffer = reinterpret_cast(gemm_buffer.get()); if (nullptr != bias) { diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index e088e5241cc93..ffd1b219da03c 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -66,7 +66,6 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int); size_t expert_for_source_row_size = k_ * moe_params.num_rows * sizeof(int); -#ifdef BUILD_CUDA_EP_AS_PLUGIN IAllocatorUniquePtr work_space = this->template GetScratchBuffer(ws_size, this->GetComputeStream(context)); IAllocatorUniquePtr fc2_output = this->template GetScratchBuffer(fc2_output_size, this->GetComputeStream(context)); IAllocatorUniquePtr expert_scales = this->template GetScratchBuffer(expert_scales_size, this->GetComputeStream(context)); @@ -74,20 +73,6 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { this->template GetScratchBuffer(expanded_source_row_to_expanded_dest_row_size, this->GetComputeStream(context)); IAllocatorUniquePtr expert_for_source_row = this->template GetScratchBuffer(expert_for_source_row_size, this->GetComputeStream(context)); -#else - AllocatorPtr allocator; - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - - // TODO: allocate one buffer and reuse it. - IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, ws_size, false, context->GetComputeStream()); - IAllocatorUniquePtr fc2_output = IAllocator::MakeUniquePtr(allocator, fc2_output_size, false, context->GetComputeStream()); - IAllocatorUniquePtr expert_scales = - IAllocator::MakeUniquePtr(allocator, expert_scales_size, false, context->GetComputeStream()); - IAllocatorUniquePtr expanded_source_row_to_expanded_dest_row = - IAllocator::MakeUniquePtr(allocator, expanded_source_row_to_expanded_dest_row_size, false, context->GetComputeStream()); - IAllocatorUniquePtr expert_for_source_row = - IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, context->GetComputeStream()); -#endif const CudaT* fc_scales_ptr = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc index 70f4690e60b92..4b261346887f6 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc @@ -76,7 +76,6 @@ Status QMoE::QuantizedMoEImpl(OpKernelContext* context, size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int); size_t expert_for_source_row_size = k_ * moe_params.num_rows * sizeof(int); -#ifdef BUILD_CUDA_EP_AS_PLUGIN IAllocatorUniquePtr work_space = this->template GetScratchBuffer(ws_size, this->GetComputeStream(context)); IAllocatorUniquePtr fc2_output = this->template GetScratchBuffer(fc2_output_size, this->GetComputeStream(context)); IAllocatorUniquePtr expert_scales = this->template GetScratchBuffer(expert_scales_size, this->GetComputeStream(context)); @@ -84,19 +83,6 @@ Status QMoE::QuantizedMoEImpl(OpKernelContext* context, this->template GetScratchBuffer(expanded_source_row_to_expanded_dest_row_size, this->GetComputeStream(context)); IAllocatorUniquePtr expert_for_source_row = this->template GetScratchBuffer(expert_for_source_row_size, this->GetComputeStream(context)); -#else - AllocatorPtr allocator; - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - - IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, ws_size, false, context->GetComputeStream()); - IAllocatorUniquePtr fc2_output = IAllocator::MakeUniquePtr(allocator, fc2_output_size, false, context->GetComputeStream()); - IAllocatorUniquePtr expert_scales = - IAllocator::MakeUniquePtr(allocator, expert_scales_size, false, context->GetComputeStream()); - IAllocatorUniquePtr expanded_source_row_to_expanded_dest_row = - IAllocator::MakeUniquePtr(allocator, expanded_source_row_to_expanded_dest_row_size, false, context->GetComputeStream()); - IAllocatorUniquePtr expert_for_source_row = - IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, context->GetComputeStream()); -#endif moe_runner.run_moe_fc( reinterpret_cast(input->template Data()), diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.h b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.h index 8a136d657b661..8b0d41cad6541 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.h @@ -15,8 +15,8 @@ namespace cuda_plugin { /// Allocator type: device memory (GPU) or pinned (page-locked host) memory. enum class CudaAllocatorKind { - kDevice, ///< GPU device memory via cudaMalloc - kPinned, ///< Page-locked host memory via cudaHostAlloc + kDevice, ///< GPU device memory via cudaMalloc + kPinned, ///< Page-locked host memory via cudaHostAlloc }; /// Base class for CUDA allocators implementing the OrtAllocator C interface. diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc index dea3b75078a1a..eb55a09c024b1 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -126,7 +126,10 @@ OrtStatus* ORT_API_CALL CudaEpFactory::GetSupportedDevicesImpl( continue; // Skip non-NVIDIA GPUs } - int32_t current_device_id = cuda_device_index++; + // CUDA uses contiguous ordinals for CUDA-visible NVIDIA devices. Build that + // mapping from the filtered hardware-device list instead of relying on the + // ORT hardware device id, which is not guaranteed to be a CUDA ordinal. + const int32_t current_device_id = cuda_device_index++; { std::lock_guard lock(factory->device_map_mutex_); @@ -204,7 +207,8 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( if (num_devices != 1) { return factory->ort_api_.CreateStatus( ORT_INVALID_ARGUMENT, - "CUDA EP factory currently supports only one device at a time."); + "CUDA EP factory currently supports exactly one device per EP instance. " + "Pass a single OrtHardwareDevice when creating the CUDA plugin EP."); } if (devices == nullptr || devices[0] == nullptr) { return factory->ort_api_.CreateStatus( diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h index 12c9c0e05b106..d88bf1ac9b647 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h @@ -97,7 +97,8 @@ class CudaEpFactory : public OrtEpFactory { std::vector cached_memory_infos_; int device_id_ = 0; - // Map ORT hardware device pointers to internal CUDA ordinal indices + // Map ORT hardware device pointers to CUDA ordinals for the NVIDIA devices + // visible to the CUDA runtime. std::mutex device_map_mutex_; std::unordered_map hw_device_to_cuda_index_; diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index 19b10f57c40d5..9cabf20a83a36 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -423,6 +423,9 @@ struct DefaultCudaHandles { }; inline DefaultCudaHandles& GetDefaultCudaHandlesForDevice(int device_id) { + // Fallback handles are only used for code paths that need cuBLAS/cuDNN + // without an active CudaSyncStream. Keep them thread-local so they are not + // shared across callers that may use the libraries concurrently. thread_local std::unordered_map handles_by_device; auto [it, inserted] = handles_by_device.try_emplace(device_id); if (inserted) { @@ -826,20 +829,30 @@ class CudaKernel : public OpKernel { inline IAllocatorUniquePtr GetScratchBuffer(size_t cnt, void* s) const { if (cnt == 0) return IAllocatorUniquePtr(nullptr, [](T*) {}); size_t sz = detail::BytesForCount(cnt, detail::SizeOf::value); + if (sz == 0) { + ORT_THROW("CUDA scratch buffer allocation size overflow for ", cnt, " elements"); + } + void* p = nullptr; cudaError_t alloc_result = cudaSuccess; + bool used_async_alloc = false; if (s) { alloc_result = cudaMallocAsync(&p, sz, static_cast(s)); - if (alloc_result == cudaErrorNotSupported || alloc_result == cudaErrorInvalidValue) { + used_async_alloc = (alloc_result == cudaSuccess); + if (!used_async_alloc && (alloc_result == cudaErrorNotSupported || alloc_result == cudaErrorInvalidValue)) { alloc_result = cudaMalloc(&p, sz); } } else { alloc_result = cudaMalloc(&p, sz); } - if (alloc_result != cudaSuccess) return IAllocatorUniquePtr(nullptr, [](T*) {}); - return IAllocatorUniquePtr(static_cast(p), [s](T* ptr) { + + if (alloc_result != cudaSuccess) { + ORT_THROW("CUDA scratch buffer allocation failed for ", sz, " bytes: ", cudaGetErrorString(alloc_result)); + } + + return IAllocatorUniquePtr(static_cast(p), [s, used_async_alloc](T* ptr) { if (ptr) { - if (s) { + if (used_async_alloc && s) { cudaError_t free_result = cudaFreeAsync(ptr, static_cast(s)); if (free_result == cudaErrorNotSupported || free_result == cudaErrorInvalidValue) { cudaFree(ptr); diff --git a/onnxruntime/python/onnxruntime_pybind_module.cc b/onnxruntime/python/onnxruntime_pybind_module.cc index 4eaa057a68cff..ebe027c9efa95 100644 --- a/onnxruntime/python/onnxruntime_pybind_module.cc +++ b/onnxruntime/python/onnxruntime_pybind_module.cc @@ -117,13 +117,16 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { "get_available_providers", []() -> std::vector { auto available = GetAvailableExecutionProviderNames(); #if !defined(ORT_MINIMAL_BUILD) + const auto& ep_devices = GetEnv().GetOrtEpDevices(); + available.reserve(available.size() + ep_devices.size()); + InlinedHashSet existing; - existing.reserve(available.size()); + existing.reserve(available.size() + ep_devices.size()); for (const auto& ep_name : available) { existing.insert(ep_name); } - for (const OrtEpDevice* ep_device : GetEnv().GetOrtEpDevices()) { + for (const OrtEpDevice* ep_device : ep_devices) { if (!ep_device) { continue; } diff --git a/onnxruntime/test/unittest_util/base_tester.cc b/onnxruntime/test/unittest_util/base_tester.cc index 6796fba2c9a01..6622960a57680 100644 --- a/onnxruntime/test/unittest_util/base_tester.cc +++ b/onnxruntime/test/unittest_util/base_tester.cc @@ -44,7 +44,8 @@ void DebugTrap() { bool ShouldRouteCudaToDynamicPluginEp(const std::optional& dynamic_plugin_ep_name) { // Route CUDA requests to the CUDA plugin EP when unit test main has initialized // dynamic plugin EP infrastructure with the CUDA plugin registration. - return dynamic_plugin_ep_name.has_value() && *dynamic_plugin_ep_name == "CudaPluginExecutionProvider"; + return dynamic_plugin_ep_name.has_value() && + *dynamic_plugin_ep_name == dynamic_plugin_ep_infra::kCudaPluginExecutionProviderName; } } // namespace diff --git a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h index 0962df8e35308..b5bf075a25447 100644 --- a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h +++ b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h @@ -29,6 +29,8 @@ namespace test { // unit testing infrastructure. namespace dynamic_plugin_ep_infra { +inline constexpr std::string_view kCudaPluginExecutionProviderName{"CudaPluginExecutionProvider"}; + // Note: `Initialize()` and `Shutdown()` are not thread-safe. // They should be called before and after calls to most of the other functions in this namespace. // The exception to this is `ParseInitializationConfig()`, which may be called before `Initialize()`. From b35226c8595edc8fb98408df13ec3856da2ea14c Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 25 Mar 2026 11:30:45 -0700 Subject: [PATCH 20/48] fix Windows build --- cmake/external/cuda_configuration.cmake | 14 ++++++ cmake/onnxruntime_providers_cuda_plugin.cmake | 45 ++++++++++++++++++- .../onnxruntime/ep/adapter/op_kernel_info.h | 9 ++-- .../cuda/bert/xqa/xqa_impl_gen.cuh | 23 ++++------ 4 files changed, 70 insertions(+), 21 deletions(-) diff --git a/cmake/external/cuda_configuration.cmake b/cmake/external/cuda_configuration.cmake index d8378c934f0cc..df180d185a268 100644 --- a/cmake/external/cuda_configuration.cmake +++ b/cmake/external/cuda_configuration.cmake @@ -161,6 +161,20 @@ macro(setup_cuda_architectures) set(CMAKE_CUDA_ARCHITECTURES_ORIG "${CMAKE_CUDA_ARCHITECTURES}") message(STATUS "GPU architectures: ${CMAKE_CUDA_ARCHITECTURES_ORIG}") + unset(ORT_HAS_SM80_OR_LATER) + foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES_ORIG) + if(CUDA_ARCH MATCHES "^([0-9]+)") + if(CMAKE_MATCH_1 GREATER_EQUAL 80) + set(ORT_HAS_SM80_OR_LATER ON) + break() + endif() + endif() + endforeach() + + if(ORT_HAS_SM80_OR_LATER) + add_definitions("-DHAS_SM80_OR_LATER") + endif() + set(ARCHITECTURES_WITH_KERNELS "80" "86" "89" "90" "100" "110" "120") foreach(CUDA_ARCH IN LISTS ARCHITECTURES_WITH_KERNELS) if(NOT "${CUDA_ARCH}" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG) diff --git a/cmake/onnxruntime_providers_cuda_plugin.cmake b/cmake/onnxruntime_providers_cuda_plugin.cmake index 6814374b2d3ed..4c5a3b36b548b 100644 --- a/cmake/onnxruntime_providers_cuda_plugin.cmake +++ b/cmake/onnxruntime_providers_cuda_plugin.cmake @@ -116,9 +116,11 @@ onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_plugin ${CUDA_PLUGIN_EP_CC_SRCS} ${CUDA_PLUGIN_EP_CU_SRCS} ) -# Set CUDA standard and flags +# Keep the plugin CUDA target aligned with the repo-wide C++20 baseline. +# Forcing CUDA C++17 here breaks newer protobuf/absl headers used by the plugin +# build, as absl::compare expects standard ordering support in this configuration. set_target_properties(onnxruntime_providers_cuda_plugin PROPERTIES - CUDA_STANDARD 17 + CUDA_STANDARD 20 CUDA_STANDARD_REQUIRED ON ) @@ -133,15 +135,54 @@ target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE # receive the ORT-framework force-include (it conflicts with cute::Tensor etc.). # cuda_plugin_kernels.cu already #include "cuda_kernel_adapter.h" directly. # Op-registration .cc files do not include it directly, so they need it here. + # Force NVCC onto C++20 explicitly. With the VS generator the CUDA standard + # property alone still leaves `-std=c++17` in AdditionalOptions. # Suppress NVCC cudafe warnings: # 550 - variable set but never used (in adapter headers) # 2810 - [[nodiscard]] false positive on Status assignments in op_kernel.h / kernel_registry.h + "$<$:SHELL:--std c++20>" "$<$:--expt-relaxed-constexpr;-Xcudafe;--diag_suppress=550>" "$<$:SHELL:-Xcudafe --diag_suppress=2810>" "$<$:-include;${REPO_ROOT}/include/onnxruntime/ep/adapters.h>" "$<$:SHELL:-include ${CUDA_PLUGIN_EP_DIR}/cuda_kernel_adapter.h>" ) +if (MSVC) + target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE + "$<$:SHELL:-Xcompiler /permissive>" + "$<$:SHELL:-Xcompiler /wd4834>" + "$<$:SHELL:-Xcompiler /wd4127>" + "$<$:SHELL:-Xcompiler /wd4211>" + "$<$:SHELL:-Xcompiler /Zc:__cplusplus>" + "$<$:SHELL:-Xcompiler /bigobj>" + ) + + target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE + "$<$:/wd4127>" + ) +endif() + +# Mirror the core CUDA provider's CUDA 12.8+ NVCC workarounds so the plugin +# target handles stricter cudafe diagnostics consistently. +set(onnxruntime_plugin_nvcc_threads "1") +target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE + "$<$:SHELL:--threads \"${onnxruntime_plugin_nvcc_threads}\">" + "$<$:--diag-suppress=177>" +) + +if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8) + target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE + "$<$:--static-global-template-stub=false>" + "$<$:--diag-suppress=221>" + ) + + if (MSVC) + target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE + "$<$:SHELL:-Xcompiler /wd4505>" + ) + endif() +endif() + include(cudnn_frontend) include(cutlass) diff --git a/include/onnxruntime/ep/adapter/op_kernel_info.h b/include/onnxruntime/ep/adapter/op_kernel_info.h index 644cb30788ec6..638beabb8cbf0 100644 --- a/include/onnxruntime/ep/adapter/op_kernel_info.h +++ b/include/onnxruntime/ep/adapter/op_kernel_info.h @@ -9,6 +9,7 @@ #include +#include "core/common/narrow.h" #include "core/common/status.h" #include "core/framework/config_options.h" #include "core/framework/tensor_shape.h" @@ -42,11 +43,11 @@ struct OpKernelInfo { struct KernelInfoCache { explicit KernelInfoCache(const OrtKernelInfo* kernel_info) : kernel_info_(kernel_info) { Ort::ConstKernelInfo info{kernel_info}; - const int input_count = info.GetInputCount(); + const size_t input_count = info.GetInputCount(); constant_input_tensors.resize(input_count); - for (int i = 0; i < input_count; ++i) { + for (size_t i = 0; i < input_count; ++i) { int is_constant = 0; - Ort::ConstValue const_input = info.GetTensorConstantInput(i, &is_constant); + Ort::ConstValue const_input = info.GetTensorConstantInput(gsl::narrow_cast(i), &is_constant); if (is_constant && const_input != nullptr && const_input.IsTensor()) { constant_input_tensors[i] = CreateTensorFromApiValue(const_cast(static_cast(const_input))); } @@ -85,7 +86,7 @@ struct OpKernelInfo { } int GetInputCount() const noexcept { - return info_.GetInputCount(); + return gsl::narrow_cast(info_.GetInputCount()); } const std::vector& GetConstantInputTensors() const noexcept { diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh index b1a6badc6b3f1..d132fba85988c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh @@ -17,26 +17,19 @@ namespace NAMESPACE_NAME { // XQA kernels require SM80+ (Ampere or newer). We need a guard that works correctly // during both host and device compilation passes: // - Device pass: __CUDA_ARCH__ is defined, check it directly. -// - Host pass: __CUDA_ARCH__ is NOT defined. Use __CUDA_ARCH_LIST__ (available since -// CUDA 11.5), which nvcc defines as a comma-separated ascending list of target -// architectures (e.g. 750,800). In #if, the comma operator returns the rightmost -// (highest) value, so __CUDA_ARCH_LIST__ >= 800 checks the max target arch. -// - Non-nvcc (e.g. IDE parser): cuda_hint.cuh defines __CUDA_ARCH__ 900, which -// takes the first branch. -// Using !defined(__CUDA_ARCH__) here would be WRONG: it always evaluates true during -// the host pass, causing the kernel to be declared even when no SM80+ device code -// exists. CUDA 13+ then fails to generate a host stub, producing C2129 / LNK2001. +// - Host pass: rely on HAS_SM80_OR_LATER from cmake/external/cuda_configuration.cmake. +// If any SM80+ arch is enabled, the host stub must be emitted. +// - Non-nvcc parsers usually won't see the CMake-provided define, so keep editor parsing +// intact by taking the fallback branch when __CUDACC__ is not defined. +// Using only !defined(__CUDA_ARCH__) here would be WRONG: it always evaluates true during +// the host pass, causing the kernel to be declared even when no SM80+ device code exists. +// CUDA 13+ then fails to generate a host stub, producing C2129 / LNK2001. #undef XQA_HAS_SM80_TARGET #ifdef __CUDA_ARCH__ #if __CUDA_ARCH__ >= 800 #define XQA_HAS_SM80_TARGET 1 #endif -#elif defined(__CUDA_ARCH_LIST__) -#if __CUDA_ARCH_LIST__ >= 800 -#define XQA_HAS_SM80_TARGET 1 -#endif -#else -// Non-nvcc fallback: assume supported (IDE parsers, etc.) +#elif defined(HAS_SM80_OR_LATER) || !defined(__CUDACC__) #define XQA_HAS_SM80_TARGET 1 #endif From c8341c39977c9ffe81bcb4b56522d9a608ca6284 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 25 Mar 2026 16:07:13 -0700 Subject: [PATCH 21/48] refine CleanupDeferredCPUBuffers etc. --- .../cuda/plugin/cuda_data_transfer_plugin.cc | 5 + .../cuda/plugin/cuda_plugin_kernels.cu | 2 +- .../cuda/plugin/cuda_stream_plugin.cc | 47 ++++- .../cuda/plugin/cuda_stream_plugin.h | 2 +- .../python/onnxruntime_pybind_state.cc | 163 +++++++++++++----- 5 files changed, 172 insertions(+), 47 deletions(-) diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc index 37810e0c9d6d8..0b248438a421d 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc @@ -36,6 +36,11 @@ CudaDataTransfer::CudaDataTransfer(const OrtApi& ort_api, const OrtEpApi& ep_api bool src_is_gpu = (src_type == OrtMemoryInfoDeviceType_GPU); bool dst_is_gpu = (dst_type == OrtMemoryInfoDeviceType_GPU); + if ((src_is_gpu && ep_api.MemoryDevice_GetVendorId(src_device) != OrtDevice::VendorIds::NVIDIA) || + (dst_is_gpu && ep_api.MemoryDevice_GetVendorId(dst_device) != OrtDevice::VendorIds::NVIDIA)) { + return false; + } + // Support CPU→GPU, GPU→CPU, GPU→GPU return (src_is_cpu && dst_is_gpu) || (src_is_gpu && dst_is_cpu) || diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.cu b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.cu index 822ea3d5fa72f..684784400cbeb 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.cu +++ b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.cu @@ -45,7 +45,7 @@ OrtStatus* CreateCudaKernelRegistry(const OrtEpApi& /*ep_api*/, for (auto build_fn : entries) { ::onnxruntime::ep::adapter::KernelCreateInfo info = build_fn(); if (info.kernel_def != nullptr) { // filter the BuildKernelCreateInfo sentinel - (void)registry.Register(std::move(info)); + ORT_THROW_IF_ERROR(registry.Register(std::move(info))); } } diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc index 0339f8add1bfd..0764b8931dc9e 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc @@ -53,14 +53,40 @@ CudaSyncStream::CudaSyncStream(CudaEpFactory& factory, int device_id, } CudaSyncStream::~CudaSyncStream() { - CleanupDeferredCPUBuffers(); + if (!deferred_cpu_buffers_.empty()) { + if (cuda_stream_ != nullptr) { + OrtStatus* status = OnSessionRunEndImpl(this); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + } + } else { + OrtStatus* status = CleanupDeferredCPUBuffers(); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + } + } + } if (cuda_stream_) UnregisterStream(cuda_stream_); if (cublas_handle_) cublasDestroy(cublas_handle_); if (cudnn_handle_) cudnnDestroy(cudnn_handle_); if (cublas_lt_handle_) cublasLtDestroy(cublas_lt_handle_); - if (cuda_stream_) cudaStreamDestroy(cuda_stream_); + if (cuda_stream_) { + auto destroy_result = cudaStreamDestroy(cuda_stream_); + if (destroy_result == cudaSuccess && !deferred_cpu_buffers_.empty()) { + // Fallback: we only reach here when the earlier cudaStreamSynchronize in + // OnSessionRunEndImpl failed, leaving some buffers un-freed. + // cudaStreamDestroy on a non-blocking stream returns immediately (async + // cleanup), so in-flight ops may still reference these buffers. However, + // a prior sync failure indicates a serious CUDA error, so best-effort + // cleanup is the most we can do here. + OrtStatus* status = CleanupDeferredCPUBuffers(); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + } + } + } } OrtStatus* CudaSyncStream::InitHandles() { @@ -84,11 +110,18 @@ void CudaSyncStream::EnqueueDeferredCPUBuffer(void* cpu_buffer) { deferred_cpu_buffers_.push_back(cpu_buffer); } -void CudaSyncStream::CleanupDeferredCPUBuffers() { +OrtStatus* CudaSyncStream::CleanupDeferredCPUBuffers() noexcept { + OrtStatus* first_error = nullptr; for (void* buf : deferred_cpu_buffers_) { - cudaFreeHost(buf); + cudaError_t err = cudaFreeHost(buf); + if (err != cudaSuccess && first_error == nullptr) { + first_error = Ort::GetApi().CreateStatus( + ORT_EP_FAIL, + (std::string("CUDA error: ") + cudaGetErrorName(err) + ": " + cudaGetErrorString(err)).c_str()); + } } deferred_cpu_buffers_.clear(); + return first_error; } /*static*/ void* ORT_API_CALL CudaSyncStream::GetHandleImpl(OrtSyncStreamImpl* this_ptr) noexcept { @@ -114,11 +147,13 @@ void CudaSyncStream::CleanupDeferredCPUBuffers() { /*static*/ OrtStatus* ORT_API_CALL CudaSyncStream::OnSessionRunEndImpl(OrtSyncStreamImpl* this_ptr) noexcept { auto* stream = static_cast(this_ptr); + if (stream->cuda_stream_ == nullptr) { + return stream->CleanupDeferredCPUBuffers(); + } // Synchronize before releasing deferred CPU buffers to ensure // all async copies using those buffers have completed. PL_CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream->cuda_stream_)); - stream->CleanupDeferredCPUBuffers(); - return nullptr; + return stream->CleanupDeferredCPUBuffers(); } /*static*/ void ORT_API_CALL CudaSyncStream::ReleaseImpl(OrtSyncStreamImpl* this_ptr) noexcept { diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h index 4014ead5feb03..60d85c043bf67 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h @@ -53,7 +53,7 @@ class CudaSyncStream : public OrtSyncStreamImpl { static OrtStatus* ORT_API_CALL OnSessionRunEndImpl(OrtSyncStreamImpl* this_ptr) noexcept; static void ORT_API_CALL ReleaseImpl(OrtSyncStreamImpl* this_ptr) noexcept; - void CleanupDeferredCPUBuffers(); + OrtStatus* CleanupDeferredCPUBuffers() noexcept; CudaEpFactory& factory_; int device_id_; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index d60735ed5f621..2b68d7b2c0cad 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -573,6 +573,49 @@ void RegisterNvTensorRTRtxPluginsAsCustomOps(PySessionOptions& so, const Provide } #endif +#if !defined(ORT_MINIMAL_BUILD) +// Find a registered plugin EP device matching the given EP name and optional device_id from provider options. +// Returns nullptr if no matching device is found. +static const OrtEpDevice* FindRegisteredPluginEpDevice( + const std::string& ep_name, + const ProviderOptions* provider_options) { + const auto& ep_devices = GetEnv().GetOrtEpDevices(); + if (ep_devices.empty()) { + return nullptr; + } + + bool has_requested_device_id = false; + int requested_device_id = 0; + if (provider_options != nullptr) { + if (const auto device_id_it = provider_options->find("device_id"); device_id_it != provider_options->end()) { + try { + requested_device_id = std::stoi(device_id_it->second); + has_requested_device_id = requested_device_id >= 0; + } catch (const std::exception&) { + // Invalid device_id values are logged by callers when appropriate. + } + } + } + + for (const OrtEpDevice* ep_device : ep_devices) { + if (!ep_device || ep_device->ep_name != ep_name) { + continue; + } + + if (has_requested_device_id) { + Ort::ConstEpDevice current_device(ep_device); + if (static_cast(current_device.Device().DeviceId()) != requested_device_id) { + continue; + } + } + + return ep_device; + } + + return nullptr; +} +#endif + /** * Creates an IExecutionProviderFactory instance of the specified type. * @param session_options The session options. @@ -586,56 +629,35 @@ static std::shared_ptr CreateExecutionProviderFactory const std::string& type, const ProviderOptionsMap& provider_options_map) { #if !defined(ORT_MINIMAL_BUILD) - auto try_create_registered_plugin_factory = [&]() -> std::shared_ptr { - const auto& ep_devices = GetEnv().GetOrtEpDevices(); - if (ep_devices.empty()) { - return nullptr; - } + auto get_registered_plugin_ep_devices = [&]() -> InlinedVector { + InlinedVector selected_devices; - bool has_requested_device_id = false; - int requested_device_id = 0; + const ProviderOptions* provider_options = nullptr; if (const auto provider_it = provider_options_map.find(type); provider_it != provider_options_map.end()) { - if (const auto device_id_it = provider_it->second.find("device_id"); device_id_it != provider_it->second.end()) { - try { - requested_device_id = std::stoi(device_id_it->second); - has_requested_device_id = requested_device_id >= 0; - } catch (const std::exception& ex) { - LOGS_DEFAULT(WARNING) << "Ignoring invalid device_id provider option '" << device_id_it->second - << "' for registered plugin EP '" << type << "': " << ex.what(); - } - } + provider_options = &provider_it->second; } - const OrtEpDevice* selected_device = nullptr; - for (const OrtEpDevice* ep_device : ep_devices) { - if (!ep_device || ep_device->ep_name != type) { - continue; - } - - if (has_requested_device_id) { - Ort::ConstEpDevice current_device(ep_device); - if (static_cast(current_device.Device().DeviceId()) != requested_device_id) { - continue; + const OrtEpDevice* selected_device = FindRegisteredPluginEpDevice(type, provider_options); + if (selected_device == nullptr) { + if (provider_options != nullptr) { + if (const auto device_id_it = provider_options->find("device_id"); device_id_it != provider_options->end()) { + LOGS_DEFAULT(WARNING) << "No registered plugin EP device found for '" << type + << "' with device_id=" << device_id_it->second; } } - - if (selected_device == nullptr) { - selected_device = ep_device; - break; - } + return selected_devices; } - if (selected_device == nullptr) { - if (has_requested_device_id) { - LOGS_DEFAULT(WARNING) << "No registered plugin EP device found for '" << type - << "' with device_id=" << requested_device_id; - } + selected_devices.push_back(selected_device); + return selected_devices; + }; + + auto try_create_registered_plugin_factory = [&]() -> std::shared_ptr { + auto selected_devices = get_registered_plugin_ep_devices(); + if (selected_devices.empty()) { return nullptr; } - InlinedVector selected_devices; - selected_devices.push_back(selected_device); - std::unique_ptr ep_factory; const auto status = onnxruntime::CreateIExecutionProviderFactoryForEpDevices(GetEnv(), selected_devices, ep_factory); if (!status.IsOK()) { @@ -1314,6 +1336,44 @@ std::unique_ptr CreateExecutionProviderInstance(const Sessio const auto& default_logger = GetEnv().GetLoggingManager()->DefaultLogger(); OrtSessionOptions ort_session_options; ort_session_options.value = session_options; + +#if !defined(ORT_MINIMAL_BUILD) + auto add_registered_plugin_ep_options_to_session = [&]() -> Status { + const ProviderOptions* provider_options = nullptr; + if (const auto provider_it = provider_options_map.find(type); provider_it != provider_options_map.end()) { + provider_options = &provider_it->second; + } + + if (provider_options == nullptr || provider_options->empty()) { + return Status::OK(); + } + + const OrtEpDevice* selected_device = FindRegisteredPluginEpDevice(type, provider_options); + if (selected_device == nullptr) { + return Status::OK(); + } + + InlinedVector selected_devices; + selected_devices.push_back(selected_device); + + std::vector ep_option_keys; + std::vector ep_option_vals; + ep_option_keys.reserve(provider_options->size()); + ep_option_vals.reserve(provider_options->size()); + for (const auto& [key, val] : *provider_options) { + ep_option_keys.push_back(key.c_str()); + ep_option_vals.push_back(val.c_str()); + } + + return AddEpOptionsToSessionOptions(selected_devices, ep_option_keys, ep_option_vals, ort_session_options.value); + }; + + auto status = add_registered_plugin_ep_options_to_session(); + if (!status.IsOK()) { + ORT_THROW("Error applying registered plugin EP options: ", status); + } +#endif + return ep_factory->CreateProvider(ort_session_options, *default_logger.ToExternal()); } return nullptr; @@ -1336,6 +1396,31 @@ static Status AddExplicitEpFactory(PySessionOptions& py_sess_options, const std: return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to add provider of type '", provider_type, "' to SessionOptions. Provider configuration is not supported."); } + +#if !defined(ORT_MINIMAL_BUILD) + if (!provider_options.empty()) { + const OrtEpDevice* selected_device = FindRegisteredPluginEpDevice(provider_type, &provider_options); + if (selected_device != nullptr) { + InlinedVector selected_devices; + selected_devices.push_back(selected_device); + + std::vector ep_option_keys; + std::vector ep_option_vals; + ep_option_keys.reserve(provider_options.size()); + ep_option_vals.reserve(provider_options.size()); + for (const auto& [key, val] : provider_options) { + ep_option_keys.push_back(key.c_str()); + ep_option_vals.push_back(val.c_str()); + } + + ORT_RETURN_IF_ERROR(AddEpOptionsToSessionOptions(selected_devices, + ep_option_keys, + ep_option_vals, + py_sess_options.value)); + } + } +#endif + py_sess_options.provider_factories.push_back(std::move(ep_factory)); return Status::OK(); } From 3c7e3e0c86a5daf0a86c4f50baf31dc46ba0f15e Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 26 Mar 2026 10:24:08 -0700 Subject: [PATCH 22/48] CUDA Plugin EP: Test Coverage & Bug Fixes (#27817) ## Summary - Adds comprehensive test suite for the CUDA Plugin EP (`test_cuda_plugin_ep.py`) covering 5 stages: registration, ONNX ops, NHWC layout preference, contrib ops, and op-level validation - Adds `cuda_plugin_ep_helper.py` utility for transparently routing existing tests to the plugin EP - Fixes `test_gqa.py`: corrects `total_sequence_length` tensor placement from CUDA to CPU (was causing failures under the plugin EP's stricter memory layout) and routes tests through plugin EP - Updates `test_moe_cuda.py` to route through plugin EP when available - Fixes temp file collision risk in `_run_model_test` by using `tempfile.NamedTemporaryFile` --- .../transformers/cuda_plugin_ep_helper.py | 166 ++++ .../transformers/test_cuda_plugin_ep.py | 855 ++++++++++++++++++ .../test/python/transformers/test_gqa.py | 16 +- .../test/python/transformers/test_moe_cuda.py | 15 +- 4 files changed, 1043 insertions(+), 9 deletions(-) create mode 100644 onnxruntime/test/python/transformers/cuda_plugin_ep_helper.py create mode 100644 onnxruntime/test/python/transformers/test_cuda_plugin_ep.py diff --git a/onnxruntime/test/python/transformers/cuda_plugin_ep_helper.py b/onnxruntime/test/python/transformers/cuda_plugin_ep_helper.py new file mode 100644 index 0000000000000..665f1d6828202 --- /dev/null +++ b/onnxruntime/test/python/transformers/cuda_plugin_ep_helper.py @@ -0,0 +1,166 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import sys +from importlib.metadata import PackageNotFoundError, distribution +from pathlib import Path + +import onnxruntime as onnxrt +from onnxruntime import get_build_info + + +class _CudaPluginRegistrationState: + attempted = False + registered = False + + +CUDA_PLUGIN_EP_NAME = "CudaPluginExecutionProvider" +enable_debug_print = False + + +def should_test_with_cuda_plugin_ep(default_value: bool = True) -> bool: + return os.getenv("ORT_TEST_CUDA_PLUGIN_EP", "1" if default_value else "0") == "1" + + +def _get_package_root(package_name: str, directory_name: str | None = None): + root_directory_name = directory_name or package_name + try: + dist = distribution(package_name) + files = dist.files or [] + + for file in files: + if file.name.endswith("__init__.py") and root_directory_name in file.parts: + return file.locate().parent + + if not directory_name: + for file in files: + if file.name.endswith("__init__.py"): + return file.locate().parent + except PackageNotFoundError: + pass + + return None + + +def _is_cuda_plugin_ep_built() -> bool: + build_info = get_build_info() + return ", cuda-plugin-ep=" in build_info + + +def _get_cuda_plugin_library_name() -> str: + if sys.platform == "win32": + return "onnxruntime_providers_cuda_plugin.dll" + + if sys.platform == "darwin": + return "libonnxruntime_providers_cuda_plugin.dylib" + + return "libonnxruntime_providers_cuda_plugin.so" + + +def _get_default_cuda_plugin_ep_path() -> str | None: + library_name = _get_cuda_plugin_library_name() + + # 1) Match currently imported onnxruntime module first to avoid ABI mismatch. + loaded_onnxruntime_root = Path(onnxrt.__file__).resolve().parent + loaded_candidate = loaded_onnxruntime_root / "capi" / library_name + if loaded_candidate.exists(): + return str(loaded_candidate) + + # 2) Installed wheel location. + for package_name in ("onnxruntime-gpu", "onnxruntime"): + package_root = _get_package_root(package_name, "onnxruntime") + if package_root: + candidate = os.path.join(str(package_root), "capi", library_name) + if os.path.exists(candidate): + return candidate + + # 3) In-tree build location fallback. Search under the repo build dir so we + # can handle different platforms/configurations without hard-coding Release/.so. + # This assumes that user only builds in one configuration. + # Recommend to use ORT_CUDA_PLUGIN_PATH if building in multiple configurations. + repo_root = Path(__file__).resolve().parents[4] + build_root = repo_root / "build" + if not build_root.exists(): + return None + + matches = [path for path in build_root.rglob(library_name) if "CMakeFiles" not in path.parts] + if matches: + + def _sort_key(path: Path) -> tuple[int, int, str]: + path_str = str(path) + if "Release" in path.parts: + config_rank = 0 + elif "RelWithDebInfo" in path.parts: + config_rank = 1 + elif "Debug" in path.parts: + config_rank = 2 + else: + config_rank = 3 + + return (config_rank, len(path.parts), path_str) + + return str(sorted(matches, key=_sort_key)[0]) + + return None + + +def ensure_cuda_plugin_ep_registered(default_test_with_cuda_plugin_ep: bool = True) -> bool: + if _CudaPluginRegistrationState.registered: + return True + + if not should_test_with_cuda_plugin_ep(default_test_with_cuda_plugin_ep): + return False + + if not _is_cuda_plugin_ep_built(): + return False + + ep_lib_path = os.environ.get("ORT_CUDA_PLUGIN_PATH", "") + if not ep_lib_path: + detected_path = _get_default_cuda_plugin_ep_path() + ep_lib_path = detected_path if detected_path else "" + + if not ep_lib_path or not os.path.exists(ep_lib_path): + if enable_debug_print: + print(f"CUDA Plugin EP library not found: {ep_lib_path}") + return False + + _CudaPluginRegistrationState.attempted = True + + try: + onnxrt.register_execution_provider_library(CUDA_PLUGIN_EP_NAME, ep_lib_path) + _CudaPluginRegistrationState.registered = True + except Exception as e: + if "already registered" in str(e).lower(): + _CudaPluginRegistrationState.registered = True + else: + try: + providers = {device.ep_name for device in onnxrt.get_ep_devices()} + except Exception: + providers = set() + + _CudaPluginRegistrationState.registered = CUDA_PLUGIN_EP_NAME in providers + + if enable_debug_print and not _CudaPluginRegistrationState.registered: + print(f"Failed to register CUDA Plugin EP from {ep_lib_path}: {e}") + + return _CudaPluginRegistrationState.registered + + +def resolve_cuda_plugin_ep(ep: str, default_test_with_cuda_plugin_ep: bool = True) -> str: + # Keep all existing test call-sites unchanged: they pass CUDA EP, + # and we transparently route to plugin EP when it is built and loadable. + if ep == "CUDAExecutionProvider" and ensure_cuda_plugin_ep_registered(default_test_with_cuda_plugin_ep): + if _is_plugin_provider_type_available(): + return CUDA_PLUGIN_EP_NAME + + if enable_debug_print: + print(f"{CUDA_PLUGIN_EP_NAME} is not exposed in available provider types. Falling back to {ep}.") + return ep + + +def _is_plugin_provider_type_available() -> bool: + try: + return CUDA_PLUGIN_EP_NAME in onnxrt.get_available_providers() + except Exception: + return False diff --git a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py new file mode 100644 index 0000000000000..75a146d7d3bb0 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py @@ -0,0 +1,855 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import tempfile +import unittest + +import numpy as np +import onnx +import torch +import torch.nn.functional as F +from cuda_plugin_ep_helper import CUDA_PLUGIN_EP_NAME, ensure_cuda_plugin_ep_registered, should_test_with_cuda_plugin_ep +from onnx import TensorProto, helper, save + +import onnxruntime as onnxrt + +try: + import faulthandler + + faulthandler.enable() +except ImportError: + pass + + +TEST_PASS = "PASS" +TEST_SKIP = "SKIP" +TEST_FAIL = "FAIL" + + +def require_cuda_plugin_ep(): + if not should_test_with_cuda_plugin_ep(): + raise unittest.SkipTest("CUDA plugin EP is not enabled for testing") + + if not ensure_cuda_plugin_ep_registered(): + raise unittest.SkipTest("CUDA plugin EP is not built or could not be registered") + + +def get_cuda_plugin_device(): + require_cuda_plugin_ep() + + try: + devices = onnxrt.get_ep_devices() + except Exception as exc: + raise unittest.SkipTest(f"Failed to enumerate CUDA plugin EP devices: {exc}") from exc + + plugin_devices = [device for device in devices if device.ep_name == CUDA_PLUGIN_EP_NAME] + if not plugin_devices: + raise unittest.SkipTest("CUDA plugin EP registered, but no plugin devices were enumerated") + + return plugin_devices[0] + + +def create_add_model(model_path): + # Create a simple Add model: Y = A + B + node_def = helper.make_node("Add", ["A", "B"], ["Y"]) + graph_def = helper.make_graph( + [node_def], + "test-model-add", + [ + helper.make_tensor_value_info("A", TensorProto.FLOAT, [3, 2]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [3, 2]), + ], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 2])], + ) + model_def = helper.make_model(graph_def, producer_name="onnx-example") + save(model_def, model_path) + + +def create_matmul_model(model_path): + # Create a simple MatMul model: Y = A @ B + node_def = helper.make_node("MatMul", ["A", "B"], ["Y"]) + graph_def = helper.make_graph( + [node_def], + "test-model-matmul", + [ + helper.make_tensor_value_info("A", TensorProto.FLOAT, [3, 4]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [4, 5]), + ], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 5])], + ) + model_def = helper.make_model(graph_def, producer_name="onnx-example") + save(model_def, model_path) + + +def create_gemm_model(model_path, alpha=1.0, beta=1.0, transA=0, transB=0): + # Create a simple Gemm model: Y = alpha*A*B + beta*C + node_def = helper.make_node("Gemm", ["A", "B", "C"], ["Y"], alpha=alpha, beta=beta, transA=transA, transB=transB) + + m = 3 + k = 4 + n = 5 + shape_a = [m, k] if transA == 0 else [k, m] + shape_b = [k, n] if transB == 0 else [n, k] + shape_c = [n] # Test broadcast + + graph_def = helper.make_graph( + [node_def], + "test-model-gemm", + [ + helper.make_tensor_value_info("A", TensorProto.FLOAT, shape_a), + helper.make_tensor_value_info("B", TensorProto.FLOAT, shape_b), + helper.make_tensor_value_info("C", TensorProto.FLOAT, shape_c), + ], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [m, n])], + ) + model_def = helper.make_model(graph_def, producer_name="onnx-example") + save(model_def, model_path) + + +def create_conv_model(model_path): + # Create a simple Conv model: Y = Conv(X, W) + node_def = helper.make_node("Conv", ["X", "W"], ["Y"], pads=[1, 1, 1, 1], strides=[1, 1], dilations=[1, 1], group=1) + graph_def = helper.make_graph( + [node_def], + "test-model-conv", + [ + helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2, 4, 4]), + helper.make_tensor_value_info("W", TensorProto.FLOAT, [3, 2, 3, 3]), + ], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 4, 4])], + ) + opset = onnx.OperatorSetIdProto() + opset.version = 11 + model_def = helper.make_model(graph_def, producer_name="onnx-example", opset_imports=[opset]) + save(model_def, model_path) + + +def create_batch_norm_model(model_path): + """Create a BatchNormalization model for NHWC testing.""" + num_channels = 3 + node_def = helper.make_node( + "BatchNormalization", + ["X", "scale", "B", "input_mean", "input_var"], + ["Y"], + epsilon=1e-5, + ) + # scale, B, mean, var are 1D tensors of shape [num_channels] + scale_init = helper.make_tensor( + "scale", TensorProto.FLOAT, [num_channels], np.ones(num_channels, dtype=np.float32).tolist() + ) + bias_init = helper.make_tensor( + "B", TensorProto.FLOAT, [num_channels], np.zeros(num_channels, dtype=np.float32).tolist() + ) + mean_init = helper.make_tensor( + "input_mean", TensorProto.FLOAT, [num_channels], np.zeros(num_channels, dtype=np.float32).tolist() + ) + var_init = helper.make_tensor( + "input_var", TensorProto.FLOAT, [num_channels], np.ones(num_channels, dtype=np.float32).tolist() + ) + + graph_def = helper.make_graph( + [node_def], + "test-model-batchnorm", + [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, num_channels, 4, 4])], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, num_channels, 4, 4])], + initializer=[scale_init, bias_init, mean_init, var_init], + ) + opset = onnx.OperatorSetIdProto() + opset.version = 15 + model_def = helper.make_model(graph_def, producer_name="onnx-example", opset_imports=[opset]) + save(model_def, model_path) + + +def create_maxpool_model(model_path): + """Create a MaxPool model for NHWC testing.""" + node_def = helper.make_node( + "MaxPool", + ["X"], + ["Y"], + kernel_shape=[2, 2], + strides=[2, 2], + ) + graph_def = helper.make_graph( + [node_def], + "test-model-maxpool", + [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 4, 4])], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 2, 2])], + ) + opset = onnx.OperatorSetIdProto() + opset.version = 12 + model_def = helper.make_model(graph_def, producer_name="onnx-example", opset_imports=[opset]) + save(model_def, model_path) + + +def create_avgpool_model(model_path): + """Create an AveragePool model for NHWC testing.""" + node_def = helper.make_node( + "AveragePool", + ["X"], + ["Y"], + kernel_shape=[2, 2], + strides=[2, 2], + ) + graph_def = helper.make_graph( + [node_def], + "test-model-avgpool", + [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 4, 4])], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 2, 2])], + ) + opset = onnx.OperatorSetIdProto() + opset.version = 12 + model_def = helper.make_model(graph_def, producer_name="onnx-example", opset_imports=[opset]) + save(model_def, model_path) + + +def make_bias_dropout_model(): + """Create a deterministic BiasDropout model by forcing inference mode.""" + node = helper.make_node( + "BiasDropout", + ["X", "bias", "residual", "ratio", "training_mode"], + ["Y", ""], + domain="com.microsoft", + ) + graph = helper.make_graph( + [node], + "test-BiasDropout", + [ + helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 4]), + helper.make_tensor_value_info("bias", TensorProto.FLOAT, [4]), + helper.make_tensor_value_info("residual", TensorProto.FLOAT, [2, 4]), + helper.make_tensor_value_info("ratio", TensorProto.FLOAT, []), + helper.make_tensor_value_info("training_mode", TensorProto.BOOL, []), + ], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4])], + ) + opset_onnx = onnx.OperatorSetIdProto() + opset_onnx.version = 13 + opset_ms = onnx.OperatorSetIdProto() + opset_ms.domain = "com.microsoft" + opset_ms.version = 1 + return helper.make_model(graph, opset_imports=[opset_onnx, opset_ms]) + + +def run_operator_test( + target_device, model_creator, inputs, expected_fn, ep_name=CUDA_PLUGIN_EP_NAME, session_config=None +): + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp: + model_path = tmp.name + try: + model_creator(model_path) + sess_options = onnxrt.SessionOptions() + if session_config: + for key, value in session_config.items(): + sess_options.add_session_config_entry(key, value) + sess_options.add_provider_for_devices([target_device], {}) + sess = onnxrt.InferenceSession(model_path, sess_options=sess_options) + + active_providers = sess.get_providers() + if ep_name not in active_providers: + print(f"FAILURE: {ep_name} is NOT active for this operator. Providers: {active_providers}") + return False + + print(f"(Session created with {active_providers})", end=" ", flush=True) + print("Running...", end=" ", flush=True) + res = sess.run(None, inputs) + print("Done.", end=" ", flush=True) + expected = expected_fn(inputs) + np.testing.assert_allclose(res[0], expected, rtol=1e-3, atol=1e-3) + return True + finally: + if os.path.exists(model_path): + os.remove(model_path) + + +def run_provider_options_test(provider_options, expect_plugin_provider=True): + require_cuda_plugin_ep() + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp: + model_path = tmp.name + try: + create_add_model(model_path) + providers = [(CUDA_PLUGIN_EP_NAME, provider_options), "CPUExecutionProvider"] + sess = onnxrt.InferenceSession(model_path, providers=providers) + active_providers = sess.get_providers() + + if expect_plugin_provider and CUDA_PLUGIN_EP_NAME not in active_providers: + print(f"FAILURE: {CUDA_PLUGIN_EP_NAME} is NOT active. Providers: {active_providers}") + return False + if not expect_plugin_provider and CUDA_PLUGIN_EP_NAME in active_providers: + print(f"FAILURE: {CUDA_PLUGIN_EP_NAME} unexpectedly active. Providers: {active_providers}") + return False + + a = np.random.rand(3, 2).astype(np.float32) + b = np.random.rand(3, 2).astype(np.float32) + res = sess.run(None, {"A": a, "B": b}) + np.testing.assert_allclose(res[0], a + b, rtol=1e-3, atol=1e-3) + return True + except Exception as e: + if expect_plugin_provider: + print(f"FAIL ({e})") + return False + + print(f"Expected failure for provider options {provider_options}: {e}") + return True + finally: + if os.path.exists(model_path): + os.remove(model_path) + + +def _run_registration_checks(test_case: unittest.TestCase): + target_device = get_cuda_plugin_device() + print(f"Using registered plugin: {CUDA_PLUGIN_EP_NAME}", flush=True) + print(f"Using device: {target_device.ep_name}", flush=True) + + x = np.random.rand(1, 2, 4, 4).astype(np.float32) + w = np.random.rand(3, 2, 3, 3).astype(np.float32) + + def expected_conv(inputs): + return F.conv2d(torch.from_numpy(inputs["X"]), torch.from_numpy(inputs["W"]), padding=1).numpy() + + stage2_cases = [ + ( + "Add", + create_add_model, + {"A": np.random.rand(3, 2).astype(np.float32), "B": np.random.rand(3, 2).astype(np.float32)}, + lambda feed: feed["A"] + feed["B"], + None, + ), + ( + "MatMul", + create_matmul_model, + {"A": np.random.rand(3, 4).astype(np.float32), "B": np.random.rand(4, 5).astype(np.float32)}, + lambda feed: feed["A"] @ feed["B"], + None, + ), + ( + "Gemm", + lambda model_path: create_gemm_model(model_path, alpha=2.0, beta=0.5), + { + "A": np.random.rand(3, 4).astype(np.float32), + "B": np.random.rand(4, 5).astype(np.float32), + "C": np.random.rand(5).astype(np.float32), + }, + lambda feed: 2.0 * (feed["A"] @ feed["B"]) + 0.5 * feed["C"], + None, + ), + ("Conv", create_conv_model, {"X": x, "W": w}, expected_conv, None), + ] + + for name, model_creator, inputs, expected_fn, session_config in stage2_cases: + print(f"Testing {name}...", end=" ", flush=True) + result = run_operator_test(target_device, model_creator, inputs, expected_fn, session_config=session_config) + with test_case.subTest(op=name): + test_case.assertTrue( + result, + f"{name} plugin registration test failed", + ) + print(TEST_PASS if result else TEST_FAIL, flush=True) + + print("\nAll Stage 2 tests finished successfully.", flush=True) + + nhwc_config = {"ep.cuda.prefer_nhwc_layout": "1"} + + def expected_batchnorm(inputs): + return inputs["X"] / np.sqrt(1.0 + 1e-5) + + stage3_cases = [ + ( + "Conv (NHWC)", + create_conv_model, + { + "X": np.random.rand(1, 2, 4, 4).astype(np.float32), + "W": np.random.rand(3, 2, 3, 3).astype(np.float32), + }, + expected_conv, + ), + ( + "BatchNormalization (NHWC)", + create_batch_norm_model, + {"X": np.random.rand(1, 3, 4, 4).astype(np.float32)}, + expected_batchnorm, + ), + ( + "MaxPool (NHWC)", + create_maxpool_model, + {"X": np.random.rand(1, 3, 4, 4).astype(np.float32)}, + lambda feed: F.max_pool2d(torch.from_numpy(feed["X"]), kernel_size=2, stride=2).numpy(), + ), + ( + "AveragePool (NHWC)", + create_avgpool_model, + {"X": np.random.rand(1, 3, 4, 4).astype(np.float32)}, + lambda feed: F.avg_pool2d(torch.from_numpy(feed["X"]), kernel_size=2, stride=2).numpy(), + ), + ] + + for name, model_creator, inputs, expected_fn in stage3_cases: + print(f"Testing {name}...", end=" ", flush=True) + result = run_operator_test(target_device, model_creator, inputs, expected_fn, session_config=nhwc_config) + with test_case.subTest(op=name): + test_case.assertTrue( + result, + f"{name} plugin NHWC test failed", + ) + print(TEST_PASS if result else TEST_FAIL, flush=True) + + print("\nAll Stage 3 NHWC tests finished successfully.", flush=True) + + provider_option_cases = [ + ("provider options with valid device_id/use_tf32", {"device_id": "0", "use_tf32": "0"}, True), + ("provider options with invalid device_id", {"device_id": "999"}, False), + ] + + print("\nTesting provider options path...", flush=True) + for name, provider_options, expect_plugin_provider in provider_option_cases: + print(f"Testing {name}...", end=" ", flush=True) + result = run_provider_options_test(provider_options, expect_plugin_provider=expect_plugin_provider) + with test_case.subTest(op=name): + test_case.assertTrue( + result, + f"{name} failed", + ) + print(TEST_PASS if result else TEST_FAIL, flush=True) + + +def _make_simple_model(op_type, inputs_info, outputs_info, attrs=None, opset=13, domain=""): + """Helper to create a simple single-node ONNX model. + + Args: + op_type: ONNX op type string + inputs_info: list of (name, elem_type, shape) tuples + outputs_info: list of (name, elem_type, shape) tuples + attrs: dict of node attributes + opset: opset version + domain: op domain (empty string for default ONNX domain) + """ + input_names = [info[0] for info in inputs_info] + output_names = [info[0] for info in outputs_info] + node = helper.make_node(op_type, input_names, output_names, domain=domain, **(attrs or {})) + graph = helper.make_graph( + [node], + f"test-{op_type}", + [helper.make_tensor_value_info(n, t, s) for n, t, s in inputs_info], + [helper.make_tensor_value_info(n, t, s) for n, t, s in outputs_info], + ) + opset_import = [onnx.OperatorSetIdProto()] + opset_import[0].version = opset + if domain: + ms_opset = onnx.OperatorSetIdProto() + ms_opset.domain = domain + ms_opset.version = 1 + opset_import.append(ms_opset) + model = helper.make_model(graph, opset_imports=opset_import) + return model + + +def _run_model_test( + target_device, op_name, model, feed_dict, expected_fn, ep_name=CUDA_PLUGIN_EP_NAME, rtol=1e-3, atol=1e-3 +): + """Run a single op test: save model, create session, run, compare.""" + with tempfile.NamedTemporaryFile(suffix=f"_{op_name}.onnx", delete=False) as tmp: + model_path = tmp.name + try: + save(model, model_path) + sess_options = onnxrt.SessionOptions() + sess_options.add_provider_for_devices([target_device], {}) + sess = onnxrt.InferenceSession(model_path, sess_options=sess_options) + active_providers = sess.get_providers() + if ep_name not in active_providers: + print(f"{TEST_SKIP} ({ep_name} not active)") + return TEST_SKIP + res = sess.run(None, feed_dict) + expected = expected_fn(feed_dict) + if isinstance(expected, (list, tuple)): + if len(res) != len(expected): + raise AssertionError(f"{op_name} produced {len(res)} outputs, expected {len(expected)}") + + for r, e in zip(res, expected, strict=True): + np.testing.assert_allclose(r, e, rtol=rtol, atol=atol) + else: + np.testing.assert_allclose(res[0], expected, rtol=rtol, atol=atol) + return TEST_PASS + except Exception as e: + print(f"{TEST_FAIL} ({e})") + return TEST_FAIL + finally: + if os.path.exists(model_path): + os.remove(model_path) + + +def _run_stage5_checks(test_case: unittest.TestCase): + """Stage 5: Test all ops enabled during Stage 5 (5A through 5D).""" + target_device = get_cuda_plugin_device() + passed = 0 + failed = 0 + skipped = 0 + + def run_test(name, model, feed, expected_fn, rtol=1e-3, atol=1e-3): + nonlocal passed, failed, skipped + print(f" {name}...", end=" ", flush=True) + result = _run_model_test(target_device, name, model, feed, expected_fn, rtol=rtol, atol=atol) + with test_case.subTest(op=name): + if result == TEST_PASS: + passed += 1 + print(TEST_PASS, flush=True) + return + + if result == TEST_SKIP: + skipped += 1 + print(TEST_SKIP, flush=True) + return + + failed += 1 + print(TEST_FAIL, flush=True) + test_case.fail(f"{name} Stage 5 plugin op test failed") + + print("\n==================== Stage 5: Expanded Op Tests ====================", flush=True) + f_dtype = TensorProto.FLOAT + + # ---- 5A/5B: Standard ops ---- + print("\n--- Standard Ops (5A/5B) ---", flush=True) + + # Reshape + model = _make_simple_model( + "Reshape", [("X", f_dtype, [2, 3, 4]), ("shape", TensorProto.INT64, [2])], [("Y", f_dtype, [6, 4])] + ) + # Need shape as initializer; build manually + shape_init = helper.make_tensor("shape", TensorProto.INT64, [2], [6, 4]) + model.graph.initializer.append(shape_init) + x = np.random.rand(2, 3, 4).astype(np.float32) + run_test("Reshape", model, {"X": x}, lambda f: f["X"].reshape(6, 4)) + + # Split (opset 18 supports num_outputs; use split input for opset 13) + node = helper.make_node("Split", ["X", "split"], ["Y1", "Y2"], axis=0) + graph = helper.make_graph( + [node], + "test-Split", + [helper.make_tensor_value_info("X", f_dtype, [6, 4])], + [helper.make_tensor_value_info("Y1", f_dtype, [3, 4]), helper.make_tensor_value_info("Y2", f_dtype, [3, 4])], + ) + opset = onnx.OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + model.graph.initializer.append(helper.make_tensor("split", TensorProto.INT64, [2], [3, 3])) + x = np.random.rand(6, 4).astype(np.float32) + run_test("Split", model, {"X": x}, lambda f: [f["X"][:3], f["X"][3:]]) + + # Concat + model = _make_simple_model( + "Concat", [("A", f_dtype, [2, 3]), ("B", f_dtype, [2, 3])], [("Y", f_dtype, [4, 3])], attrs={"axis": 0} + ) + a = np.random.rand(2, 3).astype(np.float32) + b = np.random.rand(2, 3).astype(np.float32) + run_test("Concat", model, {"A": a, "B": b}, lambda f: np.concatenate([f["A"], f["B"]], axis=0)) + + # Gather + gather_model = _make_simple_model( + "Gather", + [("X", f_dtype, [5, 4]), ("indices", TensorProto.INT64, [3])], + [("Y", f_dtype, [3, 4])], + attrs={"axis": 0}, + opset=13, + ) + x = np.random.rand(5, 4).astype(np.float32) + idx = np.array([0, 2, 4], dtype=np.int64) + run_test("Gather", gather_model, {"X": x, "indices": idx}, lambda f: f["X"][f["indices"]]) + + # Unsqueeze (opset 13 uses axes as input) + node = helper.make_node("Unsqueeze", ["X", "axes"], ["Y"]) + graph = helper.make_graph( + [node], + "test-Unsqueeze", + [helper.make_tensor_value_info("X", f_dtype, [3, 4])], + [helper.make_tensor_value_info("Y", f_dtype, [1, 3, 4])], + ) + opset = onnx.OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + axes_init = helper.make_tensor("axes", TensorProto.INT64, [1], [0]) + model.graph.initializer.append(axes_init) + x = np.random.rand(3, 4).astype(np.float32) + run_test("Unsqueeze", model, {"X": x}, lambda f: np.expand_dims(f["X"], 0)) + + # Tile + node = helper.make_node("Tile", ["X", "repeats"], ["Y"]) + graph = helper.make_graph( + [node], + "test-Tile", + [helper.make_tensor_value_info("X", f_dtype, [2, 3])], + [helper.make_tensor_value_info("Y", f_dtype, [4, 9])], + ) + opset = onnx.OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + repeats_init = helper.make_tensor("repeats", TensorProto.INT64, [2], [2, 3]) + model.graph.initializer.append(repeats_init) + x = np.random.rand(2, 3).astype(np.float32) + run_test("Tile", model, {"X": x}, lambda f: np.tile(f["X"], (2, 3))) + + # CumSum + node = helper.make_node("CumSum", ["X", "axis"], ["Y"]) + graph = helper.make_graph( + [node], + "test-CumSum", + [helper.make_tensor_value_info("X", f_dtype, [3, 4])], + [helper.make_tensor_value_info("Y", f_dtype, [3, 4])], + ) + opset = onnx.OperatorSetIdProto() + opset.version = 14 + model = helper.make_model(graph, opset_imports=[opset]) + axis_init = helper.make_tensor("axis", TensorProto.INT64, [], [1]) + model.graph.initializer.append(axis_init) + x = np.random.rand(3, 4).astype(np.float32) + run_test("CumSum", model, {"X": x}, lambda f: np.cumsum(f["X"], axis=1)) + + # ConstantOfShape + node = helper.make_node( + "ConstantOfShape", ["shape"], ["Y"], value=helper.make_tensor("value", TensorProto.FLOAT, [1], [3.14]) + ) + graph = helper.make_graph( + [node], + "test-ConstantOfShape", + [helper.make_tensor_value_info("shape", TensorProto.INT64, [2])], + [helper.make_tensor_value_info("Y", f_dtype, None)], + ) + opset = onnx.OperatorSetIdProto() + opset.version = 9 + model = helper.make_model(graph, opset_imports=[opset]) + run_test( + "ConstantOfShape", + model, + {"shape": np.array([2, 3], dtype=np.int64)}, + lambda f: np.full((2, 3), 3.14, dtype=np.float32), + ) + + # SpaceToDepth + model = _make_simple_model( + "SpaceToDepth", [("X", f_dtype, [1, 2, 4, 4])], [("Y", f_dtype, [1, 8, 2, 2])], attrs={"blocksize": 2}, opset=13 + ) + x = np.random.rand(1, 2, 4, 4).astype(np.float32) + + def space_to_depth(f): + inp = f["X"] + b, c, h, w = inp.shape + bs = 2 + # ONNX SpaceToDepth: rearrange blocks of spatial data into depth + # (b, c, h, w) -> (b, c, h/bs, bs, w/bs, bs) -> (b, c*bs*bs, h/bs, w/bs) + tmp = inp.reshape(b, c, h // bs, bs, w // bs, bs) + tmp = tmp.transpose(0, 3, 5, 1, 2, 4) + return tmp.reshape(b, c * bs * bs, h // bs, w // bs) + + run_test("SpaceToDepth", model, {"X": x}, space_to_depth) + + # Pad + node = helper.make_node("Pad", ["X", "pads", "constant_value"], ["Y"]) + graph = helper.make_graph( + [node], + "test-Pad", + [helper.make_tensor_value_info("X", f_dtype, [2, 3])], + [helper.make_tensor_value_info("Y", f_dtype, [4, 5])], + ) + opset = onnx.OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + model.graph.initializer.append(helper.make_tensor("pads", TensorProto.INT64, [4], [1, 1, 1, 1])) + model.graph.initializer.append(helper.make_tensor("constant_value", TensorProto.FLOAT, [], [0.0])) + x = np.random.rand(2, 3).astype(np.float32) + run_test("Pad", model, {"X": x}, lambda f: np.pad(f["X"], ((1, 1), (1, 1)), constant_values=0)) + + # Slice + node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"]) + graph = helper.make_graph( + [node], + "test-Slice", + [helper.make_tensor_value_info("X", f_dtype, [4, 6])], + [helper.make_tensor_value_info("Y", f_dtype, [2, 4])], + ) + opset = onnx.OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + model.graph.initializer.append(helper.make_tensor("starts", TensorProto.INT64, [2], [1, 1])) + model.graph.initializer.append(helper.make_tensor("ends", TensorProto.INT64, [2], [3, 5])) + model.graph.initializer.append(helper.make_tensor("axes", TensorProto.INT64, [2], [0, 1])) + x = np.random.rand(4, 6).astype(np.float32) + run_test("Slice", model, {"X": x}, lambda f: f["X"][1:3, 1:5]) + + # Resize (nearest) + node = helper.make_node("Resize", ["X", "", "scales"], ["Y"], mode="nearest") + graph = helper.make_graph( + [node], + "test-Resize", + [helper.make_tensor_value_info("X", f_dtype, [1, 1, 2, 2])], + [helper.make_tensor_value_info("Y", f_dtype, [1, 1, 4, 4])], + ) + opset = onnx.OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + model.graph.initializer.append(helper.make_tensor("scales", TensorProto.FLOAT, [4], [1.0, 1.0, 2.0, 2.0])) + x = np.random.rand(1, 1, 2, 2).astype(np.float32) + run_test("Resize", model, {"X": x}, lambda f: np.repeat(np.repeat(f["X"], 2, axis=2), 2, axis=3)) + + # Sum (variadic) + model = _make_simple_model( + "Sum", + [("A", f_dtype, [3, 4]), ("B", f_dtype, [3, 4]), ("C", f_dtype, [3, 4])], + [("Y", f_dtype, [3, 4])], + opset=13, + ) + a = np.random.rand(3, 4).astype(np.float32) + b = np.random.rand(3, 4).astype(np.float32) + c = np.random.rand(3, 4).astype(np.float32) + run_test("Sum_variadic", model, {"A": a, "B": b, "C": c}, lambda f: f["A"] + f["B"] + f["C"]) + + # ---- 5C: CPU base class ops ---- + print("\n--- CPU Base Class Ops (5C) ---", flush=True) + + # Upsample (deprecated but still present) + node = helper.make_node("Upsample", ["X", "scales"], ["Y"], mode="nearest") + graph = helper.make_graph( + [node], + "test-Upsample", + [helper.make_tensor_value_info("X", f_dtype, [1, 1, 2, 2])], + [helper.make_tensor_value_info("Y", f_dtype, [1, 1, 4, 4])], + ) + opset = onnx.OperatorSetIdProto() + opset.version = 9 + model = helper.make_model(graph, opset_imports=[opset]) + model.graph.initializer.append(helper.make_tensor("scales", TensorProto.FLOAT, [4], [1.0, 1.0, 2.0, 2.0])) + x = np.random.rand(1, 1, 2, 2).astype(np.float32) + run_test("Upsample", model, {"X": x}, lambda f: np.repeat(np.repeat(f["X"], 2, axis=2), 2, axis=3)) + + # DepthToSpace + model = _make_simple_model( + "DepthToSpace", + [("X", f_dtype, [1, 8, 2, 2])], + [("Y", f_dtype, [1, 2, 4, 4])], + attrs={"blocksize": 2, "mode": "DCR"}, + opset=13, + ) + x = np.random.rand(1, 8, 2, 2).astype(np.float32) + + def depth_to_space_dcr(f): + inp = f["X"] + b, c, h, w = inp.shape + bs = 2 + return ( + inp.reshape(b, bs, bs, c // (bs * bs), h, w) + .transpose(0, 3, 4, 1, 5, 2) + .reshape(b, c // (bs * bs), h * bs, w * bs) + ) + + run_test("DepthToSpace", model, {"X": x}, depth_to_space_dcr) + + # ---- 5D: Contrib Ops ---- + print("\n--- Contrib Ops (5D) ---", flush=True) + + # FastGelu (com.microsoft domain) + node = helper.make_node("FastGelu", ["X"], ["Y"], domain="com.microsoft") + graph = helper.make_graph( + [node], + "test-FastGelu", + [helper.make_tensor_value_info("X", f_dtype, [2, 4])], + [helper.make_tensor_value_info("Y", f_dtype, [2, 4])], + ) + opset_onnx = onnx.OperatorSetIdProto() + opset_onnx.version = 13 + opset_ms = onnx.OperatorSetIdProto() + opset_ms.domain = "com.microsoft" + opset_ms.version = 1 + model = helper.make_model(graph, opset_imports=[opset_onnx, opset_ms]) + x = np.random.rand(2, 4).astype(np.float32) + + def fast_gelu_ref(f): + x = f["X"] + # FastGelu approximation: x * sigmoid(1.702 * x) + return x * (1.0 / (1.0 + np.exp(-1.702 * x))) + + run_test("FastGelu", model, {"X": x}, fast_gelu_ref, rtol=1e-2, atol=1e-2) + + # BiasDropout (com.microsoft). We force inference mode so the op is deterministic. + model = make_bias_dropout_model() + x = np.random.rand(2, 4).astype(np.float32) + bias = np.random.rand(4).astype(np.float32) + residual = np.random.rand(2, 4).astype(np.float32) + ratio = np.array(0.5, dtype=np.float32) + training_mode = np.array(False, dtype=np.bool_) + run_test( + "BiasDropout", + model, + { + "X": x, + "bias": bias, + "residual": residual, + "ratio": ratio, + "training_mode": training_mode, + }, + lambda feed: feed["X"] + feed["bias"] + feed["residual"], + ) + + # SkipLayerNormalization (com.microsoft) + hidden_size = 8 + node = helper.make_node( + "SkipLayerNormalization", + ["X", "skip", "gamma", "beta"], + ["Y", "", "", "input_skip_bias_sum"], + domain="com.microsoft", + epsilon=1e-5, + ) + graph = helper.make_graph( + [node], + "test-SkipLayerNorm", + [ + helper.make_tensor_value_info("X", f_dtype, [2, hidden_size]), + helper.make_tensor_value_info("skip", f_dtype, [2, hidden_size]), + helper.make_tensor_value_info("gamma", f_dtype, [hidden_size]), + helper.make_tensor_value_info("beta", f_dtype, [hidden_size]), + ], + [ + helper.make_tensor_value_info("Y", f_dtype, [2, hidden_size]), + helper.make_tensor_value_info("input_skip_bias_sum", f_dtype, None), + ], + ) + opset_onnx = onnx.OperatorSetIdProto() + opset_onnx.version = 13 + opset_ms = onnx.OperatorSetIdProto() + opset_ms.domain = "com.microsoft" + opset_ms.version = 1 + model = helper.make_model(graph, opset_imports=[opset_onnx, opset_ms]) + x = np.random.rand(2, hidden_size).astype(np.float32) + skip = np.random.rand(2, hidden_size).astype(np.float32) + gamma = np.ones(hidden_size, dtype=np.float32) + beta = np.zeros(hidden_size, dtype=np.float32) + + def skip_layer_norm_ref(f): + added = f["X"] + f["skip"] + mean = added.mean(axis=-1, keepdims=True) + var = added.var(axis=-1, keepdims=True) + normed = (added - mean) / np.sqrt(var + 1e-5) + return [normed * f["gamma"] + f["beta"], added] + + run_test( + "SkipLayerNorm", + model, + {"X": x, "skip": skip, "gamma": gamma, "beta": beta}, + skip_layer_norm_ref, + rtol=1e-2, + atol=1e-2, + ) + + # ---- Summary ---- + total = passed + failed + skipped + print(f"\n--- Stage 5 Results: {passed} passed, {failed} failed, {skipped} skipped ({total} total) ---", flush=True) + test_case.assertEqual(failed, 0, f"Stage 5 had {failed} failing plugin op checks") + print("All Stage 5 tests finished successfully.", flush=True) + + +class TestCudaPluginEP(unittest.TestCase): + def test_cuda_plugin_registration(self): + _run_registration_checks(self) + + def test_cuda_plugin_stage5_ops(self): + _run_stage5_checks(self) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index 5ff0572c927c6..5d15a70c207f3 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -20,6 +20,7 @@ import numpy import torch +from cuda_plugin_ep_helper import resolve_cuda_plugin_ep from einops import rearrange, repeat # --- ONNX and Torch/Numpy Dtype Mappings --- @@ -456,7 +457,7 @@ def gqa_prompt_func( new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=[ep]) + ort_session = InferenceSession(onnx_model_str, sess_options, providers=[resolve_cuda_plugin_ep(ep)]) io_binding = ort_session.io_binding() # Determine input device for binding @@ -492,8 +493,9 @@ def gqa_prompt_func( # total_sequence_length is INT32 [1] # Schema requires this to be on CPU (OrtMemTypeCPUInput) - tsl = torch.tensor([config.q_sequence_length], dtype=torch.int32, device="cpu") - bind_tensor(io_binding, "total_sequence_length", tsl, "cpu", TensorProto.INT32) + cpu_device = torch.device("cpu") + tsl = torch.tensor([config.q_sequence_length], dtype=torch.int32, device=cpu_device) + bind_tensor(io_binding, "total_sequence_length", tsl, cpu_device, TensorProto.INT32) # 5. Optional inputs if cos is not None: @@ -616,7 +618,7 @@ def gqa_past_func( sess_options = SessionOptions() # sess_options.log_severity_level = 0 - ort_session = InferenceSession(onnx_model_str, sess_options, providers=[ep]) + ort_session = InferenceSession(onnx_model_str, sess_options, providers=[resolve_cuda_plugin_ep(ep)]) io_binding = ort_session.io_binding() # Common inputs @@ -653,8 +655,10 @@ def gqa_past_func( seqlens_k_int32 = seqlens_k.to(dtype=torch.int32, device=device) bind_tensor(io_binding, "seqlens_k", seqlens_k_int32, device, TensorProto.INT32) - tsl = torch.tensor([total_seq_len], dtype=torch.int32, device=device) - bind_tensor(io_binding, "total_sequence_length", tsl, device, TensorProto.INT32) + # GroupQueryAttention expects total_sequence_length as CPU input. + cpu_device = torch.device("cpu") + tsl = torch.tensor([total_seq_len], dtype=torch.int32, device=cpu_device) + bind_tensor(io_binding, "total_sequence_length", tsl, cpu_device, TensorProto.INT32) # 5. Optional inputs if cos is not None: diff --git a/onnxruntime/test/python/transformers/test_moe_cuda.py b/onnxruntime/test/python/transformers/test_moe_cuda.py index c09d8bacf1fa2..67caf903f0165 100644 --- a/onnxruntime/test/python/transformers/test_moe_cuda.py +++ b/onnxruntime/test/python/transformers/test_moe_cuda.py @@ -17,6 +17,7 @@ import numpy import torch import torch.nn.functional as F +from cuda_plugin_ep_helper import resolve_cuda_plugin_ep from onnx import TensorProto, helper from parameterized import parameterized from torch import nn @@ -31,7 +32,14 @@ # Determine the execution provider and device based on CUDA availability. use_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers() and torch.cuda.is_available() device = torch.device("cuda:0" if use_cuda else "cpu") -ort_provider = ["CUDAExecutionProvider"] if use_cuda else ["CPUExecutionProvider"] + + +def get_ort_provider(): + if not use_cuda: + return ["CPUExecutionProvider"] + + return [resolve_cuda_plugin_ep("CUDAExecutionProvider")] + torch.manual_seed(42) numpy.random.seed(42) @@ -586,11 +594,12 @@ def create_ort_session(self, moe_onnx_graph): sess_options = SessionOptions() sess_options.log_severity_level = 2 + providers = get_ort_provider() try: - ort_session = InferenceSession(moe_onnx_graph, sess_options, providers=ort_provider) + ort_session = InferenceSession(moe_onnx_graph, sess_options, providers=providers) except Exception as e: - print(f"Failed to create ONNX Runtime session with provider {ort_provider}: {e}") + print(f"Failed to create ONNX Runtime session with provider {providers}: {e}") print("Skipping ONNX Runtime execution for this test case.") return None From 34a5416f9c2244249a062dc800a8885ef6d9d2de Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 26 Mar 2026 08:32:36 -0700 Subject: [PATCH 23/48] fill controlflow opset gap --- .../cuda/plugin/cuda_controlflow_plugin.cc | 104 ++++++++++++++++-- 1 file changed, 97 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc index 62646a210e1a0..8cda066ffccea 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc @@ -282,17 +282,47 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, (*KernelDefBuilder::Create()) .InputMemoryType(OrtMemTypeCPUInput, 0) .TypeConstraint("B", DataTypeImpl::GetTensorType()) - .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypes()), + PluginIfKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 19, 20, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypesIRv9()), + PluginIfKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 21, 22, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypesIRv9()), + PluginIfKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 23, 24, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypesIRv9()), PluginIfKernel); ONNX_OPERATOR_KERNEL_EX(If, kOnnxDomain, - 19, + 25, kCudaExecutionProvider, (*KernelDefBuilder::Create()) .InputMemoryType(OrtMemTypeCPUInput, 0) .TypeConstraint("B", DataTypeImpl::GetTensorType()) - .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypesIRv9()), + .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypesIRv9()), PluginIfKernel); // --- Loop --- @@ -330,19 +360,55 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(Loop, .InputMemoryType(OrtMemTypeCPUInput, 1) .TypeConstraint("I", DataTypeImpl::GetTensorType()) .TypeConstraint("B", DataTypeImpl::GetTensorType()) - .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypes()), + PluginLoopKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Loop, + kOnnxDomain, + 19, 20, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("I", DataTypeImpl::GetTensorType()) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypesIRv9()), + PluginLoopKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Loop, + kOnnxDomain, + 21, 22, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("I", DataTypeImpl::GetTensorType()) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypesIRv9()), + PluginLoopKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Loop, + kOnnxDomain, + 23, 24, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("I", DataTypeImpl::GetTensorType()) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypesIRv9()), PluginLoopKernel); ONNX_OPERATOR_KERNEL_EX(Loop, kOnnxDomain, - 19, + 25, kCudaExecutionProvider, (*KernelDefBuilder::Create()) .InputMemoryType(OrtMemTypeCPUInput, 0) .InputMemoryType(OrtMemTypeCPUInput, 1) .TypeConstraint("I", DataTypeImpl::GetTensorType()) .TypeConstraint("B", DataTypeImpl::GetTensorType()) - .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypesIRv9()), + .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypesIRv9()), PluginLoopKernel); // --- Scan (opset 8) --- @@ -383,9 +449,33 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(Scan, .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), PluginScanKernel); +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Scan, + kOnnxDomain, + 19, 20, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypesIRv9()), + PluginScanKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Scan, + kOnnxDomain, + 21, 22, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypesIRv9()), + PluginScanKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Scan, + kOnnxDomain, + 23, 24, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypesIRv9()), + PluginScanKernel); + ONNX_OPERATOR_KERNEL_EX(Scan, kOnnxDomain, - 19, + 25, kCudaExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypesIRv9()), From 5ac5b9f2ba29c34d227f48a02d288437ae7e6f01 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 26 Mar 2026 08:34:32 -0700 Subject: [PATCH 24/48] ep config --- cmake/onnxruntime_unittests.cmake | 5 + .../providers/cuda/plugin/cuda_ep_factory.cc | 172 +++++++++++++----- .../python/onnxruntime_pybind_state.cc | 23 ++- .../test/framework/dynamic_plugin_ep_test.cc | 94 ++++++++++ 4 files changed, 249 insertions(+), 45 deletions(-) create mode 100644 onnxruntime/test/framework/dynamic_plugin_ep_test.cc diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 59f5e6ff11fd8..f9d6c5d6fd980 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -818,6 +818,11 @@ file(GLOB onnxruntime_test_framework_src CONFIGURE_DEPENDS ${onnxruntime_test_framework_src_patterns} ) +if (NOT (onnxruntime_USE_CUDA AND onnxruntime_BUILD_CUDA_EP_AS_PLUGIN)) + list(REMOVE_ITEM onnxruntime_test_framework_src + "${TEST_SRC_DIR}/framework/dynamic_plugin_ep_test.cc") +endif() + #This is a small wrapper library that shouldn't use any onnxruntime internal symbols(except onnxruntime_common). #Because it could dynamically link to onnxruntime. Otherwise you will have two copies of onnxruntime in the same #process and you won't know which one you are testing. diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc index eb55a09c024b1..a9e89c3df15b9 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -4,6 +4,13 @@ #include "cuda_ep_factory.h" #include "cuda_ep.h" #include "cuda_plugin_kernels.h" +#include "core/session/abi_session_options_impl.h" + +#include +#include +#include +#include +#include namespace onnxruntime { namespace cuda_plugin { @@ -100,6 +107,17 @@ const char* ORT_API_CALL CudaEpFactory::GetVersionImpl(const OrtEpFactory* this_ return static_cast(this_ptr)->ep_version_.c_str(); } +namespace { + +std::string ToUpper(std::string value) { + std::transform(value.begin(), value.end(), value.begin(), [](unsigned char c) { + return static_cast(std::toupper(c)); + }); + return value; +} + +} // namespace + /*static*/ OrtStatus* ORT_API_CALL CudaEpFactory::GetSupportedDevicesImpl( OrtEpFactory* this_ptr, @@ -137,7 +155,11 @@ OrtStatus* ORT_API_CALL CudaEpFactory::GetSupportedDevicesImpl( } OrtKeyValuePairs* ep_metadata = nullptr; + OrtKeyValuePairs* ep_options = nullptr; factory->ort_api_.CreateKeyValuePairs(&ep_metadata); + factory->ort_api_.CreateKeyValuePairs(&ep_options); + factory->ort_api_.AddKeyValuePair(ep_metadata, "cuda_device_id", std::to_string(current_device_id).c_str()); + factory->ort_api_.AddKeyValuePair(ep_options, "device_id", std::to_string(current_device_id).c_str()); // Get CUDA device properties for metadata int cuda_device_count = 0; @@ -153,9 +175,10 @@ OrtStatus* ORT_API_CALL CudaEpFactory::GetSupportedDevicesImpl( } OrtEpDevice* ep_device = nullptr; - auto* status = factory->ep_api_.CreateEpDevice(factory, &device, ep_metadata, nullptr, + auto* status = factory->ep_api_.CreateEpDevice(factory, &device, ep_metadata, ep_options, &ep_device); factory->ort_api_.ReleaseKeyValuePairs(ep_metadata); + factory->ort_api_.ReleaseKeyValuePairs(ep_options); if (status != nullptr) { return status; @@ -230,67 +253,128 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( } } - auto read_session_config_bool = [&](const char* key, bool& value) { + auto try_get_session_config = [&](std::string_view key) -> std::optional { + if (session_options == nullptr) { + return std::nullopt; + } + size_t size = 0; - OrtStatus* status = factory->ort_api_.GetSessionConfigEntry(session_options, key, nullptr, &size); + OrtStatus* status = factory->ort_api_.GetSessionConfigEntry(session_options, key.data(), nullptr, &size); if (status != nullptr) { Ort::Status s(status); - return; + return std::nullopt; + } + if (size == 0) { + return std::nullopt; } - if (size == 0) return; std::vector buf(size); - status = factory->ort_api_.GetSessionConfigEntry(session_options, key, buf.data(), &size); + status = factory->ort_api_.GetSessionConfigEntry(session_options, key.data(), buf.data(), &size); if (status != nullptr) { Ort::Status s(status); - return; + return std::nullopt; } - const std::string val(buf.data()); - value = (val == "1" || val == "true"); + return std::string(buf.data()); }; - auto read_session_config_int = [&](const char* key, int& value) { - size_t size = 0; - OrtStatus* status = factory->ort_api_.GetSessionConfigEntry(session_options, key, nullptr, &size); - if (status != nullptr) { - Ort::Status s(status); + auto log_invalid_session_config = [&](std::string_view key, std::string_view expected) { + if (logger == nullptr) { return; } - if (size == 0) return; - std::vector buf(size); - status = factory->ort_api_.GetSessionConfigEntry(session_options, key, buf.data(), &size); - if (status != nullptr) { - Ort::Status s(status); + + const std::string msg = std::format( + "Failed to parse session config for key '{}'. Expected {}. Using default value.", + key, expected); + + OrtStatus* st = factory->ort_api_.Logger_LogMessage( + logger, ORT_LOGGING_LEVEL_WARNING, msg.c_str(), "cuda_ep_factory.cc", __LINE__, "CudaEpFactory"); + if (st != nullptr) { + factory->ort_api_.ReleaseStatus(st); + } + }; + + auto read_session_config_bool = [&](std::initializer_list keys, bool& value) { + for (const auto& key : keys) { + auto raw_value = try_get_session_config(key); + if (!raw_value.has_value()) { + continue; + } + + const auto normalized = ToUpper(*raw_value); + if (normalized == "1" || normalized == "TRUE") { + value = true; + return; + } + if (normalized == "0" || normalized == "FALSE") { + value = false; + return; + } + + log_invalid_session_config(key, "a boolean"); return; } - try { - value = std::stoi(buf.data()); - } catch (...) { - if (logger) { - std::string msg = std::string("Failed to parse session config for key: ") + key + ". Using default value."; - OrtStatus* st = factory->ort_api_.Logger_LogMessage(logger, ORT_LOGGING_LEVEL_WARNING, msg.c_str(), "cuda_ep_factory.cc", __LINE__, "CudaEpFactory"); - if (st) { - factory->ort_api_.ReleaseStatus(st); - } + }; + + auto read_cudnn_conv_algo = [&](std::initializer_list keys, int& value) { + for (const auto& key : keys) { + auto raw_value = try_get_session_config(key); + if (!raw_value.has_value()) { + continue; + } + + try { + value = std::stoi(*raw_value); + return; + } catch (const std::exception&) { + } + + const auto normalized = ToUpper(*raw_value); + if (normalized == "EXHAUSTIVE") { + value = 0; + return; } + if (normalized == "HEURISTIC") { + value = 1; + return; + } + if (normalized == "DEFAULT") { + value = 2; + return; + } + + log_invalid_session_config(key, "an integer or one of EXHAUSTIVE/HEURISTIC/DEFAULT"); + return; } }; - // Read from flat keys first, then from ep.cuda.* prefixed keys. - // The second pass intentionally overwrites the first so that - // ep.cuda.* takes precedence over unprefixed keys. - read_session_config_bool("prefer_nhwc", config.prefer_nhwc); - read_session_config_bool("use_tf32", config.use_tf32); - read_session_config_bool("enable_skip_layer_norm_strict_mode", config.enable_skip_layer_norm_strict_mode); - read_session_config_bool("cudnn_conv_use_max_workspace", config.cudnn_conv_use_max_workspace); - read_session_config_bool("cudnn_conv1d_pad_to_nc1d", config.cudnn_conv1d_pad_to_nc1d); - read_session_config_int("cudnn_conv_algo", config.cudnn_conv_algo); - - read_session_config_bool("ep.cuda.prefer_nhwc_layout", config.prefer_nhwc); - read_session_config_bool("ep.cuda.use_tf32", config.use_tf32); - read_session_config_bool("ep.cuda.enable_skip_layer_norm_strict_mode", config.enable_skip_layer_norm_strict_mode); - read_session_config_bool("ep.cuda.cudnn_conv_use_max_workspace", config.cudnn_conv_use_max_workspace); - read_session_config_bool("ep.cuda.cudnn_conv1d_pad_to_nc1d", config.cudnn_conv1d_pad_to_nc1d); - read_session_config_int("ep.cuda.cudnn_conv_algo", config.cudnn_conv_algo); + const std::string ep_options_prefix = OrtSessionOptions::GetProviderOptionPrefix(factory->GetEpName().c_str()); + const std::string prefer_nhwc_key = ep_options_prefix + "prefer_nhwc"; + const std::string prefer_nhwc_layout_key = ep_options_prefix + "prefer_nhwc_layout"; + const std::string use_tf32_key = ep_options_prefix + "use_tf32"; + const std::string skip_layer_norm_key = ep_options_prefix + "enable_skip_layer_norm_strict_mode"; + const std::string cudnn_use_max_workspace_key = ep_options_prefix + "cudnn_conv_use_max_workspace"; + const std::string cudnn_conv1d_pad_key = ep_options_prefix + "cudnn_conv1d_pad_to_nc1d"; + const std::string cudnn_conv_algo_key = ep_options_prefix + "cudnn_conv_algo"; + const std::string cudnn_conv_algo_search_key = ep_options_prefix + "cudnn_conv_algo_search"; + + // Prefer plugin-provider-option keys, then fall back to the legacy ep.cuda.* + // aliases and finally to the historical flat session config names. + read_session_config_bool( + {prefer_nhwc_key, prefer_nhwc_layout_key, "ep.cuda.prefer_nhwc_layout", "prefer_nhwc", "prefer_nhwc_layout"}, + config.prefer_nhwc); + read_session_config_bool({use_tf32_key, "ep.cuda.use_tf32", "use_tf32"}, config.use_tf32); + read_session_config_bool( + {skip_layer_norm_key, "ep.cuda.enable_skip_layer_norm_strict_mode", "enable_skip_layer_norm_strict_mode"}, + config.enable_skip_layer_norm_strict_mode); + read_session_config_bool( + {cudnn_use_max_workspace_key, "ep.cuda.cudnn_conv_use_max_workspace", "cudnn_conv_use_max_workspace"}, + config.cudnn_conv_use_max_workspace); + read_session_config_bool( + {cudnn_conv1d_pad_key, "ep.cuda.cudnn_conv1d_pad_to_nc1d", "cudnn_conv1d_pad_to_nc1d"}, + config.cudnn_conv1d_pad_to_nc1d); + read_cudnn_conv_algo( + {cudnn_conv_algo_search_key, cudnn_conv_algo_key, "ep.cuda.cudnn_conv_algo_search", "ep.cuda.cudnn_conv_algo", + "cudnn_conv_algo_search", "cudnn_conv_algo"}, + config.cudnn_conv_algo); const OrtLogger& ep_logger = logger ? *logger : factory->default_logger_; auto actual_ep = std::make_unique(*factory, config, ep_logger); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 2b68d7b2c0cad..f7b7056fd277b 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -604,7 +604,28 @@ static const OrtEpDevice* FindRegisteredPluginEpDevice( if (has_requested_device_id) { Ort::ConstEpDevice current_device(ep_device); - if (static_cast(current_device.Device().DeviceId()) != requested_device_id) { + std::optional current_device_id{}; + if (const char* device_id = current_device.EpOptions().GetValue("device_id"); device_id != nullptr) { + try { + current_device_id = std::stoi(device_id); + } catch (const std::exception&) { + } + } + + if (!current_device_id.has_value()) { + if (const char* device_id = current_device.EpMetadata().GetValue("cuda_device_id"); device_id != nullptr) { + try { + current_device_id = std::stoi(device_id); + } catch (const std::exception&) { + } + } + } + + if (!current_device_id.has_value()) { + current_device_id = static_cast(current_device.Device().DeviceId()); + } + + if (*current_device_id != requested_device_id) { continue; } } diff --git a/onnxruntime/test/framework/dynamic_plugin_ep_test.cc b/onnxruntime/test/framework/dynamic_plugin_ep_test.cc new file mode 100644 index 0000000000000..90aff351b017b --- /dev/null +++ b/onnxruntime/test/framework/dynamic_plugin_ep_test.cc @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/execution_provider.h" +#include "test/unittest_util/test_dynamic_plugin_ep.h" + +#include + +#include "test/util/include/asserts.h" + +namespace onnxruntime::test { + +namespace dynamic_plugin_ep_infra = onnxruntime::test::dynamic_plugin_ep_infra; + +TEST(DynamicPluginEpInfraTest, ParseInitializationConfigParsesOptionalFields) { + constexpr std::string_view kConfigJson = R"json( +{ + "ep_library_registration_name": "CudaPluginExecutionProvider", + "ep_library_path": "/tmp/libonnxruntime_providers_cuda_plugin.so", + "selected_ep_device_indices": [0, 2], + "default_ep_options": { + "ep.cuda.use_tf32": "1", + "ep.cuda.prefer_nhwc_layout": "1" + }, + "tests_to_skip": [ + "CudaTests.SkipMe", + "GraphTests.SkipMeToo" + ] +} +)json"; + + dynamic_plugin_ep_infra::InitializationConfig config{}; + ASSERT_STATUS_OK(dynamic_plugin_ep_infra::ParseInitializationConfig(kConfigJson, config)); + + EXPECT_EQ(config.ep_library_registration_name, "CudaPluginExecutionProvider"); + EXPECT_EQ(config.ep_library_path, "/tmp/libonnxruntime_providers_cuda_plugin.so"); + EXPECT_TRUE(config.selected_ep_name.empty()); + EXPECT_THAT(config.selected_ep_device_indices, ::testing::ElementsAre(0u, 2u)); + EXPECT_THAT(config.default_ep_options, + ::testing::UnorderedElementsAre( + ::testing::Pair("ep.cuda.prefer_nhwc_layout", "1"), + ::testing::Pair("ep.cuda.use_tf32", "1"))); + EXPECT_THAT(config.tests_to_skip, + ::testing::ElementsAre("CudaTests.SkipMe", "GraphTests.SkipMeToo")); +} + +TEST(DynamicPluginEpInfraTest, ParseInitializationConfigDefaultsUnsetOptionalFields) { + constexpr std::string_view kConfigJson = R"json( +{ + "ep_library_registration_name": "ExamplePluginEP", + "ep_library_path": "/tmp/libexample_plugin_ep.so", + "selected_ep_name": "ExampleExecutionProvider" +} +)json"; + + dynamic_plugin_ep_infra::InitializationConfig config{}; + ASSERT_STATUS_OK(dynamic_plugin_ep_infra::ParseInitializationConfig(kConfigJson, config)); + + EXPECT_EQ(config.ep_library_registration_name, "ExamplePluginEP"); + EXPECT_EQ(config.ep_library_path, "/tmp/libexample_plugin_ep.so"); + EXPECT_EQ(config.selected_ep_name, "ExampleExecutionProvider"); + EXPECT_TRUE(config.selected_ep_device_indices.empty()); + EXPECT_TRUE(config.default_ep_options.empty()); + EXPECT_TRUE(config.tests_to_skip.empty()); +} + +TEST(DynamicPluginEpInfraTest, ParseInitializationConfigRejectsMissingRequiredFields) { + constexpr std::string_view kConfigJson = R"json( +{ + "ep_library_registration_name": "CudaPluginExecutionProvider" +} +)json"; + + dynamic_plugin_ep_infra::InitializationConfig config{}; + ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(dynamic_plugin_ep_infra::ParseInitializationConfig(kConfigJson, config), + "JSON parse error"); +} + +TEST(DynamicPluginEpInfraTest, UninitializedStateReturnsSafeDefaults) { + dynamic_plugin_ep_infra::Shutdown(); + + EXPECT_FALSE(dynamic_plugin_ep_infra::IsInitialized()); + EXPECT_EQ(dynamic_plugin_ep_infra::MakeEp(), nullptr); + EXPECT_FALSE(dynamic_plugin_ep_infra::GetEpName().has_value()); + EXPECT_TRUE(dynamic_plugin_ep_infra::GetTestsToSkip().empty()); + + dynamic_plugin_ep_infra::Shutdown(); + + EXPECT_FALSE(dynamic_plugin_ep_infra::IsInitialized()); + EXPECT_FALSE(dynamic_plugin_ep_infra::GetEpName().has_value()); + EXPECT_TRUE(dynamic_plugin_ep_infra::GetTestsToSkip().empty()); +} + +} // namespace onnxruntime::test From 0c311f9d99866d246437c51e56552e73a14e74a4 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 26 Mar 2026 09:01:50 -0700 Subject: [PATCH 25/48] Add fuse_conv_bias and sdpa_kernel option --- .../core/providers/cuda/plugin/cuda_ep.cc | 3 ++- .../core/providers/cuda/plugin/cuda_ep.h | 2 ++ .../providers/cuda/plugin/cuda_ep_factory.cc | 26 +++++++++++++++++++ .../cuda/plugin/cuda_kernel_adapter.h | 19 +++++++++++--- .../test/framework/dynamic_plugin_ep_test.cc | 24 +++++++++++++++++ 5 files changed, 69 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index cf43d6c7721cc..b619f845a7b21 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -48,7 +48,8 @@ CudaEp::CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& lo onnxruntime::cuda::SetCudaKernelAdapterRuntimeConfigForProvider( static_cast(static_cast(this)), config_.use_tf32, config_.device_id, config_.enable_skip_layer_norm_strict_mode, - config_.cudnn_conv_algo, config_.cudnn_conv_use_max_workspace, config_.cudnn_conv1d_pad_to_nc1d); + config_.cudnn_conv_algo, config_.cudnn_conv_use_max_workspace, config_.cudnn_conv1d_pad_to_nc1d, + config_.fuse_conv_bias, config_.sdpa_kernel); } CudaEp::~CudaEp() { diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.h b/onnxruntime/core/providers/cuda/plugin/cuda_ep.h index 0e3fe81561af7..1d5ce8a77118e 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.h @@ -25,6 +25,8 @@ class CudaEp : public OrtEp { int cudnn_conv_algo = 0; ///< cuDNN convolution algorithm selection. bool cudnn_conv_use_max_workspace = true; ///< Use maximum workspace for cuDNN conv algo search. bool cudnn_conv1d_pad_to_nc1d = false; ///< Pad 1D convolutions to NC1D format. + bool fuse_conv_bias = false; ///< Enable cuDNN frontend conv+bias fusion. + int sdpa_kernel = 0; ///< Attention backend bitmask override. }; CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& logger); diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc index a9e89c3df15b9..217b831fd5935 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -346,6 +346,24 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( } }; + auto read_session_config_int = [&](std::initializer_list keys, int& value) { + for (const auto& key : keys) { + auto raw_value = try_get_session_config(key); + if (!raw_value.has_value()) { + continue; + } + + try { + value = std::stoi(*raw_value); + return; + } catch (const std::exception&) { + } + + log_invalid_session_config(key, "an integer"); + return; + } + }; + const std::string ep_options_prefix = OrtSessionOptions::GetProviderOptionPrefix(factory->GetEpName().c_str()); const std::string prefer_nhwc_key = ep_options_prefix + "prefer_nhwc"; const std::string prefer_nhwc_layout_key = ep_options_prefix + "prefer_nhwc_layout"; @@ -355,6 +373,8 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( const std::string cudnn_conv1d_pad_key = ep_options_prefix + "cudnn_conv1d_pad_to_nc1d"; const std::string cudnn_conv_algo_key = ep_options_prefix + "cudnn_conv_algo"; const std::string cudnn_conv_algo_search_key = ep_options_prefix + "cudnn_conv_algo_search"; + const std::string fuse_conv_bias_key = ep_options_prefix + "fuse_conv_bias"; + const std::string sdpa_kernel_key = ep_options_prefix + "sdpa_kernel"; // Prefer plugin-provider-option keys, then fall back to the legacy ep.cuda.* // aliases and finally to the historical flat session config names. @@ -375,6 +395,12 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( {cudnn_conv_algo_search_key, cudnn_conv_algo_key, "ep.cuda.cudnn_conv_algo_search", "ep.cuda.cudnn_conv_algo", "cudnn_conv_algo_search", "cudnn_conv_algo"}, config.cudnn_conv_algo); + read_session_config_bool( + {fuse_conv_bias_key, "ep.cuda.fuse_conv_bias", "fuse_conv_bias"}, + config.fuse_conv_bias); + read_session_config_int( + {sdpa_kernel_key, "ep.cuda.sdpa_kernel", "sdpa_kernel"}, + config.sdpa_kernel); const OrtLogger& ep_logger = logger ? *logger : factory->default_logger_; auto actual_ep = std::make_unique(*factory, config, ep_logger); diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index 9cabf20a83a36..69eef22f6dbfe 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -345,8 +345,11 @@ struct CudaKernelAdapterRuntimeConfig { int cudnn_conv_algo = 0; bool cudnn_conv_use_max_workspace = true; bool cudnn_conv1d_pad_to_nc1d = false; + bool fuse_conv_bias = false; + int sdpa_kernel = 0; int device_id = 0; cudaDeviceProp device_prop{}; + onnxruntime::AttentionKernelOptions attention_kernel_options; }; // Shared storage for per-provider runtime configurations. // Both Get and Remove must operate on the same static map instance, @@ -511,7 +514,12 @@ class CUDAExecutionProvider : public onnxruntime::IExecutionProvider { return cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(this).use_tf32; } bool IsFuseConvBias() const { - return false; + return cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(this).fuse_conv_bias; + } + const onnxruntime::AttentionKernelOptions* GetAttentionKernelOptions() const { + auto& config = cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(this); + config.attention_kernel_options.InitializeOnce(config.sdpa_kernel, true, true); + return &config.attention_kernel_options; } const cudaDeviceProp& GetDeviceProp() const { return cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(this).device_prop; @@ -527,13 +535,17 @@ namespace cuda { inline void SetCudaKernelAdapterRuntimeConfigForProvider(const void* provider, bool use_tf32, int device_id, bool skip_layer_norm_strict_mode = false, int cudnn_conv_algo = 0, bool cudnn_conv_use_max_workspace = true, - bool cudnn_conv1d_pad_to_nc1d = false) { + bool cudnn_conv1d_pad_to_nc1d = false, + bool fuse_conv_bias = false, + int sdpa_kernel = 0) { auto& config = detail::GetCudaKernelAdapterRuntimeConfigForProvider(provider); config.use_tf32 = use_tf32; config.skip_layer_norm_strict_mode = skip_layer_norm_strict_mode; config.cudnn_conv_algo = cudnn_conv_algo; config.cudnn_conv_use_max_workspace = cudnn_conv_use_max_workspace; config.cudnn_conv1d_pad_to_nc1d = cudnn_conv1d_pad_to_nc1d; + config.fuse_conv_bias = fuse_conv_bias; + config.sdpa_kernel = sdpa_kernel; config.device_id = device_id; PL_CUDA_CALL_THROW(cudaGetDeviceProperties(&config.device_prop, device_id)); } @@ -802,8 +814,7 @@ class CudaKernel : public OpKernel { bool IsArchAvailable(int arch) const { return device_prop_.major >= arch; } const OpKernelInfo& Info() const { return info_; } const onnxruntime::AttentionKernelOptions* GetAttentionKernelOptions() const { - static onnxruntime::AttentionKernelOptions options; - return &options; + return static_cast(info_.GetExecutionProvider())->GetAttentionKernelOptions(); } // Stub for GetTuningContext — tunable ops are not supported in the plugin. diff --git a/onnxruntime/test/framework/dynamic_plugin_ep_test.cc b/onnxruntime/test/framework/dynamic_plugin_ep_test.cc index 90aff351b017b..430e76df6e588 100644 --- a/onnxruntime/test/framework/dynamic_plugin_ep_test.cc +++ b/onnxruntime/test/framework/dynamic_plugin_ep_test.cc @@ -8,6 +8,11 @@ #include "test/util/include/asserts.h" +#if defined(USE_CUDA) && defined(ORT_USE_EP_API_ADAPTERS) +#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/providers/cuda/plugin/cuda_kernel_adapter.h" +#endif + namespace onnxruntime::test { namespace dynamic_plugin_ep_infra = onnxruntime::test::dynamic_plugin_ep_infra; @@ -91,4 +96,23 @@ TEST(DynamicPluginEpInfraTest, UninitializedStateReturnsSafeDefaults) { EXPECT_TRUE(dynamic_plugin_ep_infra::GetTestsToSkip().empty()); } +#if defined(USE_CUDA) && defined(ORT_USE_EP_API_ADAPTERS) +TEST(DynamicPluginEpInfraTest, CudaKernelAdapterRuntimeConfigExposesFuseConvBiasAndSdpaKernel) { + onnxruntime::cuda::CUDAExecutionProvider provider{"CudaPluginExecutionProvider"}; + auto& config = onnxruntime::cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(&provider); + config.fuse_conv_bias = true; + config.sdpa_kernel = static_cast(onnxruntime::contrib::attention::AttentionBackend::MATH); + + EXPECT_TRUE(provider.IsFuseConvBias()); + + const auto* attention_kernel_options = provider.GetAttentionKernelOptions(); + EXPECT_TRUE(attention_kernel_options->UseUnfusedAttention()); + EXPECT_FALSE(attention_kernel_options->UseFlashAttention()); + EXPECT_FALSE(attention_kernel_options->UseEfficientAttention()); + EXPECT_FALSE(attention_kernel_options->UseCudnnFlashAttention()); + + onnxruntime::cuda::detail::RemoveCudaKernelAdapterRuntimeConfigForProvider(&provider); +} +#endif + } // namespace onnxruntime::test From 6175224cae756ad4a25f90c530d9264aef02a05f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 26 Mar 2026 10:42:41 -0700 Subject: [PATCH 26/48] RestoreDeviceIfKnown and GetCublasHandleOrDefault --- .../cuda/plugin/cuda_allocator_plugin.cc | 28 ++++++++++++++----- .../core/providers/cuda/tensor/transpose.cc | 3 +- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc index 6708bbd61dacf..2534ce31de6c2 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc @@ -6,6 +6,16 @@ namespace onnxruntime { namespace cuda_plugin { +namespace { + +void RestoreDeviceIfKnown(bool restore_prev_device, int prev_device) noexcept { + if (restore_prev_device) { + static_cast(cudaSetDevice(prev_device)); + } +} + +} // namespace + // --------------------------------------------------------------------------- // CudaDeviceAllocator — uses cudaMalloc/cudaFree for GPU device memory. // Note: No arena or caching layer — every allocation goes directly to CUDA. @@ -30,13 +40,13 @@ CudaDeviceAllocator::CudaDeviceAllocator(const OrtMemoryInfo* memory_info, int d // Save and restore CUDA device context to avoid corrupting the calling // thread's device state in multi-GPU scenarios. int prev_device = -1; - cudaGetDevice(&prev_device); + const bool restore_prev_device = cudaGetDevice(&prev_device) == cudaSuccess; if (cudaSetDevice(alloc->device_id_) != cudaSuccess) { - cudaSetDevice(prev_device); + RestoreDeviceIfKnown(restore_prev_device, prev_device); return nullptr; } cudaError_t err = cudaMalloc(&p, size); - cudaSetDevice(prev_device); + RestoreDeviceIfKnown(restore_prev_device, prev_device); if (err != cudaSuccess) { return nullptr; } @@ -47,10 +57,14 @@ CudaDeviceAllocator::CudaDeviceAllocator(const OrtMemoryInfo* memory_info, int d auto* alloc = static_cast(this_ptr); if (p != nullptr) { int prev_device = -1; - cudaGetDevice(&prev_device); - cudaSetDevice(alloc->device_id_); - cudaFree(p); - cudaSetDevice(prev_device); + const bool restore_prev_device = cudaGetDevice(&prev_device) == cudaSuccess; + if (cudaSetDevice(alloc->device_id_) != cudaSuccess) { + RestoreDeviceIfKnown(restore_prev_device, prev_device); + return; + } + + static_cast(cudaFree(p)); + RestoreDeviceIfKnown(restore_prev_device, prev_device); } } diff --git a/onnxruntime/core/providers/cuda/tensor/transpose.cc b/onnxruntime/core/providers/cuda/tensor/transpose.cc index 82096a2f397a7..7df9238b3acb1 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose.cc +++ b/onnxruntime/core/providers/cuda/tensor/transpose.cc @@ -99,7 +99,8 @@ Status Transpose::DoTranspose(const Transpose& transpose_kernel, onnxruntime::Stream* ort_stream, const gsl::span& permutations, const Tensor& input, Tensor& output) { cudaStream_t cuda_stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; - const cublasHandle_t cublas_handle = transpose_kernel.GetCublasHandle(ort_stream); + const cublasHandle_t cublas_handle = + ort_stream ? transpose_kernel.GetCublasHandle(ort_stream) : transpose_kernel.DefaultCublasHandle(); return Transpose::DoTranspose(transpose_kernel.GetDeviceProp(), cuda_stream, cublas_handle, From d7f52053be1f3c4817171bd2b1f4f4a9425a3d05 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 26 Mar 2026 10:47:19 -0700 Subject: [PATCH 27/48] update import style --- .../transformers/cuda_plugin_ep_helper.py | 4 +- .../transformers/test_cuda_plugin_ep.py | 46 +++++++++---------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/onnxruntime/test/python/transformers/cuda_plugin_ep_helper.py b/onnxruntime/test/python/transformers/cuda_plugin_ep_helper.py index 665f1d6828202..bc024c9274f03 100644 --- a/onnxruntime/test/python/transformers/cuda_plugin_ep_helper.py +++ b/onnxruntime/test/python/transformers/cuda_plugin_ep_helper.py @@ -7,7 +7,6 @@ from pathlib import Path import onnxruntime as onnxrt -from onnxruntime import get_build_info class _CudaPluginRegistrationState: @@ -38,13 +37,14 @@ def _get_package_root(package_name: str, directory_name: str | None = None): if file.name.endswith("__init__.py"): return file.locate().parent except PackageNotFoundError: + # Some test environments only have an in-tree build, not an installed wheel. pass return None def _is_cuda_plugin_ep_built() -> bool: - build_info = get_build_info() + build_info = onnxrt.get_build_info() return ", cuda-plugin-ep=" in build_info diff --git a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py index 75a146d7d3bb0..41132e7bf2b89 100644 --- a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py +++ b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py @@ -6,11 +6,10 @@ import unittest import numpy as np -import onnx import torch import torch.nn.functional as F from cuda_plugin_ep_helper import CUDA_PLUGIN_EP_NAME, ensure_cuda_plugin_ep_registered, should_test_with_cuda_plugin_ep -from onnx import TensorProto, helper, save +from onnx import OperatorSetIdProto, TensorProto, helper, save import onnxruntime as onnxrt @@ -19,6 +18,7 @@ faulthandler.enable() except ImportError: + # faulthandler is optional in some Python runtimes used by CI. pass @@ -119,7 +119,7 @@ def create_conv_model(model_path): ], [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 4, 4])], ) - opset = onnx.OperatorSetIdProto() + opset = OperatorSetIdProto() opset.version = 11 model_def = helper.make_model(graph_def, producer_name="onnx-example", opset_imports=[opset]) save(model_def, model_path) @@ -155,7 +155,7 @@ def create_batch_norm_model(model_path): [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, num_channels, 4, 4])], initializer=[scale_init, bias_init, mean_init, var_init], ) - opset = onnx.OperatorSetIdProto() + opset = OperatorSetIdProto() opset.version = 15 model_def = helper.make_model(graph_def, producer_name="onnx-example", opset_imports=[opset]) save(model_def, model_path) @@ -176,7 +176,7 @@ def create_maxpool_model(model_path): [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 4, 4])], [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 2, 2])], ) - opset = onnx.OperatorSetIdProto() + opset = OperatorSetIdProto() opset.version = 12 model_def = helper.make_model(graph_def, producer_name="onnx-example", opset_imports=[opset]) save(model_def, model_path) @@ -197,7 +197,7 @@ def create_avgpool_model(model_path): [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 4, 4])], [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 2, 2])], ) - opset = onnx.OperatorSetIdProto() + opset = OperatorSetIdProto() opset.version = 12 model_def = helper.make_model(graph_def, producer_name="onnx-example", opset_imports=[opset]) save(model_def, model_path) @@ -223,9 +223,9 @@ def make_bias_dropout_model(): ], [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4])], ) - opset_onnx = onnx.OperatorSetIdProto() + opset_onnx = OperatorSetIdProto() opset_onnx.version = 13 - opset_ms = onnx.OperatorSetIdProto() + opset_ms = OperatorSetIdProto() opset_ms.domain = "com.microsoft" opset_ms.version = 1 return helper.make_model(graph, opset_imports=[opset_onnx, opset_ms]) @@ -432,10 +432,10 @@ def _make_simple_model(op_type, inputs_info, outputs_info, attrs=None, opset=13, [helper.make_tensor_value_info(n, t, s) for n, t, s in inputs_info], [helper.make_tensor_value_info(n, t, s) for n, t, s in outputs_info], ) - opset_import = [onnx.OperatorSetIdProto()] + opset_import = [OperatorSetIdProto()] opset_import[0].version = opset if domain: - ms_opset = onnx.OperatorSetIdProto() + ms_opset = OperatorSetIdProto() ms_opset.domain = domain ms_opset.version = 1 opset_import.append(ms_opset) @@ -527,7 +527,7 @@ def run_test(name, model, feed, expected_fn, rtol=1e-3, atol=1e-3): [helper.make_tensor_value_info("X", f_dtype, [6, 4])], [helper.make_tensor_value_info("Y1", f_dtype, [3, 4]), helper.make_tensor_value_info("Y2", f_dtype, [3, 4])], ) - opset = onnx.OperatorSetIdProto() + opset = OperatorSetIdProto() opset.version = 13 model = helper.make_model(graph, opset_imports=[opset]) model.graph.initializer.append(helper.make_tensor("split", TensorProto.INT64, [2], [3, 3])) @@ -562,7 +562,7 @@ def run_test(name, model, feed, expected_fn, rtol=1e-3, atol=1e-3): [helper.make_tensor_value_info("X", f_dtype, [3, 4])], [helper.make_tensor_value_info("Y", f_dtype, [1, 3, 4])], ) - opset = onnx.OperatorSetIdProto() + opset = OperatorSetIdProto() opset.version = 13 model = helper.make_model(graph, opset_imports=[opset]) axes_init = helper.make_tensor("axes", TensorProto.INT64, [1], [0]) @@ -578,7 +578,7 @@ def run_test(name, model, feed, expected_fn, rtol=1e-3, atol=1e-3): [helper.make_tensor_value_info("X", f_dtype, [2, 3])], [helper.make_tensor_value_info("Y", f_dtype, [4, 9])], ) - opset = onnx.OperatorSetIdProto() + opset = OperatorSetIdProto() opset.version = 13 model = helper.make_model(graph, opset_imports=[opset]) repeats_init = helper.make_tensor("repeats", TensorProto.INT64, [2], [2, 3]) @@ -594,7 +594,7 @@ def run_test(name, model, feed, expected_fn, rtol=1e-3, atol=1e-3): [helper.make_tensor_value_info("X", f_dtype, [3, 4])], [helper.make_tensor_value_info("Y", f_dtype, [3, 4])], ) - opset = onnx.OperatorSetIdProto() + opset = OperatorSetIdProto() opset.version = 14 model = helper.make_model(graph, opset_imports=[opset]) axis_init = helper.make_tensor("axis", TensorProto.INT64, [], [1]) @@ -612,7 +612,7 @@ def run_test(name, model, feed, expected_fn, rtol=1e-3, atol=1e-3): [helper.make_tensor_value_info("shape", TensorProto.INT64, [2])], [helper.make_tensor_value_info("Y", f_dtype, None)], ) - opset = onnx.OperatorSetIdProto() + opset = OperatorSetIdProto() opset.version = 9 model = helper.make_model(graph, opset_imports=[opset]) run_test( @@ -648,7 +648,7 @@ def space_to_depth(f): [helper.make_tensor_value_info("X", f_dtype, [2, 3])], [helper.make_tensor_value_info("Y", f_dtype, [4, 5])], ) - opset = onnx.OperatorSetIdProto() + opset = OperatorSetIdProto() opset.version = 13 model = helper.make_model(graph, opset_imports=[opset]) model.graph.initializer.append(helper.make_tensor("pads", TensorProto.INT64, [4], [1, 1, 1, 1])) @@ -664,7 +664,7 @@ def space_to_depth(f): [helper.make_tensor_value_info("X", f_dtype, [4, 6])], [helper.make_tensor_value_info("Y", f_dtype, [2, 4])], ) - opset = onnx.OperatorSetIdProto() + opset = OperatorSetIdProto() opset.version = 13 model = helper.make_model(graph, opset_imports=[opset]) model.graph.initializer.append(helper.make_tensor("starts", TensorProto.INT64, [2], [1, 1])) @@ -681,7 +681,7 @@ def space_to_depth(f): [helper.make_tensor_value_info("X", f_dtype, [1, 1, 2, 2])], [helper.make_tensor_value_info("Y", f_dtype, [1, 1, 4, 4])], ) - opset = onnx.OperatorSetIdProto() + opset = OperatorSetIdProto() opset.version = 13 model = helper.make_model(graph, opset_imports=[opset]) model.graph.initializer.append(helper.make_tensor("scales", TensorProto.FLOAT, [4], [1.0, 1.0, 2.0, 2.0])) @@ -711,7 +711,7 @@ def space_to_depth(f): [helper.make_tensor_value_info("X", f_dtype, [1, 1, 2, 2])], [helper.make_tensor_value_info("Y", f_dtype, [1, 1, 4, 4])], ) - opset = onnx.OperatorSetIdProto() + opset = OperatorSetIdProto() opset.version = 9 model = helper.make_model(graph, opset_imports=[opset]) model.graph.initializer.append(helper.make_tensor("scales", TensorProto.FLOAT, [4], [1.0, 1.0, 2.0, 2.0])) @@ -751,9 +751,9 @@ def depth_to_space_dcr(f): [helper.make_tensor_value_info("X", f_dtype, [2, 4])], [helper.make_tensor_value_info("Y", f_dtype, [2, 4])], ) - opset_onnx = onnx.OperatorSetIdProto() + opset_onnx = OperatorSetIdProto() opset_onnx.version = 13 - opset_ms = onnx.OperatorSetIdProto() + opset_ms = OperatorSetIdProto() opset_ms.domain = "com.microsoft" opset_ms.version = 1 model = helper.make_model(graph, opset_imports=[opset_onnx, opset_ms]) @@ -809,9 +809,9 @@ def fast_gelu_ref(f): helper.make_tensor_value_info("input_skip_bias_sum", f_dtype, None), ], ) - opset_onnx = onnx.OperatorSetIdProto() + opset_onnx = OperatorSetIdProto() opset_onnx.version = 13 - opset_ms = onnx.OperatorSetIdProto() + opset_ms = OperatorSetIdProto() opset_ms.domain = "com.microsoft" opset_ms.version = 1 model = helper.make_model(graph, opset_imports=[opset_onnx, opset_ms]) From 827afa700f11192e466a927b1d90ef3d7026a8ea Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 26 Mar 2026 22:20:19 -0700 Subject: [PATCH 28/48] update tests --- cmake/onnxruntime_providers_cuda_plugin.cmake | 6 +- cmake/onnxruntime_unittests.cmake | 5 - docs/cuda_plugin_ep/cuda_plugin_ep_design.md | 190 ++- include/onnxruntime/ep/adapter/op_kernel.h | 12 +- .../onnxruntime/ep/adapter/op_kernel_info.h | 15 +- onnxruntime/core/providers/cuda/math/gemm.cc | 4 + onnxruntime/core/providers/cuda/nn/conv.cc | 6 +- .../core/providers/cuda/nn/conv_transpose.cc | 6 +- .../cuda/plugin/cuda_controlflow_plugin.cc | 21 +- .../core/providers/cuda/plugin/cuda_ep.cc | 58 +- .../core/providers/cuda/plugin/cuda_ep.h | 3 +- .../providers/cuda/plugin/cuda_ep_factory.cc | 8 +- .../cuda/plugin/cuda_kernel_adapter.h | 194 +-- .../transformers/cuda_plugin_ep_helper.py | 28 +- .../transformers/test_cuda_plugin_ep.py | 1045 +++++++++-------- .../test/python/transformers/test_gqa.py | 10 +- .../test/python/transformers/test_moe_cuda.py | 10 +- 17 files changed, 958 insertions(+), 663 deletions(-) diff --git a/cmake/onnxruntime_providers_cuda_plugin.cmake b/cmake/onnxruntime_providers_cuda_plugin.cmake index 4c5a3b36b548b..84edd0b35e8f0 100644 --- a/cmake/onnxruntime_providers_cuda_plugin.cmake +++ b/cmake/onnxruntime_providers_cuda_plugin.cmake @@ -126,7 +126,7 @@ set_target_properties(onnxruntime_providers_cuda_plugin PROPERTIES # Suppress -Werror=maybe-uninitialized for local variables written by # adapter OpKernelInfo::GetAttr<> (GCC falsely warns about variables that are -# initialised inside GetAttr’s output parameter path). +# initialized inside GetAttr’s output parameter path). target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE $<$,$>:-Wno-maybe-uninitialized> ) @@ -250,6 +250,10 @@ target_link_libraries(onnxruntime_providers_cuda_plugin PRIVATE # Symbol visibility — only export CreateEpFactories and ReleaseEpFactory target_compile_definitions(onnxruntime_providers_cuda_plugin PRIVATE ORT_API_MANUAL_INIT BUILD_CUDA_EP_AS_PLUGIN ORT_USE_EP_API_ADAPTERS=1 ONNX_ML=1 ONNX_NAMESPACE=onnx ONNX_USE_LITE_PROTO=1) +if (onnxruntime_USE_CUDA_NHWC_OPS) + target_compile_definitions(onnxruntime_providers_cuda_plugin PRIVATE ENABLE_CUDA_NHWC_OPS) +endif() + if(WIN32) # Windows: use .def file for symbol exports set(CUDA_PLUGIN_DEF_FILE ${CUDA_PLUGIN_EP_DIR}/cuda_plugin_ep_symbols.def) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index f9d6c5d6fd980..59f5e6ff11fd8 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -818,11 +818,6 @@ file(GLOB onnxruntime_test_framework_src CONFIGURE_DEPENDS ${onnxruntime_test_framework_src_patterns} ) -if (NOT (onnxruntime_USE_CUDA AND onnxruntime_BUILD_CUDA_EP_AS_PLUGIN)) - list(REMOVE_ITEM onnxruntime_test_framework_src - "${TEST_SRC_DIR}/framework/dynamic_plugin_ep_test.cc") -endif() - #This is a small wrapper library that shouldn't use any onnxruntime internal symbols(except onnxruntime_common). #Because it could dynamically link to onnxruntime. Otherwise you will have two copies of onnxruntime in the same #process and you won't know which one you are testing. diff --git a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md index c7f25b8cedd0c..98856422bffab 100644 --- a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md +++ b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md @@ -40,26 +40,28 @@ Each build target uses different preprocessor defines that control how framework ### 2.3 Class Hierarchy ``` -OrtEpFactory OrtEp - ↑ ↑ -CudaEpFactory CudaEp - │ │ - ├─ creates OrtEpDevice ├─ stores session-derived Config - ├─ creates CudaSyncStream └─ seeds adapter runtime config for kernels - ├─ caches kernel registry +OrtEpFactory OrtEp + ↑ ↑ +CudaEpFactory adapter::Ep + │ ↑ + ├─ creates OrtEpDevice CudaEp + ├─ creates CudaSyncStream ├─ stores session-derived Config + ├─ caches kernel registry └─ owns a real shim CUDAExecutionProvider via EpImpl() ├─ caches stable OrtMemoryInfo objects └─ maps OrtHardwareDevice* → CUDA ordinal Migrated CUDA kernels └─ use CudaKernel / cuda_kernel_adapter.h - ├─ read provider config through a phantom CUDAExecutionProvider shim + ├─ receive EpImpl() from info.GetExecutionProvider() + ├─ cast that pointer to the shim CUDAExecutionProvider └─ resolve stream-local handles via CudaSyncStream::FromCudaStream() ``` Key ownership relationships: - `CudaEpFactory` implements raw `OrtEpFactory` callbacks and owns shared factory-level state such as the kernel registry, cached `OrtMemoryInfo` instances, and the hardware-device to CUDA-ordinal map. -- `CudaEp` inherits directly from `OrtEp`; it does not derive from `ep::adapter::Ep` and does not own a separate framework `IExecutionProvider` object. -- The plugin-local `CUDAExecutionProvider` in `cuda_kernel_adapter.h` is a zero-state compatibility shim used by migrated kernels. Runtime state is stored in adapter-side maps keyed by the `CudaEp` address. +- `CudaEp` inherits from `ep::adapter::Ep`, which itself derives from `OrtEp` and owns a framework-facing `IExecutionProvider` object. +- The plugin-local `CUDAExecutionProvider` in `cuda_kernel_adapter.h` is a real shim object owned by `ep::adapter::Ep`. It is not the full in-tree CUDA EP, but it does have its own object identity and may store plugin-specific members such as the wrapped `OrtEp*`. +- Runtime state needed by migrated kernels is still stored in adapter-side maps keyed by the shim provider address (`EpImpl()`), not by the `CudaEp` object address. - `CudaSyncStream` owns `cudaStream_t`, `cublasHandle_t`, `cudnnHandle_t`, and `cublasLtHandle_t` for each sync stream created through the EP API. ### 2.4 Plugin DLL Entry Points @@ -247,13 +249,129 @@ For code paths that need handles without an active stream, `cuda_kernel_adapter. ### 5.3 Provider Access -Kernels access provider configuration through the pointer returned by `info.GetExecutionProvider()`, but in the plugin build that pointer is treated as a phantom `CUDAExecutionProvider` shim. The shim must remain layout-compatible with `IExecutionProvider` and carries no member state; runtime configuration is stored in the adapter-side `ProviderConfigStore`, keyed by the provider address. +Kernels access provider configuration through the pointer returned by `info.GetExecutionProvider()`. In the plugin build, `ep::adapter::OpKernelInfo` snapshots three related pointers from the framework `OpKernelInfo` when the kernel is created: + +- the session-owned outer `PluginExecutionProvider` +- the wrapped plugin `OrtEp` / `CudaEp` +- the inner shim provider returned by `static_cast(ort_ep)->EpImpl()` + +`OpKernelInfo::GetExecutionProvider()` then returns that cached shim pointer, so migrated kernels receive the real shim `CUDAExecutionProvider` object owned by `ep::adapter::Ep`, not the outer `PluginExecutionProvider` and not the `OrtEp`/`CudaEp` object reinterpreted at the same address. + +Caching the shim pointer at kernel-creation time is important for the NHWC path. Re-querying `OrtKernelInfo -> OrtEp -> EpImpl()` during execution was fragile after layout transformation; the current implementation avoids that repeated round-trip and uses the cached shim for runtime provider access. + +This changes the safety model from the earlier "phantom shim" design: +- The shim no longer needs to remain layout-compatible with `IExecutionProvider`. +- Adding plugin-local members to the shim is safe as long as normal C++ object lifetime/ownership rules are respected. +- The shim still is not the full bundled `CUDAExecutionProvider`; it only exposes the subset of methods that migrated kernels currently need. + +Provider options flow through the plugin in two stages: +- `CudaEpFactory` parses session/provider options into `CudaEp::Config`. +- `CudaEp` mirrors the subset needed by migrated kernels into the adapter-side `ProviderConfigStore`, keyed by `EpImpl()`. + +Today that mirrored subset includes TF32, device ID/device properties, cuDNN convolution settings, skip-layer-norm strict mode, fused-conv-bias, and SDPA kernel selection. Other plugin behaviors, such as preferred layout, are handled directly by `CudaEp` callbacks instead of through the shim. For stream bridging, the preferred helpers are: - `Stream(ctx)` when the kernel only needs a raw `cudaStream_t` - `GetComputeStream(ctx)` when the kernel API already accepts the adapter's opaque stream pointer - `GetOrtStream(ctx)` when framework-style `Stream*` plumbing is still needed by shared helper code +### 5.3.1 NHWC Layout-Transformation Support + +The bundled CUDA EP's NHWC path is not just a kernel-registration feature. It is a coordinated contract between provider configuration, ORT's layout transformer, kernel registration, adapter/provider access, and graph partitioning. On the current branch, the CUDA plugin EP now supports this path when NHWC is compiled in and the session requests `prefer_nhwc`. + +#### 5.3.1.1 End-to-End Flow + +When NHWC is enabled for an EP, the expected ORT flow is: + +1. The EP reports `NHWC` from `GetPreferredLayout()` (or `OrtEp::GetPreferredDataLayout()` for plugins) when `prefer_nhwc` is enabled. +2. During layout transformation, ORT asks the EP whether each layout-sensitive op should be converted via `ShouldConvertDataLayoutForOp()`. +3. For each accepted op, `TransformLayoutForEP()` inserts `Transpose` nodes around the operator and rewrites the operator into the internal NHWC domain (`com.ms.internal.nhwc`). +4. Graph partitioning runs again. The EP must now claim the rewritten internal-domain nodes, not the original ONNX-domain nodes. +5. Kernel lookup must succeed against the EP's kernel registry for the rewritten internal-domain node, with matching domain, opset range, and type constraints. + +This means the plugin must satisfy two distinct contracts at the same time: + +| Contract | Owner | Requirement | +|----------|-------|-------------| +| Layout preference contract | `CudaEp` + ORT plugin bridge | Only request NHWC when the plugin can handle the rewritten graph | +| Kernel/capability contract | Kernel registry + `CudaEp::GetCapabilityImpl()` | Claim the resulting `com.ms.internal.nhwc` nodes during partitioning | + +#### 5.3.1.2 Current Plugin Status + +The current branch already has the core runtime pieces in place: + +| Component | Current state | +|-----------|---------------| +| ORT plugin bridge | `PluginExecutionProvider` already maps `OrtEp::GetPreferredDataLayout()` and `OrtEp::ShouldConvertDataLayoutForOp()` into the normal `IExecutionProvider` layout APIs | +| Plugin callback implementations | `CudaEp` installs `GetPreferredDataLayoutImpl()` and `ShouldConvertDataLayoutForOpImpl()` and advertises NHWC when `prefer_nhwc` is enabled | +| Provider option parsing | `CudaEpFactory` already parses `prefer_nhwc` / `prefer_nhwc_layout` into `CudaEp::Config` | +| Build-time gating | `cmake/onnxruntime_providers_cuda_plugin.cmake` propagates `ENABLE_CUDA_NHWC_OPS` to the plugin target when `onnxruntime_USE_CUDA_NHWC_OPS=ON` | +| NHWC kernel registration | NHWC kernels are compiled from the normal CUDA kernel sources and self-register through `PluginKernelCollector`; the centralized `cuda_nhwc_kernels.cc` table stays excluded in plugin builds | +| Second capability pass | `CudaEp::GetCapabilityImpl()` preserves nodes already assigned to `CudaPluginExecutionProvider`, so ORT's post-layout-transformation partitioning pass does not drop rewritten NHWC nodes that were previously selected by the plugin | +| Adapter provider access | `ep::adapter::OpKernelInfo` caches the inner shim `EpImpl()` pointer at kernel-creation time, avoiding a fragile runtime `OrtKernelInfo -> OrtEp -> EpImpl()` round-trip in NHWC kernels | +| Focused validation | `test_cuda_plugin_ep.py` Stage 3 now runs NHWC-requested sessions for Conv, BatchNormalization, MaxPool, and AveragePool and requires plugin-backed execution to succeed numerically | + +The fixes that made this work were not limited to turning the callbacks back on: + +- The plugin now keeps both newly discovered candidate nodes and nodes already assigned to `CudaPluginExecutionProvider` during the second `GetCapability()` pass that runs after layout transformation. +- NHWC kernels now obtain provider configuration through the cached shim pointer in `ep::adapter::OpKernelInfo`, which removed a runtime crash path in migrated kernels such as NHWC `Conv`. + +With those pieces in place, NHWC-requested sessions take the real plugin execution path rather than silently falling back to the stable NCHW path. + +#### 5.3.1.3 NHWC Design Requirements + +The implementation should preserve the following invariants: + +| Requirement | Why it matters | +|-------------|----------------| +| The plugin must never advertise NHWC unless it can claim every internal-domain op it requests ORT to generate | Otherwise ORT can create an invalid graph containing `com.ms.internal.nhwc` nodes that no EP selects | +| The NHWC conversion allowlist must come from a single shared source of truth | The bundled CUDA EP and the plugin EP must not drift on which ops are safe to rewrite | +| Kernel coverage checks must validate internal-domain registrations, not just original ONNX-domain registrations | The rewritten graph uses `com.ms.internal.nhwc`, so ONNX-domain coverage alone is insufficient | +| Capability diagnostics must identify internal-domain kernel misses clearly | NHWC failures are difficult to debug after rewrite unless the missing domain/op/version/type information is surfaced | +| Tests must verify plugin-backed NHWC execution explicitly | Output correctness alone is not enough because a fallback path can still pass numerically | + +#### 5.3.1.4 Implemented Design and Remaining Follow-Ups + +The current branch has already landed the minimum runtime fixes required for plugin-side NHWC execution. The remaining work is mostly cleanup, consolidation, and stronger diagnostics. + +**A. Keep partitioning registry-driven and preserve pre-assigned NHWC nodes** + +`CudaEp::GetCapabilityImpl()` should continue to rely on `EpGraphSupportInfo_LookUpKernel()` as the source of truth for whether a rewritten node is supported. The important implementation detail is that it must preserve nodes already assigned to the plugin when ORT reruns partitioning after layout transformation. + +That behavior is now implemented by tracking: +- `tentative_nodes`: newly discovered nodes with matching kernel registrations +- `candidate_nodes`: both tentative nodes and nodes already assigned to `CudaPluginExecutionProvider` + +The final support set is chosen from `candidate_nodes`, with the existing CPU-preferred-node filtering applied only where appropriate. + +**B. Cache the shim provider pointer at kernel creation** + +Migrated CUDA kernels expect `info.GetExecutionProvider()` to return the shim `CUDAExecutionProvider`, not the outer `PluginExecutionProvider`. The adapter now resolves that relationship once during kernel creation and caches the inner `EpImpl()` pointer. + +This is especially important for NHWC kernels because layout transformation introduces additional runtime paths before the actual CUDA kernel executes. Repeatedly reconstructing `OrtEp -> EpImpl()` from `OrtKernelInfo` during execution proved fragile in that path. The cached-shim approach keeps provider access deterministic and matches the actual object model: + +- outer session EP: `PluginExecutionProvider` +- wrapped plugin object: `CudaEp` / `ep::adapter::Ep` +- inner shim: `CUDAExecutionProvider` returned by `EpImpl()` + +**C. Remaining follow-ups** + +The main follow-ups are now design quality items rather than blockers: + +- Unify the NHWC conversion allowlist between the bundled CUDA EP and the plugin CUDA EP instead of keeping separate hard-coded tables. +- Improve diagnostics when kernel lookup fails for a rewritten `com.ms.internal.nhwc` node. +- Extend tests to assert internal-domain rewrite structure directly, not just plugin-backed execution and numerical correctness. + +#### 5.3.1.5 Rollout Status + +The NHWC rollout is effectively in a "runtime enabled, cleanup remaining" state: + +| Phase | Change | Expected outcome | +|-------|--------|------------------| +| 1 | Enable plugin NHWC callbacks and preserve pre-assigned nodes in the second capability pass | Completed on the current branch | +| 2 | Cache the shim provider pointer in the adapter `OpKernelInfo` | Completed on the current branch; fixes the observed NHWC runtime crash | +| 3 | Consolidate allowlists, improve internal-domain diagnostics, and strengthen structural NHWC assertions | Recommended follow-up work | + ### 5.4 CUDA Graph Support #### 5.4.1 How CUDA Graph Works in Bundled CUDA EP @@ -421,23 +539,23 @@ Section 7 reflects the current source exclusions in `cmake/onnxruntime_providers | `tensor/size.cc` | Pure CPU op, handled by `GetCpuPreferredNodes` | | `tensor/shape_op.cc` | Pure CPU op, inherits from `onnxruntime::OpKernel` (framework) | -### 7.3 Operators Excluded Due to Missing Features +### 7.3 Additional Current Source Exclusions | File / Pattern | Why It Is Excluded Today | What Would Unblock It | |----------------|--------------------------|------------------------| -| `core/providers/cuda/controlflow/*` | Framework controlflow kernels are omitted from the source list | Plugin equivalents already exist in `cuda_controlflow_plugin.cc`; the framework sources stay excluded by design | -| `tunable/*` | Depends on the real tuning context and framework CUDA EP infrastructure | Add a plugin-capable tuning context and remove the remaining tunable guards | -| `math/einsum.cc` | The top-level framework einsum source is still excluded | Provide a plugin-safe top-level einsum provider path; `einsum_utils/*` are no longer excluded | +| `core/providers/cuda/controlflow/*` | The framework controlflow kernels inherit from CPU-side controlflow bases (`If`, `Loop`, `Scan`) and are intentionally omitted from the plugin source list | No change is currently planned. The plugin uses its own `cuda_controlflow_plugin.cc` wrappers instead of these framework sources | +| `tunable/*` | Depends on the real `CudaTuningContext` and other framework CUDA EP infrastructure that is not available in the plugin build | Add a plugin-capable tuning context and remove the remaining framework-only tunable dependencies | +| `math/einsum.cc` | The top-level framework einsum provider path is still not plugin-safe even though supporting utility code is now included | Finish a plugin-safe top-level einsum path and then remove the CMake exclusion | | `tensor/identity_op.cc` | Uses `TensorSeq`, which is still not adapter-safe here | Add `TensorSeq` adapter coverage | | `tensor/sequence_op.cc` | Uses `TensorSeq`, which is still not adapter-safe here | Same as above | -| `contrib_ops/cuda/llm/*` | Contrib LLM kernels still need their own plugin migration pass | Finish contrib-LLM-specific adapter work | -| `contrib_ops/cuda/tensor/shrunken_gather.cc` | Training header path still depends on framework/provider API wiring | Low-priority training-specific adapter work | -| `contrib_ops/cuda/math/fft_ops.cc` | Still excluded in CMake due to remaining framework/stream assumptions | Finish FFT-specific adapter cleanup | -| `contrib_ops/cuda/tensor/crop.cc` | Still excluded in CMake even though the constructor-side helper work is mostly done | Finish and validate the remaining plugin-safe path, then remove the CMake exclusion | +| `contrib_ops/cuda/llm/*` | Contrib LLM sources have not gone through the same adapter-migration pass as the core CUDA LLM kernels | Finish contrib-LLM-specific adapter work | +| `contrib_ops/cuda/tensor/shrunken_gather.cc` | The training-side header path still depends on framework/provider API wiring | Low-priority training-specific adapter work | +| `contrib_ops/cuda/math/fft_ops.cc` | Still has framework stream/type assumptions that are not yet adapter-safe | Finish FFT-specific adapter cleanup | +| `contrib_ops/cuda/tensor/crop.cc` | Still has remaining framework assumptions, so it stays excluded even though some helper-side migration work is already done | Finish and validate the remaining plugin-safe path, then remove the CMake exclusion | | `contrib_ops/cuda/tensor/dynamicslice.cc` | Still excluded in CMake due to remaining framework assumptions | Finish dynamicslice-specific adapter cleanup | -| `contrib_ops/cuda/transformers/*` | Beam search / greedy search / sampling require broader framework integration | Significant adapter and subgraph support work | -| `onnxruntime/contrib_ops/cuda/aten_ops/*` | ATen interop is out of scope for the plugin build | Separate ATen plugin strategy | -| `onnxruntime/contrib_ops/cuda/collective/*` | Collective/NCCL path is out of scope for the plugin build | Separate collective/NCCL plugin strategy | +| `contrib_ops/cuda/transformers/*` | Beam search, greedy search, and sampling depend on broader framework/subgraph integration that has not been adapted for the plugin build | Significant adapter and subgraph support work | +| `contrib_ops/cuda/aten_ops/*` | ATen interop is intentionally out of scope for the standalone CUDA plugin build | A separate ATen/plugin strategy | +| `contrib_ops/cuda/collective/*` | Collective/NCCL support is intentionally out of scope for the standalone CUDA plugin build | A separate collective/NCCL plugin strategy | ### 7.4 Common Exclusion Themes @@ -483,12 +601,6 @@ sh build.sh --config Release --build_dir build/cuda --parallel --use_cuda \ --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES="90" ``` -Or using the existing `cuda.sh` convenience script: - -```bash -./cuda.sh --build --test_plugin # --test_plugin sets BUILD_CUDA_EP_AS_PLUGIN=ON -``` - ### 9.2 Impact on Other Build Targets The `onnxruntime_BUILD_CUDA_EP_AS_PLUGIN=ON` flag has **no impact** on `libonnxruntime_providers_cuda.so` or `libonnxruntime_providers_shared.so`. It only: @@ -536,7 +648,7 @@ The plugin is then available as `CudaPluginExecutionProvider` in session provide |-------|---------------| | Registration | Dynamic loading via `register_execution_provider_library()` and EP device discovery | | Stage 2 | Basic ops: Add, MatMul, Gemm, Conv | -| Stage 3 | NHWC layout: Conv, BatchNorm, MaxPool, AveragePool | +| Stage 3 | NHWC-requested sessions: Conv, BatchNorm, MaxPool, AveragePool. These validate correctness under `prefer_nhwc` and require plugin-backed NHWC execution to succeed; they are the focused regression suite for the plugin NHWC path | | Stage 5A | Standard ops: Reshape, Split, Concat, Gather, Unsqueeze | | Stage 5B | More ops: Tile, CumSum, ConstantOfShape, SpaceToDepth, Pad, Slice, Resize, Sum | | Stage 5C | CPU base class ops: Upsample, DepthToSpace | @@ -659,14 +771,22 @@ include/onnxruntime/ep/ ## 13. Future Work -1. **Contrib LLM migration pass** — The core CUDA LLM attention path is now adapter-safe, but `contrib_ops/cuda/llm/*` is still excluded as a separate follow-up. +1. **Memory arena / allocator parity** — The plugin currently relies on direct `cudaMalloc`/`cudaFree` paths instead of the in-tree CUDA EP's BFC-style arena. Adding a plugin-side arena or a clean way to reuse ORT's allocator infrastructure would reduce allocation overhead, improve memory reuse, and let the plugin honor options such as `gpu_mem_limit` and `arena_extend_strategy`. + +2. **Profiling and observability** — The in-tree CUDA EP exposes an EP profiler, while the plugin shim currently does not surface equivalent profiling hooks. Future work should wire up `GetProfiler()` for the plugin path, integrate CUDA/NVTX/CUPTI-based tracing where appropriate, and make plugin execution visible in the same profiling flows users already rely on for the bundled CUDA EP. + +3. **Stream/adapter parity for framework-style `Stream*` consumers** — A number of excluded or recently re-included kernels still assume access to a richer framework `Stream*` object rather than only a raw `cudaStream_t` view. Extending the adapter path here would unblock additional LLM, FFT, quantization, diffusion, and other CUDA kernels. + +4. **Contrib LLM migration pass** — The core CUDA LLM attention path is now adapter-safe, but `contrib_ops/cuda/llm/*` remains excluded as a separate follow-up. + +5. **Tunable ops** — Implement a plugin-side `ITuningContext` and remove the `ORT_USE_EP_API_ADAPTERS` guards in `matmul.cc`/`gemm.cc` so the plugin can recover runtime kernel selection and profiling-based tuning behavior. -2. **Tunable ops** — Implement a plugin-side `ITuningContext` and remove the `ORT_USE_EP_API_ADAPTERS` guards in `matmul.cc`/`gemm.cc`. +6. **TensorSeq and additional C API coverage** — Add enough sequence/tensor-sequence support to unblock `identity_op.cc` and `sequence_op.cc`, and extend the ORT C API where needed for remaining framework-style attribute accessors such as string-array attributes used by RNN kernels. -3. **TensorSeq adapter coverage** — Add enough sequence/tensor-sequence support to unblock `identity_op.cc` and `sequence_op.cc`. +7. **Remaining contrib exclusions** — Remove the current CMake exclusions for FFT, crop, dynamicslice, and other remaining contrib paths once their framework assumptions are gone or adapter equivalents exist. -4. **Remaining contrib exclusions** — Remove the CMake exclusions for FFT, crop, and dynamicslice once their remaining framework assumptions are gone. +8. **CI integration and targeted benchmarking** — Add plugin build + test coverage to CI and include perf-oriented validation so allocator, profiling, and tunable-op regressions are caught early. -5. **CI integration** — Add plugin build + test to the CI pipeline. +9. **NHWC cleanup and hardening** — Complete the follow-up work described in [Section 5.3.1](#531-nhwc-layout-transformation-support): unify the allowlist, improve internal-domain diagnostics, and add stronger structural NHWC assertions. -6. **CUDA Graph API for plugin EPs** — Add `IsGraphCaptureEnabled`, `IsGraphCaptured`, and `ReplayGraph` callbacks to the `OrtEp` C API (see [Section 5.4.4](#544-what-needs-to-change-in-ort-core-option-a)). This is required for efficient CUDA graph replay in the plugin EP. The capture/replay infrastructure will be reintroduced once the API is extended. +10. **CUDA Graph API for plugin EPs** — Add `IsGraphCaptureEnabled`, `IsGraphCaptured`, and `ReplayGraph` callbacks to the `OrtEp` C API (see [Section 5.4.4](#544-what-needs-to-change-in-ort-core-option-a)). This is required for efficient CUDA graph replay in the plugin EP. The capture/replay infrastructure will be reintroduced once the API is extended. diff --git a/include/onnxruntime/ep/adapter/op_kernel.h b/include/onnxruntime/ep/adapter/op_kernel.h index 969f34ae820ba..c90ba2205abf1 100644 --- a/include/onnxruntime/ep/adapter/op_kernel.h +++ b/include/onnxruntime/ep/adapter/op_kernel.h @@ -131,10 +131,18 @@ struct OpKernelContext { return Output(index, TensorShape{shape}); } [[nodiscard]] Status GetTempSpaceCPUAllocator(AllocatorPtr* output) const { - return static_cast(op_kernel_.Info().GetKernelInfo().GetEp())->GetTempSpaceCPUAllocator(output); + const auto* execution_provider = op_kernel_.Info().GetExecutionProvider(); + ORT_ENFORCE(execution_provider != nullptr, "Kernel does not have an execution provider."); + const auto* ort_ep = execution_provider->GetOrtEp(); + ORT_ENFORCE(ort_ep != nullptr, "Kernel execution provider is not associated with an OrtEp instance."); + return static_cast(ort_ep)->GetTempSpaceCPUAllocator(output); } [[nodiscard]] Status GetTempSpaceAllocator(AllocatorPtr* output) const { - return static_cast(op_kernel_.Info().GetKernelInfo().GetEp())->GetTempSpaceAllocator(output); + const auto* execution_provider = op_kernel_.Info().GetExecutionProvider(); + ORT_ENFORCE(execution_provider != nullptr, "Kernel does not have an execution provider."); + const auto* ort_ep = execution_provider->GetOrtEp(); + ORT_ENFORCE(ort_ep != nullptr, "Kernel execution provider is not associated with an OrtEp instance."); + return static_cast(ort_ep)->GetTempSpaceAllocator(output); } int InputCount() const { return static_cast(input_tensors_.size()); diff --git a/include/onnxruntime/ep/adapter/op_kernel_info.h b/include/onnxruntime/ep/adapter/op_kernel_info.h index 638beabb8cbf0..3ed5f3034d9ee 100644 --- a/include/onnxruntime/ep/adapter/op_kernel_info.h +++ b/include/onnxruntime/ep/adapter/op_kernel_info.h @@ -12,6 +12,7 @@ #include "core/common/narrow.h" #include "core/common/status.h" #include "core/framework/config_options.h" +#include "core/framework/op_kernel_info.h" #include "core/framework/tensor_shape.h" #include "core/framework/tensor.h" @@ -42,6 +43,11 @@ struct OpKernelInfo { // to manage the lifetime of the cached data. struct KernelInfoCache { explicit KernelInfoCache(const OrtKernelInfo* kernel_info) : kernel_info_(kernel_info) { + const auto* core_kernel_info = reinterpret_cast(kernel_info); + execution_provider_ = core_kernel_info->GetExecutionProvider(); + ort_ep_ = execution_provider_ != nullptr ? execution_provider_->GetOrtEp() : nullptr; + ep_impl_ = ort_ep_ != nullptr ? (static_cast(ort_ep_))->EpImpl() : execution_provider_; + Ort::ConstKernelInfo info{kernel_info}; const size_t input_count = info.GetInputCount(); constant_input_tensors.resize(input_count); @@ -54,6 +60,9 @@ struct OpKernelInfo { } } const OrtKernelInfo* kernel_info_; + const ::onnxruntime::IExecutionProvider* execution_provider_{}; + const OrtEp* ort_ep_{}; + const ::onnxruntime::IExecutionProvider* ep_impl_{}; std::vector constant_input_tensors; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(KernelInfoCache); }; @@ -62,13 +71,13 @@ struct OpKernelInfo { } const DataTransferManager& GetDataTransferManager() const noexcept { - return (static_cast(info_.GetEp()))->GetDataTransferManager(); + return (static_cast(cache_->ort_ep_))->GetDataTransferManager(); } Node node() const noexcept { return Node{cache_->kernel_info_}; } const IExecutionProvider* GetExecutionProvider() const noexcept { - return (static_cast(info_.GetEp()))->EpImpl(); + return cache_->ep_impl_; } KernelDef GetKernelDef() const noexcept { @@ -76,7 +85,7 @@ struct OpKernelInfo { } const Ort::ConstKernelInfo GetKernelInfo() const noexcept { - return info_; + return Ort::ConstKernelInfo{cache_->kernel_info_}; } ConfigOptions GetConfigOptions() const noexcept { diff --git a/onnxruntime/core/providers/cuda/math/gemm.cc b/onnxruntime/core/providers/cuda/math/gemm.cc index 7fa5e74b54248..59602c2458a7b 100644 --- a/onnxruntime/core/providers/cuda/math/gemm.cc +++ b/onnxruntime/core/providers/cuda/math/gemm.cc @@ -5,7 +5,9 @@ #include "core/providers/cpu/math/gemm_helper.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" +#ifndef BUILD_CUDA_EP_AS_PLUGIN #include "core/providers/cuda/tunable/math/gemm.h" +#endif namespace onnxruntime { namespace cuda { @@ -72,9 +74,11 @@ Status Gemm::ComputeInternal(OpKernelContext* ctx) const { // Bail out early if the output is going to be empty if (Y->Shape().Size() == 0) return Status::OK(); +#ifndef BUILD_CUDA_EP_AS_PLUGIN if (GetTuningContext()->IsTunableOpEnabled()) { return tunable::TunableGemm(M, N, K, trans_A_, trans_B_, alpha_, B ? beta_ : 0.0f, this, ctx); } +#endif return ComputeDefault(ctx, M, N, K); } diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 7d6335dbd2cbf..56e5a35a5c73d 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -237,8 +237,7 @@ Status Conv::CreateCudnnFeExecutionPlan(const onnxruntime::TensorShap CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_operation_graph(handle)); CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->create_execution_plans({heur_mode})); } catch (const std::exception& ex) { - std::string message = MakeString("Failed to initialize CUDNN Frontend", ex.what(), - "with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); + std::string message = MakeString("Failed to initialize CUDNN Frontend: ", ex.what()); return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); } @@ -249,8 +248,7 @@ Status Conv::CreateCudnnFeExecutionPlan(const onnxruntime::TensorShap CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_plans(handle)); } catch (const std::exception& ex) { if (!fuse_bias && !fuse_act && use_tf32) { - std::string message = MakeString("OP not supported by CUDNN Frontend", ex.what(), - "with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); + std::string message = MakeString("OP not supported by CUDNN Frontend: ", ex.what()); return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); } diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 16d219ee4ef1c..a875b1c0b2aaa 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -194,8 +194,7 @@ Status ConvTranspose::CreateCudnnFeExecutionPlan(const onnxruntime::T CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_operation_graph(handle)); CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->create_execution_plans({heur_mode})); } catch (const std::exception& ex) { - std::string message = MakeString("Failed to initialize CUDNN Frontend", ex.what(), - "with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); + std::string message = MakeString("Failed to initialize CUDNN Frontend: ", ex.what()); return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); } @@ -206,8 +205,7 @@ Status ConvTranspose::CreateCudnnFeExecutionPlan(const onnxruntime::T CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_plans(handle)); } catch (const std::exception& ex) { if (!fuse_bias && !fuse_act && use_tf32) { - std::string message = MakeString("OP not supported by CUDNN Frontend", ex.what(), - "with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); + std::string message = MakeString("OP not supported by CUDNN Frontend: ", ex.what()); return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); } diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc index 8cda066ffccea..31ca602a60b77 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc @@ -282,7 +282,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, (*KernelDefBuilder::Create()) .InputMemoryType(OrtMemTypeCPUInput, 0) .TypeConstraint("B", DataTypeImpl::GetTensorType()) - .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypes()), + // The adapter EP API currently exposes tensor OrtDataType creation only. + .TypeConstraint("V", DataTypeImpl::AllTensorTypes()), PluginIfKernel); ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, @@ -292,7 +293,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, (*KernelDefBuilder::Create()) .InputMemoryType(OrtMemTypeCPUInput, 0) .TypeConstraint("B", DataTypeImpl::GetTensorType()) - .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypesIRv9()), + .TypeConstraint("V", DataTypeImpl::AllTensorTypesIRv9()), PluginIfKernel); ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, @@ -302,7 +303,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, (*KernelDefBuilder::Create()) .InputMemoryType(OrtMemTypeCPUInput, 0) .TypeConstraint("B", DataTypeImpl::GetTensorType()) - .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypesIRv9()), + .TypeConstraint("V", DataTypeImpl::AllTensorTypesIRv9()), PluginIfKernel); ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, @@ -312,7 +313,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, (*KernelDefBuilder::Create()) .InputMemoryType(OrtMemTypeCPUInput, 0) .TypeConstraint("B", DataTypeImpl::GetTensorType()) - .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypesIRv9()), + .TypeConstraint("V", DataTypeImpl::AllTensorTypesIRv9()), PluginIfKernel); ONNX_OPERATOR_KERNEL_EX(If, @@ -322,7 +323,7 @@ ONNX_OPERATOR_KERNEL_EX(If, (*KernelDefBuilder::Create()) .InputMemoryType(OrtMemTypeCPUInput, 0) .TypeConstraint("B", DataTypeImpl::GetTensorType()) - .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypesIRv9()), + .TypeConstraint("V", DataTypeImpl::AllTensorTypesIRv9()), PluginIfKernel); // --- Loop --- @@ -360,7 +361,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(Loop, .InputMemoryType(OrtMemTypeCPUInput, 1) .TypeConstraint("I", DataTypeImpl::GetTensorType()) .TypeConstraint("B", DataTypeImpl::GetTensorType()) - .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypes()), + .TypeConstraint("V", DataTypeImpl::AllTensorTypes()), PluginLoopKernel); ONNX_OPERATOR_VERSIONED_KERNEL_EX(Loop, @@ -372,7 +373,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(Loop, .InputMemoryType(OrtMemTypeCPUInput, 1) .TypeConstraint("I", DataTypeImpl::GetTensorType()) .TypeConstraint("B", DataTypeImpl::GetTensorType()) - .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypesIRv9()), + .TypeConstraint("V", DataTypeImpl::AllTensorTypesIRv9()), PluginLoopKernel); ONNX_OPERATOR_VERSIONED_KERNEL_EX(Loop, @@ -384,7 +385,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(Loop, .InputMemoryType(OrtMemTypeCPUInput, 1) .TypeConstraint("I", DataTypeImpl::GetTensorType()) .TypeConstraint("B", DataTypeImpl::GetTensorType()) - .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypesIRv9()), + .TypeConstraint("V", DataTypeImpl::AllTensorTypesIRv9()), PluginLoopKernel); ONNX_OPERATOR_VERSIONED_KERNEL_EX(Loop, @@ -396,7 +397,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(Loop, .InputMemoryType(OrtMemTypeCPUInput, 1) .TypeConstraint("I", DataTypeImpl::GetTensorType()) .TypeConstraint("B", DataTypeImpl::GetTensorType()) - .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypesIRv9()), + .TypeConstraint("V", DataTypeImpl::AllTensorTypesIRv9()), PluginLoopKernel); ONNX_OPERATOR_KERNEL_EX(Loop, @@ -408,7 +409,7 @@ ONNX_OPERATOR_KERNEL_EX(Loop, .InputMemoryType(OrtMemTypeCPUInput, 1) .TypeConstraint("I", DataTypeImpl::GetTensorType()) .TypeConstraint("B", DataTypeImpl::GetTensorType()) - .TypeConstraint("V", DataTypeImpl::AllTensorAndSequenceTensorTypesIRv9()), + .TypeConstraint("V", DataTypeImpl::AllTensorTypesIRv9()), PluginLoopKernel); // --- Scan (opset 8) --- diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index b619f845a7b21..807122ff6ad97 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -5,6 +5,8 @@ #include "cuda_ep_factory.h" #include "cuda_stream_plugin.h" #include "core/providers/cuda/plugin/cuda_kernel_adapter.h" +#include "core/providers/cuda/cuda_allocator.h" +#include "core/framework/allocator.h" #include "ep/get_capability_utils.h" #include @@ -15,8 +17,26 @@ namespace onnxruntime { namespace cuda_plugin { +namespace { + +std::unique_ptr CreateCudaPluginProvider(std::string_view ep_name, const OrtEp* ort_ep) { + return std::make_unique<::onnxruntime::CUDAExecutionProvider>(std::string{ep_name}, ort_ep); +} + +AllocatorPtr CreateCudaPluginTempSpaceAllocator(int device_id) { + return std::make_shared<::onnxruntime::CUDAAllocator>(device_id, ::onnxruntime::CUDA); +} + +AllocatorPtr CreateCudaPluginTempSpaceCpuAllocator() { + return ::onnxruntime::CPUAllocator::DefaultInstance(); +} + +} // namespace + CudaEp::CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& logger) - : OrtEp{}, + : onnxruntime::ep::adapter::Ep{CreateCudaPluginProvider(factory.GetEpName(), static_cast(this)), + CreateCudaPluginTempSpaceCpuAllocator(), + CreateCudaPluginTempSpaceAllocator(config.device_id)}, factory_(factory), name_(factory.GetEpName()), config_(config), @@ -42,18 +62,17 @@ CudaEp::CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& lo ORT_FILE, __LINE__, __FUNCTION__)); // Store per-EP runtime configuration (TF32, device ID, tuning options, etc.) - // in a global map keyed by OrtEp pointer. Migrated kernels retrieve these - // settings at runtime via GetCudaKernelAdapterRuntimeConfig() without needing - // to thread the config through multiple layers of framework code. + // in a global map keyed by the adapter-wrapped execution provider. Migrated + // kernels retrieve these settings via info.GetExecutionProvider(). onnxruntime::cuda::SetCudaKernelAdapterRuntimeConfigForProvider( - static_cast(static_cast(this)), + static_cast(EpImpl()), config_.use_tf32, config_.device_id, config_.enable_skip_layer_norm_strict_mode, config_.cudnn_conv_algo, config_.cudnn_conv_use_max_workspace, config_.cudnn_conv1d_pad_to_nc1d, config_.fuse_conv_bias, config_.sdpa_kernel); } CudaEp::~CudaEp() { - onnxruntime::cuda::detail::RemoveCudaKernelAdapterRuntimeConfigForProvider(static_cast(static_cast(this))); + onnxruntime::cuda::detail::RemoveCudaKernelAdapterRuntimeConfigForProvider(static_cast(EpImpl())); } /*static*/ @@ -84,13 +103,17 @@ OrtStatus* ORT_API_CALL CudaEp::GetCapabilityImpl( // Phase 3: Register remaining nodes as supported by this EP. // Phase 1: Collect tentative nodes — those for which we have a registered kernel. + std::vector candidate_nodes; + candidate_nodes.reserve(all_nodes.size()); std::vector tentative_nodes; tentative_nodes.reserve(all_nodes.size()); for (const auto& node : all_nodes) { - // Skip nodes already assigned to another EP. std::string ep_name = node.GetEpName(); if (!ep_name.empty()) { + if (ep_name == ep->name_) { + candidate_nodes.push_back(node); + } continue; } @@ -99,6 +122,7 @@ OrtStatus* ORT_API_CALL CudaEp::GetCapabilityImpl( graph_support_info, node, &kernel_def)); if (kernel_def != nullptr) { + candidate_nodes.push_back(node); tentative_nodes.push_back(node); } } @@ -112,7 +136,7 @@ OrtStatus* ORT_API_CALL CudaEp::GetCapabilityImpl( cpu_preferred_nodes)); // Phase 3: Add final supported nodes (tentative minus CPU-preferred). - for (const OrtNode* ort_node : tentative_nodes) { + for (const OrtNode* ort_node : candidate_nodes) { if (cpu_preferred_nodes.count(ort_node) == 0) { Ort::ConstNode node{ort_node}; RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_AddSingleNode( @@ -140,7 +164,12 @@ OrtStatus* ORT_API_CALL CudaEp::GetKernelRegistryImpl( OrtStatus* ORT_API_CALL CudaEp::GetPreferredDataLayoutImpl( OrtEp* this_ptr, OrtEpDataLayout* preferred_data_layout) noexcept { const auto* ep = static_cast(this_ptr); +#ifdef ENABLE_CUDA_NHWC_OPS *preferred_data_layout = ep->config_.prefer_nhwc ? OrtEpDataLayout_NHWC : OrtEpDataLayout_NCHW; +#else + ORT_UNUSED_PARAMETER(ep); + *preferred_data_layout = OrtEpDataLayout_NCHW; +#endif return nullptr; } @@ -148,7 +177,15 @@ OrtStatus* ORT_API_CALL CudaEp::GetPreferredDataLayoutImpl( OrtStatus* ORT_API_CALL CudaEp::ShouldConvertDataLayoutForOpImpl( OrtEp* this_ptr, const char* domain, const char* op_type, OrtEpDataLayout target_data_layout, int* should_convert) noexcept { - (void)this_ptr; + ORT_UNUSED_PARAMETER(this_ptr); + +#ifndef ENABLE_CUDA_NHWC_OPS + ORT_UNUSED_PARAMETER(domain); + ORT_UNUSED_PARAMETER(op_type); + ORT_UNUSED_PARAMETER(target_data_layout); + *should_convert = 0; // NHWC kernels are not compiled into this plugin build. + return nullptr; +#else // Only convert to NHWC; for any other target layout, let ORT decide. if (target_data_layout != OrtEpDataLayout_NHWC) { @@ -185,8 +222,9 @@ OrtStatus* ORT_API_CALL CudaEp::ShouldConvertDataLayoutForOpImpl( return nullptr; } - *should_convert = -1; // Let ORT decide for other ops + *should_convert = 0; // Explicitly decline conversion for unsupported NHWC ops. return nullptr; +#endif } } // namespace cuda_plugin diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.h b/onnxruntime/core/providers/cuda/plugin/cuda_ep.h index 1d5ce8a77118e..5f961fe3b0a8c 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.h @@ -4,6 +4,7 @@ #pragma once #include "cuda_plugin_utils.h" +#include "ep/adapters.h" #include #include @@ -14,7 +15,7 @@ namespace cuda_plugin { class CudaEpFactory; /// CUDA execution provider implementation using public OrtEp interface. -class CudaEp : public OrtEp { +class CudaEp : public onnxruntime::ep::adapter::Ep { public: /// Configuration parameters for the CUDA EP, parsed from session options. struct Config { diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc index 217b831fd5935..3e3eeb5b4af1a 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -4,7 +4,7 @@ #include "cuda_ep_factory.h" #include "cuda_ep.h" #include "cuda_plugin_kernels.h" -#include "core/session/abi_session_options_impl.h" +#include "core/common/string_utils.h" #include #include @@ -116,6 +116,10 @@ std::string ToUpper(std::string value) { return value; } +std::string GetProviderOptionPrefix(std::string_view provider_name) { + return std::format("ep.{}.", onnxruntime::utils::GetLowercaseString(std::string{provider_name})); +} + } // namespace /*static*/ @@ -364,7 +368,7 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( } }; - const std::string ep_options_prefix = OrtSessionOptions::GetProviderOptionPrefix(factory->GetEpName().c_str()); + const std::string ep_options_prefix = GetProviderOptionPrefix(factory->GetEpName()); const std::string prefer_nhwc_key = ep_options_prefix + "prefer_nhwc"; const std::string prefer_nhwc_layout_key = ep_options_prefix + "prefer_nhwc_layout"; const std::string use_tf32_key = ep_options_prefix + "use_tf32"; diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index 69eef22f6dbfe..cbea32b81f6ed 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -40,6 +40,7 @@ // =================================================================== #include "core/providers/cuda/plugin/cuda_stream_plugin.h" +#include "core/providers/cuda/gpu_data_transfer.h" #include "core/providers/cuda/shared_inc/cuda_utils.h" #include "core/session/onnxruntime_cxx_api.h" @@ -96,6 +97,10 @@ class OrtStreamAdapter { #include "core/framework/op_kernel.h" #include "core/providers/common.h" +namespace onnxruntime { +inline constexpr const char* kCudaPluginExecutionProvider = "CudaPluginExecutionProvider"; +} + namespace onnxruntime { namespace cuda { @@ -218,47 +223,49 @@ class PluginKernelCollector { // These macros mirror the framework's ONNX_OPERATOR_*_KERNEL_EX definitions // from core/framework/op_kernel.h, but additionally register each // BuildKernelCreateInfoFn into PluginKernelCollector at static init time. - #define ORT_ADAPTER_CONCAT_IMPL(x, y) x##y #define ORT_ADAPTER_CONCAT(x, y) ORT_ADAPTER_CONCAT_IMPL(x, y) +// The provider parameter are not used in below macros since we are hardcoding the provider to cuda plugin. +#define CUDA_PLUGIN_EP ::onnxruntime::kCudaPluginExecutionProvider + #undef ONNX_OPERATOR_KERNEL_EX -#define ONNX_OPERATOR_KERNEL_EX(name, domain, ver, provider, builder, ...) \ - class ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name); \ - template <> \ - KernelCreateInfo \ - BuildKernelCreateInfo() { \ - return KernelCreateInfo( \ - builder.SetName(#name).SetDomain(domain).SinceVersion(ver).Provider(provider).Build(), \ - static_cast( \ - [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ - out = std::make_unique<__VA_ARGS__>(info); \ - return Status::OK(); \ - })); \ - } \ - static const bool ORT_ADAPTER_CONCAT(ORT_ADAPTER_AUTOREG_##name##_, __COUNTER__) = \ - (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ - &BuildKernelCreateInfo), \ +#define ONNX_OPERATOR_KERNEL_EX(name, domain, ver, provider, builder, ...) \ + class ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name).SetDomain(domain).SinceVersion(ver).Provider(CUDA_PLUGIN_EP).Build(), \ + static_cast( \ + [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ + out = std::make_unique<__VA_ARGS__>(info); \ + return Status::OK(); \ + })); \ + } \ + static const bool ORT_ADAPTER_CONCAT(ORT_ADAPTER_AUTOREG_##name##_, __COUNTER__) = \ + (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ + &BuildKernelCreateInfo), \ true); #undef ONNX_OPERATOR_VERSIONED_KERNEL_EX -#define ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, domain, startver, endver, provider, builder, ...) \ - class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name); \ - template <> \ - KernelCreateInfo \ - BuildKernelCreateInfo() { \ - return KernelCreateInfo( \ - builder.SetName(#name).SetDomain(domain).SinceVersion(startver, endver).Provider(provider).Build(), \ - static_cast( \ - [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ - out = std::make_unique<__VA_ARGS__>(info); \ - return Status::OK(); \ - })); \ - } \ - static const bool ORT_ADAPTER_CONCAT(ORT_ADAPTER_AUTOREG_##name##_, __COUNTER__) = \ - (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ - &BuildKernelCreateInfo), \ +#define ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, domain, startver, endver, provider, builder, ...) \ + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name).SetDomain(domain).SinceVersion(startver, endver).Provider(CUDA_PLUGIN_EP).Build(), \ + static_cast( \ + [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ + out = std::make_unique<__VA_ARGS__>(info); \ + return Status::OK(); \ + })); \ + } \ + static const bool ORT_ADAPTER_CONCAT(ORT_ADAPTER_AUTOREG_##name##_, __COUNTER__) = \ + (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ + &BuildKernelCreateInfo), \ true); #undef ONNX_OPERATOR_TYPED_KERNEL_EX @@ -268,7 +275,7 @@ class PluginKernelCollector { KernelCreateInfo \ BuildKernelCreateInfo() { \ return KernelCreateInfo( \ - builder.SetName(#name).SetDomain(domain).SinceVersion(ver).Provider(provider).Build(), \ + builder.SetName(#name).SetDomain(domain).SinceVersion(ver).Provider(CUDA_PLUGIN_EP).Build(), \ static_cast( \ [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ out = std::make_unique<__VA_ARGS__>(info); \ @@ -281,24 +288,24 @@ class PluginKernelCollector { true); #undef ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX -#define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, domain, startver, endver, type, provider, builder, ...) \ - class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name); \ - template <> \ - KernelCreateInfo \ - BuildKernelCreateInfo() { \ - return KernelCreateInfo( \ - builder.SetName(#name).SetDomain(domain).SinceVersion(startver, endver).Provider(provider).Build(), \ - static_cast( \ - [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ - out = std::make_unique<__VA_ARGS__>(info); \ - return Status::OK(); \ - })); \ - } \ - static const bool ORT_ADAPTER_CONCAT(ORT_ADAPTER_AUTOREG_##name##_##type##_, __COUNTER__) = \ - (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ - &BuildKernelCreateInfo), \ +#define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, domain, startver, endver, type, provider, builder, ...) \ + class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name).SetDomain(domain).SinceVersion(startver, endver).Provider(CUDA_PLUGIN_EP).Build(), \ + static_cast( \ + [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ + out = std::make_unique<__VA_ARGS__>(info); \ + return Status::OK(); \ + })); \ + } \ + static const bool ORT_ADAPTER_CONCAT(ORT_ADAPTER_AUTOREG_##name##_##type##_, __COUNTER__) = \ + (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ + &BuildKernelCreateInfo), \ true); // =================================================================== @@ -353,7 +360,7 @@ struct CudaKernelAdapterRuntimeConfig { }; // Shared storage for per-provider runtime configurations. // Both Get and Remove must operate on the same static map instance, -// so we centralise them in a single struct with static lifetime. +// so we centralize them in a single struct with static lifetime. struct ProviderConfigStore { std::shared_mutex mutex; std::unordered_map> configs; @@ -475,31 +482,24 @@ inline const cudaDeviceProp& GetDevicePropForDevice(int device_id) { // (GetCudnnConvAlgo, UseTF32, GetDeviceProp, etc.) without the full // CUDAExecutionProvider class from onnxruntime/core/providers/cuda/. // -// DESIGN NOTE: Why does this class have no state/member variables? -// In the plugin build, the object returned by `info.GetExecutionProvider()` -// is an opaque C-API struct (`OrtEp*`/`CudaEp*`), NOT this class. -// The raw kernel code performs `static_cast` on it. -// If this shim class defined any member variables (e.g., `config_`), the -// compiler would read them at specific byte offsets relative to `this`, causing -// memory layout UB (garbage reads/segfaults) since the underlying object in -// memory is actually an `OrtEp`. -// Therefore, `CUDAExecutionProvider` here must remain a pure "phantom shim." -// To safely access state (like TF32 settings), it dynamically queries a static -// map keyed by its own `this` pointer (which equals the `CudaEp*` memory address). +// In the plugin build this shim is wrapped by adapter::Ep, so migrated CUDA +// kernels can keep casting `info.GetExecutionProvider()` to +// `CUDAExecutionProvider*` and retrieve the plugin `OrtEp` via GetOrtEp(). // =================================================================== // Shim for CUDAExecutionProvider required by conv.cc, einsum, and others class CUDAExecutionProvider : public onnxruntime::IExecutionProvider { public: - explicit CUDAExecutionProvider(const std::string& name) : onnxruntime::IExecutionProvider{name} {} + explicit CUDAExecutionProvider(const std::string& name, const OrtEp* ort_ep = nullptr) + : onnxruntime::IExecutionProvider{name}, ort_ep_{ort_ep} {} + + std::unique_ptr GetDataTransfer() const override { + return std::make_unique(); + } - // SAFETY: This class must remain empty (no added member variables beyond - // IExecutionProvider). In the plugin build, an OrtEp*/CudaEp* is cast to - // CUDAExecutionProvider*. Adding members would cause the compiler to read - // them at incorrect byte offsets, silently corrupting data. All runtime - // state is stored in ProviderConfigStore, keyed by `this`. - // If the static_assert below fires, move the new state into - // CudaKernelAdapterRuntimeConfig instead of adding members here. + const OrtEp* GetOrtEp() const override { + return ort_ep_; + } int GetCudnnConvAlgo() const { return cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(this).cudnn_conv_algo; @@ -524,11 +524,10 @@ class CUDAExecutionProvider : public onnxruntime::IExecutionProvider { const cudaDeviceProp& GetDeviceProp() const { return cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(this).device_prop; } -}; -// Verify CUDAExecutionProvider has no added members — see phantom cast design note. -static_assert(sizeof(CUDAExecutionProvider) == sizeof(onnxruntime::IExecutionProvider), - "CUDAExecutionProvider must not add member variables."); + private: + const OrtEp* ort_ep_ = nullptr; +}; namespace cuda { @@ -783,7 +782,19 @@ class CudaKernel : public OpKernel { static inline cudnnHandle_t GetCudnnHandle(onnxruntime::Stream* stream) { return stream ? GetCudnnHandle(static_cast(stream->GetHandle())) : nullptr; } - cudnnHandle_t GetCudnnHandle(OpKernelContext* ctx) const { return GetCudnnHandle(Stream(ctx)); } + cudnnHandle_t GetCudnnHandle(OpKernelContext* ctx) const { + auto stream = Stream(ctx); + auto handle = GetCudnnHandle(stream); + if (handle != nullptr) { + return handle; + } + + handle = DefaultCudnnHandle(); + if (stream != nullptr) { + CUDNN_CALL_THROW(cudnnSetStream(handle, stream)); + } + return handle; + } static cublasHandle_t GetCublasHandle(cudaStream_t s) { auto* sync = cuda_plugin::CudaSyncStream::FromCudaStream(s); @@ -795,7 +806,19 @@ class CudaKernel : public OpKernel { static inline cublasHandle_t GetCublasHandle(onnxruntime::Stream* stream) { return stream ? GetCublasHandle(static_cast(stream->GetHandle())) : nullptr; } - cublasHandle_t GetCublasHandle(OpKernelContext* ctx) const { return GetCublasHandle(Stream(ctx)); } + cublasHandle_t GetCublasHandle(OpKernelContext* ctx) const { + auto stream = Stream(ctx); + auto handle = GetCublasHandle(stream); + if (handle != nullptr) { + return handle; + } + + handle = DefaultCublasHandle(); + if (stream != nullptr) { + CUBLAS_CALL_THROW(cublasSetStream(handle, stream)); + } + return handle; + } static cublasLtHandle_t GetCublasLtHandle(cudaStream_t s) { auto* sync = cuda_plugin::CudaSyncStream::FromCudaStream(s); @@ -809,9 +832,18 @@ class CudaKernel : public OpKernel { } cublasLtHandle_t GetCublasLtHandle(OpKernelContext* ctx) const { return GetCublasLtHandle(Stream(ctx)); } - const cudaDeviceProp& GetDeviceProp() const { return device_prop_; } + const cudaDeviceProp& GetDeviceProp() const { + // Some migrated kernels size their launches from device properties. If the + // per-provider cache was not populated for this kernel instance, fall back + // to a direct lookup instead of returning an all-zero struct. + if (device_prop_.maxThreadsPerMultiProcessor == 0 || device_prop_.multiProcessorCount == 0) { + return detail::GetDevicePropForDevice(device_id_); + } + + return device_prop_; + } bool UseTF32() const { return use_tf32_; } - bool IsArchAvailable(int arch) const { return device_prop_.major >= arch; } + bool IsArchAvailable(int arch) const { return GetDeviceProp().major >= arch; } const OpKernelInfo& Info() const { return info_; } const onnxruntime::AttentionKernelOptions* GetAttentionKernelOptions() const { return static_cast(info_.GetExecutionProvider())->GetAttentionKernelOptions(); diff --git a/onnxruntime/test/python/transformers/cuda_plugin_ep_helper.py b/onnxruntime/test/python/transformers/cuda_plugin_ep_helper.py index bc024c9274f03..581edc5940c77 100644 --- a/onnxruntime/test/python/transformers/cuda_plugin_ep_helper.py +++ b/onnxruntime/test/python/transformers/cuda_plugin_ep_helper.py @@ -6,6 +6,8 @@ from importlib.metadata import PackageNotFoundError, distribution from pathlib import Path +import torch + import onnxruntime as onnxrt @@ -45,7 +47,15 @@ def _get_package_root(package_name: str, directory_name: str | None = None): def _is_cuda_plugin_ep_built() -> bool: build_info = onnxrt.get_build_info() - return ", cuda-plugin-ep=" in build_info + if ", cuda-plugin-ep=" in build_info: + return True + + ep_lib_path = os.environ.get("ORT_CUDA_PLUGIN_PATH", "") + if ep_lib_path and os.path.exists(ep_lib_path): + return True + + detected_path = _get_default_cuda_plugin_ep_path() + return bool(detected_path and os.path.exists(detected_path)) def _get_cuda_plugin_library_name() -> str: @@ -159,6 +169,22 @@ def resolve_cuda_plugin_ep(ep: str, default_test_with_cuda_plugin_ep: bool = Tru return ep +def get_cuda_provider_name() -> str | None: + if not torch.cuda.is_available(): + return None + + resolved_provider = resolve_cuda_plugin_ep("CUDAExecutionProvider") + available_providers = onnxrt.get_available_providers() + + if resolved_provider in available_providers: + return resolved_provider + + if "CUDAExecutionProvider" in available_providers: + return "CUDAExecutionProvider" + + return None + + def _is_plugin_provider_type_available() -> bool: try: return CUDA_PLUGIN_EP_NAME in onnxrt.get_available_providers() diff --git a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py index 41132e7bf2b89..ea645bcbdc39b 100644 --- a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py +++ b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py @@ -25,6 +25,7 @@ TEST_PASS = "PASS" TEST_SKIP = "SKIP" TEST_FAIL = "FAIL" +EP_GRAPH_ASSIGNMENT_CONFIG_KEY = "session.record_ep_graph_assignment_info" def require_cuda_plugin_ep(): @@ -50,6 +51,46 @@ def get_cuda_plugin_device(): return plugin_devices[0] +def _create_session_options(session_config=None): + sess_options = onnxrt.SessionOptions() + if session_config: + for key, value in session_config.items(): + sess_options.add_session_config_entry(key, value) + + # Require graph-assignment data so the tests validate that nodes actually run on the plugin. + sess_options.add_session_config_entry(EP_GRAPH_ASSIGNMENT_CONFIG_KEY, "1") + return sess_options + + +def _format_assigned_node(node): + domain = node.domain or "ai.onnx" + if node.name: + return f"{domain}::{node.op_type}:{node.name}" + return f"{domain}::{node.op_type}" + + +def _get_assigned_nodes(session, ep_name): + assignment_info = list(session.get_provider_graph_assignment_info()) + assigned_nodes = [] + for subgraph in assignment_info: + if subgraph.ep_name == ep_name: + assigned_nodes.extend(subgraph.get_nodes()) + + return assigned_nodes, assignment_info + + +def _format_assignment_summary(assignment_info): + if not assignment_info: + return "" + + summaries = [] + for subgraph in assignment_info: + node_summary = ", ".join(_format_assigned_node(node) for node in subgraph.get_nodes()) or "" + summaries.append(f"{subgraph.ep_name}[{node_summary}]") + + return "; ".join(summaries) + + def create_add_model(model_path): # Create a simple Add model: Y = A + B node_def = helper.make_node("Add", ["A", "B"], ["Y"]) @@ -238,22 +279,25 @@ def run_operator_test( model_path = tmp.name try: model_creator(model_path) - sess_options = onnxrt.SessionOptions() - if session_config: - for key, value in session_config.items(): - sess_options.add_session_config_entry(key, value) + sess_options = _create_session_options(session_config) sess_options.add_provider_for_devices([target_device], {}) sess = onnxrt.InferenceSession(model_path, sess_options=sess_options) active_providers = sess.get_providers() - if ep_name not in active_providers: - print(f"FAILURE: {ep_name} is NOT active for this operator. Providers: {active_providers}") + assigned_nodes, assignment_info = _get_assigned_nodes(sess, ep_name) + if not assigned_nodes: + print( + f"FAILURE: {ep_name} was assigned no nodes. Providers: {active_providers}. " + f"Assignments: {_format_assignment_summary(assignment_info)}" + ) return False - print(f"(Session created with {active_providers})", end=" ", flush=True) - print("Running...", end=" ", flush=True) + print( + f"(Session created with {active_providers}; assigned nodes: " + f"{', '.join(_format_assigned_node(node) for node in assigned_nodes)})", + flush=True, + ) res = sess.run(None, inputs) - print("Done.", end=" ", flush=True) expected = expected_fn(inputs) np.testing.assert_allclose(res[0], expected, rtol=1e-3, atol=1e-3) return True @@ -269,14 +313,21 @@ def run_provider_options_test(provider_options, expect_plugin_provider=True): try: create_add_model(model_path) providers = [(CUDA_PLUGIN_EP_NAME, provider_options), "CPUExecutionProvider"] - sess = onnxrt.InferenceSession(model_path, providers=providers) + sess = onnxrt.InferenceSession(model_path, sess_options=_create_session_options(), providers=providers) active_providers = sess.get_providers() + assigned_nodes, assignment_info = _get_assigned_nodes(sess, CUDA_PLUGIN_EP_NAME) - if expect_plugin_provider and CUDA_PLUGIN_EP_NAME not in active_providers: - print(f"FAILURE: {CUDA_PLUGIN_EP_NAME} is NOT active. Providers: {active_providers}") + if expect_plugin_provider and not assigned_nodes: + print( + f"FAILURE: {CUDA_PLUGIN_EP_NAME} was assigned no nodes. Providers: {active_providers}. " + f"Assignments: {_format_assignment_summary(assignment_info)}" + ) return False - if not expect_plugin_provider and CUDA_PLUGIN_EP_NAME in active_providers: - print(f"FAILURE: {CUDA_PLUGIN_EP_NAME} unexpectedly active. Providers: {active_providers}") + if not expect_plugin_provider and assigned_nodes: + print( + f"FAILURE: {CUDA_PLUGIN_EP_NAME} unexpectedly owned nodes. " + f"Assignments: {_format_assignment_summary(assignment_info)}" + ) return False a = np.random.rand(3, 2).astype(np.float32) @@ -296,120 +347,15 @@ def run_provider_options_test(provider_options, expect_plugin_provider=True): os.remove(model_path) -def _run_registration_checks(test_case: unittest.TestCase): - target_device = get_cuda_plugin_device() - print(f"Using registered plugin: {CUDA_PLUGIN_EP_NAME}", flush=True) - print(f"Using device: {target_device.ep_name}", flush=True) - - x = np.random.rand(1, 2, 4, 4).astype(np.float32) - w = np.random.rand(3, 2, 3, 3).astype(np.float32) - - def expected_conv(inputs): - return F.conv2d(torch.from_numpy(inputs["X"]), torch.from_numpy(inputs["W"]), padding=1).numpy() - - stage2_cases = [ - ( - "Add", - create_add_model, - {"A": np.random.rand(3, 2).astype(np.float32), "B": np.random.rand(3, 2).astype(np.float32)}, - lambda feed: feed["A"] + feed["B"], - None, - ), - ( - "MatMul", - create_matmul_model, - {"A": np.random.rand(3, 4).astype(np.float32), "B": np.random.rand(4, 5).astype(np.float32)}, - lambda feed: feed["A"] @ feed["B"], - None, - ), - ( - "Gemm", - lambda model_path: create_gemm_model(model_path, alpha=2.0, beta=0.5), - { - "A": np.random.rand(3, 4).astype(np.float32), - "B": np.random.rand(4, 5).astype(np.float32), - "C": np.random.rand(5).astype(np.float32), - }, - lambda feed: 2.0 * (feed["A"] @ feed["B"]) + 0.5 * feed["C"], - None, - ), - ("Conv", create_conv_model, {"X": x, "W": w}, expected_conv, None), - ] - - for name, model_creator, inputs, expected_fn, session_config in stage2_cases: - print(f"Testing {name}...", end=" ", flush=True) - result = run_operator_test(target_device, model_creator, inputs, expected_fn, session_config=session_config) - with test_case.subTest(op=name): - test_case.assertTrue( - result, - f"{name} plugin registration test failed", - ) - print(TEST_PASS if result else TEST_FAIL, flush=True) - - print("\nAll Stage 2 tests finished successfully.", flush=True) - - nhwc_config = {"ep.cuda.prefer_nhwc_layout": "1"} - - def expected_batchnorm(inputs): - return inputs["X"] / np.sqrt(1.0 + 1e-5) - - stage3_cases = [ - ( - "Conv (NHWC)", - create_conv_model, - { - "X": np.random.rand(1, 2, 4, 4).astype(np.float32), - "W": np.random.rand(3, 2, 3, 3).astype(np.float32), - }, - expected_conv, - ), - ( - "BatchNormalization (NHWC)", - create_batch_norm_model, - {"X": np.random.rand(1, 3, 4, 4).astype(np.float32)}, - expected_batchnorm, - ), - ( - "MaxPool (NHWC)", - create_maxpool_model, - {"X": np.random.rand(1, 3, 4, 4).astype(np.float32)}, - lambda feed: F.max_pool2d(torch.from_numpy(feed["X"]), kernel_size=2, stride=2).numpy(), - ), - ( - "AveragePool (NHWC)", - create_avgpool_model, - {"X": np.random.rand(1, 3, 4, 4).astype(np.float32)}, - lambda feed: F.avg_pool2d(torch.from_numpy(feed["X"]), kernel_size=2, stride=2).numpy(), - ), - ] - - for name, model_creator, inputs, expected_fn in stage3_cases: - print(f"Testing {name}...", end=" ", flush=True) - result = run_operator_test(target_device, model_creator, inputs, expected_fn, session_config=nhwc_config) - with test_case.subTest(op=name): - test_case.assertTrue( - result, - f"{name} plugin NHWC test failed", - ) - print(TEST_PASS if result else TEST_FAIL, flush=True) - - print("\nAll Stage 3 NHWC tests finished successfully.", flush=True) - - provider_option_cases = [ - ("provider options with valid device_id/use_tf32", {"device_id": "0", "use_tf32": "0"}, True), - ("provider options with invalid device_id", {"device_id": "999"}, False), - ] - - print("\nTesting provider options path...", flush=True) - for name, provider_options, expect_plugin_provider in provider_option_cases: - print(f"Testing {name}...", end=" ", flush=True) - result = run_provider_options_test(provider_options, expect_plugin_provider=expect_plugin_provider) - with test_case.subTest(op=name): - test_case.assertTrue( - result, - f"{name} failed", - ) - print(TEST_PASS if result else TEST_FAIL, flush=True) +def _expected_conv(inputs): + return F.conv2d(torch.from_numpy(inputs["X"]), torch.from_numpy(inputs["W"]), padding=1).numpy() + + +_NHWC_CONFIG = {"ep.cuda.prefer_nhwc_layout": "1"} + + +def _expected_batchnorm(inputs): + return inputs["X"] / np.sqrt(1.0 + 1e-5) def _make_simple_model(op_type, inputs_info, outputs_info, attrs=None, opset=13, domain=""): @@ -451,13 +397,17 @@ def _run_model_test( model_path = tmp.name try: save(model, model_path) - sess_options = onnxrt.SessionOptions() + sess_options = _create_session_options() sess_options.add_provider_for_devices([target_device], {}) sess = onnxrt.InferenceSession(model_path, sess_options=sess_options) active_providers = sess.get_providers() - if ep_name not in active_providers: - print(f"{TEST_SKIP} ({ep_name} not active)") - return TEST_SKIP + assigned_nodes, assignment_info = _get_assigned_nodes(sess, ep_name) + if not assigned_nodes: + print( + f"{TEST_FAIL} ({ep_name} was assigned no nodes; providers={active_providers}; " + f"assignments={_format_assignment_summary(assignment_info)})" + ) + return TEST_FAIL res = sess.run(None, feed_dict) expected = expected_fn(feed_dict) if isinstance(expected, (list, tuple)): @@ -477,378 +427,483 @@ def _run_model_test( os.remove(model_path) -def _run_stage5_checks(test_case: unittest.TestCase): - """Stage 5: Test all ops enabled during Stage 5 (5A through 5D).""" - target_device = get_cuda_plugin_device() - passed = 0 - failed = 0 - skipped = 0 - - def run_test(name, model, feed, expected_fn, rtol=1e-3, atol=1e-3): - nonlocal passed, failed, skipped - print(f" {name}...", end=" ", flush=True) - result = _run_model_test(target_device, name, model, feed, expected_fn, rtol=rtol, atol=atol) - with test_case.subTest(op=name): - if result == TEST_PASS: - passed += 1 - print(TEST_PASS, flush=True) - return - - if result == TEST_SKIP: - skipped += 1 - print(TEST_SKIP, flush=True) - return - - failed += 1 - print(TEST_FAIL, flush=True) - test_case.fail(f"{name} Stage 5 plugin op test failed") - - print("\n==================== Stage 5: Expanded Op Tests ====================", flush=True) - f_dtype = TensorProto.FLOAT - - # ---- 5A/5B: Standard ops ---- - print("\n--- Standard Ops (5A/5B) ---", flush=True) - - # Reshape - model = _make_simple_model( - "Reshape", [("X", f_dtype, [2, 3, 4]), ("shape", TensorProto.INT64, [2])], [("Y", f_dtype, [6, 4])] - ) - # Need shape as initializer; build manually - shape_init = helper.make_tensor("shape", TensorProto.INT64, [2], [6, 4]) - model.graph.initializer.append(shape_init) - x = np.random.rand(2, 3, 4).astype(np.float32) - run_test("Reshape", model, {"X": x}, lambda f: f["X"].reshape(6, 4)) - - # Split (opset 18 supports num_outputs; use split input for opset 13) - node = helper.make_node("Split", ["X", "split"], ["Y1", "Y2"], axis=0) - graph = helper.make_graph( - [node], - "test-Split", - [helper.make_tensor_value_info("X", f_dtype, [6, 4])], - [helper.make_tensor_value_info("Y1", f_dtype, [3, 4]), helper.make_tensor_value_info("Y2", f_dtype, [3, 4])], - ) - opset = OperatorSetIdProto() - opset.version = 13 - model = helper.make_model(graph, opset_imports=[opset]) - model.graph.initializer.append(helper.make_tensor("split", TensorProto.INT64, [2], [3, 3])) - x = np.random.rand(6, 4).astype(np.float32) - run_test("Split", model, {"X": x}, lambda f: [f["X"][:3], f["X"][3:]]) - - # Concat - model = _make_simple_model( - "Concat", [("A", f_dtype, [2, 3]), ("B", f_dtype, [2, 3])], [("Y", f_dtype, [4, 3])], attrs={"axis": 0} - ) - a = np.random.rand(2, 3).astype(np.float32) - b = np.random.rand(2, 3).astype(np.float32) - run_test("Concat", model, {"A": a, "B": b}, lambda f: np.concatenate([f["A"], f["B"]], axis=0)) - - # Gather - gather_model = _make_simple_model( - "Gather", - [("X", f_dtype, [5, 4]), ("indices", TensorProto.INT64, [3])], - [("Y", f_dtype, [3, 4])], - attrs={"axis": 0}, - opset=13, - ) - x = np.random.rand(5, 4).astype(np.float32) - idx = np.array([0, 2, 4], dtype=np.int64) - run_test("Gather", gather_model, {"X": x, "indices": idx}, lambda f: f["X"][f["indices"]]) - - # Unsqueeze (opset 13 uses axes as input) - node = helper.make_node("Unsqueeze", ["X", "axes"], ["Y"]) - graph = helper.make_graph( - [node], - "test-Unsqueeze", - [helper.make_tensor_value_info("X", f_dtype, [3, 4])], - [helper.make_tensor_value_info("Y", f_dtype, [1, 3, 4])], - ) - opset = OperatorSetIdProto() - opset.version = 13 - model = helper.make_model(graph, opset_imports=[opset]) - axes_init = helper.make_tensor("axes", TensorProto.INT64, [1], [0]) - model.graph.initializer.append(axes_init) - x = np.random.rand(3, 4).astype(np.float32) - run_test("Unsqueeze", model, {"X": x}, lambda f: np.expand_dims(f["X"], 0)) - - # Tile - node = helper.make_node("Tile", ["X", "repeats"], ["Y"]) - graph = helper.make_graph( - [node], - "test-Tile", - [helper.make_tensor_value_info("X", f_dtype, [2, 3])], - [helper.make_tensor_value_info("Y", f_dtype, [4, 9])], - ) - opset = OperatorSetIdProto() - opset.version = 13 - model = helper.make_model(graph, opset_imports=[opset]) - repeats_init = helper.make_tensor("repeats", TensorProto.INT64, [2], [2, 3]) - model.graph.initializer.append(repeats_init) - x = np.random.rand(2, 3).astype(np.float32) - run_test("Tile", model, {"X": x}, lambda f: np.tile(f["X"], (2, 3))) - - # CumSum - node = helper.make_node("CumSum", ["X", "axis"], ["Y"]) - graph = helper.make_graph( - [node], - "test-CumSum", - [helper.make_tensor_value_info("X", f_dtype, [3, 4])], - [helper.make_tensor_value_info("Y", f_dtype, [3, 4])], - ) - opset = OperatorSetIdProto() - opset.version = 14 - model = helper.make_model(graph, opset_imports=[opset]) - axis_init = helper.make_tensor("axis", TensorProto.INT64, [], [1]) - model.graph.initializer.append(axis_init) - x = np.random.rand(3, 4).astype(np.float32) - run_test("CumSum", model, {"X": x}, lambda f: np.cumsum(f["X"], axis=1)) - - # ConstantOfShape - node = helper.make_node( - "ConstantOfShape", ["shape"], ["Y"], value=helper.make_tensor("value", TensorProto.FLOAT, [1], [3.14]) - ) - graph = helper.make_graph( - [node], - "test-ConstantOfShape", - [helper.make_tensor_value_info("shape", TensorProto.INT64, [2])], - [helper.make_tensor_value_info("Y", f_dtype, None)], - ) - opset = OperatorSetIdProto() - opset.version = 9 - model = helper.make_model(graph, opset_imports=[opset]) - run_test( - "ConstantOfShape", - model, - {"shape": np.array([2, 3], dtype=np.int64)}, - lambda f: np.full((2, 3), 3.14, dtype=np.float32), - ) - - # SpaceToDepth - model = _make_simple_model( - "SpaceToDepth", [("X", f_dtype, [1, 2, 4, 4])], [("Y", f_dtype, [1, 8, 2, 2])], attrs={"blocksize": 2}, opset=13 - ) - x = np.random.rand(1, 2, 4, 4).astype(np.float32) - - def space_to_depth(f): - inp = f["X"] - b, c, h, w = inp.shape - bs = 2 - # ONNX SpaceToDepth: rearrange blocks of spatial data into depth - # (b, c, h, w) -> (b, c, h/bs, bs, w/bs, bs) -> (b, c*bs*bs, h/bs, w/bs) - tmp = inp.reshape(b, c, h // bs, bs, w // bs, bs) - tmp = tmp.transpose(0, 3, 5, 1, 2, 4) - return tmp.reshape(b, c * bs * bs, h // bs, w // bs) - - run_test("SpaceToDepth", model, {"X": x}, space_to_depth) - - # Pad - node = helper.make_node("Pad", ["X", "pads", "constant_value"], ["Y"]) - graph = helper.make_graph( - [node], - "test-Pad", - [helper.make_tensor_value_info("X", f_dtype, [2, 3])], - [helper.make_tensor_value_info("Y", f_dtype, [4, 5])], - ) - opset = OperatorSetIdProto() - opset.version = 13 - model = helper.make_model(graph, opset_imports=[opset]) - model.graph.initializer.append(helper.make_tensor("pads", TensorProto.INT64, [4], [1, 1, 1, 1])) - model.graph.initializer.append(helper.make_tensor("constant_value", TensorProto.FLOAT, [], [0.0])) - x = np.random.rand(2, 3).astype(np.float32) - run_test("Pad", model, {"X": x}, lambda f: np.pad(f["X"], ((1, 1), (1, 1)), constant_values=0)) - - # Slice - node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"]) - graph = helper.make_graph( - [node], - "test-Slice", - [helper.make_tensor_value_info("X", f_dtype, [4, 6])], - [helper.make_tensor_value_info("Y", f_dtype, [2, 4])], - ) - opset = OperatorSetIdProto() - opset.version = 13 - model = helper.make_model(graph, opset_imports=[opset]) - model.graph.initializer.append(helper.make_tensor("starts", TensorProto.INT64, [2], [1, 1])) - model.graph.initializer.append(helper.make_tensor("ends", TensorProto.INT64, [2], [3, 5])) - model.graph.initializer.append(helper.make_tensor("axes", TensorProto.INT64, [2], [0, 1])) - x = np.random.rand(4, 6).astype(np.float32) - run_test("Slice", model, {"X": x}, lambda f: f["X"][1:3, 1:5]) - - # Resize (nearest) - node = helper.make_node("Resize", ["X", "", "scales"], ["Y"], mode="nearest") - graph = helper.make_graph( - [node], - "test-Resize", - [helper.make_tensor_value_info("X", f_dtype, [1, 1, 2, 2])], - [helper.make_tensor_value_info("Y", f_dtype, [1, 1, 4, 4])], - ) - opset = OperatorSetIdProto() - opset.version = 13 - model = helper.make_model(graph, opset_imports=[opset]) - model.graph.initializer.append(helper.make_tensor("scales", TensorProto.FLOAT, [4], [1.0, 1.0, 2.0, 2.0])) - x = np.random.rand(1, 1, 2, 2).astype(np.float32) - run_test("Resize", model, {"X": x}, lambda f: np.repeat(np.repeat(f["X"], 2, axis=2), 2, axis=3)) - - # Sum (variadic) - model = _make_simple_model( - "Sum", - [("A", f_dtype, [3, 4]), ("B", f_dtype, [3, 4]), ("C", f_dtype, [3, 4])], - [("Y", f_dtype, [3, 4])], - opset=13, - ) - a = np.random.rand(3, 4).astype(np.float32) - b = np.random.rand(3, 4).astype(np.float32) - c = np.random.rand(3, 4).astype(np.float32) - run_test("Sum_variadic", model, {"A": a, "B": b, "C": c}, lambda f: f["A"] + f["B"] + f["C"]) - - # ---- 5C: CPU base class ops ---- - print("\n--- CPU Base Class Ops (5C) ---", flush=True) - - # Upsample (deprecated but still present) - node = helper.make_node("Upsample", ["X", "scales"], ["Y"], mode="nearest") - graph = helper.make_graph( - [node], - "test-Upsample", - [helper.make_tensor_value_info("X", f_dtype, [1, 1, 2, 2])], - [helper.make_tensor_value_info("Y", f_dtype, [1, 1, 4, 4])], - ) - opset = OperatorSetIdProto() - opset.version = 9 - model = helper.make_model(graph, opset_imports=[opset]) - model.graph.initializer.append(helper.make_tensor("scales", TensorProto.FLOAT, [4], [1.0, 1.0, 2.0, 2.0])) - x = np.random.rand(1, 1, 2, 2).astype(np.float32) - run_test("Upsample", model, {"X": x}, lambda f: np.repeat(np.repeat(f["X"], 2, axis=2), 2, axis=3)) - - # DepthToSpace - model = _make_simple_model( - "DepthToSpace", - [("X", f_dtype, [1, 8, 2, 2])], - [("Y", f_dtype, [1, 2, 4, 4])], - attrs={"blocksize": 2, "mode": "DCR"}, - opset=13, - ) - x = np.random.rand(1, 8, 2, 2).astype(np.float32) - - def depth_to_space_dcr(f): - inp = f["X"] - b, c, h, w = inp.shape - bs = 2 - return ( - inp.reshape(b, bs, bs, c // (bs * bs), h, w) - .transpose(0, 3, 4, 1, 5, 2) - .reshape(b, c // (bs * bs), h * bs, w * bs) +class TestCudaPluginEP(unittest.TestCase): + # ---- Registration tests (verify nodes run on the plugin EP) ---- + + def test_registration_add(self): + target_device = get_cuda_plugin_device() + inputs = {"A": np.random.rand(3, 2).astype(np.float32), "B": np.random.rand(3, 2).astype(np.float32)} + result = run_operator_test(target_device, create_add_model, inputs, lambda feed: feed["A"] + feed["B"]) + self.assertTrue(result, "Add plugin registration test failed") + + def test_registration_matmul(self): + target_device = get_cuda_plugin_device() + inputs = {"A": np.random.rand(3, 4).astype(np.float32), "B": np.random.rand(4, 5).astype(np.float32)} + result = run_operator_test(target_device, create_matmul_model, inputs, lambda feed: feed["A"] @ feed["B"]) + self.assertTrue(result, "MatMul plugin registration test failed") + + def test_registration_gemm(self): + target_device = get_cuda_plugin_device() + inputs = { + "A": np.random.rand(3, 4).astype(np.float32), + "B": np.random.rand(4, 5).astype(np.float32), + "C": np.random.rand(5).astype(np.float32), + } + result = run_operator_test( + target_device, + lambda model_path: create_gemm_model(model_path, alpha=2.0, beta=0.5), + inputs, + lambda feed: 2.0 * (feed["A"] @ feed["B"]) + 0.5 * feed["C"], + ) + self.assertTrue(result, "Gemm plugin registration test failed") + + def test_registration_conv(self): + target_device = get_cuda_plugin_device() + inputs = { + "X": np.random.rand(1, 2, 4, 4).astype(np.float32), + "W": np.random.rand(3, 2, 3, 3).astype(np.float32), + } + result = run_operator_test(target_device, create_conv_model, inputs, _expected_conv) + self.assertTrue(result, "Conv plugin registration test failed") + + # ---- Provider options tests ---- + + def test_provider_options_valid(self): + result = run_provider_options_test({"device_id": "0", "use_tf32": "0"}, expect_plugin_provider=True) + self.assertTrue(result, "Provider options with valid device_id/use_tf32 failed") + + def test_provider_options_invalid_device(self): + result = run_provider_options_test({"device_id": "999"}, expect_plugin_provider=False) + self.assertTrue(result, "Provider options with invalid device_id failed") + + # ---- NHWC layout tests ---- + + def test_nhwc_conv(self): + target_device = get_cuda_plugin_device() + inputs = { + "X": np.random.rand(1, 2, 4, 4).astype(np.float32), + "W": np.random.rand(3, 2, 3, 3).astype(np.float32), + } + result = run_operator_test( + target_device, create_conv_model, inputs, _expected_conv, session_config=_NHWC_CONFIG ) + self.assertTrue(result, "Conv (NHWC) plugin test failed") - run_test("DepthToSpace", model, {"X": x}, depth_to_space_dcr) + def test_nhwc_batch_normalization(self): + target_device = get_cuda_plugin_device() + inputs = {"X": np.random.rand(1, 3, 4, 4).astype(np.float32)} + result = run_operator_test( + target_device, create_batch_norm_model, inputs, _expected_batchnorm, session_config=_NHWC_CONFIG + ) + self.assertTrue(result, "BatchNormalization (NHWC) plugin test failed") - # ---- 5D: Contrib Ops ---- - print("\n--- Contrib Ops (5D) ---", flush=True) + def test_nhwc_maxpool(self): + target_device = get_cuda_plugin_device() + inputs = {"X": np.random.rand(1, 3, 4, 4).astype(np.float32)} + result = run_operator_test( + target_device, + create_maxpool_model, + inputs, + lambda feed: F.max_pool2d(torch.from_numpy(feed["X"]), kernel_size=2, stride=2).numpy(), + session_config=_NHWC_CONFIG, + ) + self.assertTrue(result, "MaxPool (NHWC) plugin test failed") - # FastGelu (com.microsoft domain) - node = helper.make_node("FastGelu", ["X"], ["Y"], domain="com.microsoft") - graph = helper.make_graph( - [node], - "test-FastGelu", - [helper.make_tensor_value_info("X", f_dtype, [2, 4])], - [helper.make_tensor_value_info("Y", f_dtype, [2, 4])], - ) - opset_onnx = OperatorSetIdProto() - opset_onnx.version = 13 - opset_ms = OperatorSetIdProto() - opset_ms.domain = "com.microsoft" - opset_ms.version = 1 - model = helper.make_model(graph, opset_imports=[opset_onnx, opset_ms]) - x = np.random.rand(2, 4).astype(np.float32) - - def fast_gelu_ref(f): - x = f["X"] - # FastGelu approximation: x * sigmoid(1.702 * x) - return x * (1.0 / (1.0 + np.exp(-1.702 * x))) - - run_test("FastGelu", model, {"X": x}, fast_gelu_ref, rtol=1e-2, atol=1e-2) - - # BiasDropout (com.microsoft). We force inference mode so the op is deterministic. - model = make_bias_dropout_model() - x = np.random.rand(2, 4).astype(np.float32) - bias = np.random.rand(4).astype(np.float32) - residual = np.random.rand(2, 4).astype(np.float32) - ratio = np.array(0.5, dtype=np.float32) - training_mode = np.array(False, dtype=np.bool_) - run_test( - "BiasDropout", - model, - { - "X": x, - "bias": bias, - "residual": residual, - "ratio": ratio, - "training_mode": training_mode, - }, - lambda feed: feed["X"] + feed["bias"] + feed["residual"], - ) + def test_nhwc_avgpool(self): + target_device = get_cuda_plugin_device() + inputs = {"X": np.random.rand(1, 3, 4, 4).astype(np.float32)} + result = run_operator_test( + target_device, + create_avgpool_model, + inputs, + lambda feed: F.avg_pool2d(torch.from_numpy(feed["X"]), kernel_size=2, stride=2).numpy(), + session_config=_NHWC_CONFIG, + ) + self.assertTrue(result, "AveragePool (NHWC) plugin test failed") - # SkipLayerNormalization (com.microsoft) - hidden_size = 8 - node = helper.make_node( - "SkipLayerNormalization", - ["X", "skip", "gamma", "beta"], - ["Y", "", "", "input_skip_bias_sum"], - domain="com.microsoft", - epsilon=1e-5, - ) - graph = helper.make_graph( - [node], - "test-SkipLayerNorm", - [ - helper.make_tensor_value_info("X", f_dtype, [2, hidden_size]), - helper.make_tensor_value_info("skip", f_dtype, [2, hidden_size]), - helper.make_tensor_value_info("gamma", f_dtype, [hidden_size]), - helper.make_tensor_value_info("beta", f_dtype, [hidden_size]), - ], - [ - helper.make_tensor_value_info("Y", f_dtype, [2, hidden_size]), - helper.make_tensor_value_info("input_skip_bias_sum", f_dtype, None), - ], - ) - opset_onnx = OperatorSetIdProto() - opset_onnx.version = 13 - opset_ms = OperatorSetIdProto() - opset_ms.domain = "com.microsoft" - opset_ms.version = 1 - model = helper.make_model(graph, opset_imports=[opset_onnx, opset_ms]) - x = np.random.rand(2, hidden_size).astype(np.float32) - skip = np.random.rand(2, hidden_size).astype(np.float32) - gamma = np.ones(hidden_size, dtype=np.float32) - beta = np.zeros(hidden_size, dtype=np.float32) - - def skip_layer_norm_ref(f): - added = f["X"] + f["skip"] - mean = added.mean(axis=-1, keepdims=True) - var = added.var(axis=-1, keepdims=True) - normed = (added - mean) / np.sqrt(var + 1e-5) - return [normed * f["gamma"] + f["beta"], added] - - run_test( - "SkipLayerNorm", - model, - {"X": x, "skip": skip, "gamma": gamma, "beta": beta}, - skip_layer_norm_ref, - rtol=1e-2, - atol=1e-2, - ) + # ---- Standard op tests ---- - # ---- Summary ---- - total = passed + failed + skipped - print(f"\n--- Stage 5 Results: {passed} passed, {failed} failed, {skipped} skipped ({total} total) ---", flush=True) - test_case.assertEqual(failed, 0, f"Stage 5 had {failed} failing plugin op checks") - print("All Stage 5 tests finished successfully.", flush=True) + def test_op_reshape(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + model = _make_simple_model( + "Reshape", [("X", f_dtype, [2, 3, 4]), ("shape", TensorProto.INT64, [2])], [("Y", f_dtype, [6, 4])] + ) + model.graph.initializer.append(helper.make_tensor("shape", TensorProto.INT64, [2], [6, 4])) + x = np.random.rand(2, 3, 4).astype(np.float32) + result = _run_model_test(target_device, "Reshape", model, {"X": x}, lambda f: f["X"].reshape(6, 4)) + self.assertEqual(result, TEST_PASS, "Reshape plugin op test failed") + + def test_op_split(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("Split", ["X", "split"], ["Y1", "Y2"], axis=0) + graph = helper.make_graph( + [node], + "test-Split", + [helper.make_tensor_value_info("X", f_dtype, [6, 4])], + [ + helper.make_tensor_value_info("Y1", f_dtype, [3, 4]), + helper.make_tensor_value_info("Y2", f_dtype, [3, 4]), + ], + ) + opset = OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + model.graph.initializer.append(helper.make_tensor("split", TensorProto.INT64, [2], [3, 3])) + x = np.random.rand(6, 4).astype(np.float32) + result = _run_model_test(target_device, "Split", model, {"X": x}, lambda f: [f["X"][:3], f["X"][3:]]) + self.assertEqual(result, TEST_PASS, "Split plugin op test failed") + + def test_op_concat(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + model = _make_simple_model( + "Concat", [("A", f_dtype, [2, 3]), ("B", f_dtype, [2, 3])], [("Y", f_dtype, [4, 3])], attrs={"axis": 0} + ) + a = np.random.rand(2, 3).astype(np.float32) + b = np.random.rand(2, 3).astype(np.float32) + result = _run_model_test( + target_device, "Concat", model, {"A": a, "B": b}, lambda f: np.concatenate([f["A"], f["B"]], axis=0) + ) + self.assertEqual(result, TEST_PASS, "Concat plugin op test failed") + + def test_op_gather(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + model = _make_simple_model( + "Gather", + [("X", f_dtype, [5, 4]), ("indices", TensorProto.INT64, [3])], + [("Y", f_dtype, [3, 4])], + attrs={"axis": 0}, + opset=13, + ) + x = np.random.rand(5, 4).astype(np.float32) + idx = np.array([0, 2, 4], dtype=np.int64) + result = _run_model_test( + target_device, "Gather", model, {"X": x, "indices": idx}, lambda f: f["X"][f["indices"]] + ) + self.assertEqual(result, TEST_PASS, "Gather plugin op test failed") + + def test_op_unsqueeze(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("Unsqueeze", ["X", "axes"], ["Y"]) + graph = helper.make_graph( + [node], + "test-Unsqueeze", + [helper.make_tensor_value_info("X", f_dtype, [3, 4])], + [helper.make_tensor_value_info("Y", f_dtype, [1, 3, 4])], + ) + opset = OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + model.graph.initializer.append(helper.make_tensor("axes", TensorProto.INT64, [1], [0])) + x = np.random.rand(3, 4).astype(np.float32) + result = _run_model_test(target_device, "Unsqueeze", model, {"X": x}, lambda f: np.expand_dims(f["X"], 0)) + self.assertEqual(result, TEST_PASS, "Unsqueeze plugin op test failed") + + def test_op_tile(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("Tile", ["X", "repeats"], ["Y"]) + graph = helper.make_graph( + [node], + "test-Tile", + [helper.make_tensor_value_info("X", f_dtype, [2, 3])], + [helper.make_tensor_value_info("Y", f_dtype, [4, 9])], + ) + opset = OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + model.graph.initializer.append(helper.make_tensor("repeats", TensorProto.INT64, [2], [2, 3])) + x = np.random.rand(2, 3).astype(np.float32) + result = _run_model_test(target_device, "Tile", model, {"X": x}, lambda f: np.tile(f["X"], (2, 3))) + self.assertEqual(result, TEST_PASS, "Tile plugin op test failed") + + def test_op_cumsum(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("CumSum", ["X", "axis"], ["Y"]) + graph = helper.make_graph( + [node], + "test-CumSum", + [helper.make_tensor_value_info("X", f_dtype, [3, 4])], + [helper.make_tensor_value_info("Y", f_dtype, [3, 4])], + ) + opset = OperatorSetIdProto() + opset.version = 14 + model = helper.make_model(graph, opset_imports=[opset]) + model.graph.initializer.append(helper.make_tensor("axis", TensorProto.INT64, [], [1])) + x = np.random.rand(3, 4).astype(np.float32) + result = _run_model_test(target_device, "CumSum", model, {"X": x}, lambda f: np.cumsum(f["X"], axis=1)) + self.assertEqual(result, TEST_PASS, "CumSum plugin op test failed") + + def test_op_constant_of_shape(self): + target_device = get_cuda_plugin_device() + node = helper.make_node( + "ConstantOfShape", ["shape"], ["Y"], value=helper.make_tensor("value", TensorProto.FLOAT, [1], [3.14]) + ) + graph = helper.make_graph( + [node], + "test-ConstantOfShape", + [helper.make_tensor_value_info("shape", TensorProto.INT64, [2])], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, None)], + ) + opset = OperatorSetIdProto() + opset.version = 9 + model = helper.make_model(graph, opset_imports=[opset]) + result = _run_model_test( + target_device, + "ConstantOfShape", + model, + {"shape": np.array([2, 3], dtype=np.int64)}, + lambda f: np.full((2, 3), 3.14, dtype=np.float32), + ) + self.assertEqual(result, TEST_PASS, "ConstantOfShape plugin op test failed") + + def test_op_space_to_depth(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + model = _make_simple_model( + "SpaceToDepth", + [("X", f_dtype, [1, 2, 4, 4])], + [("Y", f_dtype, [1, 8, 2, 2])], + attrs={"blocksize": 2}, + opset=13, + ) + x = np.random.rand(1, 2, 4, 4).astype(np.float32) + + def expected(f): + inp = f["X"] + b, c, h, w = inp.shape + bs = 2 + tmp = inp.reshape(b, c, h // bs, bs, w // bs, bs) + tmp = tmp.transpose(0, 3, 5, 1, 2, 4) + return tmp.reshape(b, c * bs * bs, h // bs, w // bs) + + result = _run_model_test(target_device, "SpaceToDepth", model, {"X": x}, expected) + self.assertEqual(result, TEST_PASS, "SpaceToDepth plugin op test failed") + + def test_op_pad(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("Pad", ["X", "pads", "constant_value"], ["Y"]) + graph = helper.make_graph( + [node], + "test-Pad", + [helper.make_tensor_value_info("X", f_dtype, [2, 3])], + [helper.make_tensor_value_info("Y", f_dtype, [4, 5])], + ) + opset = OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + model.graph.initializer.append(helper.make_tensor("pads", TensorProto.INT64, [4], [1, 1, 1, 1])) + model.graph.initializer.append(helper.make_tensor("constant_value", TensorProto.FLOAT, [], [0.0])) + x = np.random.rand(2, 3).astype(np.float32) + result = _run_model_test( + target_device, "Pad", model, {"X": x}, lambda f: np.pad(f["X"], ((1, 1), (1, 1)), constant_values=0) + ) + self.assertEqual(result, TEST_PASS, "Pad plugin op test failed") + + def test_op_slice(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"]) + graph = helper.make_graph( + [node], + "test-Slice", + [helper.make_tensor_value_info("X", f_dtype, [4, 6])], + [helper.make_tensor_value_info("Y", f_dtype, [2, 4])], + ) + opset = OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + model.graph.initializer.append(helper.make_tensor("starts", TensorProto.INT64, [2], [1, 1])) + model.graph.initializer.append(helper.make_tensor("ends", TensorProto.INT64, [2], [3, 5])) + model.graph.initializer.append(helper.make_tensor("axes", TensorProto.INT64, [2], [0, 1])) + x = np.random.rand(4, 6).astype(np.float32) + result = _run_model_test(target_device, "Slice", model, {"X": x}, lambda f: f["X"][1:3, 1:5]) + self.assertEqual(result, TEST_PASS, "Slice plugin op test failed") + + def test_op_resize(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("Resize", ["X", "", "scales"], ["Y"], mode="nearest") + graph = helper.make_graph( + [node], + "test-Resize", + [helper.make_tensor_value_info("X", f_dtype, [1, 1, 2, 2])], + [helper.make_tensor_value_info("Y", f_dtype, [1, 1, 4, 4])], + ) + opset = OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + model.graph.initializer.append(helper.make_tensor("scales", TensorProto.FLOAT, [4], [1.0, 1.0, 2.0, 2.0])) + x = np.random.rand(1, 1, 2, 2).astype(np.float32) + result = _run_model_test( + target_device, "Resize", model, {"X": x}, lambda f: np.repeat(np.repeat(f["X"], 2, axis=2), 2, axis=3) + ) + self.assertEqual(result, TEST_PASS, "Resize plugin op test failed") + + def test_op_sum_variadic(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + model = _make_simple_model( + "Sum", + [("A", f_dtype, [3, 4]), ("B", f_dtype, [3, 4]), ("C", f_dtype, [3, 4])], + [("Y", f_dtype, [3, 4])], + opset=13, + ) + a = np.random.rand(3, 4).astype(np.float32) + b = np.random.rand(3, 4).astype(np.float32) + c = np.random.rand(3, 4).astype(np.float32) + result = _run_model_test( + target_device, "Sum_variadic", model, {"A": a, "B": b, "C": c}, lambda f: f["A"] + f["B"] + f["C"] + ) + self.assertEqual(result, TEST_PASS, "Sum_variadic plugin op test failed") + + # ---- CPU base class op tests ---- + + def test_op_upsample(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("Upsample", ["X", "scales"], ["Y"], mode="nearest") + graph = helper.make_graph( + [node], + "test-Upsample", + [helper.make_tensor_value_info("X", f_dtype, [1, 1, 2, 2])], + [helper.make_tensor_value_info("Y", f_dtype, [1, 1, 4, 4])], + ) + opset = OperatorSetIdProto() + opset.version = 9 + model = helper.make_model(graph, opset_imports=[opset]) + model.graph.initializer.append(helper.make_tensor("scales", TensorProto.FLOAT, [4], [1.0, 1.0, 2.0, 2.0])) + x = np.random.rand(1, 1, 2, 2).astype(np.float32) + result = _run_model_test( + target_device, "Upsample", model, {"X": x}, lambda f: np.repeat(np.repeat(f["X"], 2, axis=2), 2, axis=3) + ) + self.assertEqual(result, TEST_PASS, "Upsample plugin op test failed") + + def test_op_depth_to_space(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + model = _make_simple_model( + "DepthToSpace", + [("X", f_dtype, [1, 8, 2, 2])], + [("Y", f_dtype, [1, 2, 4, 4])], + attrs={"blocksize": 2, "mode": "DCR"}, + opset=13, + ) + x = np.random.rand(1, 8, 2, 2).astype(np.float32) + + def expected(f): + inp = f["X"] + b, c, h, w = inp.shape + bs = 2 + return ( + inp.reshape(b, bs, bs, c // (bs * bs), h, w) + .transpose(0, 3, 4, 1, 5, 2) + .reshape(b, c // (bs * bs), h * bs, w * bs) + ) + result = _run_model_test(target_device, "DepthToSpace", model, {"X": x}, expected) + self.assertEqual(result, TEST_PASS, "DepthToSpace plugin op test failed") -class TestCudaPluginEP(unittest.TestCase): - def test_cuda_plugin_registration(self): - _run_registration_checks(self) + # ---- Contrib op tests ---- - def test_cuda_plugin_stage5_ops(self): - _run_stage5_checks(self) + def test_op_fast_gelu(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("FastGelu", ["X"], ["Y"], domain="com.microsoft") + graph = helper.make_graph( + [node], + "test-FastGelu", + [helper.make_tensor_value_info("X", f_dtype, [2, 4])], + [helper.make_tensor_value_info("Y", f_dtype, [2, 4])], + ) + opset_onnx = OperatorSetIdProto() + opset_onnx.version = 13 + opset_ms = OperatorSetIdProto() + opset_ms.domain = "com.microsoft" + opset_ms.version = 1 + model = helper.make_model(graph, opset_imports=[opset_onnx, opset_ms]) + x = np.random.rand(2, 4).astype(np.float32) + + def expected(f): + v = f["X"] + return v * (1.0 / (1.0 + np.exp(-1.702 * v))) + + result = _run_model_test(target_device, "FastGelu", model, {"X": x}, expected, rtol=1e-2, atol=1e-2) + self.assertEqual(result, TEST_PASS, "FastGelu plugin op test failed") + + def test_op_bias_dropout(self): + target_device = get_cuda_plugin_device() + model = make_bias_dropout_model() + x = np.random.rand(2, 4).astype(np.float32) + bias = np.random.rand(4).astype(np.float32) + residual = np.random.rand(2, 4).astype(np.float32) + ratio = np.array(0.5, dtype=np.float32) + training_mode = np.array(False, dtype=np.bool_) + feed = {"X": x, "bias": bias, "residual": residual, "ratio": ratio, "training_mode": training_mode} + result = _run_model_test( + target_device, "BiasDropout", model, feed, lambda f: f["X"] + f["bias"] + f["residual"] + ) + self.assertEqual(result, TEST_PASS, "BiasDropout plugin op test failed") + + def test_op_skip_layer_norm(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + hidden_size = 8 + node = helper.make_node( + "SkipLayerNormalization", + ["X", "skip", "gamma", "beta"], + ["Y", "", "", "input_skip_bias_sum"], + domain="com.microsoft", + epsilon=1e-5, + ) + graph = helper.make_graph( + [node], + "test-SkipLayerNorm", + [ + helper.make_tensor_value_info("X", f_dtype, [2, hidden_size]), + helper.make_tensor_value_info("skip", f_dtype, [2, hidden_size]), + helper.make_tensor_value_info("gamma", f_dtype, [hidden_size]), + helper.make_tensor_value_info("beta", f_dtype, [hidden_size]), + ], + [ + helper.make_tensor_value_info("Y", f_dtype, [2, hidden_size]), + helper.make_tensor_value_info("input_skip_bias_sum", f_dtype, None), + ], + ) + opset_onnx = OperatorSetIdProto() + opset_onnx.version = 13 + opset_ms = OperatorSetIdProto() + opset_ms.domain = "com.microsoft" + opset_ms.version = 1 + model = helper.make_model(graph, opset_imports=[opset_onnx, opset_ms]) + x = np.random.rand(2, hidden_size).astype(np.float32) + skip = np.random.rand(2, hidden_size).astype(np.float32) + gamma = np.ones(hidden_size, dtype=np.float32) + beta = np.zeros(hidden_size, dtype=np.float32) + + def expected(f): + added = f["X"] + f["skip"] + mean = added.mean(axis=-1, keepdims=True) + var = added.var(axis=-1, keepdims=True) + normed = (added - mean) / np.sqrt(var + 1e-5) + return [normed * f["gamma"] + f["beta"], added] + + result = _run_model_test( + target_device, + "SkipLayerNorm", + model, + {"X": x, "skip": skip, "gamma": gamma, "beta": beta}, + expected, + rtol=1e-2, + atol=1e-2, + ) + self.assertEqual(result, TEST_PASS, "SkipLayerNorm plugin op test failed") if __name__ == "__main__": diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index 5d15a70c207f3..23c47e84c1630 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -20,7 +20,7 @@ import numpy import torch -from cuda_plugin_ep_helper import resolve_cuda_plugin_ep +from cuda_plugin_ep_helper import get_cuda_provider_name, resolve_cuda_plugin_ep from einops import rearrange, repeat # --- ONNX and Torch/Numpy Dtype Mappings --- @@ -35,7 +35,7 @@ from packaging import version from parameterized import parameterized -from onnxruntime import InferenceSession, SessionOptions, get_available_providers, get_build_info +from onnxruntime import InferenceSession, SessionOptions, get_build_info from onnxruntime import __version__ as ort_version # Set seed for reproducibility @@ -1930,7 +1930,7 @@ def gqa_cuda_quantized_test_cases(is_past: bool): def has_cuda_provider(): - return "CUDAExecutionProvider" in get_available_providers() + return get_cuda_provider_name() is not None def has_cuda_device(min_capability: int = 80): @@ -2346,7 +2346,7 @@ def test_gqa_rope_separate_qkv_bug(self): The bug caused q_out to be nullptr when unpacking separate QKV with only Q rotation (standard GQA), leading to unrotated Q being used in Attention. """ - if "CUDAExecutionProvider" not in get_available_providers(): + if not has_cuda_provider(): self.skipTest("CUDA required") # Config that triggers the path: Prompt phase, Separate QKV inputs, RoPE enabled @@ -2385,7 +2385,7 @@ def test_gqa_int8_large_seq_batch4(self): Regression test for batch_size=4 + max_seq_len=8192 + int8 KV cache crash. This reproduces a CUDA illegal memory access due to scratch size under-allocation. """ - if "CUDAExecutionProvider" not in get_available_providers(): + if not has_cuda_provider(): self.skipTest("CUDA required") # Config that triggers the crash: batch=4, large max_seq_len, int8 kv diff --git a/onnxruntime/test/python/transformers/test_moe_cuda.py b/onnxruntime/test/python/transformers/test_moe_cuda.py index 67caf903f0165..4b9f4e3634a9b 100644 --- a/onnxruntime/test/python/transformers/test_moe_cuda.py +++ b/onnxruntime/test/python/transformers/test_moe_cuda.py @@ -17,7 +17,7 @@ import numpy import torch import torch.nn.functional as F -from cuda_plugin_ep_helper import resolve_cuda_plugin_ep +from cuda_plugin_ep_helper import get_cuda_provider_name from onnx import TensorProto, helper from parameterized import parameterized from torch import nn @@ -29,8 +29,10 @@ onnxruntime.preload_dlls() + # Determine the execution provider and device based on CUDA availability. -use_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers() and torch.cuda.is_available() +cuda_provider = get_cuda_provider_name() +use_cuda = cuda_provider is not None device = torch.device("cuda:0" if use_cuda else "cpu") @@ -38,7 +40,7 @@ def get_ort_provider(): if not use_cuda: return ["CPUExecutionProvider"] - return [resolve_cuda_plugin_ep("CUDAExecutionProvider")] + return [cuda_provider] torch.manual_seed(42) @@ -1412,7 +1414,7 @@ def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits): def has_bf16_moe(): - if "CUDAExecutionProvider" not in onnxruntime.get_available_providers() or not torch.cuda.is_available(): + if not use_cuda or not torch.cuda.is_available(): return False major, _ = torch.cuda.get_device_capability() return major >= 8 From ef4dcc05d3d9f5c65d91e0ae730c048dff4db070 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 27 Mar 2026 01:01:48 -0700 Subject: [PATCH 29/48] review feedback --- .../core/providers/cuda/plugin/cuda_ep.cc | 22 ++++++--- .../cuda/plugin/cuda_kernel_adapter.h | 47 ++++++++++--------- 2 files changed, 39 insertions(+), 30 deletions(-) diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index 807122ff6ad97..fe2d845a46cad 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -61,14 +61,22 @@ CudaEp::CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& lo "CUDA Plugin EP created", ORT_FILE, __LINE__, __FUNCTION__)); - // Store per-EP runtime configuration (TF32, device ID, tuning options, etc.) - // in a global map keyed by the adapter-wrapped execution provider. Migrated - // kernels retrieve these settings via info.GetExecutionProvider(). + // Store per-EP runtime configuration in a global map keyed by the + // adapter-wrapped execution provider pointer. Migrated kernels retrieve these + // settings at compute time via GetCudaKernelAdapterRuntimeConfigForProvider(). + // Adding a new config field only requires updating CudaKernelAdapterRuntimeConfig, + // CudaEp::Config, and the struct-initializer below — no function-signature change. + onnxruntime::cuda::detail::CudaKernelAdapterRuntimeConfig adapter_config; + adapter_config.use_tf32 = config_.use_tf32; + adapter_config.skip_layer_norm_strict_mode = config_.enable_skip_layer_norm_strict_mode; + adapter_config.cudnn_conv_algo = config_.cudnn_conv_algo; + adapter_config.cudnn_conv_use_max_workspace = config_.cudnn_conv_use_max_workspace; + adapter_config.cudnn_conv1d_pad_to_nc1d = config_.cudnn_conv1d_pad_to_nc1d; + adapter_config.fuse_conv_bias = config_.fuse_conv_bias; + adapter_config.sdpa_kernel = config_.sdpa_kernel; + adapter_config.device_id = config_.device_id; onnxruntime::cuda::SetCudaKernelAdapterRuntimeConfigForProvider( - static_cast(EpImpl()), - config_.use_tf32, config_.device_id, config_.enable_skip_layer_norm_strict_mode, - config_.cudnn_conv_algo, config_.cudnn_conv_use_max_workspace, config_.cudnn_conv1d_pad_to_nc1d, - config_.fuse_conv_bias, config_.sdpa_kernel); + static_cast(EpImpl()), adapter_config); } CudaEp::~CudaEp() { diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index cbea32b81f6ed..a3d2394741773 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -310,8 +310,10 @@ class PluginKernelCollector { // =================================================================== // Section 4: Logging shim (adapter path only) -// Replaces LOGS_DEFAULT with a no-op stream to avoid pulling in the -// full ORT logging framework inside the plugin shared library. +// LOGS_DEFAULT is re-routed through ep::adapter::LoggingManager, which +// holds the ORT default logger set up in CudaEpFactory::CudaEpFactory. +// All severity levels (including ERROR/WARNING) are forwarded to the +// ORT logger; no log output is suppressed. // =================================================================== // Explicit function instantiation — called once per unique class in each .cc file @@ -320,7 +322,6 @@ class PluginKernelCollector { // The plugin utilizes ep::adapter::LoggingManager for LOGS_DEFAULT, // which is initialized in CudaEpFactory::CudaEpFactory. -#include #include #include #include @@ -338,14 +339,12 @@ namespace cuda { // =================================================================== // Section 5: Runtime configuration for migrated kernels -// Stored as atomics so SetCudaKernelAdapterRuntimeConfig() can be -// called from CudaEp's constructor on any thread. +// Fields are written once during CudaEp construction (under unique_lock) +// and only read afterwards; a shared_mutex in ProviderConfigStore guards +// concurrent access. // =================================================================== namespace detail { -// All fields are written once during CudaEp construction (under unique_lock) -// and only read afterwards, so std::atomic is not needed — the shared_mutex -// in ProviderConfigStore provides the necessary happens-before guarantee. struct CudaKernelAdapterRuntimeConfig { bool use_tf32 = true; bool skip_layer_norm_strict_mode = false; @@ -531,22 +530,24 @@ class CUDAExecutionProvider : public onnxruntime::IExecutionProvider { namespace cuda { -inline void SetCudaKernelAdapterRuntimeConfigForProvider(const void* provider, bool use_tf32, int device_id, - bool skip_layer_norm_strict_mode = false, - int cudnn_conv_algo = 0, bool cudnn_conv_use_max_workspace = true, - bool cudnn_conv1d_pad_to_nc1d = false, - bool fuse_conv_bias = false, - int sdpa_kernel = 0) { +// Populate the per-provider adapter config from a pre-filled initializer struct. +// Callers (e.g. CudaEp constructor) construct a detail::CudaKernelAdapterRuntimeConfig, +// fill every field they care about, then call this function. Adding a new config +// field only requires updating the struct and the call site — no signature change. +inline void SetCudaKernelAdapterRuntimeConfigForProvider( + const void* provider, const detail::CudaKernelAdapterRuntimeConfig& init_config) { auto& config = detail::GetCudaKernelAdapterRuntimeConfigForProvider(provider); - config.use_tf32 = use_tf32; - config.skip_layer_norm_strict_mode = skip_layer_norm_strict_mode; - config.cudnn_conv_algo = cudnn_conv_algo; - config.cudnn_conv_use_max_workspace = cudnn_conv_use_max_workspace; - config.cudnn_conv1d_pad_to_nc1d = cudnn_conv1d_pad_to_nc1d; - config.fuse_conv_bias = fuse_conv_bias; - config.sdpa_kernel = sdpa_kernel; - config.device_id = device_id; - PL_CUDA_CALL_THROW(cudaGetDeviceProperties(&config.device_prop, device_id)); + // AttentionKernelOptions contains std::once_flag (not copyable), so assign + // the plain-data fields individually rather than relying on operator=. + config.use_tf32 = init_config.use_tf32; + config.skip_layer_norm_strict_mode = init_config.skip_layer_norm_strict_mode; + config.cudnn_conv_algo = init_config.cudnn_conv_algo; + config.cudnn_conv_use_max_workspace = init_config.cudnn_conv_use_max_workspace; + config.cudnn_conv1d_pad_to_nc1d = init_config.cudnn_conv1d_pad_to_nc1d; + config.fuse_conv_bias = init_config.fuse_conv_bias; + config.sdpa_kernel = init_config.sdpa_kernel; + config.device_id = init_config.device_id; + PL_CUDA_CALL_THROW(cudaGetDeviceProperties(&config.device_prop, config.device_id)); } inline bool GetCudaKernelAdapterSkipLayerNormStrictMode(const void* provider) { From 5328c53fcb8218eff691440fb5053df3d7408536 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 27 Mar 2026 01:20:24 -0700 Subject: [PATCH 30/48] refactoring --- docs/cuda_plugin_ep/cuda_plugin_ep_design.md | 8 +-- .../core/providers/cuda/plugin/cuda_ep.cc | 1 - .../cuda/plugin/cuda_kernel_adapter.h | 64 ++++++------------- .../test/framework/dynamic_plugin_ep_test.cc | 4 +- 4 files changed, 25 insertions(+), 52 deletions(-) diff --git a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md index 98856422bffab..d7f7730bb4330 100644 --- a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md +++ b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md @@ -60,8 +60,8 @@ Migrated CUDA kernels Key ownership relationships: - `CudaEpFactory` implements raw `OrtEpFactory` callbacks and owns shared factory-level state such as the kernel registry, cached `OrtMemoryInfo` instances, and the hardware-device to CUDA-ordinal map. - `CudaEp` inherits from `ep::adapter::Ep`, which itself derives from `OrtEp` and owns a framework-facing `IExecutionProvider` object. -- The plugin-local `CUDAExecutionProvider` in `cuda_kernel_adapter.h` is a real shim object owned by `ep::adapter::Ep`. It is not the full in-tree CUDA EP, but it does have its own object identity and may store plugin-specific members such as the wrapped `OrtEp*`. -- Runtime state needed by migrated kernels is still stored in adapter-side maps keyed by the shim provider address (`EpImpl()`), not by the `CudaEp` object address. +- The plugin-local `CUDAExecutionProvider` in `cuda_kernel_adapter.h` is a real shim object owned by `ep::adapter::Ep`. It is not the full in-tree CUDA EP, but it has its own object identity and stores plugin-specific members — including the wrapped `OrtEp*` and the `CudaKernelAdapterRuntimeConfig` that migrated kernels read at compute time. +- Runtime configuration needed by migrated kernels is stored directly as a member (`config_`) of the shim `CUDAExecutionProvider` object, rather than in a separate map keyed by the provider address. - `CudaSyncStream` owns `cudaStream_t`, `cublasHandle_t`, `cudnnHandle_t`, and `cublasLtHandle_t` for each sync stream created through the EP API. ### 2.4 Plugin DLL Entry Points @@ -266,9 +266,9 @@ This changes the safety model from the earlier "phantom shim" design: Provider options flow through the plugin in two stages: - `CudaEpFactory` parses session/provider options into `CudaEp::Config`. -- `CudaEp` mirrors the subset needed by migrated kernels into the adapter-side `ProviderConfigStore`, keyed by `EpImpl()`. +- `CudaEp` copies the subset needed by migrated kernels into `CUDAExecutionProvider::config_` via `SetCudaKernelAdapterRuntimeConfigForProvider(EpImpl(), ...)` during EP construction. -Today that mirrored subset includes TF32, device ID/device properties, cuDNN convolution settings, skip-layer-norm strict mode, fused-conv-bias, and SDPA kernel selection. Other plugin behaviors, such as preferred layout, are handled directly by `CudaEp` callbacks instead of through the shim. +Because `config_` is a direct member of the shim object, there is no heap-allocated map and no mutex — reads at kernel compute time are simple field accesses. Today that stored subset includes TF32, device ID/device properties, cuDNN convolution settings, skip-layer-norm strict mode, fused-conv-bias, and SDPA kernel selection. Other plugin behaviors, such as preferred layout, are handled directly by `CudaEp` callbacks instead of through the shim. For stream bridging, the preferred helpers are: - `Stream(ctx)` when the kernel only needs a raw `cudaStream_t` diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index fe2d845a46cad..5461e30c8199c 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -80,7 +80,6 @@ CudaEp::CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& lo } CudaEp::~CudaEp() { - onnxruntime::cuda::detail::RemoveCudaKernelAdapterRuntimeConfigForProvider(static_cast(EpImpl())); } /*static*/ diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index a3d2394741773..db8de3865a2c4 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -332,7 +332,6 @@ class PluginKernelCollector { #include #include #include -#include namespace onnxruntime { namespace cuda { @@ -357,40 +356,6 @@ struct CudaKernelAdapterRuntimeConfig { cudaDeviceProp device_prop{}; onnxruntime::AttentionKernelOptions attention_kernel_options; }; -// Shared storage for per-provider runtime configurations. -// Both Get and Remove must operate on the same static map instance, -// so we centralize them in a single struct with static lifetime. -struct ProviderConfigStore { - std::shared_mutex mutex; - std::unordered_map> configs; - - static ProviderConfigStore& Instance() { - static ProviderConfigStore store; - return store; - } -}; - -inline CudaKernelAdapterRuntimeConfig& GetCudaKernelAdapterRuntimeConfigForProvider(const void* provider) { - auto& store = ProviderConfigStore::Instance(); - std::shared_lock lock(store.mutex); - auto it = store.configs.find(provider); - if (it != store.configs.end()) { - return *it->second; - } - lock.unlock(); - std::unique_lock unique_lock(store.mutex); - auto& ptr = store.configs[provider]; - if (!ptr) { - ptr = std::make_unique(); - } - return *ptr; -} - -inline void RemoveCudaKernelAdapterRuntimeConfigForProvider(const void* provider) { - auto& store = ProviderConfigStore::Instance(); - std::unique_lock lock(store.mutex); - store.configs.erase(provider); -} template struct SizeOf { static constexpr size_t value = sizeof(T); @@ -501,34 +466,45 @@ class CUDAExecutionProvider : public onnxruntime::IExecutionProvider { } int GetCudnnConvAlgo() const { - return cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(this).cudnn_conv_algo; + return config_.cudnn_conv_algo; } bool GetCudnnConvUseMaxWorkspace() const { - return cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(this).cudnn_conv_use_max_workspace; + return config_.cudnn_conv_use_max_workspace; } bool GetCudnnConv1dPadToNc1d() const { - return cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(this).cudnn_conv1d_pad_to_nc1d; + return config_.cudnn_conv1d_pad_to_nc1d; } bool UseTF32() const { - return cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(this).use_tf32; + return config_.use_tf32; } bool IsFuseConvBias() const { - return cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(this).fuse_conv_bias; + return config_.fuse_conv_bias; } const onnxruntime::AttentionKernelOptions* GetAttentionKernelOptions() const { - auto& config = cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(this); - config.attention_kernel_options.InitializeOnce(config.sdpa_kernel, true, true); - return &config.attention_kernel_options; + config_.attention_kernel_options.InitializeOnce(config_.sdpa_kernel, true, true); + return &config_.attention_kernel_options; } const cudaDeviceProp& GetDeviceProp() const { - return cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(this).device_prop; + return config_.device_prop; } + // Config is public so that detail::GetCudaKernelAdapterRuntimeConfigForProvider + // (a free function defined after this class) can access it via pointer cast. + mutable cuda::detail::CudaKernelAdapterRuntimeConfig config_; + private: const OrtEp* ort_ep_ = nullptr; }; namespace cuda { +namespace detail { + +// Accessor: config is stored directly on CUDAExecutionProvider; no map or mutex needed. +inline CudaKernelAdapterRuntimeConfig& GetCudaKernelAdapterRuntimeConfigForProvider(const void* provider) { + return const_cast(static_cast(provider))->config_; +} + +} // namespace detail // Populate the per-provider adapter config from a pre-filled initializer struct. // Callers (e.g. CudaEp constructor) construct a detail::CudaKernelAdapterRuntimeConfig, diff --git a/onnxruntime/test/framework/dynamic_plugin_ep_test.cc b/onnxruntime/test/framework/dynamic_plugin_ep_test.cc index 430e76df6e588..61687632b23b8 100644 --- a/onnxruntime/test/framework/dynamic_plugin_ep_test.cc +++ b/onnxruntime/test/framework/dynamic_plugin_ep_test.cc @@ -98,7 +98,7 @@ TEST(DynamicPluginEpInfraTest, UninitializedStateReturnsSafeDefaults) { #if defined(USE_CUDA) && defined(ORT_USE_EP_API_ADAPTERS) TEST(DynamicPluginEpInfraTest, CudaKernelAdapterRuntimeConfigExposesFuseConvBiasAndSdpaKernel) { - onnxruntime::cuda::CUDAExecutionProvider provider{"CudaPluginExecutionProvider"}; + onnxruntime::CUDAExecutionProvider provider{"CudaPluginExecutionProvider"}; auto& config = onnxruntime::cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(&provider); config.fuse_conv_bias = true; config.sdpa_kernel = static_cast(onnxruntime::contrib::attention::AttentionBackend::MATH); @@ -110,8 +110,6 @@ TEST(DynamicPluginEpInfraTest, CudaKernelAdapterRuntimeConfigExposesFuseConvBias EXPECT_FALSE(attention_kernel_options->UseFlashAttention()); EXPECT_FALSE(attention_kernel_options->UseEfficientAttention()); EXPECT_FALSE(attention_kernel_options->UseCudnnFlashAttention()); - - onnxruntime::cuda::detail::RemoveCudaKernelAdapterRuntimeConfigForProvider(&provider); } #endif From 8da28f5573f38c1b2866f75bc0a160d7e49424c5 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 27 Mar 2026 01:57:05 -0700 Subject: [PATCH 31/48] update doc about arena and resource accounting --- docs/cuda_plugin_ep/cuda_plugin_ep_design.md | 67 +++++++++++++++++++- 1 file changed, 66 insertions(+), 1 deletion(-) diff --git a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md index d7f7730bb4330..f04b4966151ff 100644 --- a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md +++ b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md @@ -771,7 +771,54 @@ include/onnxruntime/ep/ ## 13. Future Work -1. **Memory arena / allocator parity** — The plugin currently relies on direct `cudaMalloc`/`cudaFree` paths instead of the in-tree CUDA EP's BFC-style arena. Adding a plugin-side arena or a clean way to reuse ORT's allocator infrastructure would reduce allocation overhead, improve memory reuse, and let the plugin honor options such as `gpu_mem_limit` and `arena_extend_strategy`. +1. **Memory arena / allocator parity** — The plugin currently relies on direct `cudaMalloc`/`cudaFree` in `CudaDeviceAllocator` instead of an arena-backed allocator. Two complementary improvements are planned: + + **A. `CudaMempoolArena` (commit e6023b0c)** + + The in-tree CUDA EP gained a native-CUDA-mempool allocator (`cuda_mempool_arena.h/.cc`) that uses `cudaMallocFromPoolAsync` / `cudaFreeAsync` on stream-ordered allocation paths, with a configurable `cudaMemPoolAttrReleaseThreshold` to return memory to the device as it becomes idle. Enabling this in the plugin requires: + + 1. **Make `CudaMempoolArena` compilable in the plugin build.** `cuda_mempool_arena.h` currently includes `cuda_stream_handle.h` and `provider_api.h` (both `SHARED_PROVIDER`-only). The only real dependency is resolving the stream framework pointer. When migrating for plugin use, this class can be refactored to accept a raw `cudaStream_t` directly (or an `OrtSyncStream*`), bypassing the internal `stream->GetHandle()` logic. + + 2. **Implement a thin `OrtAllocator` wrapper around `CudaMempoolArena`.** The plugin factory's `CreateAllocatorImpl` returns an `OrtAllocator*`, while `CudaMempoolArena` is an `IArena` / `IAllocator`. A new class (e.g., `CudaMempoolOrtAllocator`) should own a `CudaMempoolArena` instance and forward the `OrtAllocator` callbacks to it: + + | `OrtAllocator` callback | Implementation | + |-------------------------|----------------| + | `Alloc(size)` | `arena_->Alloc(size)` (allocates on the legacy default stream) | + | `Free(ptr)` | `arena_->Free(ptr)` | + | `Reserve(size)` | `arena_->Reserve(size)` | + | `AllocOnStream(size, stream)` | `cudaStream_t cu_stream = (cudaStream_t)api->SyncStream_GetHandle(stream);`
`arena_->AllocWithCudaStream(size, cu_stream);` | + | `GetStats(kvps)` | Populate from `arena_->GetStats()` | + | `Info()` | Return the `OrtMemoryInfo*` used at construction | + + The `OrtAllocator` C API already supports stream-aware allocation via the optional `AllocOnStream` callback (set on `OrtAllocator` when `version >= kOrtAllocatorAllocOnStreamMinVersion`). ORT core wraps every plugin `OrtAllocator` into `IAllocatorImplWrappingOrtAllocator` (`allocator_adapters.cc`), which dispatches to `AllocOnStream` when the wrapper reports `IsStreamAware() == true`. So there is **no additional plumbing needed in the adapter or framework** — the plugin allocator just needs to set `AllocOnStream` to a non-null function pointer to get full stream-ordered semantics. + + **Important:** The `OrtMemoryInfo::alloc_type` returned by the wrapper must be `OrtDeviceAllocator`, **not** `OrtArenaAllocator`. Both `PluginExecutionProvider::CreatePreferredAllocators()` and `Environment::CreateSharedAllocatorImpl()` explicitly reject `OrtArenaAllocator` from plugin factories — the arena is expected to be opaque to ORT. + + 3. **Parse mempool options.** ORT can pass allocator configuration to the plugin factory through the `allocator_options` (`OrtKeyValuePairs*`) argument of `OrtEpFactory::CreateAllocator`. The relevant keys are defined in `OrtArenaCfg::Keys` (in `allocator.h`): + - `arena.use_cuda_mempool` — set to `"1"` to enable + - `arena.cuda_mempool_release_threshold` — bytes; `0` disables the threshold + - `arena.cuda_mempool_bytes_to_keep_on_shrink` — bytes retained after `Shrink()` + + **How options reach the plugin factory — two paths:** + + | Path | How it calls `CreateAllocator` | `allocator_options` | + |------|-------------------------------|---------------------| + | **Shared allocator** (`OrtApi::CreateSharedAllocator`) | `Environment::CreateSharedAllocatorImpl` → `ep_factory->CreateAllocator(factory, &mem_info, allocator_options, &alloc)` | Caller-provided `OrtKeyValuePairs*` — can carry arena keys | + | **Per-EP allocator** (`PluginExecutionProvider::CreatePreferredAllocators`) | `ep_factory.CreateAllocator(&ep_factory, memory_info, /*options*/ nullptr, &alloc)` | Always `nullptr` today | + + The per-EP path currently passes `nullptr` for options. To support mempool configuration on this path, either: + - **(a)** Parse the arena keys from session options inside `CudaEp` / `CudaEpFactory` (similar to how `CudaEp::Config` already parses other provider options) and store them so `CreateAllocatorImpl` can read them without needing `allocator_options`. + - **(b)** Extend the ORT core per-EP allocator path to forward the config entries to `CreateAllocator` (requires an ORT core change). + + Option (a) is self-contained within the plugin and does not require ORT core changes. + + 4. **Thread the factory logger.** `CudaMempoolArena` takes a `const logging::Logger*`. The plugin factory already owns a logger (`factory.default_logger_` / the `OrtLogger` passed at EP creation). Convert or wrap it and pass it to the arena constructor. + + 5. **Handle `ReleaseAllocatorImpl`.** The factory's `ReleaseAllocatorImpl` switch currently only knows about `CudaDeviceAllocator` and `CudaPinnedAllocator`. Add a third case (`kMempool` or similar) to correctly destroy the new wrapper and its owned `CudaMempoolArena`. + + **B. BFC arena (longer term)** + + If BFC-style arena behavior (`gpu_mem_limit`, `arena_extend_strategy`) is also needed, a similar `OrtAllocator`-wrapping approach would work for `BFCArena`, once its `SHARED_PROVIDER`-only dependencies are removed. The same `AllocOnStream` / `OrtDeviceAllocator` / option-parsing patterns apply. 2. **Profiling and observability** — The in-tree CUDA EP exposes an EP profiler, while the plugin shim currently does not surface equivalent profiling hooks. Future work should wire up `GetProfiler()` for the plugin path, integrate CUDA/NVTX/CUPTI-based tracing where appropriate, and make plugin execution visible in the same profiling flows users already rely on for the bundled CUDA EP. @@ -790,3 +837,21 @@ include/onnxruntime/ep/ 9. **NHWC cleanup and hardening** — Complete the follow-up work described in [Section 5.3.1](#531-nhwc-layout-transformation-support): unify the allowlist, improve internal-domain diagnostics, and add stronger structural NHWC assertions. 10. **CUDA Graph API for plugin EPs** — Add `IsGraphCaptureEnabled`, `IsGraphCaptured`, and `ReplayGraph` callbacks to the `OrtEp` C API (see [Section 5.4.4](#544-what-needs-to-change-in-ort-core-option-a)). This is required for efficient CUDA graph replay in the plugin EP. The capture/replay infrastructure will be reintroduced once the API is extended. + +11. **Resource accounting and annotation-based partitioning (PR #27595)** — ORT is acquiring two related features that affect how graph nodes are partitioned to EPs: + + **A. Resource accounting** + + `IResourceAccountant` lets an EP declare a resource budget (e.g., available VRAM) and have the partitioner stop assigning nodes once that budget is exhausted. The framework passes an `IResourceAccountant*` to `IExecutionProvider::GetCapability()`; the in-tree CUDA EP uses it to compute per-node estimated VRAM cost from initializer sizes. + + For plugin EPs, the `OrtEp::GetCapability` callback currently has no mechanism to receive or report resource usage — the `OrtEp` C API does not expose `IResourceAccountant`. Two design options: + + - **Option A (preferred — ORT core change):** Add an `OrtEp` analogue of the current `IResourceAccountant` flow instead of inventing a separate plugin-only protocol. In the core path, the partitioner passes an `IResourceAccountant*` into `IExecutionProvider::GetCapability(...)`. The plugin equivalent should preserve that model as closely as possible: either expose an accountant-like object through the `OrtEp` API, or add a small `OrtEp` callback surface that the partitioner can use to compute and accumulate per-node resource cost while still keeping the partitioner's threshold/stop-assignment logic in one place. + + - **Option B (plugin-side workaround):** Expose the VRAM threshold through a plugin-specific session option key. During `GetCapabilityImpl`, the plugin reads the threshold from the parsed config and performs its own initializer-size accounting using `OrtEp_GetNodeAttributes` / node-graph-view APIs already present in the `OrtEp` API surface. This avoids an ORT core change but duplicates budget-tracking logic. + + **B. Annotation-based layering** + + PR #27595 also introduces `layering_annotations` — node-level `"layer_ann"` metadata that routes nodes to specific EPs or CPU during partitioning. The expected model is that plugin EPs participate through the same `GetCapability` flow and therefore observe whatever node set ORT presents after applying layering rules. In practice that should mean no plugin-specific changes are needed to respect annotations that exclude nodes from the plugin. However, the plugin design should avoid depending on undocumented filtering details in the `OrtGraph*` contract. If the plugin EP itself needs to *read* layering annotations for internal decisions, or if the API needs to make filtered-vs-unfiltered graph semantics explicit, that would require new `OrtEp` API surface. + + **Recommended action:** Extend the `OrtEp` C API with a plugin equivalent of the existing resource-accounting flow (Option A), and document the graph-view contract for plugin `GetCapability` under layering annotations as part of the same effort that exposes `kOrtSessionOptionsLayerAssignmentSettings` to plugin EP sessions. From 0e105fe1c49ff8dcfc5d763bc67870215efd963f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 27 Mar 2026 13:30:44 -0700 Subject: [PATCH 32/48] update doc for OpSchema API --- docs/cuda_plugin_ep/cuda_plugin_ep_design.md | 120 ++++++++++++++++++- 1 file changed, 119 insertions(+), 1 deletion(-) diff --git a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md index f04b4966151ff..1438c1ddbfdc0 100644 --- a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md +++ b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md @@ -142,6 +142,35 @@ In the in-tree build, kernels register through centralized tables (`cuda_nhwc_ke // and registers each kernel into an adapter::KernelRegistry. ``` +#### 3.5.1 Type Constraint Names and OpSchema Access + +Every kernel registration includes type constraint names — string literals such as `"T"`, `"T1"`, `"U"` — that must exactly match the formal parameter type-constraint strings defined in the ONNX operator schema. In the current plugin build, these names are **hard-coded** at compile time with no runtime validation against the actual schema. If a constraint name is wrong, kernel matching silently fails during `GetCapability`. + +PR #27713 adds `OrtEpApi` functions that let plugin EPs query ONNX operator schemas from ORT's global schema registry at runtime (available from ORT 1.25): + +| `OrtEpApi` Function | C++ Wrapper | Purpose | +|---------------------|-------------|----------| +| `GetOpSchema(name, max_ver, domain)` | `Ort::GetOpSchema()` | Look up a schema by op name, max opset version, and domain | +| `OpSchema_GetSinceVersion` | `ConstOpSchema::GetSinceVersion()` | Opset version that introduced this schema entry | +| `OpSchema_GetNumInputs` / `_GetNumOutputs` | `GetNumInputs()` / `GetNumOutputs()` | Formal input/output count | +| `OpSchema_GetInputName` / `_GetOutputName` | `GetInputName(i)` / `GetOutputName(i)` | Formal parameter name | +| `OpSchema_GetInputTypeStr` / `_GetOutputTypeStr` | `GetInputTypeStr(i)` / `GetOutputTypeStr(i)` | Type constraint string (e.g., `"T"`) | +| `OpSchema_HasTypeConstraint` | `HasTypeConstraint(str)` | Whether a string is a valid type constraint name in the schema | + +The returned `OrtOpSchema*` is non-owning — it points into the global `ONNX_NAMESPACE::OpSchemaRegistry` singleton and is valid for the lifetime of the ORT process. + +**Why the plugin cannot link its own ONNX library:** The `OpSchemaRegistry` is a Meyers singleton (`static` local in `Instance()`). Each shared library gets its own copy of that static variable — on Windows each DLL is isolated by default, on macOS two-level namespaces have the same effect, and on Linux behavior depends on `dlopen` flags. Even when isolation doesn't occur, the EP's registry would lack ORT's contrib and internal schemas, and version mismatches between the EP's ONNX library and ORT's vendored copy could cause silent divergence. The `OrtEpApi` route is the only reliable, portable way to query the schemas ORT actually uses. + +**Impact on the CUDA plugin EP:** + +1. **Registration-time validation.** `CreateCudaKernelRegistry()` can optionally validate each registered kernel's type constraint names against the schema after collecting all entries from `PluginKernelCollector`. A mismatch can be logged as a warning (debug builds) or an error, catching drift when ONNX spec updates rename constraint strings. + +2. **NHWC / internal-domain diagnostics.** For rewritten `com.ms.internal.nhwc` nodes, the schema API can confirm that the kernel's registered domain, version range, and constraint names actually match the internal-domain schema entry, improving the diagnostics called for in [Section 5.3.1.3](#5313-nhwc-design-requirements). + +3. **Parity tooling.** `cuda_plugin_parity_report.py` can use the C++ wrapper to compare the plugin's registered constraint names against the schema, flagging incorrect or missing constraints in the parity report. + +4. **Future: schema-driven registration helpers.** A `KernelDefBuilder` helper could derive constraint names automatically from the schema rather than relying on hard-coded strings, reducing the manual maintenance burden when new opset versions change constraint names. See [Section 11.6](#116-if-opschema-access-is-available-schema-validated-type-constraints). + --- ## 4. CPU Base Class Helpers — The SHARED_PROVIDER Pattern @@ -510,6 +539,7 @@ The adapter layer provides thin wrappers around the ORT C API that present a C++ | `Ep` | `OrtEp` | `EpImpl()`, allocators, data transfer | | `Logger` | Plugin logger | Logging interface | | `DataTransferManager` | `IDataTransfer` | `CopyTensor()` | +| `ConstOpSchema` | `const OrtOpSchema*` | `GetSinceVersion()`, `GetNumInputs()`, `GetInputName()`, `GetInputTypeStr()`, `HasTypeConstraint()` (ORT ≥ 1.25, PR #27713) | --- @@ -731,6 +761,66 @@ Use the plugin-compatible overloads already in `CudaKernel`: // Use: GetCublasHandle(ctx) // or GetCublasHandle(Stream(ctx)) ``` +### 11.6 If OpSchema access is available — schema-validated type constraints + +With the `OrtEpApi` OpSchema APIs (ORT ≥ 1.25, PR #27713), the plugin can validate or derive type constraint names at kernel registration time rather than relying solely on hard-coded strings. + +#### 11.6.1 Validation mode (recommended first step) + +Add a debug-mode validation pass in `CreateCudaKernelRegistry()` that runs after all kernels are collected from `PluginKernelCollector`. For each registered kernel, look up its `OrtOpSchema` and verify that every type constraint name used in the `KernelDef` actually appears in the schema's type constraint map: + +```cpp +// In CreateCudaKernelRegistry(), after building the registry: +#ifndef NDEBUG +for (auto build_fn : entries) { + auto info = build_fn(); + if (info.kernel_def == nullptr) continue; + + // Retrieve the op name, domain, and since_version from the KernelDef. + const char* op_name = info.kernel_def->GetOpName(); + const char* domain = info.kernel_def->GetDomain(); + int since_version = info.kernel_def->GetSinceVersion(); + + // Look up the ONNX schema from ORT's global registry. + Ort::ConstOpSchema schema = Ort::GetOpSchema(op_name, since_version, domain); + if (!schema) continue; // contrib/internal ops may not have an ONNX schema + + // Validate each type constraint name against the schema. + for (const auto& [constraint_name, types] : info.kernel_def->GetTypeConstraints()) { + if (!schema.HasTypeConstraint(constraint_name.c_str())) { + LOGS_DEFAULT(WARNING) << "Plugin kernel " << op_name + << ": type constraint '" << constraint_name + << "' not found in OpSchema (domain=" << domain + << ", version=" << since_version << ")"; + } + } +} +#endif +``` + +This catches hard-to-debug kernel-matching failures caused by constraint name typos or opset version drift. + +#### 11.6.2 Schema-driven constraint helper (future) + +A `KernelDefBuilder` extension could derive constraint names from the schema automatically: + +```cpp +/// Look up the type constraint string for a given input index from the OpSchema. +/// Falls back to the provided default if the schema is not found (e.g., contrib ops). +inline const char* GetInputTypeConstraintName( + const char* op_name, int opset_version, const char* domain, size_t input_index, + const char* fallback = "T") { + Ort::ConstOpSchema schema = Ort::GetOpSchema(op_name, opset_version, domain); + if (!schema || input_index >= schema.GetNumInputs()) return fallback; + // Cache the result to avoid repeated lookups for typed kernel variants. + static thread_local std::string cached_name; + cached_name = schema.GetInputTypeStr(input_index); + return cached_name.c_str(); +} +``` + +This is a quality-of-life improvement rather than a required change — the existing hard-coded constraint names are correct for all currently registered kernels. + --- ## 12. File Layout @@ -838,7 +928,35 @@ include/onnxruntime/ep/ 10. **CUDA Graph API for plugin EPs** — Add `IsGraphCaptureEnabled`, `IsGraphCaptured`, and `ReplayGraph` callbacks to the `OrtEp` C API (see [Section 5.4.4](#544-what-needs-to-change-in-ort-core-option-a)). This is required for efficient CUDA graph replay in the plugin EP. The capture/replay infrastructure will be reintroduced once the API is extended. -11. **Resource accounting and annotation-based partitioning (PR #27595)** — ORT is acquiring two related features that affect how graph nodes are partitioned to EPs: +11. **OpSchema-validated kernel registration (PR #27713)** — PR #27713 adds `OrtEpApi` functions that let plugin EPs query ONNX operator schemas from ORT's global registry (see [Section 3.5.1](#351-type-constraint-names-and-opschema-access)). Concrete follow-up work for the CUDA plugin EP: + + **A. Registration-time validation pass** + + Add a debug/diagnostic pass in `CreateCudaKernelRegistry()` that validates every registered kernel's type constraint names against the schema. This is the highest-value, lowest-risk change — it catches silent kernel-matching failures caused by constraint name drift without altering the registration flow. See [Section 11.6.1](#1161-validation-mode-recommended-first-step) for the implementation pattern. + + **B. NHWC internal-domain schema diagnostics** + + Extend the validation pass to cover `com.ms.internal.nhwc`-domain registrations. When kernel lookup fails for a rewritten NHWC node, the diagnostic can now report exactly which constraint name was expected vs. what the kernel registered, directly addressing the diagnostic requirement in [Section 5.3.1.3](#5313-nhwc-design-requirements). + + **C. Parity report enhancement** + + Update `cuda_plugin_parity_report.py` to use the schema API (via a small C++ test harness or Python ONNX bindings) to flag type-constraint mismatches between the plugin's registered kernels and the ONNX schema, in addition to the existing op-coverage comparison. + + **D. Schema-driven `KernelDefBuilder` helpers (longer term)** + + Create a `KernelDefBuilder` helper that auto-derives constraint names from the schema instead of requiring hard-coded strings. This reduces maintenance burden when new opset versions introduce constraint name changes, but is lower priority than the validation pass since all current constraint names are correct. + + **E. Potential code locations for changes** + + | File | Change | + |------|--------| + | `cuda_plugin_kernels.cu` / `CreateCudaKernelRegistry()` | Add schema validation loop after kernel collection | + | `cuda_kernel_adapter.h` | (Optional) Add schema-aware macro variant or post-registration hook | + | `include/onnxruntime/ep/adapter/kernel_def_builder.h` | (Optional) Add schema-lookup helper for constraint names | + | `cuda_ep.cc` / `GetCapabilityImpl()` | (Optional) Add schema-based diagnostic when `EpGraphSupportInfo_LookUpKernel` returns nullptr | + | `test_cuda_plugin_ep.py` | Add a validation stage that exercises schema-validated registration | + +12. **Resource accounting and annotation-based partitioning (PR #27595)** — ORT is acquiring two related features that affect how graph nodes are partitioned to EPs: **A. Resource accounting** From 18e4a2e11f6751afd03a64dc6781550beb365a24 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 27 Mar 2026 16:15:09 -0700 Subject: [PATCH 33/48] refine design of config storage --- docs/cuda_plugin_ep/cuda_plugin_ep_design.md | 20 ++--- onnxruntime/core/providers/cuda/cuda_kernel.h | 4 + .../core/providers/cuda/math/einsum.cc | 2 +- onnxruntime/core/providers/cuda/math/einsum.h | 3 - onnxruntime/core/providers/cuda/nn/conv.cc | 10 +-- onnxruntime/core/providers/cuda/nn/conv_8.h | 11 +-- .../core/providers/cuda/nn/conv_transpose.cc | 10 +-- .../core/providers/cuda/nn/conv_transpose_8.h | 8 +- .../core/providers/cuda/plugin/cuda_ep.cc | 3 +- .../cuda/plugin/cuda_kernel_adapter.h | 84 ++++++++++--------- 10 files changed, 77 insertions(+), 78 deletions(-) diff --git a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md index 1438c1ddbfdc0..f9848461a5464 100644 --- a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md +++ b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md @@ -52,16 +52,16 @@ CudaEpFactory adapter::Ep Migrated CUDA kernels └─ use CudaKernel / cuda_kernel_adapter.h - ├─ receive EpImpl() from info.GetExecutionProvider() - ├─ cast that pointer to the shim CUDAExecutionProvider + ├─ cache a shared runtime-config handle during construction + ├─ use CudaKernel accessors for provider settings during Compute() └─ resolve stream-local handles via CudaSyncStream::FromCudaStream() ``` Key ownership relationships: - `CudaEpFactory` implements raw `OrtEpFactory` callbacks and owns shared factory-level state such as the kernel registry, cached `OrtMemoryInfo` instances, and the hardware-device to CUDA-ordinal map. - `CudaEp` inherits from `ep::adapter::Ep`, which itself derives from `OrtEp` and owns a framework-facing `IExecutionProvider` object. -- The plugin-local `CUDAExecutionProvider` in `cuda_kernel_adapter.h` is a real shim object owned by `ep::adapter::Ep`. It is not the full in-tree CUDA EP, but it has its own object identity and stores plugin-specific members — including the wrapped `OrtEp*` and the `CudaKernelAdapterRuntimeConfig` that migrated kernels read at compute time. -- Runtime configuration needed by migrated kernels is stored directly as a member (`config_`) of the shim `CUDAExecutionProvider` object, rather than in a separate map keyed by the provider address. +- The plugin-local `CUDAExecutionProvider` in `cuda_kernel_adapter.h` is a real shim object owned by `ep::adapter::Ep`. It is not the full in-tree CUDA EP, but it has its own object identity and stores plugin-specific members — including the wrapped `OrtEp*` and a provider-owned shared runtime-config object. +- Runtime configuration needed by migrated kernels is stored on that shim provider and exposed to kernels as a cached `shared_ptr`, rather than through a separate global map keyed by the provider address. - `CudaSyncStream` owns `cudaStream_t`, `cublasHandle_t`, `cudnnHandle_t`, and `cublasLtHandle_t` for each sync stream created through the EP API. ### 2.4 Plugin DLL Entry Points @@ -278,7 +278,7 @@ For code paths that need handles without an active stream, `cuda_kernel_adapter. ### 5.3 Provider Access -Kernels access provider configuration through the pointer returned by `info.GetExecutionProvider()`. In the plugin build, `ep::adapter::OpKernelInfo` snapshots three related pointers from the framework `OpKernelInfo` when the kernel is created: +Kernels still discover the shim provider through the pointer returned by `info.GetExecutionProvider()` at construction time. In the plugin build, `ep::adapter::OpKernelInfo` snapshots three related pointers from the framework `OpKernelInfo` when the kernel is created: - the session-owned outer `PluginExecutionProvider` - the wrapped plugin `OrtEp` / `CudaEp` @@ -286,7 +286,7 @@ Kernels access provider configuration through the pointer returned by `info.GetE `OpKernelInfo::GetExecutionProvider()` then returns that cached shim pointer, so migrated kernels receive the real shim `CUDAExecutionProvider` object owned by `ep::adapter::Ep`, not the outer `PluginExecutionProvider` and not the `OrtEp`/`CudaEp` object reinterpreted at the same address. -Caching the shim pointer at kernel-creation time is important for the NHWC path. Re-querying `OrtKernelInfo -> OrtEp -> EpImpl()` during execution was fragile after layout transformation; the current implementation avoids that repeated round-trip and uses the cached shim for runtime provider access. +Caching the shim pointer at kernel-creation time is important for the NHWC path. Re-querying `OrtKernelInfo -> OrtEp -> EpImpl()` during execution was fragile after layout transformation. The current implementation resolves the shim once in the `CudaKernel` constructor, caches a `shared_ptr`, and routes later provider-setting reads through `CudaKernel` accessors instead of repeated provider-pointer casts during `Compute()`. This changes the safety model from the earlier "phantom shim" design: - The shim no longer needs to remain layout-compatible with `IExecutionProvider`. @@ -295,9 +295,9 @@ This changes the safety model from the earlier "phantom shim" design: Provider options flow through the plugin in two stages: - `CudaEpFactory` parses session/provider options into `CudaEp::Config`. -- `CudaEp` copies the subset needed by migrated kernels into `CUDAExecutionProvider::config_` via `SetCudaKernelAdapterRuntimeConfigForProvider(EpImpl(), ...)` during EP construction. +- `CudaEp` copies the subset needed by migrated kernels into the shim provider's runtime config via `SetCudaKernelAdapterRuntimeConfigForProvider(EpImpl(), ...)` during EP construction. -Because `config_` is a direct member of the shim object, there is no heap-allocated map and no mutex — reads at kernel compute time are simple field accesses. Today that stored subset includes TF32, device ID/device properties, cuDNN convolution settings, skip-layer-norm strict mode, fused-conv-bias, and SDPA kernel selection. Other plugin behaviors, such as preferred layout, are handled directly by `CudaEp` callbacks instead of through the shim. +Because the runtime config is provider-owned and cached by kernels as a shared pointer, there is no global map and no mutex. Today that stored subset includes TF32, device ID/device properties, cuDNN convolution settings, skip-layer-norm strict mode, fused-conv-bias, and SDPA kernel selection. Other plugin behaviors, such as preferred layout, are handled directly by `CudaEp` callbacks instead of through the shim. For stream bridging, the preferred helpers are: - `Stream(ctx)` when the kernel only needs a raw `cudaStream_t` @@ -375,9 +375,9 @@ The final support set is chosen from `candidate_nodes`, with the existing CPU-pr **B. Cache the shim provider pointer at kernel creation** -Migrated CUDA kernels expect `info.GetExecutionProvider()` to return the shim `CUDAExecutionProvider`, not the outer `PluginExecutionProvider`. The adapter now resolves that relationship once during kernel creation and caches the inner `EpImpl()` pointer. +Migrated CUDA kernels expect `info.GetExecutionProvider()` to return the shim `CUDAExecutionProvider`, not the outer `PluginExecutionProvider`. The adapter now resolves that relationship once during kernel creation, captures the shim provider's runtime-config object, and uses `CudaKernel` accessors for later provider-setting reads. -This is especially important for NHWC kernels because layout transformation introduces additional runtime paths before the actual CUDA kernel executes. Repeatedly reconstructing `OrtEp -> EpImpl()` from `OrtKernelInfo` during execution proved fragile in that path. The cached-shim approach keeps provider access deterministic and matches the actual object model: +This is especially important for NHWC kernels because layout transformation introduces additional runtime paths before the actual CUDA kernel executes. Repeatedly reconstructing provider access from `OrtKernelInfo` during execution proved fragile in that path. The cached-config approach keeps provider access deterministic and matches the actual object model: - outer session EP: `PluginExecutionProvider` - wrapped plugin object: `CudaEp` / `ep::adapter::Ep` diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index f3627b2f97229..13bf5b37490e0 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -109,6 +109,10 @@ class CudaKernel : public OpKernel { } const cudaDeviceProp& GetDeviceProp() const { return provider_->GetDeviceProp(); } + int GetCudnnConvAlgo() const { return provider_->GetCudnnConvAlgo(); } + bool GetCudnnConvUseMaxWorkspace() const { return provider_->GetCudnnConvUseMaxWorkspace(); } + bool GetCudnnConv1dPadToNc1d() const { return provider_->GetCudnnConv1dPadToNc1d(); } + bool IsFuseConvBias() const { return provider_->IsFuseConvBias(); } // Compatibility helper used by kernels that need the underlying ORT stream object. inline onnxruntime::Stream* GetComputeStream(OpKernelContext* ctx) const { diff --git a/onnxruntime/core/providers/cuda/math/einsum.cc b/onnxruntime/core/providers/cuda/math/einsum.cc index 6f2ed41cacab6..7250597e4f3b0 100644 --- a/onnxruntime/core/providers/cuda/math/einsum.cc +++ b/onnxruntime/core/providers/cuda/math/einsum.cc @@ -48,7 +48,7 @@ Status Einsum::ComputeInternal(OpKernelContext* context) const { GetCublasHandle(context), GetCudnnHandle(context), allocator, - cuda_ep_->UseTF32()); + UseTF32()); EinsumComputePreprocessor einsum_compute_preprocessor(einsum_equation_preprocessor, inputs, allocator, &einsum_cuda_assets); diff --git a/onnxruntime/core/providers/cuda/math/einsum.h b/onnxruntime/core/providers/cuda/math/einsum.h index 42fdb52c22c38..60facd285154a 100644 --- a/onnxruntime/core/providers/cuda/math/einsum.h +++ b/onnxruntime/core/providers/cuda/math/einsum.h @@ -6,7 +6,6 @@ #include "core/platform/threadpool.h" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cuda_kernel.h" -#include "core/providers/cuda/cuda_execution_provider.h" #include "core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.h" #include "einsum_utils/einsum_auxiliary_ops.h" @@ -19,7 +18,6 @@ class Einsum final : public CudaKernel { ORT_ENFORCE(info.GetAttr("equation", &equation_).IsOK(), "Missing 'equation' attribute"); einsum_equation_preprocessor_ = std::make_unique(equation_); - cuda_ep_ = static_cast(info.GetExecutionProvider()); } Status ComputeInternal(OpKernelContext* context) const override; @@ -27,7 +25,6 @@ class Einsum final : public CudaKernel { private: std::string equation_; std::unique_ptr einsum_equation_preprocessor_; - const CUDAExecutionProvider* cuda_ep_; }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 56e5a35a5c73d..8e5f063bb5a90 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -365,8 +365,6 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected s_.Y = context->Output(0, TensorShape(s_.y_dims)); s_.y_data = reinterpret_cast(s_.Y->MutableData()); - const CUDAExecutionProvider* cuda_ep = - static_cast(this->Info().GetExecutionProvider()); TensorShapeVector x_dims_cudnn{x_dims.begin(), x_dims.end()}; TensorShapeVector y_dims_cudnn{y_dims.begin(), y_dims.end()}; @@ -393,7 +391,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected // PyTorch also pads to [N,C,1,D]. For inference build, we still pad it to [N, C, D, 1] as this seems // to be the sweet spot for all algo search options: EXHAUSTIVE, HEURISTIC, and DEFAULT. // See PR #7348 and #7702 for more context. - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + if (this->GetCudnnConv1dPadToNc1d()) { x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); w_dims_cudnn.insert(w_dims_cudnn.begin() + 2, 1); @@ -421,7 +419,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected auto handle = GetCudnnHandle(context); - int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo(); + int cudnn_conv_algo = this->GetCudnnConvAlgo(); #if !defined(__CUDACC__) cudnn_frontend::HeurMode_t heur_mode; switch (cudnn_conv_algo) { @@ -441,9 +439,9 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected break; } - const auto use_tf32 = cuda_ep->UseTF32(); + const auto use_tf32 = this->UseTF32(); // fuse if this op is part of a FusedConv or if the EP is set to fuse ops - const auto fuse_bias = cuda_ep->IsFuseConvBias() || is_fused_node_; + const auto fuse_bias = this->IsFuseConvBias() || is_fused_node_; const auto fuse_act = is_fused_node_; ORT_RETURN_IF_ERROR(CreateCudnnFeExecutionPlan(x_dims_cudnn, w_dims_cudnn, B, Z, y_dims_cudnn, handle, heur_mode, diff --git a/onnxruntime/core/providers/cuda/nn/conv_8.h b/onnxruntime/core/providers/cuda/nn/conv_8.h index 09745d785dd69..2ce213b92810b 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_8.h +++ b/onnxruntime/core/providers/cuda/nn/conv_8.h @@ -196,9 +196,6 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) s_.y_data = reinterpret_cast(s_.Y->MutableData()); } - const CUDAExecutionProvider* cuda_ep = - static_cast(this->Info().GetExecutionProvider()); - TensorShapeVector x_dims_cudnn{x_dims.begin(), x_dims.end()}; TensorShapeVector y_dims_cudnn = !post_slicing_required ? y_dims : y_dims_with_adjusted_pads; if (kernel_rank < 2) { @@ -210,7 +207,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) // PyTorch also pads to [N,C,1,D]. For inference build, we still pad it to [N, C, D, 1] as this seems // to be the sweet spot for all algo search options: EXHAUSTIVE, HEURISTIC, and DEFAULT. // See PR #7348 and #7702 for more context. - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + if (this->GetCudnnConv1dPadToNc1d()) { x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); w_dims.insert(w_dims.begin() + 2, 1); @@ -313,13 +310,13 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) cudnnConvolutionFwdAlgoPerf_t perf; int algo_count = 1; - int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo(); + int cudnn_conv_algo = this->GetCudnnConvAlgo(); ORT_ENFORCE(cudnn_conv_algo > -1 && cudnn_conv_algo < 3, "cudnn_conv_algo should be 0, 1 or 2, but got ", cudnn_conv_algo); switch (cudnn_conv_algo) { case 0: { static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; - size_t max_ws_size = cuda_ep->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetCudnnHandle(context), s_, kAllAlgos, num_algos) - : AlgoSearchWorkspaceSize; + size_t max_ws_size = this->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetCudnnHandle(context), s_, kAllAlgos, num_algos) + : AlgoSearchWorkspaceSize; // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. IAllocatorUniquePtr algo_search_workspace = GetTransientScratchBuffer(max_ws_size); diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index a875b1c0b2aaa..c7eeadefeb555 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -361,8 +361,6 @@ Status ConvTranspose::UpdateState(OpKernelContext* context, bool dyna s_.Y = context->Output(0, s_.y_dims); s_.y_data = reinterpret_cast(s_.Y->MutableData()); - const CUDAExecutionProvider* cuda_ep = - static_cast(this->Info().GetExecutionProvider()); TensorShapeVector x_dims_cudnn{x_dims.begin(), x_dims.end()}; TensorShapeVector y_dims_cudnn{y_dims.begin(), y_dims.end()}; @@ -389,7 +387,7 @@ Status ConvTranspose::UpdateState(OpKernelContext* context, bool dyna // PyTorch also pads to [N,C,1,D]. For inference build, we still pad it to [N, C, D, 1] as this seems // to be the sweet spot for all algo search options: EXHAUSTIVE, HEURISTIC, and DEFAULT. // See PR #7348 and #7702 for more context. - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + if (this->GetCudnnConv1dPadToNc1d()) { x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); w_dims_cudnn.insert(w_dims_cudnn.begin() + 2, 1); @@ -417,7 +415,7 @@ Status ConvTranspose::UpdateState(OpKernelContext* context, bool dyna auto handle = GetCudnnHandle(context); - int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo(); + int cudnn_conv_algo = this->GetCudnnConvAlgo(); #if !defined(__CUDACC__) cudnn_frontend::HeurMode_t heur_mode; switch (cudnn_conv_algo) { @@ -435,8 +433,8 @@ Status ConvTranspose::UpdateState(OpKernelContext* context, bool dyna break; } - auto use_tf32 = cuda_ep->UseTF32(); - const auto fuse_bias = cuda_ep->IsFuseConvBias() || is_fused_node_; + auto use_tf32 = this->UseTF32(); + const auto fuse_bias = this->IsFuseConvBias() || is_fused_node_; const auto fuse_act = is_fused_node_; ORT_RETURN_IF_ERROR(CreateCudnnFeExecutionPlan(x_dims_cudnn, w_dims_cudnn, B, y_dims_cudnn, handle, heur_mode, diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h b/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h index cf0a2723111b8..10feb1acf8187 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h @@ -55,14 +55,12 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy CudaT* y_data = nullptr; - const auto* cuda_ep = static_cast(Info().GetExecutionProvider()); - // convert 1D to 2D if (x_dimensions == 3) { // we can either add a fake H or W dimension with a value of 1. to be consistent with the Conv behavior we use // GetCudnnConv1dPadToNc1d to determine which is added. // see Conv::UpdateState in /onnxruntime/core/providers/cuda/nn/conv.cc for more details. - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + if (this->GetCudnnConv1dPadToNc1d()) { // add fake H dimension const auto insert_at = NHWC ? 1 : 2; @@ -114,7 +112,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy auto y_dims = p.Y->Shape().AsShapeVector(); if (x_dimensions == 3) { - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + if (this->GetCudnnConv1dPadToNc1d()) { // add fake H dimension of 1 // NCHW: N, M, d1 -> N, M, 1, d1 or // NHWC: N, d1, M -> N, 1, d1, M @@ -222,7 +220,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy if (!y_data) { auto y_dims = s_.y_dims.AsShapeVector(); if (x_dimensions == 3) { - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + if (this->GetCudnnConv1dPadToNc1d()) { // erase the fake H dimension y_dims.erase(y_dims.begin() + (NHWC ? 1 : 2)); } else { diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index 5461e30c8199c..f8e7af0f49ed4 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -79,8 +79,7 @@ CudaEp::CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& lo static_cast(EpImpl()), adapter_config); } -CudaEp::~CudaEp() { -} +CudaEp::~CudaEp() = default; /*static*/ const char* ORT_API_CALL CudaEp::GetNameImpl(const OrtEp* this_ptr) noexcept { diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index db8de3865a2c4..90aa982021522 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -338,9 +338,10 @@ namespace cuda { // =================================================================== // Section 5: Runtime configuration for migrated kernels -// Fields are written once during CudaEp construction (under unique_lock) -// and only read afterwards; a shared_mutex in ProviderConfigStore guards -// concurrent access. +// Fields are written once during CudaEp construction and owned by the +// shim CUDAExecutionProvider. Plugin kernels cache a shared_ptr to this +// config object during construction so Compute() does not need to rely on +// raw provider-pointer casts. // =================================================================== namespace detail { @@ -446,9 +447,9 @@ inline const cudaDeviceProp& GetDevicePropForDevice(int device_id) { // (GetCudnnConvAlgo, UseTF32, GetDeviceProp, etc.) without the full // CUDAExecutionProvider class from onnxruntime/core/providers/cuda/. // -// In the plugin build this shim is wrapped by adapter::Ep, so migrated CUDA -// kernels can keep casting `info.GetExecutionProvider()` to -// `CUDAExecutionProvider*` and retrieve the plugin `OrtEp` via GetOrtEp(). +// In the plugin build this shim is wrapped by adapter::Ep. Plugin kernels +// should prefer the CudaKernel base-class accessors for runtime settings +// instead of re-casting info.GetExecutionProvider() inside Compute(). // =================================================================== // Shim for CUDAExecutionProvider required by conv.cc, einsum, and others @@ -465,43 +466,44 @@ class CUDAExecutionProvider : public onnxruntime::IExecutionProvider { return ort_ep_; } + std::shared_ptr GetRuntimeConfig() const { + return config_; + } + int GetCudnnConvAlgo() const { - return config_.cudnn_conv_algo; + return config_->cudnn_conv_algo; } bool GetCudnnConvUseMaxWorkspace() const { - return config_.cudnn_conv_use_max_workspace; + return config_->cudnn_conv_use_max_workspace; } bool GetCudnnConv1dPadToNc1d() const { - return config_.cudnn_conv1d_pad_to_nc1d; + return config_->cudnn_conv1d_pad_to_nc1d; } bool UseTF32() const { - return config_.use_tf32; + return config_->use_tf32; } bool IsFuseConvBias() const { - return config_.fuse_conv_bias; + return config_->fuse_conv_bias; } const onnxruntime::AttentionKernelOptions* GetAttentionKernelOptions() const { - config_.attention_kernel_options.InitializeOnce(config_.sdpa_kernel, true, true); - return &config_.attention_kernel_options; + config_->attention_kernel_options.InitializeOnce(config_->sdpa_kernel, true, true); + return &config_->attention_kernel_options; } const cudaDeviceProp& GetDeviceProp() const { - return config_.device_prop; + return config_->device_prop; } - // Config is public so that detail::GetCudaKernelAdapterRuntimeConfigForProvider - // (a free function defined after this class) can access it via pointer cast. - mutable cuda::detail::CudaKernelAdapterRuntimeConfig config_; - private: const OrtEp* ort_ep_ = nullptr; + std::shared_ptr config_ = + std::make_shared(); }; namespace cuda { namespace detail { -// Accessor: config is stored directly on CUDAExecutionProvider; no map or mutex needed. -inline CudaKernelAdapterRuntimeConfig& GetCudaKernelAdapterRuntimeConfigForProvider(const void* provider) { - return const_cast(static_cast(provider))->config_; +inline std::shared_ptr GetCudaKernelAdapterRuntimeConfigForProvider(const void* provider) { + return static_cast(provider)->GetRuntimeConfig(); } } // namespace detail @@ -512,23 +514,23 @@ inline CudaKernelAdapterRuntimeConfig& GetCudaKernelAdapterRuntimeConfigForProvi // field only requires updating the struct and the call site — no signature change. inline void SetCudaKernelAdapterRuntimeConfigForProvider( const void* provider, const detail::CudaKernelAdapterRuntimeConfig& init_config) { - auto& config = detail::GetCudaKernelAdapterRuntimeConfigForProvider(provider); + auto config = detail::GetCudaKernelAdapterRuntimeConfigForProvider(provider); // AttentionKernelOptions contains std::once_flag (not copyable), so assign // the plain-data fields individually rather than relying on operator=. - config.use_tf32 = init_config.use_tf32; - config.skip_layer_norm_strict_mode = init_config.skip_layer_norm_strict_mode; - config.cudnn_conv_algo = init_config.cudnn_conv_algo; - config.cudnn_conv_use_max_workspace = init_config.cudnn_conv_use_max_workspace; - config.cudnn_conv1d_pad_to_nc1d = init_config.cudnn_conv1d_pad_to_nc1d; - config.fuse_conv_bias = init_config.fuse_conv_bias; - config.sdpa_kernel = init_config.sdpa_kernel; - config.device_id = init_config.device_id; - PL_CUDA_CALL_THROW(cudaGetDeviceProperties(&config.device_prop, config.device_id)); + config->use_tf32 = init_config.use_tf32; + config->skip_layer_norm_strict_mode = init_config.skip_layer_norm_strict_mode; + config->cudnn_conv_algo = init_config.cudnn_conv_algo; + config->cudnn_conv_use_max_workspace = init_config.cudnn_conv_use_max_workspace; + config->cudnn_conv1d_pad_to_nc1d = init_config.cudnn_conv1d_pad_to_nc1d; + config->fuse_conv_bias = init_config.fuse_conv_bias; + config->sdpa_kernel = init_config.sdpa_kernel; + config->device_id = init_config.device_id; + PL_CUDA_CALL_THROW(cudaGetDeviceProperties(&config->device_prop, config->device_id)); } inline bool GetCudaKernelAdapterSkipLayerNormStrictMode(const void* provider) { - const auto& config = detail::GetCudaKernelAdapterRuntimeConfigForProvider(provider); - return config.skip_layer_norm_strict_mode; + const auto config = detail::GetCudaKernelAdapterRuntimeConfigForProvider(provider); + return config->skip_layer_norm_strict_mode; } // Global aliases and shims @@ -706,10 +708,10 @@ class CudaKernel : public OpKernel { public: explicit CudaKernel(const OpKernelInfo& info) : OpKernel(info), info_(info) { const auto* provider = info.GetExecutionProvider(); - const auto& config = detail::GetCudaKernelAdapterRuntimeConfigForProvider(provider); - use_tf32_ = config.use_tf32; - device_id_ = config.device_id; - device_prop_ = config.device_prop; + runtime_config_ = detail::GetCudaKernelAdapterRuntimeConfigForProvider(provider); + use_tf32_ = runtime_config_->use_tf32; + device_id_ = runtime_config_->device_id; + device_prop_ = runtime_config_->device_prop; } virtual ~CudaKernel() = default; Status Compute(OpKernelContext* ctx) const { @@ -819,11 +821,16 @@ class CudaKernel : public OpKernel { return device_prop_; } + int GetCudnnConvAlgo() const { return runtime_config_->cudnn_conv_algo; } + bool GetCudnnConvUseMaxWorkspace() const { return runtime_config_->cudnn_conv_use_max_workspace; } + bool GetCudnnConv1dPadToNc1d() const { return runtime_config_->cudnn_conv1d_pad_to_nc1d; } bool UseTF32() const { return use_tf32_; } + bool IsFuseConvBias() const { return runtime_config_->fuse_conv_bias; } bool IsArchAvailable(int arch) const { return GetDeviceProp().major >= arch; } const OpKernelInfo& Info() const { return info_; } const onnxruntime::AttentionKernelOptions* GetAttentionKernelOptions() const { - return static_cast(info_.GetExecutionProvider())->GetAttentionKernelOptions(); + runtime_config_->attention_kernel_options.InitializeOnce(runtime_config_->sdpa_kernel, true, true); + return &runtime_config_->attention_kernel_options; } // Stub for GetTuningContext — tunable ops are not supported in the plugin. @@ -942,6 +949,7 @@ class CudaKernel : public OpKernel { private: const OpKernelInfo& info_; + std::shared_ptr runtime_config_; cudaDeviceProp device_prop_{}; bool use_tf32_ = true; int device_id_ = 0; From 102b6b243be6748f7ec1b10602181879297e7a33 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 27 Mar 2026 16:30:32 -0700 Subject: [PATCH 34/48] fx build --- cmake/onnxruntime_unittests.cmake | 8 +++++ .../test/framework/dynamic_plugin_ep_test.cc | 36 ++++++++++--------- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 59f5e6ff11fd8..cdd67fa55266f 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1090,6 +1090,10 @@ target_include_directories(onnxruntime_test_all PRIVATE ${ONNXRUNTIME_ROOT}/core onnxruntime_apply_test_target_workarounds(onnxruntime_test_all) +if (onnxruntime_USE_CUDA AND onnxruntime_BUILD_CUDA_EP_AS_PLUGIN) + target_compile_definitions(onnxruntime_test_all PRIVATE ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP=1) +endif() + if (MSVC) # The warning means the type of two integral values around a binary operator is narrow than their result. # If we promote the two input values first, it could be more tolerant to integer overflow. @@ -1264,6 +1268,10 @@ block() onnxruntime_apply_test_target_workarounds(onnxruntime_provider_test) onnxruntime_set_plugin_ep_test_environment(onnxruntime_provider_test) + if (onnxruntime_USE_CUDA AND onnxruntime_BUILD_CUDA_EP_AS_PLUGIN) + target_compile_definitions(onnxruntime_provider_test PRIVATE ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP=1) + endif() + # Expose QNN SDK headers to unit tests via an interface target if(onnxruntime_USE_QNN) add_library(qnn_sdk_headers_include INTERFACE) diff --git a/onnxruntime/test/framework/dynamic_plugin_ep_test.cc b/onnxruntime/test/framework/dynamic_plugin_ep_test.cc index 61687632b23b8..9c7b9fc3523f0 100644 --- a/onnxruntime/test/framework/dynamic_plugin_ep_test.cc +++ b/onnxruntime/test/framework/dynamic_plugin_ep_test.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP) + #include "core/framework/execution_provider.h" #include "test/unittest_util/test_dynamic_plugin_ep.h" @@ -15,7 +17,7 @@ namespace onnxruntime::test { -namespace dynamic_plugin_ep_infra = onnxruntime::test::dynamic_plugin_ep_infra; +namespace dynamic_plugin_ep_test_infra = onnxruntime::test::dynamic_plugin_ep_infra; TEST(DynamicPluginEpInfraTest, ParseInitializationConfigParsesOptionalFields) { constexpr std::string_view kConfigJson = R"json( @@ -34,8 +36,8 @@ TEST(DynamicPluginEpInfraTest, ParseInitializationConfigParsesOptionalFields) { } )json"; - dynamic_plugin_ep_infra::InitializationConfig config{}; - ASSERT_STATUS_OK(dynamic_plugin_ep_infra::ParseInitializationConfig(kConfigJson, config)); + dynamic_plugin_ep_test_infra::InitializationConfig config{}; + ASSERT_STATUS_OK(dynamic_plugin_ep_test_infra::ParseInitializationConfig(kConfigJson, config)); EXPECT_EQ(config.ep_library_registration_name, "CudaPluginExecutionProvider"); EXPECT_EQ(config.ep_library_path, "/tmp/libonnxruntime_providers_cuda_plugin.so"); @@ -58,8 +60,8 @@ TEST(DynamicPluginEpInfraTest, ParseInitializationConfigDefaultsUnsetOptionalFie } )json"; - dynamic_plugin_ep_infra::InitializationConfig config{}; - ASSERT_STATUS_OK(dynamic_plugin_ep_infra::ParseInitializationConfig(kConfigJson, config)); + dynamic_plugin_ep_test_infra::InitializationConfig config{}; + ASSERT_STATUS_OK(dynamic_plugin_ep_test_infra::ParseInitializationConfig(kConfigJson, config)); EXPECT_EQ(config.ep_library_registration_name, "ExamplePluginEP"); EXPECT_EQ(config.ep_library_path, "/tmp/libexample_plugin_ep.so"); @@ -76,24 +78,24 @@ TEST(DynamicPluginEpInfraTest, ParseInitializationConfigRejectsMissingRequiredFi } )json"; - dynamic_plugin_ep_infra::InitializationConfig config{}; - ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(dynamic_plugin_ep_infra::ParseInitializationConfig(kConfigJson, config), + dynamic_plugin_ep_test_infra::InitializationConfig config{}; + ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(dynamic_plugin_ep_test_infra::ParseInitializationConfig(kConfigJson, config), "JSON parse error"); } TEST(DynamicPluginEpInfraTest, UninitializedStateReturnsSafeDefaults) { - dynamic_plugin_ep_infra::Shutdown(); + dynamic_plugin_ep_test_infra::Shutdown(); - EXPECT_FALSE(dynamic_plugin_ep_infra::IsInitialized()); - EXPECT_EQ(dynamic_plugin_ep_infra::MakeEp(), nullptr); - EXPECT_FALSE(dynamic_plugin_ep_infra::GetEpName().has_value()); - EXPECT_TRUE(dynamic_plugin_ep_infra::GetTestsToSkip().empty()); + EXPECT_FALSE(dynamic_plugin_ep_test_infra::IsInitialized()); + EXPECT_EQ(dynamic_plugin_ep_test_infra::MakeEp(), nullptr); + EXPECT_FALSE(dynamic_plugin_ep_test_infra::GetEpName().has_value()); + EXPECT_TRUE(dynamic_plugin_ep_test_infra::GetTestsToSkip().empty()); - dynamic_plugin_ep_infra::Shutdown(); + dynamic_plugin_ep_test_infra::Shutdown(); - EXPECT_FALSE(dynamic_plugin_ep_infra::IsInitialized()); - EXPECT_FALSE(dynamic_plugin_ep_infra::GetEpName().has_value()); - EXPECT_TRUE(dynamic_plugin_ep_infra::GetTestsToSkip().empty()); + EXPECT_FALSE(dynamic_plugin_ep_test_infra::IsInitialized()); + EXPECT_FALSE(dynamic_plugin_ep_test_infra::GetEpName().has_value()); + EXPECT_TRUE(dynamic_plugin_ep_test_infra::GetTestsToSkip().empty()); } #if defined(USE_CUDA) && defined(ORT_USE_EP_API_ADAPTERS) @@ -114,3 +116,5 @@ TEST(DynamicPluginEpInfraTest, CudaKernelAdapterRuntimeConfigExposesFuseConvBias #endif } // namespace onnxruntime::test + +#endif // !defined(ORT_MINIMAL_BUILD) && defined(ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP) From 94e6d7a376a44449efebf8bfaa95c14a31aaf3b3 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 27 Mar 2026 16:50:16 -0700 Subject: [PATCH 35/48] refine --- cmake/onnxruntime_providers_cuda_plugin.cmake | 5 +++++ include/onnxruntime/ep/adapter/allocator.h | 2 +- onnxruntime/core/providers/cuda/nn/conv.cc | 10 +++++++++ .../providers/cuda/plugin/cuda_ep_factory.cc | 21 ++++++++++++++----- .../cuda/plugin/cuda_kernel_adapter.h | 12 +++++++---- .../cuda/plugin/cuda_stream_plugin.cc | 11 ++++++++-- .../python/onnxruntime_pybind_state.cc | 3 ++- 7 files changed, 51 insertions(+), 13 deletions(-) diff --git a/cmake/onnxruntime_providers_cuda_plugin.cmake b/cmake/onnxruntime_providers_cuda_plugin.cmake index 84edd0b35e8f0..f352f48689226 100644 --- a/cmake/onnxruntime_providers_cuda_plugin.cmake +++ b/cmake/onnxruntime_providers_cuda_plugin.cmake @@ -135,6 +135,11 @@ target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE # receive the ORT-framework force-include (it conflicts with cute::Tensor etc.). # cuda_plugin_kernels.cu already #include "cuda_kernel_adapter.h" directly. # Op-registration .cc files do not include it directly, so they need it here. + # + # IMPORTANT: The CXX force-include order matters — adapters.h MUST precede + # cuda_kernel_adapter.h because the adapter establishes type aliases that the + # kernel adapter header depends on. + # # Force NVCC onto C++20 explicitly. With the VS generator the CUDA standard # property alone still leaves `-std=c++17` in AdditionalOptions. # Suppress NVCC cudafe warnings: diff --git a/include/onnxruntime/ep/adapter/allocator.h b/include/onnxruntime/ep/adapter/allocator.h index c1d4bcaf77017..4e601bb22252b 100644 --- a/include/onnxruntime/ep/adapter/allocator.h +++ b/include/onnxruntime/ep/adapter/allocator.h @@ -41,7 +41,7 @@ class Allocator : public OrtAllocator { private: explicit Allocator(const OrtMemoryInfo* memory_info) - : OrtAllocator{}, memory_info_(memory_info) { + : OrtAllocator{}, memory_info_(memory_info), get_allocator_impl_(nullptr) { version = ORT_API_VERSION; Alloc = AllocImpl; Free = FreeImpl; diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 8e5f063bb5a90..20bc990aaee24 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -237,7 +237,12 @@ Status Conv::CreateCudnnFeExecutionPlan(const onnxruntime::TensorShap CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_operation_graph(handle)); CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->create_execution_plans({heur_mode})); } catch (const std::exception& ex) { +#ifndef BUILD_CUDA_EP_AS_PLUGIN + std::string message = MakeString("Failed to initialize CUDNN Frontend: ", ex.what(), + " with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); +#else std::string message = MakeString("Failed to initialize CUDNN Frontend: ", ex.what()); +#endif return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); } @@ -248,7 +253,12 @@ Status Conv::CreateCudnnFeExecutionPlan(const onnxruntime::TensorShap CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_plans(handle)); } catch (const std::exception& ex) { if (!fuse_bias && !fuse_act && use_tf32) { +#ifndef BUILD_CUDA_EP_AS_PLUGIN + std::string message = MakeString("OP not supported by CUDNN Frontend: ", ex.what(), + " with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); +#else std::string message = MakeString("OP not supported by CUDNN Frontend: ", ex.what()); +#endif return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); } diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc index 3e3eeb5b4af1a..d7a2e9f5242b2 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -9,8 +9,8 @@ #include #include #include +#include #include -#include namespace onnxruntime { namespace cuda_plugin { @@ -117,7 +117,7 @@ std::string ToUpper(std::string value) { } std::string GetProviderOptionPrefix(std::string_view provider_name) { - return std::format("ep.{}.", onnxruntime::utils::GetLowercaseString(std::string{provider_name})); + return "ep." + onnxruntime::utils::GetLowercaseString(std::string{provider_name}) + "."; } } // namespace @@ -134,6 +134,17 @@ OrtStatus* ORT_API_CALL CudaEpFactory::GetSupportedDevicesImpl( size_t& num_ep_devices = *p_num_ep_devices; num_ep_devices = 0; + // Clear stale entries from previous enumerations. The OrtHardwareDevice* + // keys point to ORT-owned memory that may be freed between calls. + { + std::lock_guard lock(factory->device_map_mutex_); + factory->hw_device_to_cuda_index_.clear(); + } + { + std::lock_guard lock(factory->cached_memory_info_mutex_); + factory->cached_memory_infos_.clear(); + } + int cuda_device_index = 0; for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { const OrtHardwareDevice& device = *hw_devices[i]; @@ -285,9 +296,9 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( return; } - const std::string msg = std::format( - "Failed to parse session config for key '{}'. Expected {}. Using default value.", - key, expected); + const std::string msg = std::string("Failed to parse session config for key '") + + std::string(key) + "'. Expected " + std::string(expected) + + ". Using default value."; OrtStatus* st = factory->ort_api_.Logger_LogMessage( logger, ORT_LOGGING_LEVEL_WARNING, msg.c_str(), "cuda_ep_factory.cc", __LINE__, "CudaEpFactory"); diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index 90aa982021522..3e2e3dbc78c3c 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -198,9 +198,9 @@ namespace cuda { /// Each compiled kernel .cc file's macro expansion auto-registers here. /// /// Thread-safety: Instance() uses a function-local static (C++11 §6.7/4: -/// constructed exactly once, even under concurrent first-access). All Add() -/// calls happen during static initialisation of each translation unit -/// (before main()), which is single-threaded per the C++ standard. +/// constructed exactly once, even under concurrent first-access). Add() +/// is guarded by a mutex for formal correctness across translation units, +/// though in practice all calls occur during static initialization. class PluginKernelCollector { public: static PluginKernelCollector& Instance() { @@ -208,11 +208,15 @@ class PluginKernelCollector { return instance; } - void Add(BuildKernelCreateInfoFn fn) { entries_.push_back(fn); } + void Add(BuildKernelCreateInfoFn fn) { + std::lock_guard lock(mutex_); + entries_.push_back(fn); + } const std::vector& Entries() const { return entries_; } private: std::vector entries_; + std::mutex mutex_; }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc index 0764b8931dc9e..66aa3c4cf1d45 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc @@ -67,12 +67,19 @@ CudaSyncStream::~CudaSyncStream() { } } - if (cuda_stream_) UnregisterStream(cuda_stream_); - if (cublas_handle_) cublasDestroy(cublas_handle_); if (cudnn_handle_) cudnnDestroy(cudnn_handle_); if (cublas_lt_handle_) cublasLtDestroy(cublas_lt_handle_); if (cuda_stream_) { + // Unregister the stream from the global map *after* destroying handles but + // *before* destroying the stream itself. This ordering ensures: + // 1. No concurrent kernel can obtain cuBLAS/cuDNN handles from a destroyed + // CudaSyncStream during the brief window before unregistration. + // 2. UnregisterStream bumps the TLS generation counter, invalidating cached + // lookups in other threads. + // 3. The stream is destroyed only after it is no longer discoverable. + UnregisterStream(cuda_stream_); + auto destroy_result = cudaStreamDestroy(cuda_stream_); if (destroy_result == cudaSuccess && !deferred_cpu_buffers_.empty()) { // Fallback: we only reach here when the earlier cudaStreamSynchronize in diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index f7b7056fd277b..212272647f8e3 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -592,7 +592,8 @@ static const OrtEpDevice* FindRegisteredPluginEpDevice( requested_device_id = std::stoi(device_id_it->second); has_requested_device_id = requested_device_id >= 0; } catch (const std::exception&) { - // Invalid device_id values are logged by callers when appropriate. + LOGS_DEFAULT(WARNING) << "Invalid device_id value '" << device_id_it->second + << "' in provider options for EP '" << ep_name << "'; ignoring."; } } } From ab0d23e449c49381e3c4a00ee0fc0403043d254d Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sat, 28 Mar 2026 14:38:39 -0700 Subject: [PATCH 36/48] add droput, identity, crop, synamicslice and fft ops --- cmake/onnxruntime_providers_cuda_plugin.cmake | 11 +- docs/cuda_plugin_ep/cuda_plugin_ep_design.md | 41 ++- onnxruntime/core/providers/cuda/nn/dropout.cc | 11 + .../cuda/plugin/cuda_allocator_plugin.cc | 9 +- .../providers/cuda/plugin/cuda_ep_factory.cc | 4 +- .../cuda/plugin/cuda_kernel_adapter.h | 70 +++++ .../core/providers/cuda/tensor/identity_op.cc | 50 ++- .../core/providers/cuda/tensor/identity_op.h | 4 + .../python/onnxruntime_pybind_schema.cc | 56 ++++ .../transformers/test_cuda_plugin_ep.py | 297 ++++++++++++++++++ tools/python/gen_opkernel_doc.py | 54 +++- 11 files changed, 551 insertions(+), 56 deletions(-) diff --git a/cmake/onnxruntime_providers_cuda_plugin.cmake b/cmake/onnxruntime_providers_cuda_plugin.cmake index f352f48689226..790545568cb18 100644 --- a/cmake/onnxruntime_providers_cuda_plugin.cmake +++ b/cmake/onnxruntime_providers_cuda_plugin.cmake @@ -79,8 +79,9 @@ list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_common\\.cc$") list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_nhwc_kernels\\.cc$") list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_contrib_kernels\\.cc$") -# Exclude files that use TensorSeq (incomplete type in plugin build). -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/identity_op\\.cc$") +# Exclude sequence_op.cc — uses TensorSeq (incomplete type in plugin build). +# identity_op.cc is now included: TensorSeq code path is guarded by +# BUILD_CUDA_EP_AS_PLUGIN and opset 14+ registrations use Tensor-only types. list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/sequence_op\\.cc$") # Permanently excluded — pure CPU ops, handled by GetCpuPreferredNodes. @@ -101,11 +102,6 @@ list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/llm/.*") # Exclude contrib training ops (shrunken_gather depends on provider_api.h in header). list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/shrunken_gather\\.cc$") -# Exclude contrib ops using GetComputeStream() or framework type deps. -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/math/fft_ops\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/crop\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/dynamicslice\\.cc$") - # Exclude contrib transformers/ (beam search, greedy search, sampling). Those need subgraph inference. list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/transformers/.*") @@ -237,6 +233,7 @@ target_link_libraries(onnxruntime_providers_cuda_plugin PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt + CUDA::cufft CUDNN::cudnn_all cudnn_frontend Boost::mp11 diff --git a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md index f9848461a5464..7be29ff5e9442 100644 --- a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md +++ b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md @@ -121,10 +121,10 @@ This 700+ line header provides everything CUDA kernels need that would normally | Section | What It Provides | |---------|-----------------| -| Error macros | `CUDA_RETURN_IF_ERROR`, `CUBLAS_RETURN_IF_ERROR`, `CUDNN_RETURN_IF_ERROR` | +| Error macros | `CUDA_RETURN_IF_ERROR`, `CUBLAS_RETURN_IF_ERROR`, `CUDNN_RETURN_IF_ERROR`, `CUFFT_RETURN_IF_ERROR` | | Type mappings | `ToCudaType::MappedType = half`, etc. | | CudaKernel base | Scratch buffers, handle access, `Stream()`, `GetComputeStream()` | -| Kernel registration | Self-registering `ONNX_OPERATOR_*_KERNEL_EX` macro overrides via `PluginKernelCollector` | +| Kernel registration | Self-registering `ONNX_OPERATOR_*_KERNEL_EX` macro overrides via `PluginKernelCollector`, including `ONNX_OPERATOR_TWO_TYPED_KERNEL_EX`, `ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX`, and `ONNX_OPERATOR_THREE_TYPED_KERNEL_EX` variants | | CPU shims | Lightweight reimplementations of CPU helpers not linked into plugin | | Math helpers | `HalfGemmOptions`, `CublasMathModeSetter` | | Stream shim | `OrtStreamAdapter`/`PluginStreamShim` to present a framework-compatible `Stream*` view over a raw `cudaStream_t` where needed | @@ -576,13 +576,10 @@ Section 7 reflects the current source exclusions in `cmake/onnxruntime_providers | `core/providers/cuda/controlflow/*` | The framework controlflow kernels inherit from CPU-side controlflow bases (`If`, `Loop`, `Scan`) and are intentionally omitted from the plugin source list | No change is currently planned. The plugin uses its own `cuda_controlflow_plugin.cc` wrappers instead of these framework sources | | `tunable/*` | Depends on the real `CudaTuningContext` and other framework CUDA EP infrastructure that is not available in the plugin build | Add a plugin-capable tuning context and remove the remaining framework-only tunable dependencies | | `math/einsum.cc` | The top-level framework einsum provider path is still not plugin-safe even though supporting utility code is now included | Finish a plugin-safe top-level einsum path and then remove the CMake exclusion | -| `tensor/identity_op.cc` | Uses `TensorSeq`, which is still not adapter-safe here | Add `TensorSeq` adapter coverage | -| `tensor/sequence_op.cc` | Uses `TensorSeq`, which is still not adapter-safe here | Same as above | +| `tensor/sequence_op.cc` | Uses `TensorSeq`, which is still not adapter-safe here | Add `TensorSeq` adapter coverage | | `contrib_ops/cuda/llm/*` | Contrib LLM sources have not gone through the same adapter-migration pass as the core CUDA LLM kernels | Finish contrib-LLM-specific adapter work | | `contrib_ops/cuda/tensor/shrunken_gather.cc` | The training-side header path still depends on framework/provider API wiring | Low-priority training-specific adapter work | -| `contrib_ops/cuda/math/fft_ops.cc` | Still has framework stream/type assumptions that are not yet adapter-safe | Finish FFT-specific adapter cleanup | -| `contrib_ops/cuda/tensor/crop.cc` | Still has remaining framework assumptions, so it stays excluded even though some helper-side migration work is already done | Finish and validate the remaining plugin-safe path, then remove the CMake exclusion | -| `contrib_ops/cuda/tensor/dynamicslice.cc` | Still excluded in CMake due to remaining framework assumptions | Finish dynamicslice-specific adapter cleanup | + | `contrib_ops/cuda/transformers/*` | Beam search, greedy search, and sampling depend on broader framework/subgraph integration that has not been adapted for the plugin build | Significant adapter and subgraph support work | | `contrib_ops/cuda/aten_ops/*` | ATen interop is intentionally out of scope for the standalone CUDA plugin build | A separate ATen/plugin strategy | | `contrib_ops/cuda/collective/*` | Collective/NCCL support is intentionally out of scope for the standalone CUDA plugin build | A separate collective/NCCL plugin strategy | @@ -593,7 +590,7 @@ The current exclusions fall into a few categories: 1. **Tunable/framework-dependent infrastructure** — `tunable/*`, contrib transformers, and some contrib LLM paths still rely on framework-only execution-provider services. -2. **Remaining adapter gaps** — `TensorSeq`, some contrib FFT/crop/dynamicslice paths, and contrib-LLM-specific plumbing still need dedicated adapter work. +2. **Remaining adapter gaps** — `TensorSeq` (needed for `sequence_op.cc`) and contrib-LLM-specific plumbing still need dedicated adapter work. 3. **Deliberate scope cuts** — ATen and collective/NCCL sources remain intentionally out of scope for the standalone CUDA plugin. @@ -608,6 +605,7 @@ The branch still contains a small set of plugin guards in both infrastructure an - Infrastructure files such as `cuda_kernel.h`, `cuda_common.h`, and `cudnn_common.h` still need build-mode gates. - `generator/constant_of_shape.h` still needs a plugin-specific path because `ConstantOfShapeBase` depends on framework-only tensor-attribute helpers. - Tunable kernels such as `math/matmul.cc` still gate framework-only registration paths. +- `tensor/identity_op.h` guards the `TensorSeq` code path and `context->InputType()` call with `#ifndef BUILD_CUDA_EP_AS_PLUGIN` — the plugin build handles only the `Tensor` path. `identity_op.cc` uses conditional macros (`IDENTITY_V_TYPES` / `IDENTITY_V_TYPES_IRv9`) so opset 14+ registrations use `AllFixedSizeTensorTypes()` in the plugin build. Additionally, old Dropout opset 7–9 and 10–11 kernel registrations were moved from `identity_op.cc` to `nn/dropout.cc` so that each op's registrations live in that op's own source file. - A few tensor kernels (`pad.cc`, `tile.cc`, `unsqueeze.cc`, `upsample.*`, `space_depth_ops.h`, `scatter_nd.*`) still contain localized plugin guards where adapter and framework paths have not fully converged. The broad trend remains positive: most operator-level plugin conditionals were removed by moving reusable CPU/helper logic into shared headers and by centralizing stream bridging in `CudaKernel` helpers. @@ -642,7 +640,7 @@ The in-tree CUDA EP and shared provider bridge are compiled identically regardle ### 9.3 Plugin Independence -`libonnxruntime_providers_cuda_plugin.so` is **fully self-contained**. It does not depend on `libonnxruntime_providers_cuda.so` or `libonnxruntime_providers_shared.so` at load time. It statically links against `onnxruntime_framework`, `onnxruntime_graph`, `onnxruntime_common`, `onnxruntime_mlas`, `onnxruntime_flatbuffers`, and links dynamically against CUDA/cuDNN/protobuf. Communication with the ORT runtime happens exclusively through the C API (`OrtApi`/`OrtEpApi`) passed at load time. +`libonnxruntime_providers_cuda_plugin.so` is **fully self-contained**. It does not depend on `libonnxruntime_providers_cuda.so` or `libonnxruntime_providers_shared.so` at load time. It statically links against `onnxruntime_framework`, `onnxruntime_graph`, `onnxruntime_common`, `onnxruntime_mlas`, `onnxruntime_flatbuffers`, and links dynamically against CUDA (`cudart`, `cublas`, `cublasLt`, `cufft`), cuDNN, and protobuf. Communication with the ORT runtime happens exclusively through the C API (`OrtApi`/`OrtEpApi`) passed at load time. ### 9.4 Build Outputs @@ -682,7 +680,13 @@ The plugin is then available as `CudaPluginExecutionProvider` in session provide | Stage 5A | Standard ops: Reshape, Split, Concat, Gather, Unsqueeze | | Stage 5B | More ops: Tile, CumSum, ConstantOfShape, SpaceToDepth, Pad, Slice, Resize, Sum | | Stage 5C | CPU base class ops: Upsample, DepthToSpace | -| Stage 5D | Contrib ops: FastGelu, SkipLayerNorm (BiasDropout is currently skipped as a known issue in the script) | +| Stage 5D | Contrib ops: FastGelu, SkipLayerNorm, BiasDropout | +| Dropout | Dropout opset 7 and opset 10 — verifies registrations moved to `dropout.cc` | +| Quantization | DequantizeLinear / QuantizeLinear opset 21 — exercises `TWO_TYPED_KERNEL_EX` adapter macro | +| GatherBlockQuantized | Contrib `GatherBlockQuantized` — exercises `THREE_TYPED_KERNEL_EX` adapter macro | +| Identity | Identity opset 13 and opset 25 — re-enabled op with `TensorSeq` path guarded | +| Crop | Crop (opset 1) — previously excluded contrib op, now re-enabled | +| Key-ops probe | Session-based probing of representative ops: Sub, Relu, Softmax, Transpose, Cast, Sigmoid, ConvTranspose, LRN | ### 10.2 Running Tests @@ -698,7 +702,18 @@ The current branch has been validated with `./cuda_plugin.sh --build --test_plug ### 10.3 Parity Report -`tools/ci_build/cuda_plugin_parity_report.py` generates a report comparing registered kernels between the in-tree CUDA EP and the plugin EP, identifying gaps. +`tools/ci_build/cuda_plugin_parity_report.py` generates both static and runtime parity reports: + +- **Static mode** (default): Parses CMake exclusion patterns and kernel registration macros from source to compare what the plugin build includes vs. the bundled CUDA EP. + ```bash + python tools/ci_build/cuda_plugin_parity_report.py + ``` +- **Runtime mode** (`--runtime`): Uses the pybind `get_registered_ep_kernel_defs()` API (added in `onnxruntime_pybind_schema.cc`) to query actual kernel registries from both the bundled and plugin EPs, providing an accurate comparison of registered op/domain/version/type-constraint coverage. + ```bash + python tools/ci_build/cuda_plugin_parity_report.py --runtime [--plugin-ep-lib /path/to/plugin.so] + ``` + +The runtime API (`get_registered_ep_kernel_defs(ep_name)`) creates a temporary EP factory and EP instance for the named EP, iterates its kernel registry, and returns `KernelDef` objects with `op_name`, `domain`, `version_range`, `provider`, and `type_constraints` fields. --- @@ -918,9 +933,9 @@ include/onnxruntime/ep/ 5. **Tunable ops** — Implement a plugin-side `ITuningContext` and remove the `ORT_USE_EP_API_ADAPTERS` guards in `matmul.cc`/`gemm.cc` so the plugin can recover runtime kernel selection and profiling-based tuning behavior. -6. **TensorSeq and additional C API coverage** — Add enough sequence/tensor-sequence support to unblock `identity_op.cc` and `sequence_op.cc`, and extend the ORT C API where needed for remaining framework-style attribute accessors such as string-array attributes used by RNN kernels. +6. **TensorSeq and additional C API coverage** — Add enough sequence/tensor-sequence support to unblock `sequence_op.cc` (the last remaining TensorSeq-dependent file), and extend the ORT C API where needed for remaining framework-style attribute accessors such as string-array attributes used by RNN kernels. Note: `identity_op.cc` is now included in the plugin build — its TensorSeq code path is guarded by `#ifndef BUILD_CUDA_EP_AS_PLUGIN` and opset 14+ registrations use `AllFixedSizeTensorTypes()` (Tensor-only) instead of `AllFixedSizeTensorAndSequenceTensorTypes()`. -7. **Remaining contrib exclusions** — Remove the current CMake exclusions for FFT, crop, dynamicslice, and other remaining contrib paths once their framework assumptions are gone or adapter equivalents exist. +7. **Remaining contrib exclusions** — The FFT (`fft_ops.cc`), crop (`crop.cc`), and dynamicslice (`dynamicslice.cc`) exclusions have been removed. These files now compile in the plugin build: FFT ops use `Stream(context)` (which works in both builds) and the `CUFFT_RETURN_IF_ERROR` macro was added to the adapter; crop and dynamicslice had no real framework blockers once tested. The plugin CMake now links `CUDA::cufft` for cuFFT symbol resolution. Remaining contrib exclusions are: `shrunken_gather.cc` (training), `transformers/*` (subgraph), `aten_ops/*` (ATen), `collective/*` (NCCL), and `llm/*` (contrib LLM pass). 8. **CI integration and targeted benchmarking** — Add plugin build + test coverage to CI and include perf-oriented validation so allocator, profiling, and tunable-op regressions are caught early. diff --git a/onnxruntime/core/providers/cuda/nn/dropout.cc b/onnxruntime/core/providers/cuda/nn/dropout.cc index 54202e0231732..7233f2e241a73 100644 --- a/onnxruntime/core/providers/cuda/nn/dropout.cc +++ b/onnxruntime/core/providers/cuda/nn/dropout.cc @@ -35,6 +35,17 @@ struct DropoutComputeImpl { } // namespace +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Dropout, kOnnxDomain, 7, 9, kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()), + Dropout); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Dropout, kOnnxDomain, 10, 11, kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Dropout); + ONNX_OPERATOR_VERSIONED_KERNEL_EX(Dropout, kOnnxDomain, 12, 12, kCudaExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()) diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc index 2534ce31de6c2..8f2195b03d1a1 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc @@ -18,7 +18,14 @@ void RestoreDeviceIfKnown(bool restore_prev_device, int prev_device) noexcept { // --------------------------------------------------------------------------- // CudaDeviceAllocator — uses cudaMalloc/cudaFree for GPU device memory. -// Note: No arena or caching layer — every allocation goes directly to CUDA. +// +// PERFORMANCE NOTE (Direct cudaMalloc Penalty): +// No arena or caching layer is provided within this plugin. Every allocation +// goes directly to CUDA (cudaMalloc). For models with dynamic shape resizing +// or many intermediate buffers, this can cause substantial overhead. +// Compared to the built-in CUDA Execution Provider, which has an integrated +// memory arena, this is a notable performance gap unless an external +// memory pool/arena is injected or configured by the application. // --------------------------------------------------------------------------- CudaDeviceAllocator::CudaDeviceAllocator(const OrtMemoryInfo* memory_info, int device_id) diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc index d7a2e9f5242b2..7ee0ccedf6bee 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -297,8 +297,8 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( } const std::string msg = std::string("Failed to parse session config for key '") + - std::string(key) + "'. Expected " + std::string(expected) + - ". Using default value."; + std::string(key) + "'. Expected " + std::string(expected) + + ". Using default value."; OrtStatus* st = factory->ort_api_.Logger_LogMessage( logger, ORT_LOGGING_LEVEL_WARNING, msg.c_str(), "cuda_ep_factory.cc", __LINE__, "CudaEpFactory"); diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index 3e2e3dbc78c3c..ecf2354ac2779 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -175,6 +175,15 @@ using ::onnxruntime::HandleNegativeAxis; } \ } +#undef CUFFT_RETURN_IF_ERROR +#define CUFFT_RETURN_IF_ERROR(expr) \ + { \ + cufftResult _status = (expr); \ + if (_status != CUFFT_SUCCESS) { \ + return onnxruntime::common::Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, std::string("cuFFT error: ") + std::to_string((int)_status)); \ + } \ + } + // =================================================================== // Section 3: Self-registering kernel collector using BuildKernelCreateInfo<> // @@ -312,6 +321,67 @@ class PluginKernelCollector { provider, domain, startver, endver, type, name)>), \ true); +#undef ONNX_OPERATOR_TWO_TYPED_KERNEL_EX +#define ONNX_OPERATOR_TWO_TYPED_KERNEL_EX(name, domain, ver, type1, type2, provider, builder, ...) \ + class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name).SetDomain(domain).SinceVersion(ver).Provider(CUDA_PLUGIN_EP).Build(), \ + static_cast( \ + [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ + out = std::make_unique<__VA_ARGS__>(info); \ + return Status::OK(); \ + })); \ + } \ + static const bool ORT_ADAPTER_CONCAT(ORT_ADAPTER_AUTOREG_##name##_##type1##_##type2##_, __COUNTER__) = \ + (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ + &BuildKernelCreateInfo), \ + true); + +#undef ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX +#define ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX(name, domain, startver, endver, type1, type2, \ + provider, builder, ...) \ + class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type1, type2, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name).SetDomain(domain).SinceVersion(startver, endver).Provider(CUDA_PLUGIN_EP).Build(), \ + static_cast( \ + [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ + out = std::make_unique<__VA_ARGS__>(info); \ + return Status::OK(); \ + })); \ + } \ + static const bool ORT_ADAPTER_CONCAT(ORT_ADAPTER_AUTOREG_##name##_##type1##_##type2##_, __COUNTER__) = \ + (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ + &BuildKernelCreateInfo), \ + true); + +#undef ONNX_OPERATOR_THREE_TYPED_KERNEL_EX +#define ONNX_OPERATOR_THREE_TYPED_KERNEL_EX(name, domain, ver, type1, type2, type3, provider, builder, ...) \ + class ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, type3, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name).SetDomain(domain).SinceVersion(ver).Provider(CUDA_PLUGIN_EP).Build(), \ + static_cast( \ + [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ + out = std::make_unique<__VA_ARGS__>(info); \ + return Status::OK(); \ + })); \ + } \ + static const bool ORT_ADAPTER_CONCAT(ORT_ADAPTER_AUTOREG_##name##_##type1##_##type2##_##type3##_, __COUNTER__) = \ + (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ + &BuildKernelCreateInfo), \ + true); + // =================================================================== // Section 4: Logging shim (adapter path only) // LOGS_DEFAULT is re-routed through ep::adapter::LoggingManager, which diff --git a/onnxruntime/core/providers/cuda/tensor/identity_op.cc b/onnxruntime/core/providers/cuda/tensor/identity_op.cc index 5120a661ef971..a92dfdb29a0aa 100644 --- a/onnxruntime/core/providers/cuda/tensor/identity_op.cc +++ b/onnxruntime/core/providers/cuda/tensor/identity_op.cc @@ -5,32 +5,6 @@ namespace onnxruntime { namespace cuda { -ONNX_OPERATOR_VERSIONED_KERNEL_EX( - Dropout, - kOnnxDomain, - 7, 9, - kCudaExecutionProvider, - (*KernelDefBuilder::Create()) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}) - .Alias(0, 0), - IdentityOp); - -ONNX_OPERATOR_VERSIONED_KERNEL_EX( - Dropout, - kOnnxDomain, - 10, - 11, - kCudaExecutionProvider, - (*KernelDefBuilder::Create()) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}) - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .Alias(0, 0), - IdentityOp); - ONNX_OPERATOR_VERSIONED_KERNEL_EX( Identity, kOnnxDomain, @@ -51,13 +25,24 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .Alias(0, 0), IdentityOp); +// From opset 14 onward the ONNX spec's type constraint is "V" which includes +// both Tensor and TensorSequence types. In the plugin EP build TensorSeq is +// an incomplete type, so we register only the Tensor subset. +#ifdef BUILD_CUDA_EP_AS_PLUGIN +#define IDENTITY_V_TYPES DataTypeImpl::AllFixedSizeTensorTypes() +#define IDENTITY_V_TYPES_IRv9 DataTypeImpl::AllFixedSizeTensorTypes() +#else +#define IDENTITY_V_TYPES DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes() +#define IDENTITY_V_TYPES_IRv9 DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypesIRv9() +#endif + ONNX_OPERATOR_VERSIONED_KERNEL_EX( Identity, kOnnxDomain, 14, 18, kCudaExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes()) + .TypeConstraint("V", IDENTITY_V_TYPES) .Alias(0, 0), IdentityOp); @@ -67,7 +52,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 19, 20, kCudaExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypesIRv9()) + .TypeConstraint("V", IDENTITY_V_TYPES_IRv9) .Alias(0, 0), IdentityOp); @@ -77,7 +62,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 21, 22, kCudaExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypesIRv9()) + .TypeConstraint("V", IDENTITY_V_TYPES_IRv9) .Alias(0, 0), IdentityOp); @@ -87,7 +72,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 23, 24, kCudaExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypesIRv9()) + .TypeConstraint("V", IDENTITY_V_TYPES_IRv9) .Alias(0, 0), IdentityOp); @@ -97,8 +82,11 @@ ONNX_OPERATOR_KERNEL_EX( 25, kCudaExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypesIRv9()) + .TypeConstraint("V", IDENTITY_V_TYPES_IRv9) .Alias(0, 0), IdentityOp); + +#undef IDENTITY_V_TYPES +#undef IDENTITY_V_TYPES_IRv9 } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/identity_op.h b/onnxruntime/core/providers/cuda/tensor/identity_op.h index 775bc9d1ec924..de338f61126aa 100644 --- a/onnxruntime/core/providers/cuda/tensor/identity_op.h +++ b/onnxruntime/core/providers/cuda/tensor/identity_op.h @@ -15,8 +15,10 @@ class IdentityOp final : public CudaKernel { } Status ComputeInternal(OpKernelContext* context) const override { +#ifndef BUILD_CUDA_EP_AS_PLUGIN auto X_ml_type = context->InputType(0); if (X_ml_type->IsTensorType()) { +#endif const Tensor* X = context->Input(0); if (nullptr == X) { return Status(common::ONNXRUNTIME, common::FAIL, @@ -51,6 +53,7 @@ class IdentityOp final : public CudaKernel { CUDA_RETURN_IF_ERROR(cudaMemsetAsync(mask_data, 0, mask->SizeInBytes(), Stream(context))); } } +#ifndef BUILD_CUDA_EP_AS_PLUGIN } else if (X_ml_type->IsTensorSequenceType()) { const TensorSeq* X = context->Input(0); ORT_ENFORCE(X != nullptr, "IdentityOp cuda: input tensor is missing."); @@ -83,6 +86,7 @@ class IdentityOp final : public CudaKernel { return Status(common::ONNXRUNTIME, common::FAIL, "IdentityOp cuda: unsupported input type."); } +#endif return Status::OK(); } }; diff --git a/onnxruntime/python/onnxruntime_pybind_schema.cc b/onnxruntime/python/onnxruntime_pybind_schema.cc index 0f51794f12e48..1b5b888110e27 100644 --- a/onnxruntime/python/onnxruntime_pybind_schema.cc +++ b/onnxruntime/python/onnxruntime_pybind_schema.cc @@ -4,6 +4,10 @@ #include "python/onnxruntime_pybind_state_common.h" #include "core/framework/kernel_registry.h" +#if !defined(ORT_MINIMAL_BUILD) +#include "core/session/utils.h" +#include "core/session/abi_devices.h" +#endif #include namespace py = pybind11; @@ -93,6 +97,58 @@ void addGlobalSchemaFunctions(pybind11::module& m) { return result; }, "Return a vector of KernelDef for all registered OpKernels"); + +#if !defined(ORT_MINIMAL_BUILD) + m.def( + "get_registered_ep_kernel_defs", + [](const std::string& ep_name) -> std::vector { + std::vector result; + auto& env = GetEnv(); + + // Collect all OrtEpDevice pointers matching the requested EP name. + std::vector selected_devices; + for (const OrtEpDevice* device : env.GetOrtEpDevices()) { + if (device && device->ep_name == ep_name) { + selected_devices.push_back(device); + break; // one device is sufficient to create the factory and query kernels + } + } + + if (selected_devices.empty()) { + throw std::runtime_error( + "No devices found for EP '" + ep_name + + "'. Ensure the plugin EP library is registered via register_execution_provider_library()."); + } + + // Create a factory for the plugin EP. + std::unique_ptr factory; + auto status = CreateIExecutionProviderFactoryForEpDevices(env, selected_devices, factory); + if (!status.IsOK()) { + throw std::runtime_error("Failed to create EP factory for '" + ep_name + "': " + status.ToString()); + } + + // Create an EP instance with default session options. + OrtSessionOptions ort_session_options{}; + const auto& logger = *env.GetLoggingManager()->DefaultLogger().ToExternal(); + auto provider = factory->CreateProvider(ort_session_options, logger); + if (!provider) { + throw std::runtime_error("Failed to create EP instance for '" + ep_name + "'."); + } + + // Extract kernel defs from the EP's kernel registry. + auto kernel_registry = provider->GetKernelRegistry(); + if (kernel_registry) { + for (const auto& entry : kernel_registry->GetKernelCreateMap()) { + result.emplace_back(*(entry.second.kernel_def)); + } + } + + return result; + }, + py::arg("ep_name"), + "Return a vector of KernelDef for a dynamically registered plugin EP.\n" + "The EP must be loaded first via register_execution_provider_library()."); +#endif } void addOpKernelSubmodule(py::module& m) { diff --git a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py index ea645bcbdc39b..96ecd834ec6ad 100644 --- a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py +++ b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py @@ -851,6 +851,162 @@ def test_op_bias_dropout(self): ) self.assertEqual(result, TEST_PASS, "BiasDropout plugin op test failed") + def test_op_dropout_opset7(self): + """Dropout opset 7-9: simple in/out, no mask. Verifies old-version registration in dropout.cc.""" + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("Dropout", ["X"], ["Y"], ratio=0.0) + graph = helper.make_graph( + [node], + "test-Dropout-opset7", + [helper.make_tensor_value_info("X", f_dtype, [2, 4])], + [helper.make_tensor_value_info("Y", f_dtype, [2, 4])], + ) + opset = OperatorSetIdProto() + opset.version = 7 + model = helper.make_model(graph, opset_imports=[opset]) + x = np.random.rand(2, 4).astype(np.float32) + result = _run_model_test(target_device, "Dropout_opset7", model, {"X": x}, lambda f: f["X"]) + self.assertEqual(result, TEST_PASS, "Dropout opset 7 plugin op test failed") + + def test_op_dropout_opset10(self): + """Dropout opset 10-11: data + optional mask output. Verifies old-version registration in dropout.cc.""" + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("Dropout", ["X"], ["Y", "mask"]) + graph = helper.make_graph( + [node], + "test-Dropout-opset10", + [helper.make_tensor_value_info("X", f_dtype, [2, 4])], + [ + helper.make_tensor_value_info("Y", f_dtype, [2, 4]), + helper.make_tensor_value_info("mask", TensorProto.BOOL, [2, 4]), + ], + ) + opset = OperatorSetIdProto() + opset.version = 10 + model = helper.make_model(graph, opset_imports=[opset]) + x = np.random.rand(2, 4).astype(np.float32) + result = _run_model_test( + target_device, + "Dropout_opset10", + model, + {"X": x}, + lambda f: [f["X"], np.ones((2, 4), dtype=bool)], + ) + self.assertEqual(result, TEST_PASS, "Dropout opset 10 plugin op test failed") + + def test_op_dequantize_linear_opset21(self): + """DequantizeLinear opset 21 uses TWO_TYPED_KERNEL_EX — verifies the new adapter macro.""" + target_device = get_cuda_plugin_device() + node = helper.make_node("DequantizeLinear", ["x", "x_scale", "x_zero_point"], ["y"]) + x_data = np.array([0, 1, 2, 3, 4, 5], dtype=np.uint8) + scale_data = np.array(0.5, dtype=np.float32) + zp_data = np.array(2, dtype=np.uint8) + graph = helper.make_graph( + [node], + "test-DequantizeLinear-opset21", + [ + helper.make_tensor_value_info("x", TensorProto.UINT8, [6]), + helper.make_tensor_value_info("x_scale", TensorProto.FLOAT, []), + helper.make_tensor_value_info("x_zero_point", TensorProto.UINT8, []), + ], + [helper.make_tensor_value_info("y", TensorProto.FLOAT, [6])], + ) + opset = OperatorSetIdProto() + opset.version = 21 + model = helper.make_model(graph, opset_imports=[opset]) + feed = {"x": x_data, "x_scale": scale_data, "x_zero_point": zp_data} + result = _run_model_test( + target_device, + "DequantizeLinear_opset21", + model, + feed, + lambda f: (f["x"].astype(np.float32) - f["x_zero_point"].astype(np.float32)) * f["x_scale"], + ) + self.assertEqual(result, TEST_PASS, "DequantizeLinear opset 21 plugin op test failed") + + def test_op_quantize_linear_opset21(self): + """QuantizeLinear opset 21 uses TWO_TYPED_KERNEL_EX — verifies the new adapter macro.""" + target_device = get_cuda_plugin_device() + node = helper.make_node("QuantizeLinear", ["x", "y_scale", "y_zero_point"], ["y"]) + x_data = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 2.5], dtype=np.float32) + scale_data = np.array(0.5, dtype=np.float32) + zp_data = np.array(0, dtype=np.uint8) + graph = helper.make_graph( + [node], + "test-QuantizeLinear-opset21", + [ + helper.make_tensor_value_info("x", TensorProto.FLOAT, [6]), + helper.make_tensor_value_info("y_scale", TensorProto.FLOAT, []), + helper.make_tensor_value_info("y_zero_point", TensorProto.UINT8, []), + ], + [helper.make_tensor_value_info("y", TensorProto.UINT8, [6])], + ) + opset = OperatorSetIdProto() + opset.version = 21 + model = helper.make_model(graph, opset_imports=[opset]) + feed = {"x": x_data, "y_scale": scale_data, "y_zero_point": zp_data} + result = _run_model_test( + target_device, + "QuantizeLinear_opset21", + model, + feed, + lambda f: np.clip( + np.round(f["x"] / f["y_scale"]).astype(np.float32) + f["y_zero_point"].astype(np.float32), + 0, + 255, + ).astype(np.uint8), + atol=1, + ) + self.assertEqual(result, TEST_PASS, "QuantizeLinear opset 21 plugin op test failed") + + def test_op_gather_block_quantized(self): + """GatherBlockQuantized uses THREE_TYPED_KERNEL_EX — verifies the new adapter macro.""" + target_device = get_cuda_plugin_device() + # GatherBlockQuantized: gathers rows from a block-quantized weight matrix. + # data shape [4, 16] (uint8), scales shape [4, 1] (float), indices [2] (int64) + # bits=8, block_size=16 (must be >= 16 and power of 2), quantize_axis=last + node = helper.make_node( + "GatherBlockQuantized", + ["data", "indices", "scales"], + ["output"], + domain="com.microsoft", + gather_axis=0, + quantize_axis=1, + block_size=16, + bits=8, + ) + data = np.random.randint(0, 255, size=(4, 16), dtype=np.uint8) + scales = np.random.rand(4, 1).astype(np.float32) * 0.1 + 0.01 + indices = np.array([0, 2], dtype=np.int64) + graph = helper.make_graph( + [node], + "test-GatherBlockQuantized", + [ + helper.make_tensor_value_info("data", TensorProto.UINT8, [4, 16]), + helper.make_tensor_value_info("indices", TensorProto.INT64, [2]), + helper.make_tensor_value_info("scales", TensorProto.FLOAT, [4, 1]), + ], + [helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, 16])], + ) + opset_onnx = OperatorSetIdProto() + opset_onnx.version = 21 + opset_ms = OperatorSetIdProto() + opset_ms.domain = "com.microsoft" + opset_ms.version = 1 + model = helper.make_model(graph, opset_imports=[opset_onnx, opset_ms]) + feed = {"data": data, "indices": indices, "scales": scales} + + def expected(f): + # Gather rows [0, 2], then dequantize: float_val = uint8_val * scale + gathered_data = f["data"][f["indices"]] # [2, 16] + gathered_scales = f["scales"][f["indices"]] # [2, 1] + return gathered_data.astype(np.float32) * gathered_scales + + result = _run_model_test(target_device, "GatherBlockQuantized", model, feed, expected, rtol=1e-2, atol=1e-2) + self.assertEqual(result, TEST_PASS, "GatherBlockQuantized plugin op test failed") + def test_op_skip_layer_norm(self): target_device = get_cuda_plugin_device() f_dtype = TensorProto.FLOAT @@ -905,6 +1061,147 @@ def expected(f): ) self.assertEqual(result, TEST_PASS, "SkipLayerNorm plugin op test failed") + # ---- Tests for previously-excluded ops (identity, crop, dynamicslice) ---- + + def test_op_identity(self): + """Identity op: previously excluded from plugin due to TensorSeq; now Tensor-only.""" + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + model = _make_simple_model("Identity", [("X", f_dtype, [3, 4])], [("Y", f_dtype, [3, 4])]) + x = np.random.rand(3, 4).astype(np.float32) + result = _run_model_test(target_device, "Identity", model, {"X": x}, lambda f: f["X"]) + self.assertEqual(result, TEST_PASS, "Identity plugin op test failed") + + def test_op_identity_opset25(self): + """Identity opset 25: highest opset, uses V type constraint (Tensor subset in plugin).""" + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + model = _make_simple_model("Identity", [("X", f_dtype, [2, 5])], [("Y", f_dtype, [2, 5])], opset=25) + x = np.random.rand(2, 5).astype(np.float32) + result = _run_model_test(target_device, "Identity_opset25", model, {"X": x}, lambda f: f["X"]) + self.assertEqual(result, TEST_PASS, "Identity opset 25 plugin op test failed") + + def test_op_crop(self): + """Crop (opset 1, contrib): previously excluded from plugin.""" + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("Crop", ["input"], ["output"], border=[1, 1, 1, 1]) + graph = helper.make_graph( + [node], + "test-Crop", + [helper.make_tensor_value_info("input", f_dtype, [1, 1, 4, 4])], + [helper.make_tensor_value_info("output", f_dtype, [1, 1, 2, 2])], + ) + opset = OperatorSetIdProto() + opset.version = 1 + model = helper.make_model(graph, opset_imports=[opset]) + x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4) + result = _run_model_test( + target_device, + "Crop", + model, + {"input": x}, + lambda f: f["input"][:, :, 1:3, 1:3], + ) + self.assertEqual(result, TEST_PASS, "Crop plugin op test failed") + + def test_plugin_ep_claims_key_ops(self): + """Session-based probing: verify the plugin EP claims key ops via graph assignment.""" + target_device = get_cuda_plugin_device() + + # Representative ops the plugin EP must claim (op_type, domain, opset, inputs, outputs, attrs). + # One representative per major op family; ops already covered by dedicated test_registration_* + # or test_op_* tests (Add, MatMul, Gemm, Conv, …) are intentionally excluded here. + probe_specs = [ + # binary elementwise (Sub — Add is tested by test_registration_add) + ( + "Sub", + "", + 13, + [("A", TensorProto.FLOAT, [2, 4]), ("B", TensorProto.FLOAT, [2, 4])], + [("Y", TensorProto.FLOAT, [2, 4])], + None, + ), + # unary activation + ("Relu", "", 13, [("X", TensorProto.FLOAT, [2, 4])], [("Y", TensorProto.FLOAT, [2, 4])], None), + # reduction-style + ("Softmax", "", 13, [("X", TensorProto.FLOAT, [2, 4])], [("Y", TensorProto.FLOAT, [2, 4])], {"axis": -1}), + # data-movement + ( + "Transpose", + "", + 13, + [("X", TensorProto.FLOAT, [2, 4])], + [("Y", TensorProto.FLOAT, [4, 2])], + {"perm": [1, 0]}, + ), + # type-dispatch + ( + "Cast", + "", + 13, + [("X", TensorProto.FLOAT, [2, 4])], + [("Y", TensorProto.FLOAT16, [2, 4])], + {"to": int(TensorProto.FLOAT16)}, + ), + # second unary + ("Sigmoid", "", 13, [("X", TensorProto.FLOAT, [2, 4])], [("Y", TensorProto.FLOAT, [2, 4])], None), + # cuDNN: ConvTranspose (Conv already tested by test_registration_conv) + ( + "ConvTranspose", + "", + 13, + [("X", TensorProto.FLOAT, [1, 2, 3, 3]), ("W", TensorProto.FLOAT, [2, 3, 3, 3])], + [("Y", TensorProto.FLOAT, [1, 3, 5, 5])], + None, + ), + # cuDNN: LRN (local response normalization) + ( + "LRN", + "", + 13, + [("X", TensorProto.FLOAT, [1, 2, 4, 4])], + [("Y", TensorProto.FLOAT, [1, 2, 4, 4])], + {"size": 3}, + ), + ] + + claimed = [] + not_claimed = [] + errors = [] + + for op_type, domain, opset, inputs_info, outputs_info, attrs in probe_specs: + model = _make_simple_model(op_type, inputs_info, outputs_info, attrs=attrs, opset=opset, domain=domain) + with tempfile.NamedTemporaryFile(suffix=f"_probe_{op_type}.onnx", delete=False) as tmp: + model_path = tmp.name + try: + save(model, model_path) + sess_options = _create_session_options() + sess_options.graph_optimization_level = onnxrt.GraphOptimizationLevel.ORT_DISABLE_ALL + sess_options.add_provider_for_devices([target_device], {}) + sess = onnxrt.InferenceSession(model_path, sess_options=sess_options) + assigned_nodes, _ = _get_assigned_nodes(sess, CUDA_PLUGIN_EP_NAME) + if assigned_nodes: + claimed.append(op_type) + else: + not_claimed.append(op_type) + except Exception as e: + errors.append((op_type, str(e)[:120])) + finally: + if os.path.exists(model_path): + os.remove(model_path) + + # All probed ops should be claimed by the plugin EP + self.assertFalse( + not_claimed, + f"Plugin EP did not claim these key ops: {not_claimed}", + ) + self.assertFalse( + errors, + f"Errors probing ops: {errors}", + ) + self.assertGreater(len(claimed), 0, "No ops were claimed at all") + if __name__ == "__main__": unittest.main() diff --git a/tools/python/gen_opkernel_doc.py b/tools/python/gen_opkernel_doc.py index f6f9f21396859..b5a4507c1f0b3 100644 --- a/tools/python/gen_opkernel_doc.py +++ b/tools/python/gen_opkernel_doc.py @@ -58,7 +58,42 @@ def expand_providers(provider_filter: [str]): return providers -def main(output_path: pathlib.Path, provider_filter: [str]): +def load_plugin_ep_kernel_defs(plugin_eps): + """Register plugin EP libraries and return their kernel defs. + + Args: + plugin_eps: list of "name:path" strings, e.g. + ["CudaPluginExecutionProvider:/path/to/lib.so"] + + Returns: + list of KernelDef objects from the plugin EPs. + """ + if not plugin_eps: + return [] + + import onnxruntime as ort # noqa: PLC0415 + + defs = [] + for spec in plugin_eps: + if ":" not in spec: + print(f"Warning: invalid --plugin-ep spec '{spec}' (expected NAME:PATH), skipping") + continue + ep_name, lib_path = spec.split(":", 1) + try: + ort.register_execution_provider_library(ep_name, lib_path) + except Exception as e: + if "already registered" not in str(e).lower(): + print(f"Warning: failed to register plugin EP '{ep_name}': {e}") + continue + try: + defs.extend(rtpy.get_registered_ep_kernel_defs(ep_name)) + except Exception as e: + print(f"Warning: failed to get kernel defs for plugin EP '{ep_name}': {e}") + + return defs + + +def main(output_path: pathlib.Path, provider_filter: [str], plugin_eps=None): providers = expand_providers(provider_filter) with open(output_path, "w", newline="", encoding="utf-8") as fout: @@ -109,6 +144,13 @@ def main(output_path: pathlib.Path, provider_filter: [str]): domain = "ai.onnx" index[op.provider][domain][op.op_name].append(op) + # Include kernel defs from plugin EPs + for op in load_plugin_ep_kernel_defs(plugin_eps): + domain = op.domain + if not domain: + domain = "ai.onnx" + index[op.provider][domain][op.op_name].append(op) + # TOC fout.write("## Execution Providers\n\n") for provider in sorted(index.keys()): @@ -166,6 +208,14 @@ def main(output_path: pathlib.Path, provider_filter: [str]): "'ExecutionProvider' is automatically appended as needed. " "e.g. `--providers cpu cuda` will match CPUExecutionProvider and CUDAExecutionProvider.", ) + parser.add_argument( + "--plugin-ep", + nargs="+", + dest="plugin_eps", + help="Register plugin EP libraries and include their kernels. " + "Each entry is NAME:PATH, e.g. " + "'CudaPluginExecutionProvider:/path/to/libonnxruntime_providers_cuda_plugin.so'.", + ) parser.add_argument( "--output_path", help="output markdown file path", @@ -175,4 +225,4 @@ def main(output_path: pathlib.Path, provider_filter: [str]): ) args = parser.parse_args() - main(args.output_path, args.providers) + main(args.output_path, args.providers, args.plugin_eps) From d201f35477b1b8751231e0dc04811768d2075311 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sat, 28 Mar 2026 14:38:58 -0700 Subject: [PATCH 37/48] doc: quick start --- docs/cuda_plugin_ep/QUICK_START.md | 108 +++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 docs/cuda_plugin_ep/QUICK_START.md diff --git a/docs/cuda_plugin_ep/QUICK_START.md b/docs/cuda_plugin_ep/QUICK_START.md new file mode 100644 index 0000000000000..2a94d895892dc --- /dev/null +++ b/docs/cuda_plugin_ep/QUICK_START.md @@ -0,0 +1,108 @@ +# CUDA Plugin EP Quick Start + +## Build Instructions + +To build ONNX Runtime with the CUDA Plugin Execution Provider instead of the statically linked CUDA EP, use the `--build_cuda_ep_as_plugin` flag with the build script. + +```bash +# Build the core framework and the CUDA Plugin EP +./build.sh --config RelWithDebInfo --build_shared_lib --use_cuda --build_cuda_ep_as_plugin +``` + +## Running + +When the plugin is built, it will produce `libonnxruntime_providers_cuda_plugin.so` (or `.dll` on Windows) in the build output directory alongside `libonnxruntime.so`. + +The plugin EP is registered under the name **`CudaPluginExecutionProvider`** and uses the EP Plugin API (`RegisterExecutionProviderLibrary` / `GetEpDevices` / `SessionOptionsAppendExecutionProvider_V2`). It is **not** a drop-in replacement for the in-tree `CUDAExecutionProvider` — you must register the plugin library, enumerate its devices, and add them to the session. + +### C++ API + +Use `Env::RegisterExecutionProviderLibrary` to load the plugin, `Env::GetEpDevices` to discover the CUDA devices it exposes, and `SessionOptions::AppendExecutionProvider_V2` to add the selected device to the session. + +```cpp +#include "onnxruntime_cxx_api.h" + +Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "PluginTest"); + +// 1. Register the plugin library. +env.RegisterExecutionProviderLibrary("CudaPluginExecutionProvider", + ORT_TSTR("libonnxruntime_providers_cuda_plugin.so")); + +// 2. Enumerate available EP devices and pick the CUDA plugin device. +auto ep_devices = env.GetEpDevices(); +std::vector plugin_devices; +for (const auto& dev : ep_devices) { + if (std::string(dev.EpName()) == "CudaPluginExecutionProvider") { + plugin_devices.push_back(dev); + break; // use the first CUDA plugin device + } +} + +// 3. Add the plugin device to session options. +Ort::SessionOptions session_options; +session_options.AppendExecutionProvider_V2(env, plugin_devices, {}); + +Ort::Session session(env, "model.onnx", session_options); +``` + +### Python API + +Use `onnxruntime.register_execution_provider_library` to load the plugin, `onnxruntime.get_ep_devices` to discover devices, and `SessionOptions.add_provider_for_devices` to add the selected device. + +**Device-based approach (recommended):** + +```python +import onnxruntime as ort + +# 1. Register the plugin library. +ort.register_execution_provider_library( + "CudaPluginExecutionProvider", + "libonnxruntime_providers_cuda_plugin.so", +) + +# 2. Enumerate devices and pick the CUDA plugin device. +devices = ort.get_ep_devices() +plugin_device = next(d for d in devices if d.ep_name == "CudaPluginExecutionProvider") + +# 3. Create session with the plugin device. +sess_options = ort.SessionOptions() +sess_options.add_provider_for_devices([plugin_device], {}) + +sess = ort.InferenceSession("model.onnx", sess_options=sess_options) +``` + +**Provider-name approach:** + +You can also pass `CudaPluginExecutionProvider` by name in the `providers` list +(the plugin library must already be registered): + +```python +import onnxruntime as ort + +ort.register_execution_provider_library( + "CudaPluginExecutionProvider", + "libonnxruntime_providers_cuda_plugin.so", +) + +sess = ort.InferenceSession( + "model.onnx", + providers=[ + ("CudaPluginExecutionProvider", {"device_id": "0"}), + "CPUExecutionProvider", + ], +) +``` + +## Known Limitations +* The plugin does not currently support CUDA Graphs. +* The plugin direct-allocates memory using `cudaMalloc` resulting in a potential performance penalty compared to the integrated Memory Arena. + +## Verification +You can generate a parity report comparing the kernels available in the plugin EP versus the statically linked CUDA EP. +```bash +# Check static source registration parity: +python tools/ci_build/cuda_plugin_parity_report.py + +# Check runtime registry parity: +python tools/ci_build/cuda_plugin_parity_report.py --runtime --plugin-ep-lib build/Linux/RelWithDebInfo/libonnxruntime_providers_cuda_plugin.so +``` From 8ffb3c251e59179f380d08328166b9823f6c01f3 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sat, 28 Mar 2026 14:39:18 -0700 Subject: [PATCH 38/48] add script for parity report --- tools/ci_build/cuda_plugin_parity_report.py | 736 ++++++++++++++++++++ 1 file changed, 736 insertions(+) create mode 100755 tools/ci_build/cuda_plugin_parity_report.py diff --git a/tools/ci_build/cuda_plugin_parity_report.py b/tools/ci_build/cuda_plugin_parity_report.py new file mode 100755 index 0000000000000..23fd8fbb6c39f --- /dev/null +++ b/tools/ci_build/cuda_plugin_parity_report.py @@ -0,0 +1,736 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +""" +CUDA EP Plugin Registration Parity Report + +Compares kernel registrations between the bundled CUDA EP and the plugin CUDA EP +by statically parsing source files or interrogating the kernel registry at runtime. +Produces a report showing which ops are in both builds, only in bundled, or only in plugin. + +Usage: + # Static parse mode (default): + python tools/ci_build/cuda_plugin_parity_report.py [--repo-root /path/to/onnxruntime] + + # Runtime inquiry mode: + python tools/ci_build/cuda_plugin_parity_report.py --runtime [--plugin-ep-lib /path/to/libonnxruntime_providers_cuda_plugin.so] +""" + +import argparse +import os +import re +import sys +from collections import defaultdict +from pathlib import Path + +# Regex patterns for kernel registration macros +# These macros define kernel classes and are the source of truth for op registrations. +KERNEL_EX_PATTERNS = [ + # ONNX_OPERATOR_KERNEL_EX(name, domain, ver, provider, builder, ...) + re.compile( + r"ONNX_OPERATOR_KERNEL_EX\s*\(\s*" + r"(\w+)\s*,\s*" # name + r"(\w+)\s*,\s*" # domain + r"(\d+)\s*,\s*" # version + r"(\w+)\s*," # provider + ), + # ONNX_OPERATOR_TYPED_KERNEL_EX(name, domain, ver, type, provider, builder, ...) + re.compile( + r"ONNX_OPERATOR_TYPED_KERNEL_EX\s*\(\s*" + r"(\w+)\s*,\s*" # name + r"(\w+)\s*,\s*" # domain + r"(\d+)\s*,\s*" # version + r"(\w+)\s*,\s*" # type + r"(\w+)\s*," # provider + ), + # ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, domain, start_ver, end_ver, provider, builder, ...) + re.compile( + r"ONNX_OPERATOR_VERSIONED_KERNEL_EX\s*\(\s*" + r"(\w+)\s*,\s*" # name + r"(\w+)\s*,\s*" # domain + r"(\d+)\s*,\s*" # start_version + r"(\d+)\s*,\s*" # end_version + r"(\w+)\s*," # provider + ), + # ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, domain, start_ver, end_ver, type, provider, builder, ...) + re.compile( + r"ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX\s*\(\s*" + r"(\w+)\s*,\s*" # name + r"(\w+)\s*,\s*" # domain + r"(\d+)\s*,\s*" # start_version + r"(\d+)\s*,\s*" # end_version + r"(\w+)\s*,\s*" # type + r"(\w+)\s*," # provider + ), +] + +# Patterns for contrib ops (CUDA_MS_OP macros expand to ONNX_OPERATOR macros internally) +# Just match the ONNX_OPERATOR_*_KERNEL_EX patterns since the CUDA_MS_OP macros are wrappers. + +# Terminal kernel registration macro names (op_name at arg 0, domain at arg 1 in all variants) +_TERMINAL_KERNEL_MACROS = { + "ONNX_OPERATOR_KERNEL_EX", + "ONNX_OPERATOR_TYPED_KERNEL_EX", + "ONNX_OPERATOR_VERSIONED_KERNEL_EX", + "ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX", + "ONNX_OPERATOR_TWO_TYPED_KERNEL_EX", + "ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX", +} + + +def _preprocess_content(file_path): + """Read a file and preprocess: strip comments, join continuation lines.""" + try: + content = Path(file_path).read_text(errors="replace") + except OSError: + return None + content = re.sub(r"//.*?$", "", content, flags=re.MULTILINE) + content = re.sub(r"/\*.*?\*/", "", content, flags=re.DOTALL) + content = re.sub(r"\\\s*\n\s*", " ", content) + return content + + +def _strip_define_bodies(content): + """Replace #define lines with blanks so regex only matches non-define code.""" + lines = content.split("\n") + return "\n".join("" if re.match(r"\s*#\s*define\b", line) else line for line in lines) + + +def _parse_macro_args_at(text, pos): + """Parse balanced-parentheses argument list starting at *pos* (must be '('). + Returns a list of argument strings, or None on failure.""" + if pos >= len(text) or text[pos] != "(": + return None + depth = 0 + args = [] + arg_start = pos + 1 + for i in range(pos, len(text)): + ch = text[i] + if ch == "(": + if depth == 0: + arg_start = i + 1 + depth += 1 + elif ch == ")": + depth -= 1 + if depth == 0: + arg = text[arg_start:i].strip() + if arg or args: # keep empty trailing arg if there were prior args + args.append(arg) + return args + elif ch == "," and depth == 1: + args.append(text[arg_start:i].strip()) + arg_start = i + 1 + return None + + +def _parse_macro_definitions(content): + """Parse ``#define NAME(params) body`` statements. + Returns ``{name: (param_name_list, body_string)}``.""" + macros = {} + for m in re.finditer(r"^#\s*define\s+(\w+)\s*\(([^)]*)\)\s*(.*)", content, re.MULTILINE): + name = m.group(1) + params = [p.strip() for p in m.group(2).split(",") if p.strip() and p.strip() != "..."] + body = m.group(3).strip() + macros[name] = (params, body) + return macros + + +def _find_calls_in_text(text, target_names): + """Find macro invocations in *text* whose name is in *target_names*. + Returns list of ``(macro_name, [arg, ...])``.""" + results = [] + for m in re.finditer(r"\b(\w+)\s*\(", text): + name = m.group(1) + if name in target_names: + args = _parse_macro_args_at(text, m.end() - 1) + if args is not None: + results.append((name, args)) + return results + + +def _resolve_through_chain(info_field, call_args, parent_params): + """Resolve an ``('param', idx)`` or ``('literal', val)`` through one macro-call level.""" + kind, value = info_field + if kind == "literal": + return info_field + if kind == "param": + idx = value + if idx < len(call_args): + arg_val = call_args[idx].strip() + if arg_val in parent_params: + return ("param", parent_params.index(arg_val)) + return ("literal", arg_val) + return info_field + + +def _resolve_kernel_macros(macros): + """Determine which macros transitively expand to ``ONNX_OPERATOR_*_KERNEL_EX`` + and how their parameters map to (op_name, domain). + + Returns ``{macro_name: [(op_name_info, domain_info), ...]}`` where each + ``*_info`` is ``('literal', str_value)`` or ``('param', int_index)``. + """ + kernel_macros = {} # macro_name -> [(op_info, dom_info), ...] + + # Phase 1: macros whose body directly contains a terminal KERNEL_EX call + for macro_name, (params, body) in macros.items(): + calls = _find_calls_in_text(body, _TERMINAL_KERNEL_MACROS) + entries = set() + for _call_name, call_args in calls: + if len(call_args) >= 2: + a0, a1 = call_args[0].strip(), call_args[1].strip() + op_info = ("param", params.index(a0)) if a0 in params else ("literal", a0) + dom_info = ("param", params.index(a1)) if a1 in params else ("literal", a1) + entries.add((op_info, dom_info)) + if entries: + kernel_macros[macro_name] = list(entries) + + # Phase 2: iteratively resolve higher-level wrapper macros + changed = True + for _ in range(20): # bounded iteration + if not changed: + break + changed = False + for macro_name, (params, body) in macros.items(): + if macro_name in kernel_macros: + continue + calls = _find_calls_in_text(body, set(kernel_macros.keys())) + entries = set() + for call_name, call_args in calls: + for child_op, child_dom in kernel_macros[call_name]: + new_op = _resolve_through_chain(child_op, call_args, params) + new_dom = _resolve_through_chain(child_dom, call_args, params) + entries.add((new_op, new_dom)) + if entries: + kernel_macros[macro_name] = list(entries) + changed = True + + return kernel_macros + + +def _extract_wrapper_registrations(content, file_path): + """Extract registrations from wrapper macros that expand to KERNEL_EX.""" + macros = _parse_macro_definitions(content) + if not macros: + return [] + + kernel_macros = _resolve_kernel_macros(macros) + if not kernel_macros: + return [] + + non_define = _strip_define_bodies(content) + invocations = _find_calls_in_text(non_define, set(kernel_macros.keys())) + + registrations = [] + seen = set() + for call_name, call_args in invocations: + for op_info, dom_info in kernel_macros[call_name]: + # Resolve op_name + if op_info[0] == "literal": + op_name = op_info[1] + elif op_info[0] == "param" and op_info[1] < len(call_args): + op_name = call_args[op_info[1]].strip() + else: + continue + + # Resolve domain + if dom_info[0] == "literal": + domain = dom_info[1] + elif dom_info[0] == "param" and dom_info[1] < len(call_args): + domain = call_args[dom_info[1]].strip() + else: + domain = "kOnnxDomain" + + # Filter out C++ types / param names that aren't valid op names + if not op_name or not op_name[0].isupper(): + continue + + key = (op_name, domain, file_path) + if key not in seen: + seen.add(key) + registrations.append((op_name, domain, 0, file_path)) + + return registrations + + +def extract_kernel_registrations(file_path): + """Extract (op_name, domain, since_version) tuples from kernel registration macros in a file.""" + content = _preprocess_content(file_path) + if content is None: + return [] + + registrations = [] + + # Phase 1: Direct ONNX_OPERATOR_*_KERNEL_EX patterns outside #define bodies + non_define = _strip_define_bodies(content) + for pattern in KERNEL_EX_PATTERNS: + for m in pattern.finditer(non_define): + groups = m.groups() + op_name = groups[0] + domain = groups[1] + since_version = int(groups[2]) + registrations.append((op_name, domain, since_version, str(file_path))) + + # Phase 2: Wrapper macros that expand to ONNX_OPERATOR_*_KERNEL_EX + registrations.extend(_extract_wrapper_registrations(content, str(file_path))) + + return registrations + + +def parse_registration_table(file_path, table_func_name): + """Parse the registration table function to extract op names referenced in BuildKernelCreateInfo calls.""" + registrations = set() + try: + content = Path(file_path).read_text(errors="replace") + except OSError: + return registrations + + # Find the function + func_start = content.find(f"{table_func_name}") + if func_start < 0: + return registrations + + # Extract class names from BuildKernelCreateInfo entries + # Pattern: ONNX_OPERATOR_*_KERNEL_CLASS_NAME(provider, domain, ver, [type,] name) + class_name_patterns = [ + # ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name) + re.compile(r"ONNX_OPERATOR_KERNEL_CLASS_NAME\s*\(\s*\w+\s*,\s*(\w+)\s*,\s*(\d+)\s*,\s*(\w+)\s*\)"), + # ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name) + re.compile( + r"ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME\s*\(\s*\w+\s*,\s*(\w+)\s*,\s*(\d+)\s*,\s*\w+\s*,\s*(\w+)\s*\)" + ), + # ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, start, end, name) + re.compile( + r"ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME\s*\(\s*\w+\s*,\s*(\w+)\s*,\s*(\d+)\s*,\s*\d+\s*,\s*(\w+)\s*\)" + ), + # ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, start, end, type, name) + re.compile( + r"ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME\s*\(\s*\w+\s*,\s*(\w+)\s*,\s*(\d+)\s*,\s*\d+\s*,\s*\w+\s*,\s*(\w+)\s*\)" + ), + ] + + # Also handle CUDA_MS_OP_CLASS_NAME and CUDA_MS_OP_TYPED_CLASS_NAME + ms_op_patterns = [ + # CUDA_MS_OP_CLASS_NAME(ver, name) -> domain is kMSDomain + re.compile(r"CUDA_MS_OP_CLASS_NAME\s*\(\s*(\d+)\s*,\s*(\w+)\s*\)"), + # CUDA_MS_OP_TYPED_CLASS_NAME(ver, type, name) + re.compile(r"CUDA_MS_OP_TYPED_CLASS_NAME\s*\(\s*(\d+)\s*,\s*\w+\s*,\s*(\w+)\s*\)"), + ] + + # Scan from function start to end of file (conservative) + search_region = content[func_start:] + + for pattern in class_name_patterns: + for m in pattern.finditer(search_region): + domain, version, name = m.group(1), int(m.group(2)), m.group(3) + registrations.add((name, domain, version)) + + for pattern in ms_op_patterns: + for m in pattern.finditer(search_region): + version, name = int(m.group(1)), m.group(2) + registrations.add((name, "kMSDomain", version)) + + return registrations + + +def get_excluded_files(cmake_path, repo_root): + """Parse the plugin CMake file to get regex exclusion patterns.""" + exclusion_patterns = [] + try: + content = Path(cmake_path).read_text() + except OSError: + return exclusion_patterns + + # Match: list(FILTER CC_SRCS EXCLUDE REGEX "pattern") + # or: list(FILTER CU_SRCS EXCLUDE REGEX "pattern") + for m in re.finditer(r'list\s*\(\s*FILTER\s+\w+\s+EXCLUDE\s+REGEX\s+"([^"]+)"\s*\)', content): + pat = m.group(1) + # Only keep non-commented lines + line_start = content.rfind("\n", 0, m.start()) + 1 + line = content[line_start : m.start()] + if "#" not in line: + exclusion_patterns.append(pat) + + return exclusion_patterns + + +def find_kernel_files(base_dirs, extensions=(".cc",)): + """Find all kernel source files in the given directories.""" + files = [] + for base_dir in base_dirs: + for ext in extensions: + for path in Path(base_dir).rglob(f"*{ext}"): + files.append(str(path)) + return sorted(files) + + +def is_excluded(file_path, exclusion_patterns): + """Check if a file path matches any exclusion pattern.""" + return any(re.search(pat, file_path) for pat in exclusion_patterns) + + +def generate_report(repo_root): + """Generate the full parity report.""" + repo_root = Path(repo_root) + + # Paths + cuda_ep_cc = repo_root / "onnxruntime/core/providers/cuda/cuda_execution_provider.cc" + cuda_nhwc_cc = repo_root / "onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc" + contrib_cc = repo_root / "onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc" + plugin_cmake = repo_root / "cmake/onnxruntime_providers_cuda_plugin.cmake" + + # 1. Parse bundled EP registration tables + bundled_standard = parse_registration_table(cuda_ep_cc, "RegisterCudaKernels") + bundled_nhwc = parse_registration_table(cuda_nhwc_cc, "RegisterCudaKernels") # NHWC uses same function name pattern + bundled_contrib = parse_registration_table(contrib_cc, "RegisterCudaContribKernels") + + # 2. Get exclusion patterns from plugin CMake + exclusion_patterns = get_excluded_files(plugin_cmake, repo_root) + + # 3. Scan all CUDA kernel source files + core_cuda_dir = repo_root / "onnxruntime/core/providers/cuda" + contrib_cuda_dir = repo_root / "onnxruntime/contrib_ops/cuda" + + all_cc_files = find_kernel_files([core_cuda_dir, contrib_cuda_dir]) + + # 4. Categorize files and extract registrations + plugin_registrations = [] # (op, domain, ver, file) tuples - compiled into plugin + excluded_registrations = [] # (op, domain, ver, file) tuples - excluded from plugin + + for f in all_cc_files: + regs = extract_kernel_registrations(f) + if not regs: + continue + if is_excluded(f, exclusion_patterns): + excluded_registrations.extend(regs) + else: + plugin_registrations.extend(regs) + + # 5. Build op sets for comparison + plugin_ops = set() + for op, domain, ver, _ in plugin_registrations: + plugin_ops.add((op, domain, ver)) + + excluded_ops = set() + for op, domain, ver, _ in excluded_registrations: + excluded_ops.add((op, domain, ver)) + + # Unique op names (ignoring version/type variants) + plugin_op_names = {(op, domain) for op, domain, _ in plugin_ops} + excluded_op_names = {(op, domain) for op, domain, _ in excluded_ops} + bundled_op_names = set() + for ops_set in [bundled_standard, bundled_nhwc, bundled_contrib]: + for op, domain, _ver in ops_set: + bundled_op_names.add((op, domain)) + + # 6. Generate report + report = [] + report.append("=" * 70) + report.append("CUDA EP Plugin — Kernel Registration Parity Report") + report.append("=" * 70) + report.append("") + + report.append("## Summary") + report.append(" NOTE: Plugin macro counts may undercount due to nested macro") + report.append(" expansion (e.g., BINARY_OP_VERSIONED_UZILHFD wraps multiple") + report.append(" ONNX_OPERATOR_TYPED_KERNEL_EX calls). Bundled table counts") + report.append(" from RegisterCudaKernels/RegisterCudaContribKernels are accurate.") + report.append("") + report.append(" Bundled EP registration table entries:") + report.append(f" Standard ops: {len(bundled_standard)}") + report.append(f" NHWC ops: {len(bundled_nhwc)}") + report.append(f" Contrib ops: {len(bundled_contrib)}") + report.append(f" Total: {len(bundled_standard) + len(bundled_nhwc) + len(bundled_contrib)}") + report.append("") + report.append(" Plugin kernel macro invocations (in compiled .cc files):") + report.append(f" Total: {len(plugin_registrations)}") + report.append(" Excluded kernel macro invocations:") + report.append(f" Total: {len(excluded_registrations)}") + report.append("") + report.append(" Unique op names (op, domain):") + report.append(f" In plugin: {len(plugin_op_names)}") + report.append(f" Excluded: {len(excluded_op_names)}") + report.append(f" In bundled: {len(bundled_op_names)}") + report.append("") + + # Plugin-only ops (in plugin but not in bundled table — likely already handled) + plugin_only = plugin_op_names - bundled_op_names + if plugin_only: + report.append(f" Plugin-only op names (not in bundled table): {len(plugin_only)}") + for op, domain in sorted(plugin_only): + report.append(f" - {op} ({domain})") + report.append("") + + # Bundled-only ops (in bundled but not in plugin+excluded) + all_source_ops = plugin_op_names | excluded_op_names + bundled_only = bundled_op_names - all_source_ops + if bundled_only: + report.append(f" Bundled-only op names (not in any .cc file KERNEL_EX): {len(bundled_only)}") + for op, domain in sorted(bundled_only): + report.append(f" - {op} ({domain})") + report.append("") + + # Coverage ratio + if bundled_op_names: + coverage = len(plugin_op_names & bundled_op_names) / len(bundled_op_names) * 100 + report.append(f" Plugin coverage: {coverage:.1f}% of bundled unique op names") + report.append("") + + # 7. Excluded ops detail + report.append("## Excluded Ops by Category") + report.append("") + + # Group excluded by file/directory + excluded_by_dir = defaultdict(list) + for op, domain, ver, filepath in excluded_registrations: + # Extract a short category from the path + rel_path = str(filepath).replace(str(repo_root) + "/", "") + parts = rel_path.split("/") + # Find the most descriptive sub-directory + if "contrib_ops" in rel_path: + idx = parts.index("cuda") if "cuda" in parts else 0 + category = "/".join(parts[idx + 1 : -1]) or parts[-1] + elif "core/providers/cuda" in rel_path: + idx = [i for i, p in enumerate(parts) if p == "cuda"][-1] + category = "/".join(parts[idx + 1 : -1]) or parts[-1] + else: + category = "other" + excluded_by_dir[category].append((op, domain, ver, rel_path)) + + for category in sorted(excluded_by_dir): + entries = excluded_by_dir[category] + unique_ops = {(op, domain) for op, domain, _, _ in entries} + report.append(f" [{category}] ({len(entries)} registrations, {len(unique_ops)} unique ops)") + for op, domain in sorted(unique_ops): + report.append(f" - {op} ({domain})") + report.append("") + + report.append("## Active CMake Exclusion Patterns") + for i, pat in enumerate(exclusion_patterns, 1): + report.append(f" {i:2d}. {pat}") + report.append("") + + return "\n".join(report) + + +# ============================================================================ +# Runtime-based parity report (uses the actual kernel registries) +# ============================================================================ + +# Map C++ domain constants to the string forms used by the bundled EP table parser. +_DOMAIN_CONST_TO_STRING = { + "kOnnxDomain": "", + "kMSDomain": "com.microsoft", + "kMSInternalNHWCDomain": "com.microsoft.internal.nhwc", + "kPytorchAtenDomain": "com.microsoft.pytorch.aten", +} + +# Reverse: runtime domain string -> constant name for display. +_DOMAIN_STRING_TO_CONST = {v: k for k, v in _DOMAIN_CONST_TO_STRING.items()} +_DOMAIN_STRING_TO_CONST["ai.onnx"] = "kOnnxDomain" + + +def _runtime_domain_display(domain_str): + """Convert a runtime domain string (e.g. '' or 'com.microsoft') to the constant name used in reports.""" + return _DOMAIN_STRING_TO_CONST.get(domain_str, domain_str or "kOnnxDomain") + + +def _kernel_defs_to_op_names(kernel_defs): + """Convert a list of KernelDef objects to a set of (op_name, domain_display) tuples.""" + op_names = set() + for kd in kernel_defs: + domain = _runtime_domain_display(kd.domain) + op_names.add((kd.op_name, domain)) + return op_names + + +def generate_runtime_report(plugin_ep_name, plugin_lib_path, bundled_ep_name="CUDAExecutionProvider"): + """Generate a parity report by querying actual kernel registries at runtime.""" + import onnxruntime as ort # noqa: PLC0415 + import onnxruntime.capi.onnxruntime_pybind11_state as rtpy # noqa: PLC0415 + + # 1. Get bundled EP kernel defs + bundled_defs = [kd for kd in rtpy.get_all_opkernel_def() if kd.provider == bundled_ep_name] + bundled_op_names = _kernel_defs_to_op_names(bundled_defs) + + # 2. Register the plugin EP + try: + ort.register_execution_provider_library(plugin_ep_name, plugin_lib_path) + except Exception as e: + if "already registered" not in str(e).lower(): + print(f"Error: failed to register plugin EP '{plugin_ep_name}': {e}", file=sys.stderr) + sys.exit(1) + + # 3. Get plugin EP kernel defs via the C++ registry query API + if not hasattr(rtpy, "get_registered_ep_kernel_defs"): + raise RuntimeError( + "get_registered_ep_kernel_defs is not available. " + "Rebuild onnxruntime with the pybind change in onnxruntime_pybind_schema.cc." + ) + + plugin_defs = rtpy.get_registered_ep_kernel_defs(plugin_ep_name) + plugin_op_names = _kernel_defs_to_op_names(plugin_defs) + method = "kernel registry query" + + # 4. Build report + report = [] + report.append("=" * 70) + report.append("CUDA EP Plugin — Runtime Kernel Parity Report") + report.append("=" * 70) + report.append("") + + report.append(f"## Summary (Runtime — {method})") + report.append(f" Bundled EP ({bundled_ep_name}):") + report.append(f" Total kernel registrations: {len(bundled_defs)}") + report.append(f" Unique op names (op, domain): {len(bundled_op_names)}") + report.append("") + report.append(f" Plugin EP ({plugin_ep_name}):") + if plugin_defs is not None: + report.append(f" Total kernel registrations: {len(plugin_defs)}") + report.append(f" Unique op names (op, domain): {len(plugin_op_names)}") + report.append("") + + plugin_only = plugin_op_names - bundled_op_names + if plugin_only: + report.append(f" Plugin-only op names (not in bundled): {len(plugin_only)}") + for op, domain in sorted(plugin_only): + report.append(f" - {op} ({domain})") + report.append("") + + bundled_only = bundled_op_names - plugin_op_names + if bundled_only: + report.append(f" Bundled-only op names (not in plugin): {len(bundled_only)}") + for op, domain in sorted(bundled_only): + report.append(f" - {op} ({domain})") + report.append("") + + common = bundled_op_names & plugin_op_names + if bundled_op_names: + coverage = len(common) / len(bundled_op_names) * 100 + report.append( + f" Plugin coverage: {coverage:.1f}% of bundled unique op names ({len(common)}/{len(bundled_op_names)})" + ) + report.append("") + + # 5. Detailed version/type comparison (only with registry API) + if plugin_defs is not None: + bundled_by_op = defaultdict(list) + for kd in bundled_defs: + key = (kd.op_name, _runtime_domain_display(kd.domain)) + bundled_by_op[key].append(kd) + + plugin_by_op = defaultdict(list) + for kd in plugin_defs: + key = (kd.op_name, _runtime_domain_display(kd.domain)) + plugin_by_op[key].append(kd) + + version_gaps = [] + for op_key in sorted(common): + b_versions = {kd.version_range for kd in bundled_by_op[op_key]} + p_versions = {kd.version_range for kd in plugin_by_op[op_key]} + missing = b_versions - p_versions + if missing: + version_gaps.append((op_key, missing)) + + if version_gaps: + report.append("## Version Gaps (op present but some version ranges missing in plugin)") + for (op, domain), missing in version_gaps: + ranges = ", ".join(f"[{v[0]}, {v[1]}]" if v[1] < 2147483647 else f"{v[0]}+" for v in sorted(missing)) + report.append(f" {op} ({domain}): missing versions {ranges}") + report.append("") + + return "\n".join(report) + + +# _probe_plugin_ops and _make_probe_model removed — session-based probing +# has been moved to test_cuda_plugin_ep.py (test_plugin_ep_claims_key_ops). +# The runtime report now requires the C++ get_registered_ep_kernel_defs API. + + +def main(): + parser = argparse.ArgumentParser(description="CUDA EP Plugin Registration Parity Report") + parser.add_argument("--repo-root", default=None, help="Path to onnxruntime repo root") + parser.add_argument( + "--runtime", + action="store_true", + help="Use runtime kernel registry queries instead of static source analysis. " + "Requires a built onnxruntime with the plugin EP library available.", + ) + parser.add_argument( + "--plugin-ep-name", + default="CudaPluginExecutionProvider", + help="Name of the plugin EP (default: CudaPluginExecutionProvider)", + ) + parser.add_argument( + "--plugin-ep-lib", + default=None, + help="Path to the plugin EP shared library. Auto-detected from build dir if not specified.", + ) + args = parser.parse_args() + + if args.runtime: + # Auto-detect plugin library path if not specified + lib_path = args.plugin_ep_lib + if lib_path is None: + lib_path = _auto_detect_plugin_lib(args.repo_root) + if lib_path is None: + print( + "Error: could not auto-detect plugin EP library. Use --plugin-ep-lib to specify the path.", + file=sys.stderr, + ) + sys.exit(1) + + report = generate_runtime_report(args.plugin_ep_name, lib_path) + else: + if args.repo_root: + repo_root = args.repo_root + else: + script_dir = Path(__file__).resolve().parent + repo_root = script_dir.parent.parent + if not (Path(repo_root) / "onnxruntime").exists(): + print("Error: Could not find repo root. Use --repo-root flag.", file=sys.stderr) + sys.exit(1) + + report = generate_report(repo_root) + + print(report) + + +def _auto_detect_plugin_lib(repo_root): + """Try to find the plugin EP shared library in common build directories.""" + if repo_root is None: + script_dir = Path(__file__).resolve().parent + repo_root = script_dir.parent.parent + + repo_root = Path(repo_root) + lib_name = "libonnxruntime_providers_cuda_plugin.so" + + # Check ORT_CUDA_PLUGIN_PATH env var first + env_path = os.environ.get("ORT_CUDA_PLUGIN_PATH") + if env_path and Path(env_path).exists(): + return env_path + + # Search common build directories (pick the newest if multiple exist) + build_root = repo_root / "build" + if build_root.is_dir(): + candidates = sorted(build_root.rglob(lib_name), key=lambda p: p.stat().st_mtime, reverse=True) + if candidates: + return str(candidates[0]) + + # Check onnxruntime package installation + try: + import onnxruntime # noqa: PLC0415 + + pkg_dir = Path(onnxruntime.__file__).parent / "capi" + candidate = pkg_dir / lib_name + if candidate.exists(): + return str(candidate) + except ImportError: + pass + + return None + + +if __name__ == "__main__": + main() From 481b8950a8ea91dc18ca7d787eb61228932fdcbf Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sat, 28 Mar 2026 21:14:49 -0700 Subject: [PATCH 39/48] Add test cases --- cmake/onnxruntime_providers_cuda_plugin.cmake | 1 - .../core/providers/cuda/math/einsum.cc | 3 +- .../cuda/plugin/cuda_kernel_adapter.h | 7 +- .../transformers/test_cuda_plugin_ep.py | 516 ++++++++++++++++++ 4 files changed, 522 insertions(+), 5 deletions(-) diff --git a/cmake/onnxruntime_providers_cuda_plugin.cmake b/cmake/onnxruntime_providers_cuda_plugin.cmake index 790545568cb18..b823ca83a5e15 100644 --- a/cmake/onnxruntime_providers_cuda_plugin.cmake +++ b/cmake/onnxruntime_providers_cuda_plugin.cmake @@ -49,7 +49,6 @@ list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX "onnxruntime/contrib_ops/cuda/c list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_execution_provider\\.cc$") list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_provider_factory\\.cc$") list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_provider_interface\\.cc$") -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/math/einsum\\.cc$") # Exclude the framework controlflow/ subdirectory — these inherit from CPU base # classes (If, Loop, Scan). The plugin has its own control flow wrappers in diff --git a/onnxruntime/core/providers/cuda/math/einsum.cc b/onnxruntime/core/providers/cuda/math/einsum.cc index 7250597e4f3b0..648cb30f0fe55 100644 --- a/onnxruntime/core/providers/cuda/math/einsum.cc +++ b/onnxruntime/core/providers/cuda/math/einsum.cc @@ -42,8 +42,9 @@ Status Einsum::ComputeInternal(OpKernelContext* context) const { EinsumEquationPreprocessor einsum_equation_preprocessor(*einsum_equation_preprocessor_); + auto ort_stream = GetOrtStream(context); EinsumOp::EinsumCudaAssets einsum_cuda_assets( - GetComputeStream(context), + ort_stream, GetDeviceProp(), GetCublasHandle(context), GetCudnnHandle(context), diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index ecf2354ac2779..4272541eb482a 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -780,7 +780,7 @@ namespace cuda { // re-open onnxruntime::cuda class CudaKernel : public OpKernel { public: - explicit CudaKernel(const OpKernelInfo& info) : OpKernel(info), info_(info) { + explicit CudaKernel(const OpKernelInfo& info) : OpKernel(info) { const auto* provider = info.GetExecutionProvider(); runtime_config_ = detail::GetCudaKernelAdapterRuntimeConfigForProvider(provider); use_tf32_ = runtime_config_->use_tf32; @@ -901,7 +901,9 @@ class CudaKernel : public OpKernel { bool UseTF32() const { return use_tf32_; } bool IsFuseConvBias() const { return runtime_config_->fuse_conv_bias; } bool IsArchAvailable(int arch) const { return GetDeviceProp().major >= arch; } - const OpKernelInfo& Info() const { return info_; } + // Delegate to the base OpKernel::Info() which holds a safe copy of OpKernelInfo. + // Do NOT store a reference to the constructor parameter — it becomes dangling. + const OpKernelInfo& Info() const { return OpKernel::Info(); } const onnxruntime::AttentionKernelOptions* GetAttentionKernelOptions() const { runtime_config_->attention_kernel_options.InitializeOnce(runtime_config_->sdpa_kernel, true, true); return &runtime_config_->attention_kernel_options; @@ -1022,7 +1024,6 @@ class CudaKernel : public OpKernel { }; private: - const OpKernelInfo& info_; std::shared_ptr runtime_config_; cudaDeviceProp device_prop_{}; bool use_tf32_ = true; diff --git a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py index 96ecd834ec6ad..f9a4a7b05afd7 100644 --- a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py +++ b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py @@ -1202,6 +1202,522 @@ def test_plugin_ep_claims_key_ops(self): ) self.assertGreater(len(claimed), 0, "No ops were claimed at all") + # ---- Newly-included ops that previously lacked tests ---- + + def test_op_einsum(self): + """Test Einsum op (recently un-excluded from plugin build).""" + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Einsum", + [("A", TensorProto.FLOAT, [2, 3]), ("B", TensorProto.FLOAT, [3, 4])], + [("Y", TensorProto.FLOAT, [2, 4])], + attrs={"equation": "ij,jk->ik"}, + opset=12, + ) + feed = {"A": np.random.rand(2, 3).astype(np.float32), "B": np.random.rand(3, 4).astype(np.float32)} + result = _run_model_test(target_device, "Einsum", model, feed, lambda f: f["A"] @ f["B"]) + self.assertEqual(result, TEST_PASS, "Einsum test failed") + + def test_op_einsum_batch(self): + """Test Einsum op with batch matrix multiply.""" + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Einsum", + [("A", TensorProto.FLOAT, [2, 3, 4]), ("B", TensorProto.FLOAT, [2, 4, 5])], + [("Y", TensorProto.FLOAT, [2, 3, 5])], + attrs={"equation": "bij,bjk->bik"}, + opset=12, + ) + feed = {"A": np.random.rand(2, 3, 4).astype(np.float32), "B": np.random.rand(2, 4, 5).astype(np.float32)} + result = _run_model_test(target_device, "Einsum_batch", model, feed, lambda f: np.matmul(f["A"], f["B"])) + self.assertEqual(result, TEST_PASS, "Einsum batch test failed") + + def test_op_softmax(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Softmax", + [("X", TensorProto.FLOAT, [2, 5])], + [("Y", TensorProto.FLOAT, [2, 5])], + attrs={"axis": 1}, + opset=13, + ) + feed = {"X": np.random.rand(2, 5).astype(np.float32)} + + def expected(f): + x = f["X"] + e = np.exp(x - np.max(x, axis=1, keepdims=True)) + return e / np.sum(e, axis=1, keepdims=True) + + result = _run_model_test(target_device, "Softmax", model, feed, expected) + self.assertEqual(result, TEST_PASS, "Softmax test failed") + + def test_op_relu(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Relu", + [("X", TensorProto.FLOAT, [3, 4])], + [("Y", TensorProto.FLOAT, [3, 4])], + opset=14, + ) + feed = {"X": np.random.randn(3, 4).astype(np.float32)} + result = _run_model_test(target_device, "Relu", model, feed, lambda f: np.maximum(f["X"], 0)) + self.assertEqual(result, TEST_PASS, "Relu test failed") + + def test_op_sigmoid(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Sigmoid", + [("X", TensorProto.FLOAT, [3, 4])], + [("Y", TensorProto.FLOAT, [3, 4])], + opset=13, + ) + feed = {"X": np.random.randn(3, 4).astype(np.float32)} + result = _run_model_test(target_device, "Sigmoid", model, feed, lambda f: 1.0 / (1.0 + np.exp(-f["X"]))) + self.assertEqual(result, TEST_PASS, "Sigmoid test failed") + + def test_op_tanh(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Tanh", + [("X", TensorProto.FLOAT, [3, 4])], + [("Y", TensorProto.FLOAT, [3, 4])], + opset=13, + ) + feed = {"X": np.random.randn(3, 4).astype(np.float32)} + result = _run_model_test(target_device, "Tanh", model, feed, lambda f: np.tanh(f["X"])) + self.assertEqual(result, TEST_PASS, "Tanh test failed") + + def test_op_transpose(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Transpose", + [("X", TensorProto.FLOAT, [2, 3, 4])], + [("Y", TensorProto.FLOAT, [4, 2, 3])], + attrs={"perm": [2, 0, 1]}, + opset=13, + ) + feed = {"X": np.random.rand(2, 3, 4).astype(np.float32)} + result = _run_model_test(target_device, "Transpose", model, feed, lambda f: np.transpose(f["X"], (2, 0, 1))) + self.assertEqual(result, TEST_PASS, "Transpose test failed") + + def test_op_cast(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Cast", + [("X", TensorProto.FLOAT, [3, 4])], + [("Y", TensorProto.FLOAT16, [3, 4])], + attrs={"to": int(TensorProto.FLOAT16)}, + opset=13, + ) + feed = {"X": np.random.rand(3, 4).astype(np.float32)} + result = _run_model_test( + target_device, "Cast", model, feed, lambda f: f["X"].astype(np.float16), rtol=1e-2, atol=1e-2 + ) + self.assertEqual(result, TEST_PASS, "Cast test failed") + + def test_op_where(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Where", + [ + ("cond", TensorProto.BOOL, [3, 4]), + ("X", TensorProto.FLOAT, [3, 4]), + ("Y", TensorProto.FLOAT, [3, 4]), + ], + [("out", TensorProto.FLOAT, [3, 4])], + opset=16, + ) + cond = np.random.randint(0, 2, size=(3, 4)).astype(bool) + x = np.random.rand(3, 4).astype(np.float32) + y = np.random.rand(3, 4).astype(np.float32) + feed = {"cond": cond, "X": x, "Y": y} + result = _run_model_test(target_device, "Where", model, feed, lambda f: np.where(f["cond"], f["X"], f["Y"])) + self.assertEqual(result, TEST_PASS, "Where test failed") + + def test_op_flatten(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Flatten", + [("X", TensorProto.FLOAT, [2, 3, 4])], + [("Y", TensorProto.FLOAT, [2, 12])], + attrs={"axis": 1}, + opset=13, + ) + feed = {"X": np.random.rand(2, 3, 4).astype(np.float32)} + result = _run_model_test(target_device, "Flatten", model, feed, lambda f: f["X"].reshape(2, 12)) + self.assertEqual(result, TEST_PASS, "Flatten test failed") + + def test_op_argmax(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "ArgMax", + [("X", TensorProto.FLOAT, [3, 5])], + [("Y", TensorProto.INT64, [3, 1])], + attrs={"axis": 1, "keepdims": 1}, + opset=13, + ) + feed = {"X": np.random.rand(3, 5).astype(np.float32)} + result = _run_model_test( + target_device, "ArgMax", model, feed, lambda f: np.argmax(f["X"], axis=1).reshape(3, 1) + ) + self.assertEqual(result, TEST_PASS, "ArgMax test failed") + + def test_op_topk(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "TopK", + [("X", TensorProto.FLOAT, [3, 6]), ("K", TensorProto.INT64, [1])], + [("values", TensorProto.FLOAT, [3, 3]), ("indices", TensorProto.INT64, [3, 3])], + attrs={"axis": 1}, + opset=11, + ) + x = np.random.rand(3, 6).astype(np.float32) + k = np.array([3], dtype=np.int64) + feed = {"X": x, "K": k} + + def expected(f): + idx = np.argsort(-f["X"], axis=1)[:, :3] + vals = np.take_along_axis(f["X"], idx, axis=1) + return [vals, idx] + + result = _run_model_test(target_device, "TopK", model, feed, expected) + self.assertEqual(result, TEST_PASS, "TopK test failed") + + def test_op_layer_normalization(self): + """Test LayerNormalization — critical for transformer models.""" + target_device = get_cuda_plugin_device() + normalized_shape = 8 + model = _make_simple_model( + "LayerNormalization", + [ + ("X", TensorProto.FLOAT, [2, 3, normalized_shape]), + ("scale", TensorProto.FLOAT, [normalized_shape]), + ("bias", TensorProto.FLOAT, [normalized_shape]), + ], + [("Y", TensorProto.FLOAT, [2, 3, normalized_shape])], + attrs={"axis": -1, "epsilon": 1e-5}, + opset=17, + ) + scale = np.ones(normalized_shape, dtype=np.float32) + bias = np.zeros(normalized_shape, dtype=np.float32) + scale_init = helper.make_tensor("scale", TensorProto.FLOAT, [normalized_shape], scale.tolist()) + bias_init = helper.make_tensor("bias", TensorProto.FLOAT, [normalized_shape], bias.tolist()) + model.graph.initializer.append(scale_init) + model.graph.initializer.append(bias_init) + + x = np.random.rand(2, 3, normalized_shape).astype(np.float32) + feed = {"X": x} + + def expected(f): + x = f["X"] + mean = np.mean(x, axis=-1, keepdims=True) + var = np.var(x, axis=-1, keepdims=True) + return (x - mean) / np.sqrt(var + 1e-5) * scale + bias + + result = _run_model_test(target_device, "LayerNormalization", model, feed, expected) + self.assertEqual(result, TEST_PASS, "LayerNormalization test failed") + + def test_op_instance_normalization(self): + target_device = get_cuda_plugin_device() + C = 3 + model = _make_simple_model( + "InstanceNormalization", + [ + ("X", TensorProto.FLOAT, [1, C, 4, 4]), + ("scale", TensorProto.FLOAT, [C]), + ("B", TensorProto.FLOAT, [C]), + ], + [("Y", TensorProto.FLOAT, [1, C, 4, 4])], + attrs={"epsilon": 1e-5}, + opset=6, + ) + scale = np.ones(C, dtype=np.float32) + bias = np.zeros(C, dtype=np.float32) + model.graph.initializer.append(helper.make_tensor("scale", TensorProto.FLOAT, [C], scale.tolist())) + model.graph.initializer.append(helper.make_tensor("B", TensorProto.FLOAT, [C], bias.tolist())) + + x = np.random.rand(1, C, 4, 4).astype(np.float32) + feed = {"X": x} + + def expected(f): + x = f["X"] + result = np.empty_like(x) + for c in range(C): + ch = x[0, c] + mean = ch.mean() + var = ch.var() + result[0, c] = (ch - mean) / np.sqrt(var + 1e-5) * scale[c] + bias[c] + return result + + result = _run_model_test(target_device, "InstanceNormalization", model, feed, expected) + self.assertEqual(result, TEST_PASS, "InstanceNormalization test failed") + + def test_op_conv_transpose(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "ConvTranspose", + [ + ("X", TensorProto.FLOAT, [1, 3, 4, 4]), + ("W", TensorProto.FLOAT, [3, 2, 3, 3]), + ], + [("Y", TensorProto.FLOAT, [1, 2, 6, 6])], + attrs={"kernel_shape": [3, 3], "strides": [1, 1], "pads": [0, 0, 0, 0]}, + opset=11, + ) + x = np.random.rand(1, 3, 4, 4).astype(np.float32) + w = np.random.rand(3, 2, 3, 3).astype(np.float32) + feed = {"X": x, "W": w} + + def expected(f): + return F.conv_transpose2d(torch.from_numpy(f["X"]), torch.from_numpy(f["W"])).numpy() + + result = _run_model_test(target_device, "ConvTranspose", model, feed, expected) + self.assertEqual(result, TEST_PASS, "ConvTranspose test failed") + + def test_op_reduce_mean(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "ReduceMean", + [("X", TensorProto.FLOAT, [3, 4, 5])], + [("Y", TensorProto.FLOAT, [3, 1, 5])], + attrs={"axes": [1], "keepdims": 1}, + opset=13, + ) + feed = {"X": np.random.rand(3, 4, 5).astype(np.float32)} + result = _run_model_test( + target_device, "ReduceMean", model, feed, lambda f: np.mean(f["X"], axis=1, keepdims=True) + ) + self.assertEqual(result, TEST_PASS, "ReduceMean test failed") + + def test_op_reduce_sum(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "ReduceSum", + [("X", TensorProto.FLOAT, [3, 4, 5]), ("axes", TensorProto.INT64, [1])], + [("Y", TensorProto.FLOAT, [3, 1, 5])], + attrs={"keepdims": 1}, + opset=13, + ) + axes_init = helper.make_tensor("axes", TensorProto.INT64, [1], [1]) + model.graph.initializer.append(axes_init) + feed = {"X": np.random.rand(3, 4, 5).astype(np.float32)} + result = _run_model_test( + target_device, "ReduceSum", model, feed, lambda f: np.sum(f["X"], axis=1, keepdims=True) + ) + self.assertEqual(result, TEST_PASS, "ReduceSum test failed") + + def test_op_gather_nd(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "GatherND", + [ + ("data", TensorProto.FLOAT, [2, 3, 4]), + ("indices", TensorProto.INT64, [2, 1]), + ], + [("Y", TensorProto.FLOAT, [2, 4])], + attrs={"batch_dims": 1}, + opset=12, + ) + data = np.random.rand(2, 3, 4).astype(np.float32) + indices = np.array([[1], [2]], dtype=np.int64) + feed = {"data": data, "indices": indices} + + def expected(f): + return np.array([f["data"][0, 1], f["data"][1, 2]]) + + result = _run_model_test(target_device, "GatherND", model, feed, expected) + self.assertEqual(result, TEST_PASS, "GatherND test failed") + + def test_op_scatter_elements(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "ScatterElements", + [ + ("data", TensorProto.FLOAT, [3, 3]), + ("indices", TensorProto.INT64, [2, 3]), + ("updates", TensorProto.FLOAT, [2, 3]), + ], + [("Y", TensorProto.FLOAT, [3, 3])], + attrs={"axis": 0}, + opset=16, + ) + data = np.zeros((3, 3), dtype=np.float32) + indices = np.array([[1, 0, 2], [0, 2, 1]], dtype=np.int64) + updates = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) + feed = {"data": data, "indices": indices, "updates": updates} + + def expected(f): + result = f["data"].copy() + for i in range(2): + for j in range(3): + result[f["indices"][i, j], j] = f["updates"][i, j] + return result + + result = _run_model_test(target_device, "ScatterElements", model, feed, expected) + self.assertEqual(result, TEST_PASS, "ScatterElements test failed") + + def test_op_onehot(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "OneHot", + [ + ("indices", TensorProto.INT64, [4]), + ("depth", TensorProto.INT64, [1]), + ("values", TensorProto.FLOAT, [2]), + ], + [("Y", TensorProto.FLOAT, [4, 6])], + attrs={"axis": 1}, + opset=11, + ) + indices = np.array([0, 2, 4, 5], dtype=np.int64) + depth = np.array([6], dtype=np.int64) + values = np.array([0.0, 1.0], dtype=np.float32) + feed = {"indices": indices, "depth": depth, "values": values} + + def expected(f): + result = np.zeros((4, 6), dtype=np.float32) + for i, idx in enumerate(f["indices"]): + result[i, idx] = 1.0 + return result + + result = _run_model_test(target_device, "OneHot", model, feed, expected) + self.assertEqual(result, TEST_PASS, "OneHot test failed") + + # NOTE: Range is excluded — it runs on CPU (shape computation op, not claimed by CUDA EP). + + def test_op_non_zero(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "NonZero", + [("X", TensorProto.FLOAT, [3, 4])], + [("Y", TensorProto.INT64, None)], + opset=13, + ) + x = np.array([[1, 0, 3, 0], [0, 5, 0, 7], [0, 0, 0, 10]], dtype=np.float32) + feed = {"X": x} + result = _run_model_test(target_device, "NonZero", model, feed, lambda f: np.array(np.nonzero(f["X"]))) + self.assertEqual(result, TEST_PASS, "NonZero test failed") + + def test_op_grid_sample(self): + target_device = get_cuda_plugin_device() + N, C, H, W = 1, 1, 4, 4 + model = _make_simple_model( + "GridSample", + [ + ("X", TensorProto.FLOAT, [N, C, H, W]), + ("grid", TensorProto.FLOAT, [N, 2, 2, 2]), + ], + [("Y", TensorProto.FLOAT, [N, C, 2, 2])], + attrs={"mode": "bilinear", "padding_mode": "zeros", "align_corners": 0}, + opset=16, + ) + x = np.random.rand(N, C, H, W).astype(np.float32) + grid = np.random.rand(N, 2, 2, 2).astype(np.float32) * 2 - 1 # in [-1, 1] + feed = {"X": x, "grid": grid} + + def expected(f): + return F.grid_sample( + torch.from_numpy(f["X"]), + torch.from_numpy(f["grid"]), + mode="bilinear", + padding_mode="zeros", + align_corners=False, + ).numpy() + + result = _run_model_test(target_device, "GridSample", model, feed, expected, rtol=1e-3, atol=1e-3) + self.assertEqual(result, TEST_PASS, "GridSample test failed") + + def test_op_gelu(self): + """Test Gelu contrib op — important for transformer models.""" + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Gelu", + [("X", TensorProto.FLOAT, [2, 8])], + [("Y", TensorProto.FLOAT, [2, 8])], + domain="com.microsoft", + opset=13, + ) + feed = {"X": np.random.randn(2, 8).astype(np.float32)} + + def expected(f): + return torch.nn.functional.gelu(torch.from_numpy(f["X"])).numpy() + + result = _run_model_test(target_device, "Gelu", model, feed, expected) + self.assertEqual(result, TEST_PASS, "Gelu test failed") + + def test_op_bias_gelu(self): + """Test BiasGelu contrib op.""" + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "BiasGelu", + [("X", TensorProto.FLOAT, [2, 8]), ("bias", TensorProto.FLOAT, [8])], + [("Y", TensorProto.FLOAT, [2, 8])], + domain="com.microsoft", + opset=13, + ) + bias = np.random.randn(8).astype(np.float32) + model.graph.initializer.append(helper.make_tensor("bias", TensorProto.FLOAT, [8], bias.tolist())) + feed = {"X": np.random.randn(2, 8).astype(np.float32)} + + def expected(f): + x = torch.from_numpy(f["X"]) + torch.from_numpy(bias) + return torch.nn.functional.gelu(x).numpy() + + result = _run_model_test(target_device, "BiasGelu", model, feed, expected) + self.assertEqual(result, TEST_PASS, "BiasGelu test failed") + + def test_op_fused_matmul(self): + """Test FusedMatMul contrib op (MatMul with alpha).""" + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "FusedMatMul", + [("A", TensorProto.FLOAT, [3, 4]), ("B", TensorProto.FLOAT, [4, 5])], + [("Y", TensorProto.FLOAT, [3, 5])], + attrs={"alpha": 2.0}, + domain="com.microsoft", + opset=13, + ) + feed = {"A": np.random.rand(3, 4).astype(np.float32), "B": np.random.rand(4, 5).astype(np.float32)} + result = _run_model_test(target_device, "FusedMatMul", model, feed, lambda f: 2.0 * (f["A"] @ f["B"])) + self.assertEqual(result, TEST_PASS, "FusedMatMul test failed") + + def test_op_trilu(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Trilu", + [("X", TensorProto.FLOAT, [4, 4])], + [("Y", TensorProto.FLOAT, [4, 4])], + attrs={"upper": 1}, + opset=14, + ) + feed = {"X": np.random.rand(4, 4).astype(np.float32)} + result = _run_model_test(target_device, "Trilu", model, feed, lambda f: np.triu(f["X"])) + self.assertEqual(result, TEST_PASS, "Trilu test failed") + + def test_op_matmul_integer(self): + """Test MatMulInteger — used in INT8 quantization pipelines.""" + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "MatMulInteger", + [ + ("A", TensorProto.INT8, [3, 4]), + ("B", TensorProto.INT8, [4, 5]), + ], + [("Y", TensorProto.INT32, [3, 5])], + opset=10, + ) + a = np.random.randint(-128, 127, size=(3, 4)).astype(np.int8) + b = np.random.randint(-128, 127, size=(4, 5)).astype(np.int8) + feed = {"A": a, "B": b} + result = _run_model_test( + target_device, + "MatMulInteger", + model, + feed, + lambda f: f["A"].astype(np.int32) @ f["B"].astype(np.int32), + ) + self.assertEqual(result, TEST_PASS, "MatMulInteger test failed") + if __name__ == "__main__": unittest.main() From c14fb1eee382f279be81b698050338786c339399 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sun, 29 Mar 2026 00:33:56 -0700 Subject: [PATCH 40/48] MemcpyToHost and MemcpyFromHosst --- .../cuda/plugin/cuda_memcpy_plugin.cc | 79 +++++++++++++++++ .../transformers/test_cuda_plugin_ep.py | 85 +++++++++++++++++++ 2 files changed, 164 insertions(+) create mode 100644 onnxruntime/core/providers/cuda/plugin/cuda_memcpy_plugin.cc diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_memcpy_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_memcpy_plugin.cc new file mode 100644 index 0000000000000..f8b4d27e5cade --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_memcpy_plugin.cc @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Plugin-side MemcpyFromHost / MemcpyToHost kernels. +// These handle the common Tensor case using cudaMemcpyAsync directly. + +#include "core/providers/cuda/plugin/cuda_kernel_adapter.h" + +#include + +namespace onnxruntime { +namespace cuda { + +class PluginMemcpy final : public CudaKernel { + public: + PluginMemcpy(const OpKernelInfo& info) : CudaKernel(info) {} + + Status ComputeInternal(OpKernelContext* ctx) const override { + const Tensor* X = ctx->Input(0); + ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr."); + Tensor* Y = ctx->Output(0, X->Shape()); + ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor."); + + if (X->SizeInBytes() == 0) { + return Status::OK(); + } + + const void* src = X->DataRaw(); + void* dst = Y->MutableDataRaw(); + if (src == dst) { + return Status::OK(); + } + + // Determine copy direction from device placement. + const auto& src_loc = X->Location(); + const auto& dst_loc = Y->Location(); + cudaMemcpyKind kind; + if (src_loc.device.Type() == OrtDevice::CPU && dst_loc.device.Type() == OrtDevice::GPU) { + kind = cudaMemcpyHostToDevice; + } else if (src_loc.device.Type() == OrtDevice::GPU && dst_loc.device.Type() == OrtDevice::CPU) { + kind = cudaMemcpyDeviceToHost; + } else if (src_loc.device.Type() == OrtDevice::GPU && dst_loc.device.Type() == OrtDevice::GPU) { + kind = cudaMemcpyDeviceToDevice; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "PluginMemcpy: unsupported copy direction"); + } + + cudaStream_t stream = Stream(ctx); + if (stream != nullptr) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst, src, X->SizeInBytes(), kind, stream)); + } else { + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst, src, X->SizeInBytes(), kind)); + } + return Status::OK(); + } +}; + +ONNX_OPERATOR_KERNEL_EX( + MemcpyFromHost, + kOnnxDomain, + 1, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypesIRv9()), + PluginMemcpy); + +ONNX_OPERATOR_KERNEL_EX( + MemcpyToHost, + kOnnxDomain, + 1, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPUOutput, 0) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypesIRv9()), + PluginMemcpy); + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py index f9a4a7b05afd7..d9f1ff62db750 100644 --- a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py +++ b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py @@ -1718,6 +1718,91 @@ def test_op_matmul_integer(self): ) self.assertEqual(result, TEST_PASS, "MatMulInteger test failed") + # ---- MemcpyFromHost / MemcpyToHost tests ---- + # These tests explicitly place MemcpyFromHost/MemcpyToHost nodes in the graph + # to directly exercise the plugin-side copy kernels. + + def test_memcpy_from_host_explicit(self): + """Explicit MemcpyFromHost node: CPU input → GPU copy → Relu on GPU.""" + target_device = get_cuda_plugin_device() + # X (CPU) → MemcpyFromHost → X_gpu → Relu → Y + copy_node = helper.make_node("MemcpyFromHost", ["X"], ["X_gpu"]) + relu_node = helper.make_node("Relu", ["X_gpu"], ["Y"]) + graph = helper.make_graph( + [copy_node, relu_node], + "test-explicit-memcpy-from-host", + [helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4])], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 4])], + ) + opset = OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + feed = {"X": np.random.randn(3, 4).astype(np.float32)} + result = _run_model_test( + target_device, + "MemcpyFromHost_explicit", + model, + feed, + lambda f: np.maximum(f["X"], 0), + ) + self.assertEqual(result, TEST_PASS, "Explicit MemcpyFromHost test failed") + + def test_memcpy_to_host_explicit(self): + """Explicit MemcpyToHost node: GPU Add → GPU-to-CPU copy → output.""" + target_device = get_cuda_plugin_device() + # A, B → Add (GPU) → sum_gpu → MemcpyToHost → Y (CPU) + add_node = helper.make_node("Add", ["A", "B"], ["sum_gpu"]) + copy_node = helper.make_node("MemcpyToHost", ["sum_gpu"], ["Y"]) + graph = helper.make_graph( + [add_node, copy_node], + "test-explicit-memcpy-to-host", + [ + helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [2, 3]), + ], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])], + ) + opset = OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + feed = { + "A": np.random.rand(2, 3).astype(np.float32), + "B": np.random.rand(2, 3).astype(np.float32), + } + result = _run_model_test( + target_device, + "MemcpyToHost_explicit", + model, + feed, + lambda f: f["A"] + f["B"], + ) + self.assertEqual(result, TEST_PASS, "Explicit MemcpyToHost test failed") + + def test_memcpy_roundtrip_explicit(self): + """Explicit both directions: CPU → MemcpyFromHost → Relu (GPU) → MemcpyToHost → CPU.""" + target_device = get_cuda_plugin_device() + copy_in = helper.make_node("MemcpyFromHost", ["X"], ["X_gpu"]) + relu_node = helper.make_node("Relu", ["X_gpu"], ["relu_out"]) + copy_out = helper.make_node("MemcpyToHost", ["relu_out"], ["Y"]) + graph = helper.make_graph( + [copy_in, relu_node, copy_out], + "test-explicit-memcpy-roundtrip", + [helper.make_tensor_value_info("X", TensorProto.FLOAT, [4, 5])], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4, 5])], + ) + opset = OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + feed = {"X": np.random.randn(4, 5).astype(np.float32)} + result = _run_model_test( + target_device, + "MemcpyRoundtrip_explicit", + model, + feed, + lambda f: np.maximum(f["X"], 0), + ) + self.assertEqual(result, TEST_PASS, "Explicit MemcpyFromHost→Relu→MemcpyToHost roundtrip test failed") + if __name__ == "__main__": unittest.main() From 2b994c970226d9bb615bea1c0e473863f8cc366f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sun, 29 Mar 2026 10:57:42 -0700 Subject: [PATCH 41/48] lintrunner --- .../transformers/test_cuda_plugin_ep.py | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py index d9f1ff62db750..45682674756a4 100644 --- a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py +++ b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py @@ -1419,30 +1419,30 @@ def expected(f): def test_op_instance_normalization(self): target_device = get_cuda_plugin_device() - C = 3 + n_channels = 3 model = _make_simple_model( "InstanceNormalization", [ - ("X", TensorProto.FLOAT, [1, C, 4, 4]), - ("scale", TensorProto.FLOAT, [C]), - ("B", TensorProto.FLOAT, [C]), + ("X", TensorProto.FLOAT, [1, n_channels, 4, 4]), + ("scale", TensorProto.FLOAT, [n_channels]), + ("B", TensorProto.FLOAT, [n_channels]), ], - [("Y", TensorProto.FLOAT, [1, C, 4, 4])], + [("Y", TensorProto.FLOAT, [1, n_channels, 4, 4])], attrs={"epsilon": 1e-5}, opset=6, ) - scale = np.ones(C, dtype=np.float32) - bias = np.zeros(C, dtype=np.float32) - model.graph.initializer.append(helper.make_tensor("scale", TensorProto.FLOAT, [C], scale.tolist())) - model.graph.initializer.append(helper.make_tensor("B", TensorProto.FLOAT, [C], bias.tolist())) + scale = np.ones(n_channels, dtype=np.float32) + bias = np.zeros(n_channels, dtype=np.float32) + model.graph.initializer.append(helper.make_tensor("scale", TensorProto.FLOAT, [n_channels], scale.tolist())) + model.graph.initializer.append(helper.make_tensor("B", TensorProto.FLOAT, [n_channels], bias.tolist())) - x = np.random.rand(1, C, 4, 4).astype(np.float32) + x = np.random.rand(1, n_channels, 4, 4).astype(np.float32) feed = {"X": x} def expected(f): x = f["X"] result = np.empty_like(x) - for c in range(C): + for c in range(n_channels): ch = x[0, c] mean = ch.mean() var = ch.var() @@ -1600,19 +1600,19 @@ def test_op_non_zero(self): def test_op_grid_sample(self): target_device = get_cuda_plugin_device() - N, C, H, W = 1, 1, 4, 4 + n, c, h, w = 1, 1, 4, 4 model = _make_simple_model( "GridSample", [ - ("X", TensorProto.FLOAT, [N, C, H, W]), - ("grid", TensorProto.FLOAT, [N, 2, 2, 2]), + ("X", TensorProto.FLOAT, [n, c, h, w]), + ("grid", TensorProto.FLOAT, [n, 2, 2, 2]), ], - [("Y", TensorProto.FLOAT, [N, C, 2, 2])], + [("Y", TensorProto.FLOAT, [n, c, 2, 2])], attrs={"mode": "bilinear", "padding_mode": "zeros", "align_corners": 0}, opset=16, ) - x = np.random.rand(N, C, H, W).astype(np.float32) - grid = np.random.rand(N, 2, 2, 2).astype(np.float32) * 2 - 1 # in [-1, 1] + x = np.random.rand(n, c, h, w).astype(np.float32) + grid = np.random.rand(n, 2, 2, 2).astype(np.float32) * 2 - 1 # in [-1, 1] feed = {"X": x, "grid": grid} def expected(f): From 47f76744ebcc3eb4e2c2040a018220901d1b7a5e Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sun, 29 Mar 2026 11:37:34 -0700 Subject: [PATCH 42/48] add second gpu test --- cmake/onnxruntime_providers_cuda_plugin.cmake | 6 +- include/onnxruntime/ep/api.h | 4 + .../cuda/plugin/cuda_data_transfer_plugin.cc | 6 +- .../cuda/plugin/cuda_data_transfer_plugin.h | 4 +- .../core/providers/cuda/plugin/cuda_ep.cc | 19 ++- .../providers/cuda/plugin/cuda_ep_factory.cc | 113 +++++++----------- .../providers/cuda/plugin/cuda_ep_factory.h | 43 +++++-- .../cuda/plugin/cuda_stream_plugin.cc | 18 ++- .../cuda/plugin/cuda_stream_plugin.h | 1 + .../transformers/test_cuda_plugin_ep.py | 55 ++++++++- tools/ci_build/cuda_plugin_parity_report.py | 1 + 11 files changed, 172 insertions(+), 98 deletions(-) diff --git a/cmake/onnxruntime_providers_cuda_plugin.cmake b/cmake/onnxruntime_providers_cuda_plugin.cmake index b823ca83a5e15..9dbcf3721b06b 100644 --- a/cmake/onnxruntime_providers_cuda_plugin.cmake +++ b/cmake/onnxruntime_providers_cuda_plugin.cmake @@ -164,7 +164,11 @@ endif() # Mirror the core CUDA provider's CUDA 12.8+ NVCC workarounds so the plugin # target handles stricter cudafe diagnostics consistently. -set(onnxruntime_plugin_nvcc_threads "1") +if (DEFINED onnxruntime_NVCC_THREADS) + set(onnxruntime_plugin_nvcc_threads "${onnxruntime_NVCC_THREADS}") +else() + set(onnxruntime_plugin_nvcc_threads "1") +endif() target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE "$<$:SHELL:--threads \"${onnxruntime_plugin_nvcc_threads}\">" "$<$:--diag-suppress=177>" diff --git a/include/onnxruntime/ep/api.h b/include/onnxruntime/ep/api.h index 4e2b97d544e3b..c22e52ed8aaa5 100644 --- a/include/onnxruntime/ep/api.h +++ b/include/onnxruntime/ep/api.h @@ -5,6 +5,7 @@ #include #include +#include #pragma push_macro("ORT_API_MANUAL_INIT") #undef ORT_API_MANUAL_INIT @@ -31,6 +32,9 @@ inline std::optional g_api_ptrs; /// Get the global instance of ApiPtrs. /// inline const ApiPtrs& Api() { + if (!detail::g_api_ptrs.has_value()) { + throw std::logic_error("onnxruntime::ep::Api() called before ApiInit()."); + } return *detail::g_api_ptrs; } diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc index 0b248438a421d..e4b3ed8f3c314 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc @@ -6,12 +6,10 @@ namespace onnxruntime { namespace cuda_plugin { -CudaDataTransfer::CudaDataTransfer(const OrtApi& ort_api, const OrtEpApi& ep_api, - const OrtMemoryDevice* gpu_device) +CudaDataTransfer::CudaDataTransfer(const OrtApi& ort_api, const OrtEpApi& ep_api) : OrtDataTransferImpl{}, ort_api_(ort_api), - ep_api_(ep_api), - gpu_device_(gpu_device) { + ep_api_(ep_api) { ort_version_supported = ORT_API_VERSION; Release = ReleaseImpl; CanCopy = CanCopyImpl; diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.h b/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.h index 2fcf40be06fc2..a43f90cf01f72 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.h @@ -15,8 +15,7 @@ namespace cuda_plugin { /// CUDA data transfer implementation for CPU↔GPU and GPU↔GPU copies. class CudaDataTransfer : public OrtDataTransferImpl { public: - CudaDataTransfer(const OrtApi& ort_api, const OrtEpApi& ep_api, - const OrtMemoryDevice* gpu_device); + CudaDataTransfer(const OrtApi& ort_api, const OrtEpApi& ep_api); ~CudaDataTransfer() = default; private: @@ -36,7 +35,6 @@ class CudaDataTransfer : public OrtDataTransferImpl { const OrtApi& ort_api_; const OrtEpApi& ep_api_; - const OrtMemoryDevice* gpu_device_; }; } // namespace cuda_plugin diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index f8e7af0f49ed4..a210cc2162bd9 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -185,9 +185,16 @@ OrtStatus* ORT_API_CALL CudaEp::ShouldConvertDataLayoutForOpImpl( OrtEpDataLayout target_data_layout, int* should_convert) noexcept { ORT_UNUSED_PARAMETER(this_ptr); + if (should_convert == nullptr) { + return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "should_convert must not be null."); + } + + const char* safe_domain = domain != nullptr ? domain : ""; + const char* safe_op_type = op_type != nullptr ? op_type : ""; + #ifndef ENABLE_CUDA_NHWC_OPS - ORT_UNUSED_PARAMETER(domain); - ORT_UNUSED_PARAMETER(op_type); + ORT_UNUSED_PARAMETER(safe_domain); + ORT_UNUSED_PARAMETER(safe_op_type); ORT_UNUSED_PARAMETER(target_data_layout); *should_convert = 0; // NHWC kernels are not compiled into this plugin build. return nullptr; @@ -215,15 +222,15 @@ OrtStatus* ORT_API_CALL CudaEp::ShouldConvertDataLayoutForOpImpl( }; // Check ONNX domain (empty string) or MS domain (com.microsoft) - bool is_onnx_domain = (domain[0] == '\0'); - bool is_ms_domain = (std::strcmp(domain, "com.microsoft") == 0); + bool is_onnx_domain = (safe_domain[0] == '\0'); + bool is_ms_domain = (std::strcmp(safe_domain, "com.microsoft") == 0); - if (is_onnx_domain && cuda_nhwc_onnx_ops.count(op_type) > 0) { + if (is_onnx_domain && cuda_nhwc_onnx_ops.count(safe_op_type) > 0) { *should_convert = 1; // Convert return nullptr; } - if (is_ms_domain && std::strcmp(op_type, "GridSample") == 0) { + if (is_ms_domain && std::strcmp(safe_op_type, "GridSample") == 0) { *should_convert = 1; // Convert return nullptr; } diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc index 7ee0ccedf6bee..922984eb16905 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -20,9 +20,7 @@ CudaEpFactory::CudaEpFactory(const OrtApi& ort_api, const OrtEpApi& ep_api, : OrtEpFactory{}, ort_api_(ort_api), ep_api_(ep_api), - default_logger_(default_logger), - default_memory_info_{nullptr}, - pinned_memory_info_{nullptr} { + default_logger_(default_logger) { ort_version_supported = ORT_API_VERSION; if (!::onnxruntime::ep::adapter::LoggingManager::HasDefaultLogger()) { @@ -42,22 +40,6 @@ CudaEpFactory::CudaEpFactory(const OrtApi& ort_api, const OrtEpApi& ep_api, CreateDataTransfer = CreateDataTransferImpl; IsStreamAware = IsStreamAwareImpl; CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; - - // Initialize default memory info for CUDA device memory. - // The NVIDIA PCI vendor ID (0x10DE) is used to identify the device type. - default_memory_info_ = Ort::MemoryInfo{"Cuda", - OrtMemoryInfoDeviceType_GPU, - vendor_id_, - static_cast(device_id_), - OrtDeviceMemoryType_DEFAULT, - /*alignment*/ 0, - OrtAllocatorType::OrtDeviceAllocator}; - - // Initialize pinned (host accessible) memory info - pinned_memory_info_ = Ort::MemoryInfo{"CudaPinned", - OrtAllocatorType::OrtDeviceAllocator, - 0, - OrtMemType::OrtMemTypeCPU}; } CudaEpFactory::~CudaEpFactory() { @@ -122,6 +104,15 @@ std::string GetProviderOptionPrefix(std::string_view provider_name) { } // namespace +CudaEpFactory::HardwareDeviceKey CudaEpFactory::MakeDeviceKey(const OrtApi& ort_api, + const OrtHardwareDevice& device) { + return { + ort_api.HardwareDevice_Type(&device), + ort_api.HardwareDevice_VendorId(&device), + ort_api.HardwareDevice_DeviceId(&device), + }; +} + /*static*/ OrtStatus* ORT_API_CALL CudaEpFactory::GetSupportedDevicesImpl( OrtEpFactory* this_ptr, @@ -134,17 +125,6 @@ OrtStatus* ORT_API_CALL CudaEpFactory::GetSupportedDevicesImpl( size_t& num_ep_devices = *p_num_ep_devices; num_ep_devices = 0; - // Clear stale entries from previous enumerations. The OrtHardwareDevice* - // keys point to ORT-owned memory that may be freed between calls. - { - std::lock_guard lock(factory->device_map_mutex_); - factory->hw_device_to_cuda_index_.clear(); - } - { - std::lock_guard lock(factory->cached_memory_info_mutex_); - factory->cached_memory_infos_.clear(); - } - int cuda_device_index = 0; for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { const OrtHardwareDevice& device = *hw_devices[i]; @@ -162,11 +142,29 @@ OrtStatus* ORT_API_CALL CudaEpFactory::GetSupportedDevicesImpl( // CUDA uses contiguous ordinals for CUDA-visible NVIDIA devices. Build that // mapping from the filtered hardware-device list instead of relying on the // ORT hardware device id, which is not guaranteed to be a CUDA ordinal. - const int32_t current_device_id = cuda_device_index++; - + int current_device_id = cuda_device_index++; + const auto device_key = CudaEpFactory::MakeDeviceKey(factory->ort_api_, device); + DeviceCacheEntry* cache_entry = nullptr; { - std::lock_guard lock(factory->device_map_mutex_); - factory->hw_device_to_cuda_index_[&device] = current_device_id; + std::lock_guard lock(factory->device_cache_mutex_); + auto [it, inserted] = factory->device_cache_.try_emplace(device_key); + if (inserted) { + it->second.cuda_device_id = current_device_id; + it->second.device_memory_info = Ort::MemoryInfo{"Cuda", + OrtMemoryInfoDeviceType_GPU, + factory->vendor_id_, + static_cast(current_device_id), + OrtDeviceMemoryType_DEFAULT, + /*alignment is default*/ 0, + OrtAllocatorType::OrtDeviceAllocator}; + it->second.pinned_memory_info = Ort::MemoryInfo{"CudaPinned", + OrtAllocatorType::OrtDeviceAllocator, + current_device_id, + OrtMemType::OrtMemTypeCPU}; + } + + cache_entry = &it->second; + current_device_id = cache_entry->cuda_device_id; } OrtKeyValuePairs* ep_metadata = nullptr; @@ -199,27 +197,14 @@ OrtStatus* ORT_API_CALL CudaEpFactory::GetSupportedDevicesImpl( return status; } - Ort::MemoryInfo device_memory_info{"Cuda", - OrtMemoryInfoDeviceType_GPU, - factory->vendor_id_, - static_cast(current_device_id), - OrtDeviceMemoryType_DEFAULT, - /*alignment is default*/ 0, - OrtAllocatorType::OrtDeviceAllocator}; - - OrtMemoryInfo* raw_memory_info = device_memory_info; - { - std::lock_guard lock(factory->cached_memory_info_mutex_); - factory->cached_memory_infos_.push_back(std::move(device_memory_info)); - } - // Register allocator info for GPU device memory RETURN_IF_ERROR(factory->ep_api_.EpDevice_AddAllocatorInfo( - ep_device, raw_memory_info)); + ep_device, cache_entry->device_memory_info)); - // Register allocator info for CPU pinned memory (host accessible) + // Register allocator info for pinned host memory associated with the + // same CUDA ordinal as the device allocator above. RETURN_IF_ERROR(factory->ep_api_.EpDevice_AddAllocatorInfo( - ep_device, factory->pinned_memory_info_)); + ep_device, cache_entry->pinned_memory_info)); ep_devices[num_ep_devices++] = ep_device; } @@ -259,13 +244,16 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( // absent or malformed the default value in Config is kept. CudaEp::Config config{}; - config.device_id = 0; // Default { - std::lock_guard lock(factory->device_map_mutex_); - auto it = factory->hw_device_to_cuda_index_.find(devices[0]); - if (it != factory->hw_device_to_cuda_index_.end()) { - config.device_id = it->second; + std::lock_guard lock(factory->device_cache_mutex_); + auto it = factory->device_cache_.find(CudaEpFactory::MakeDeviceKey(factory->ort_api_, *devices[0])); + if (it == factory->device_cache_.end()) { + return factory->ort_api_.CreateStatus( + ORT_INVALID_ARGUMENT, + "CUDA EP factory could not resolve the requested device. " + "Enumerate EP devices again and retry session creation."); } + config.device_id = it->second.cuda_device_id; } auto try_get_session_config = [&](std::string_view key) -> std::optional { @@ -491,18 +479,7 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateDataTransferImpl( OrtEpFactory* this_ptr, OrtDataTransferImpl** data_transfer) noexcept { auto& factory = *static_cast(this_ptr); - - // Use the device ID this factory was created for - Ort::MemoryInfo device_memory_info{"Cuda", - OrtMemoryInfoDeviceType_GPU, - factory.vendor_id_, - static_cast(factory.device_id_), - OrtDeviceMemoryType_DEFAULT, - 0, - OrtAllocatorType::OrtDeviceAllocator}; - - const OrtMemoryDevice* gpu_device = factory.ep_api_.MemoryInfo_GetMemoryDevice(device_memory_info); - auto data_transfer_impl = std::make_unique(factory.ort_api_, factory.ep_api_, gpu_device); + auto data_transfer_impl = std::make_unique(factory.ort_api_, factory.ep_api_); *data_transfer = data_transfer_impl.release(); return nullptr; } diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h index d88bf1ac9b647..ea4e2da19001d 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h @@ -10,6 +10,7 @@ #include #include +#include #include namespace onnxruntime { @@ -89,18 +90,36 @@ class CudaEpFactory : public OrtEpFactory { const uint32_t vendor_id_ = 0x10DE; // NVIDIA PCI vendor ID const std::string ep_version_{"1.0.0"}; - // Memory info for GPU device and CPU pinned memory - Ort::MemoryInfo default_memory_info_{nullptr}; - Ort::MemoryInfo pinned_memory_info_{nullptr}; - - std::mutex cached_memory_info_mutex_; - std::vector cached_memory_infos_; - int device_id_ = 0; - - // Map ORT hardware device pointers to CUDA ordinals for the NVIDIA devices - // visible to the CUDA runtime. - std::mutex device_map_mutex_; - std::unordered_map hw_device_to_cuda_index_; + struct DeviceCacheEntry { + int cuda_device_id{-1}; + Ort::MemoryInfo device_memory_info{nullptr}; + Ort::MemoryInfo pinned_memory_info{nullptr}; + }; + + struct HardwareDeviceKey { + OrtHardwareDeviceType type{OrtHardwareDeviceType::OrtHardwareDeviceType_CPU}; + uint32_t vendor_id{0}; + uint32_t device_id{0}; + + bool operator==(const HardwareDeviceKey&) const = default; + }; + + struct HardwareDeviceKeyHasher { + size_t operator()(const HardwareDeviceKey& key) const noexcept { + size_t hash = static_cast(key.type); + hash = (hash * 1315423911u) ^ static_cast(key.vendor_id); + hash = (hash * 1315423911u) ^ static_cast(key.device_id); + return hash; + } + }; + + static HardwareDeviceKey MakeDeviceKey(const OrtApi& ort_api, + const OrtHardwareDevice& device); + + // Stable per-device cache keyed by public hardware-device properties instead + // of the transient OrtHardwareDevice* pointer received during enumeration. + std::mutex device_cache_mutex_; + std::unordered_map device_cache_; // Kernel registry (cached, shared across EP instances) OrtKernelRegistry* kernel_registry_ = nullptr; diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc index 66aa3c4cf1d45..6a6ab68103ef3 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc @@ -53,7 +53,13 @@ CudaSyncStream::CudaSyncStream(CudaEpFactory& factory, int device_id, } CudaSyncStream::~CudaSyncStream() { - if (!deferred_cpu_buffers_.empty()) { + bool has_deferred_cpu_buffers = false; + { + std::lock_guard lock(deferred_cpu_buffers_mutex_); + has_deferred_cpu_buffers = !deferred_cpu_buffers_.empty(); + } + + if (has_deferred_cpu_buffers) { if (cuda_stream_ != nullptr) { OrtStatus* status = OnSessionRunEndImpl(this); if (status != nullptr) { @@ -114,12 +120,19 @@ OrtStatus* CudaSyncStream::InitHandles() { } void CudaSyncStream::EnqueueDeferredCPUBuffer(void* cpu_buffer) { + std::lock_guard lock(deferred_cpu_buffers_mutex_); deferred_cpu_buffers_.push_back(cpu_buffer); } OrtStatus* CudaSyncStream::CleanupDeferredCPUBuffers() noexcept { + std::vector buffers_to_free; + { + std::lock_guard lock(deferred_cpu_buffers_mutex_); + buffers_to_free.swap(deferred_cpu_buffers_); + } + OrtStatus* first_error = nullptr; - for (void* buf : deferred_cpu_buffers_) { + for (void* buf : buffers_to_free) { cudaError_t err = cudaFreeHost(buf); if (err != cudaSuccess && first_error == nullptr) { first_error = Ort::GetApi().CreateStatus( @@ -127,7 +140,6 @@ OrtStatus* CudaSyncStream::CleanupDeferredCPUBuffers() noexcept { (std::string("CUDA error: ") + cudaGetErrorName(err) + ": " + cudaGetErrorString(err)).c_str()); } } - deferred_cpu_buffers_.clear(); return first_error; } diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h index 60d85c043bf67..4b72dee82ca38 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h @@ -65,6 +65,7 @@ class CudaSyncStream : public OrtSyncStreamImpl { // CPU buffers whose deallocation is deferred to OnSessionRunEnd. // Pinned memory must remain valid until all async device operations that // reference it have completed, so we synchronize the stream first. + mutable std::mutex deferred_cpu_buffers_mutex_; std::vector deferred_cpu_buffers_; }; diff --git a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py index 45682674756a4..cccc70b5c6f02 100644 --- a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py +++ b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py @@ -37,6 +37,10 @@ def require_cuda_plugin_ep(): def get_cuda_plugin_device(): + return get_cuda_plugin_devices()[0] + + +def get_cuda_plugin_devices(): require_cuda_plugin_ep() try: @@ -48,7 +52,18 @@ def get_cuda_plugin_device(): if not plugin_devices: raise unittest.SkipTest("CUDA plugin EP registered, but no plugin devices were enumerated") - return plugin_devices[0] + return plugin_devices + + +def get_cuda_plugin_device_by_id(device_id: int): + expected_device_id = str(device_id) + for device in get_cuda_plugin_devices(): + if device.ep_options.get("device_id") == expected_device_id: + return device + if device.ep_metadata.get("cuda_device_id") == expected_device_id: + return device + + raise unittest.SkipTest(f"CUDA plugin EP device_id={device_id} is not available in this environment") def _create_session_options(session_config=None): @@ -476,6 +491,44 @@ def test_provider_options_invalid_device(self): result = run_provider_options_test({"device_id": "999"}, expect_plugin_provider=False) self.assertTrue(result, "Provider options with invalid device_id failed") + def test_provider_options_second_device(self): + plugin_devices = get_cuda_plugin_devices() + if len(plugin_devices) < 2: + self.skipTest("Multi-GPU CUDA plugin EP test requires at least two plugin devices") + + target_device = get_cuda_plugin_device_by_id(1) + + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp: + model_path = tmp.name + try: + create_add_model(model_path) + providers = [(CUDA_PLUGIN_EP_NAME, {"device_id": "1"}), "CPUExecutionProvider"] + sess = onnxrt.InferenceSession(model_path, sess_options=_create_session_options(), providers=providers) + + active_providers = sess.get_providers() + assigned_nodes, assignment_info = _get_assigned_nodes(sess, CUDA_PLUGIN_EP_NAME) + self.assertTrue( + assigned_nodes, + f"{CUDA_PLUGIN_EP_NAME} was assigned no nodes. Providers: {active_providers}. " + f"Assignments: {_format_assignment_summary(assignment_info)}", + ) + + provider_options = sess.get_provider_options() + self.assertEqual( + provider_options[CUDA_PLUGIN_EP_NAME].get("device_id"), + "1", + f"Expected provider option device_id=1, got {provider_options[CUDA_PLUGIN_EP_NAME]}", + ) + self.assertEqual(target_device.ep_options.get("device_id"), "1") + + a = np.random.rand(3, 2).astype(np.float32) + b = np.random.rand(3, 2).astype(np.float32) + res = sess.run(None, {"A": a, "B": b}) + np.testing.assert_allclose(res[0], a + b, rtol=1e-3, atol=1e-3) + finally: + if os.path.exists(model_path): + os.remove(model_path) + # ---- NHWC layout tests ---- def test_nhwc_conv(self): diff --git a/tools/ci_build/cuda_plugin_parity_report.py b/tools/ci_build/cuda_plugin_parity_report.py index 23fd8fbb6c39f..1dffb5ea1292b 100755 --- a/tools/ci_build/cuda_plugin_parity_report.py +++ b/tools/ci_build/cuda_plugin_parity_report.py @@ -727,6 +727,7 @@ def _auto_detect_plugin_lib(repo_root): if candidate.exists(): return str(candidate) except ImportError: + # onnxruntime is not installed in the current environment. Return None here and the script will fail later. pass return None From 8471fec4bda2a84ad4201fe619e57bc3dc7cae68 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 30 Mar 2026 02:24:49 -0700 Subject: [PATCH 43/48] Fix dropout and webgpu plugin test failures --- include/onnxruntime/ep/adapter/op_kernel.h | 8 ++------ include/onnxruntime/ep/adapter/op_kernel_info.h | 3 +++ onnxruntime/core/providers/cuda/nn/dropout.cc | 14 +++++++++++--- onnxruntime/core/providers/cuda/nn/dropout.h | 2 ++ .../python/transformers/test_cuda_plugin_ep.py | 2 +- 5 files changed, 19 insertions(+), 10 deletions(-) diff --git a/include/onnxruntime/ep/adapter/op_kernel.h b/include/onnxruntime/ep/adapter/op_kernel.h index c90ba2205abf1..3cbf3d03a2c31 100644 --- a/include/onnxruntime/ep/adapter/op_kernel.h +++ b/include/onnxruntime/ep/adapter/op_kernel.h @@ -131,16 +131,12 @@ struct OpKernelContext { return Output(index, TensorShape{shape}); } [[nodiscard]] Status GetTempSpaceCPUAllocator(AllocatorPtr* output) const { - const auto* execution_provider = op_kernel_.Info().GetExecutionProvider(); - ORT_ENFORCE(execution_provider != nullptr, "Kernel does not have an execution provider."); - const auto* ort_ep = execution_provider->GetOrtEp(); + const auto* ort_ep = op_kernel_.Info().GetOrtEp(); ORT_ENFORCE(ort_ep != nullptr, "Kernel execution provider is not associated with an OrtEp instance."); return static_cast(ort_ep)->GetTempSpaceCPUAllocator(output); } [[nodiscard]] Status GetTempSpaceAllocator(AllocatorPtr* output) const { - const auto* execution_provider = op_kernel_.Info().GetExecutionProvider(); - ORT_ENFORCE(execution_provider != nullptr, "Kernel does not have an execution provider."); - const auto* ort_ep = execution_provider->GetOrtEp(); + const auto* ort_ep = op_kernel_.Info().GetOrtEp(); ORT_ENFORCE(ort_ep != nullptr, "Kernel execution provider is not associated with an OrtEp instance."); return static_cast(ort_ep)->GetTempSpaceAllocator(output); } diff --git a/include/onnxruntime/ep/adapter/op_kernel_info.h b/include/onnxruntime/ep/adapter/op_kernel_info.h index 3ed5f3034d9ee..f0b620c334d40 100644 --- a/include/onnxruntime/ep/adapter/op_kernel_info.h +++ b/include/onnxruntime/ep/adapter/op_kernel_info.h @@ -79,6 +79,9 @@ struct OpKernelInfo { const IExecutionProvider* GetExecutionProvider() const noexcept { return cache_->ep_impl_; } + const OrtEp* GetOrtEp() const noexcept { + return cache_->ort_ep_; + } KernelDef GetKernelDef() const noexcept { return KernelDef{cache_->kernel_info_}; diff --git a/onnxruntime/core/providers/cuda/nn/dropout.cc b/onnxruntime/core/providers/cuda/nn/dropout.cc index 7233f2e241a73..5011e7aef7872 100644 --- a/onnxruntime/core/providers/cuda/nn/dropout.cc +++ b/onnxruntime/core/providers/cuda/nn/dropout.cc @@ -104,14 +104,22 @@ Status Dropout::ComputeInternal(OpKernelContext* context) const { CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(Y_data, X_data, X->SizeInBytes(), cudaMemcpyDeviceToDevice, Stream(context))); } - // If mask is requested, return all 1s. + // If mask is requested, fill it appropriately. + // BitmaskDropout (UseBitmask=true): mask is always bitmask where 1 = kept. All 1s for inference. + // Opset 12+: mask is bool, spec says "mask will contain all ones". All true for inference. + // Opset 7-11: mask is memset to 0, consistent with CPU IdentityOp behavior. if (mask) { - if (UseBitmask) { + if constexpr (UseBitmask) { + // BitmaskDropout always uses bitmask semantics (all bits set = all kept). CUDA_RETURN_IF_ERROR( cudaMemsetAsync(mask->MutableDataRaw(), -1, mask_element_count * sizeof(BitmaskElementType), Stream(context))); - } else { + } else if (opset_ >= 12) { CUDA_RETURN_IF_ERROR( cudaMemsetAsync(mask->MutableData(), true, mask_element_count * sizeof(bool), Stream(context))); + } else { + // Opset 7-11: zero-fill mask to match CPU IdentityOp behavior. + CUDA_RETURN_IF_ERROR( + cudaMemsetAsync(mask->MutableDataRaw(), 0, mask->SizeInBytes(), Stream(context))); } } diff --git a/onnxruntime/core/providers/cuda/nn/dropout.h b/onnxruntime/core/providers/cuda/nn/dropout.h index 183456573f317..10049f7320981 100644 --- a/onnxruntime/core/providers/cuda/nn/dropout.h +++ b/onnxruntime/core/providers/cuda/nn/dropout.h @@ -18,6 +18,7 @@ class Dropout final : public CudaKernel { if (info.GetAttr("seed", &seed).IsOK()) { generator_ = std::make_unique(static_cast(seed)); } + opset_ = info.node().SinceVersion(); } Status ComputeInternal(OpKernelContext* context) const override; @@ -25,6 +26,7 @@ class Dropout final : public CudaKernel { private: mutable std::unique_ptr generator_; static constexpr float default_ratio_ = 0.5f; + int opset_ = 12; }; } // namespace cuda diff --git a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py index cccc70b5c6f02..f1ce5b3d187ea 100644 --- a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py +++ b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py @@ -945,7 +945,7 @@ def test_op_dropout_opset10(self): "Dropout_opset10", model, {"X": x}, - lambda f: [f["X"], np.ones((2, 4), dtype=bool)], + lambda f: [f["X"], np.zeros((2, 4), dtype=bool)], ) self.assertEqual(result, TEST_PASS, "Dropout opset 10 plugin op test failed") From 20b960926d90c2b03c4b91a7cce541797ddaba7c Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 30 Mar 2026 14:05:21 -0700 Subject: [PATCH 44/48] review feedback --- include/onnxruntime/ep/adapter/op_kernel.h | 9 ++++- .../cuda/plugin/cuda_controlflow_plugin.cu | 4 +- .../core/providers/cuda/plugin/cuda_ep.cc | 11 +++--- .../providers/cuda/plugin/cuda_ep_factory.cc | 31 +++++++++++++--- .../cuda/plugin/cuda_kernel_adapter.h | 37 ++++++++++++++----- .../cuda/plugin/cuda_stream_plugin.cc | 2 +- 6 files changed, 68 insertions(+), 26 deletions(-) diff --git a/include/onnxruntime/ep/adapter/op_kernel.h b/include/onnxruntime/ep/adapter/op_kernel.h index 3cbf3d03a2c31..8c9c16b62692e 100644 --- a/include/onnxruntime/ep/adapter/op_kernel.h +++ b/include/onnxruntime/ep/adapter/op_kernel.h @@ -8,6 +8,7 @@ #endif #include +#include #include #include "core/framework/allocator.h" @@ -131,12 +132,16 @@ struct OpKernelContext { return Output(index, TensorShape{shape}); } [[nodiscard]] Status GetTempSpaceCPUAllocator(AllocatorPtr* output) const { - const auto* ort_ep = op_kernel_.Info().GetOrtEp(); + const auto* execution_provider = op_kernel_.Info().GetExecutionProvider(); + ORT_ENFORCE(execution_provider != nullptr, "Kernel does not have an execution provider."); + const auto* ort_ep = execution_provider->GetOrtEp(); ORT_ENFORCE(ort_ep != nullptr, "Kernel execution provider is not associated with an OrtEp instance."); return static_cast(ort_ep)->GetTempSpaceCPUAllocator(output); } [[nodiscard]] Status GetTempSpaceAllocator(AllocatorPtr* output) const { - const auto* ort_ep = op_kernel_.Info().GetOrtEp(); + const auto* execution_provider = op_kernel_.Info().GetExecutionProvider(); + ORT_ENFORCE(execution_provider != nullptr, "Kernel does not have an execution provider."); + const auto* ort_ep = execution_provider->GetOrtEp(); ORT_ENFORCE(ort_ep != nullptr, "Kernel execution provider is not associated with an OrtEp instance."); return static_cast(ort_ep)->GetTempSpaceAllocator(output); } diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu index 5e4b7acc2f95a..6ff3296dadcb8 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu +++ b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu @@ -2,8 +2,8 @@ // Licensed under the MIT License. // GPU transpose kernel for the Scan control flow helper. -// Handles arbitrary N-D permutations by computing output coordinates -// from linear indices. +// Supports permutations up to kMaxTransposeDims dimensions by computing +// output coordinates from linear indices. #include #include diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index a210cc2162bd9..e6934548acfd0 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -61,11 +61,12 @@ CudaEp::CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& lo "CUDA Plugin EP created", ORT_FILE, __LINE__, __FUNCTION__)); - // Store per-EP runtime configuration in a global map keyed by the - // adapter-wrapped execution provider pointer. Migrated kernels retrieve these - // settings at compute time via GetCudaKernelAdapterRuntimeConfigForProvider(). - // Adding a new config field only requires updating CudaKernelAdapterRuntimeConfig, - // CudaEp::Config, and the struct-initializer below — no function-signature change. + // Store per-EP runtime configuration inside the adapter-wrapped execution + // provider itself. Migrated kernels retrieve a shared config object at + // compute time via GetCudaKernelAdapterRuntimeConfigForProvider(). + // Adding a new config field only requires updating + // CudaKernelAdapterRuntimeConfig, CudaEp::Config, and the struct-initializer + // below — no function-signature change. onnxruntime::cuda::detail::CudaKernelAdapterRuntimeConfig adapter_config; adapter_config.use_tf32 = config_.use_tf32; adapter_config.skip_layer_norm_strict_mode = config_.enable_skip_layer_norm_strict_mode; diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc index 922984eb16905..f5443da32433b 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -124,6 +124,14 @@ OrtStatus* ORT_API_CALL CudaEpFactory::GetSupportedDevicesImpl( auto* factory = static_cast(this_ptr); size_t& num_ep_devices = *p_num_ep_devices; num_ep_devices = 0; + auto release_ep_devices = [&](OrtStatus* status) -> OrtStatus* { + for (size_t j = 0; j < num_ep_devices; ++j) { + factory->ep_api_.ReleaseEpDevice(ep_devices[j]); + ep_devices[j] = nullptr; + } + num_ep_devices = 0; + return status; + }; int cuda_device_index = 0; for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { @@ -194,19 +202,30 @@ OrtStatus* ORT_API_CALL CudaEpFactory::GetSupportedDevicesImpl( factory->ort_api_.ReleaseKeyValuePairs(ep_options); if (status != nullptr) { - return status; + return release_ep_devices(status); } + auto release_current_ep_device = [factory](OrtEpDevice* device) { + factory->ep_api_.ReleaseEpDevice(device); + }; + // ep_device_guard owns the current device. On error, release_ep_devices cleans up + // previously committed devices [0, num_ep_devices), while the guard cleans up this one. + std::unique_ptr ep_device_guard(ep_device, release_current_ep_device); + // Register allocator info for GPU device memory - RETURN_IF_ERROR(factory->ep_api_.EpDevice_AddAllocatorInfo( - ep_device, cache_entry->device_memory_info)); + status = factory->ep_api_.EpDevice_AddAllocatorInfo(ep_device, cache_entry->device_memory_info); + if (status != nullptr) { + return release_ep_devices(status); + } // Register allocator info for pinned host memory associated with the // same CUDA ordinal as the device allocator above. - RETURN_IF_ERROR(factory->ep_api_.EpDevice_AddAllocatorInfo( - ep_device, cache_entry->pinned_memory_info)); + status = factory->ep_api_.EpDevice_AddAllocatorInfo(ep_device, cache_entry->pinned_memory_info); + if (status != nullptr) { + return release_ep_devices(status); + } - ep_devices[num_ep_devices++] = ep_device; + ep_devices[num_ep_devices++] = ep_device_guard.release(); } } diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index 4272541eb482a..81dee25576816 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -439,7 +439,7 @@ template <> struct SizeOf { static constexpr size_t value = 0; }; -inline size_t BytesForCount(size_t count_or_bytes, size_t element_size) { +[[nodiscard]] inline size_t BytesForCount(size_t count_or_bytes, size_t element_size) { if (element_size == 0) return count_or_bytes; if (count_or_bytes > (std::numeric_limits::max() / element_size)) return 0; return count_or_bytes * element_size; @@ -479,19 +479,27 @@ inline DefaultCudaHandles& GetDefaultCudaHandlesForDevice(int device_id) { auto [it, inserted] = handles_by_device.try_emplace(device_id); if (inserted) { int prev_device = -1; - cudaGetDevice(&prev_device); + const cudaError_t get_device_result = cudaGetDevice(&prev_device); PL_CUDA_CALL_THROW(cudaSetDevice(device_id)); if (cublasCreate(&it->second.cublas) != CUBLAS_STATUS_SUCCESS) { - cudaSetDevice(prev_device); + if (get_device_result == cudaSuccess) { + cudaSetDevice(prev_device); + } + handles_by_device.erase(it); ORT_THROW("Failed to create default cuBLAS handle for CUDA plugin device ", device_id); } if (cudnnCreate(&it->second.cudnn) != CUDNN_STATUS_SUCCESS) { cublasDestroy(it->second.cublas); it->second.cublas = nullptr; - cudaSetDevice(prev_device); + if (get_device_result == cudaSuccess) { + cudaSetDevice(prev_device); + } + handles_by_device.erase(it); ORT_THROW("Failed to create default cuDNN handle for CUDA plugin device ", device_id); } - PL_CUDA_CALL_THROW(cudaSetDevice(prev_device)); + if (get_device_result == cudaSuccess) { + PL_CUDA_CALL_THROW(cudaSetDevice(prev_device)); + } } return it->second; @@ -504,9 +512,9 @@ inline const cudaDeviceProp& GetDevicePropForDevice(int device_id) { auto it = props.find(device_id); if (it == props.end()) { auto prop = std::make_unique(); - if (cudaGetDeviceProperties(prop.get(), device_id) != cudaSuccess) { - std::memset(prop.get(), 0, sizeof(*prop)); - prop->major = -1; + const cudaError_t result = cudaGetDeviceProperties(prop.get(), device_id); + if (result != cudaSuccess) { + ORT_THROW("Failed to query CUDA device properties for device ", device_id, ": ", cudaGetErrorString(result)); } it = props.emplace(device_id, std::move(prop)).first; } @@ -983,6 +991,9 @@ class CudaKernel : public OpKernel { inline IAllocatorUniquePtr AllocateBufferOnCPUPinned(size_t cnt) const { if (cnt == 0) return IAllocatorUniquePtr(nullptr, [](T*) {}); size_t sz = detail::BytesForCount(cnt, detail::SizeOf::value); + if (sz == 0) { + ORT_THROW("CUDA pinned CPU buffer allocation size overflow for ", cnt, " elements"); + } void* p = nullptr; if (cudaHostAlloc(&p, sz, cudaHostAllocDefault) != cudaSuccess) return IAllocatorUniquePtr(nullptr, [](T*) {}); return IAllocatorUniquePtr(static_cast(p), [](T* ptr) { if (ptr) cudaFreeHost(ptr); }); @@ -997,7 +1008,9 @@ class CudaKernel : public OpKernel { T* p = CpuPtr(); for (size_t i = 0; i != n; ++i) *p++ = v; } - CudaAsyncBuffer(const CudaKernel* ok, gsl::span vec) : CudaAsyncBuffer(ok, vec.size()) { memcpy(CpuPtr(), vec.data(), vec.size() * sizeof(T)); } + CudaAsyncBuffer(const CudaKernel* ok, gsl::span vec) : CudaAsyncBuffer(ok, vec.size()) { + memcpy(CpuPtr(), vec.data(), detail::BytesForCount(vec.size(), sizeof(T))); + } void AllocCpuPtr(size_t n) { cpu_ = op_kernel_->AllocateBufferOnCPUPinned(n); if (!cpu_) throw std::runtime_error("alloc fail"); @@ -1006,7 +1019,11 @@ class CudaKernel : public OpKernel { Status CopyToGpu(void* s) { if (cpu_) { gpu_ = op_kernel_->GetScratchBuffer(count_, s); - if (cudaMemcpyAsync(gpu_.get(), cpu_.get(), count_ * sizeof(T), cudaMemcpyHostToDevice, static_cast(s)) != cudaSuccess) return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, "Memcpy fail"); + const size_t bytes = detail::BytesForCount(count_, sizeof(T)); + if (count_ > 0 && bytes == 0) { + ORT_THROW("CUDA async buffer copy size overflow for ", count_, " elements"); + } + if (cudaMemcpyAsync(gpu_.get(), cpu_.get(), bytes, cudaMemcpyHostToDevice, static_cast(s)) != cudaSuccess) return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, "Memcpy fail"); op_kernel_->AddDeferredReleaseCPUPtr(cpu_.release(), s); } return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc index 6a6ab68103ef3..521c6bb15c13f 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc @@ -106,7 +106,6 @@ OrtStatus* CudaSyncStream::InitHandles() { PL_CUDA_RETURN_IF_ERROR(cudaSetDevice(device_id_)); PL_CUDA_RETURN_IF_ERROR(cudaStreamCreateWithFlags(&cuda_stream_, cudaStreamNonBlocking)); - RegisterStream(cuda_stream_, this); PL_CUBLAS_RETURN_IF_ERROR(cublasCreate(&cublas_handle_)); PL_CUBLAS_RETURN_IF_ERROR(cublasSetStream(cublas_handle_, cuda_stream_)); @@ -115,6 +114,7 @@ OrtStatus* CudaSyncStream::InitHandles() { PL_CUDNN_RETURN_IF_ERROR(cudnnSetStream(cudnn_handle_, cuda_stream_)); PL_CUBLAS_RETURN_IF_ERROR(cublasLtCreate(&cublas_lt_handle_)); + RegisterStream(cuda_stream_, this); return nullptr; } From d5ff49ff0c382e575a0db2b3430dce6f63c01711 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 30 Mar 2026 14:24:03 -0700 Subject: [PATCH 45/48] update design doc --- docs/cuda_plugin_ep/cuda_plugin_ep_design.md | 42 ++++++++++++-------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md index 7be29ff5e9442..c03453eafbd2d 100644 --- a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md +++ b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md @@ -117,7 +117,7 @@ In the plugin build, `provider_api.h` (normally included from `cuda_common.h`) i ### 3.4 Kernel Adapter (`cuda_kernel_adapter.h`) -This 700+ line header provides everything CUDA kernels need that would normally come from framework infrastructure: +This 1100+ line header provides everything CUDA kernels need that would normally come from framework infrastructure: | Section | What It Provides | |---------|-----------------| @@ -575,7 +575,6 @@ Section 7 reflects the current source exclusions in `cmake/onnxruntime_providers |----------------|--------------------------|------------------------| | `core/providers/cuda/controlflow/*` | The framework controlflow kernels inherit from CPU-side controlflow bases (`If`, `Loop`, `Scan`) and are intentionally omitted from the plugin source list | No change is currently planned. The plugin uses its own `cuda_controlflow_plugin.cc` wrappers instead of these framework sources | | `tunable/*` | Depends on the real `CudaTuningContext` and other framework CUDA EP infrastructure that is not available in the plugin build | Add a plugin-capable tuning context and remove the remaining framework-only tunable dependencies | -| `math/einsum.cc` | The top-level framework einsum provider path is still not plugin-safe even though supporting utility code is now included | Finish a plugin-safe top-level einsum path and then remove the CMake exclusion | | `tensor/sequence_op.cc` | Uses `TensorSeq`, which is still not adapter-safe here | Add `TensorSeq` adapter coverage | | `contrib_ops/cuda/llm/*` | Contrib LLM sources have not gone through the same adapter-migration pass as the core CUDA LLM kernels | Finish contrib-LLM-specific adapter work | | `contrib_ops/cuda/tensor/shrunken_gather.cc` | The training-side header path still depends on framework/provider API wiring | Low-priority training-specific adapter work | @@ -594,8 +593,6 @@ The current exclusions fall into a few categories: 3. **Deliberate scope cuts** — ATen and collective/NCCL sources remain intentionally out of scope for the standalone CUDA plugin. -4. **Top-level framework wrappers still excluded** — `math/einsum.cc` remains excluded even though supporting pieces such as `einsum_utils/*` are now plugin-safe. - --- ## 8. Remaining `#ifdef` Guards in Kernel Code @@ -672,21 +669,28 @@ The plugin is then available as `CudaPluginExecutionProvider` in session provide `onnxruntime/test/python/transformers/test_cuda_plugin_ep.py` provides the current focused plugin validation flow: -| Stage | What It Tests | -|-------|---------------| -| Registration | Dynamic loading via `register_execution_provider_library()` and EP device discovery | -| Stage 2 | Basic ops: Add, MatMul, Gemm, Conv | -| Stage 3 | NHWC-requested sessions: Conv, BatchNorm, MaxPool, AveragePool. These validate correctness under `prefer_nhwc` and require plugin-backed NHWC execution to succeed; they are the focused regression suite for the plugin NHWC path | -| Stage 5A | Standard ops: Reshape, Split, Concat, Gather, Unsqueeze | -| Stage 5B | More ops: Tile, CumSum, ConstantOfShape, SpaceToDepth, Pad, Slice, Resize, Sum | -| Stage 5C | CPU base class ops: Upsample, DepthToSpace | -| Stage 5D | Contrib ops: FastGelu, SkipLayerNorm, BiasDropout | +| Category | What It Tests | +|----------|---------------| +| Registration | Dynamic loading via `register_execution_provider_library()` and EP device discovery (Add, MatMul, Gemm, Conv) | +| Provider options | Valid option parsing, invalid device rejection, multi-device selection | +| NHWC | NHWC-requested sessions: Conv, BatchNorm, MaxPool, AveragePool. These validate correctness under `prefer_nhwc` and require plugin-backed NHWC execution to succeed; they are the focused regression suite for the plugin NHWC path | +| Tensor ops | Reshape, Split, Concat, Gather, Unsqueeze, Tile, Pad, Slice, Transpose, Cast, Where, Flatten, ArgMax, TopK, Trilu, NonZero | +| Math ops | Softmax, Relu, Sigmoid, Tanh, Einsum (single and batched) | +| Reduce | ReduceMean, ReduceSum | +| Space/depth | SpaceToDepth, DepthToSpace, Upsample | +| Shape ops | CumSum, ConstantOfShape, Resize, Sum (variadic) | +| Normalization | LayerNormalization, InstanceNormalization | +| Conv | ConvTranspose | +| Scatter/gather | GatherND, ScatterElements, OneHot | +| Spatial | GridSample | +| Contrib ops | FastGelu, Gelu, BiasGelu, SkipLayerNorm, BiasDropout, FusedMatMul | | Dropout | Dropout opset 7 and opset 10 — verifies registrations moved to `dropout.cc` | -| Quantization | DequantizeLinear / QuantizeLinear opset 21 — exercises `TWO_TYPED_KERNEL_EX` adapter macro | +| Quantization | DequantizeLinear / QuantizeLinear opset 21 — exercises `TWO_TYPED_KERNEL_EX` adapter macro; MatMulInteger | | GatherBlockQuantized | Contrib `GatherBlockQuantized` — exercises `THREE_TYPED_KERNEL_EX` adapter macro | | Identity | Identity opset 13 and opset 25 — re-enabled op with `TensorSeq` path guarded | | Crop | Crop (opset 1) — previously excluded contrib op, now re-enabled | -| Key-ops probe | Session-based probing of representative ops: Sub, Relu, Softmax, Transpose, Cast, Sigmoid, ConvTranspose, LRN | +| Memcpy | Explicit `MemcpyFromHost` and `MemcpyToHost` standalone tests to ensure copy ops are dispatched | +| Key-ops probe | Session-based probing that all key ops are assigned to `CudaPluginExecutionProvider` | ### 10.2 Running Tests @@ -846,18 +850,22 @@ onnxruntime/core/providers/cuda/plugin/ ├── cuda_ep.h / .cc # CudaEp : OrtEp implementation ├── cuda_ep_factory.h / .cc # CudaEpFactory : OrtEpFactory ├── cuda_plugin_ep.cc # DLL entry points (CreateEpFactories/ReleaseEpFactory) +├── cuda_plugin_ep_symbols.def # Windows DLL export definitions ├── cuda_plugin_kernels.h / .cu # Kernel registry creation ├── cuda_stream_plugin.h / .cc # CudaSyncStream (handles, notifications) ├── cuda_allocator_plugin.h / .cc # Device/pinned allocators ├── cuda_data_transfer_plugin.h / .cc # GPU↔CPU data transfer +├── cuda_memcpy_plugin.cc # MemcpyFromHost/MemcpyToHost standalone kernels ├── cuda_controlflow_plugin.h / .cc / .cu # If/Loop/Scan wrappers ├── cuda_plugin_utils.h # Common macros, error handling └── provider_api_shims.cc # Reimplemented utility functions include/onnxruntime/ep/ +├── README.md # EP adapter layer overview ├── adapters.h # Master include + type aliasing (force-included) ├── api.h # ORT C API includes ├── common.h # EP common utilities +├── get_capability_utils.h # GetCapability helper utilities └── adapter/ ├── allocator.h # IAllocator adapter ├── data_transfer_manager.h # DataTransferManager adapter @@ -979,7 +987,7 @@ include/onnxruntime/ep/ For plugin EPs, the `OrtEp::GetCapability` callback currently has no mechanism to receive or report resource usage — the `OrtEp` C API does not expose `IResourceAccountant`. Two design options: - - **Option A (preferred — ORT core change):** Add an `OrtEp` analogue of the current `IResourceAccountant` flow instead of inventing a separate plugin-only protocol. In the core path, the partitioner passes an `IResourceAccountant*` into `IExecutionProvider::GetCapability(...)`. The plugin equivalent should preserve that model as closely as possible: either expose an accountant-like object through the `OrtEp` API, or add a small `OrtEp` callback surface that the partitioner can use to compute and accumulate per-node resource cost while still keeping the partitioner's threshold/stop-assignment logic in one place. + - **Option A (preferred — ORT core change, completed in PR #27595):** Add an `OrtEp` analogue of the current `IResourceAccountant` flow. PR #27595 introduced `OrtEpGraphSupportInfo_RequestResourceForNode` and `OrtEpGraphSupportInfo_StopAssigningNodesDueToResourceExhaustion` to the C API. This is the implementation path moving forward. - **Option B (plugin-side workaround):** Expose the VRAM threshold through a plugin-specific session option key. During `GetCapabilityImpl`, the plugin reads the threshold from the parsed config and performs its own initializer-size accounting using `OrtEp_GetNodeAttributes` / node-graph-view APIs already present in the `OrtEp` API surface. This avoids an ORT core change but duplicates budget-tracking logic. @@ -987,4 +995,4 @@ include/onnxruntime/ep/ PR #27595 also introduces `layering_annotations` — node-level `"layer_ann"` metadata that routes nodes to specific EPs or CPU during partitioning. The expected model is that plugin EPs participate through the same `GetCapability` flow and therefore observe whatever node set ORT presents after applying layering rules. In practice that should mean no plugin-specific changes are needed to respect annotations that exclude nodes from the plugin. However, the plugin design should avoid depending on undocumented filtering details in the `OrtGraph*` contract. If the plugin EP itself needs to *read* layering annotations for internal decisions, or if the API needs to make filtered-vs-unfiltered graph semantics explicit, that would require new `OrtEp` API surface. - **Recommended action:** Extend the `OrtEp` C API with a plugin equivalent of the existing resource-accounting flow (Option A), and document the graph-view contract for plugin `GetCapability` under layering annotations as part of the same effort that exposes `kOrtSessionOptionsLayerAssignmentSettings` to plugin EP sessions. + **Recommended action:** Combine with the recently added `OrtEpGraphSupportInfo_RequestResourceForNode` C API explicitly (completed in PR #27595 on the ORT core side) to correctly assign nodes within the budget in the plugin's `CudaEp::GetCapabilityImpl()` when layer assignments exist. From 6377dfd0dff2b1aec42a9164f71d6fff493cbe7b Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 30 Mar 2026 14:38:55 -0700 Subject: [PATCH 46/48] address remaining feedbacks --- .../providers/cuda/plugin/cuda_ep_factory.cc | 33 ++++++++++++---- .../cuda/plugin/cuda_kernel_adapter.h | 39 +++++++++++++++---- .../cuda/plugin/cuda_plugin_kernels.cu | 2 +- 3 files changed, 59 insertions(+), 15 deletions(-) diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc index f5443da32433b..494deff257b7b 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -7,6 +7,7 @@ #include "core/common/string_utils.h" #include +#include #include #include #include @@ -102,6 +103,14 @@ std::string GetProviderOptionPrefix(std::string_view provider_name) { return "ep." + onnxruntime::utils::GetLowercaseString(std::string{provider_name}) + "."; } +void LogWarning(const OrtApi& ort_api, const OrtLogger& logger, const char* file, int line, + const char* function, const char* msg) { + OrtStatus* st = ort_api.Logger_LogMessage(&logger, ORT_LOGGING_LEVEL_WARNING, msg, file, line, function); + if (st != nullptr) { + ort_api.ReleaseStatus(st); + } +} + } // namespace CudaEpFactory::HardwareDeviceKey CudaEpFactory::MakeDeviceKey(const OrtApi& ort_api, @@ -368,7 +377,7 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( } }; - auto read_session_config_int = [&](std::initializer_list keys, int& value) { + auto read_session_config_non_negative_int = [&](std::initializer_list keys, int& value) { for (const auto& key : keys) { auto raw_value = try_get_session_config(key); if (!raw_value.has_value()) { @@ -376,12 +385,18 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( } try { - value = std::stoi(*raw_value); + int parsed = std::stoi(*raw_value); + if (parsed < 0) { + log_invalid_session_config(key, "a non-negative integer"); + return; + } + + value = parsed; return; } catch (const std::exception&) { } - log_invalid_session_config(key, "an integer"); + log_invalid_session_config(key, "a non-negative integer"); return; } }; @@ -420,7 +435,7 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( read_session_config_bool( {fuse_conv_bias_key, "ep.cuda.fuse_conv_bias", "fuse_conv_bias"}, config.fuse_conv_bias); - read_session_config_int( + read_session_config_non_negative_int( {sdpa_kernel_key, "ep.cuda.sdpa_kernel", "sdpa_kernel"}, config.sdpa_kernel); @@ -477,8 +492,9 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateAllocatorImpl( /*static*/ void ORT_API_CALL CudaEpFactory::ReleaseAllocatorImpl( - OrtEpFactory* /*this_ptr*/, OrtAllocator* allocator) noexcept { + OrtEpFactory* this_ptr, OrtAllocator* allocator) noexcept { if (!allocator) return; + auto* factory = static_cast(this_ptr); auto* typed_allocator = static_cast(allocator); switch (typed_allocator->GetKind()) { case CudaAllocatorKind::kDevice: @@ -488,8 +504,11 @@ void ORT_API_CALL CudaEpFactory::ReleaseAllocatorImpl( delete static_cast(allocator); return; default: - // Cannot throw in noexcept function - break; + LogWarning(factory->ort_api_, factory->default_logger_, __FILE__, __LINE__, + "CudaEpFactory::ReleaseAllocatorImpl", + "ReleaseAllocatorImpl received an unknown CudaAllocatorKind. Leaking the allocator instance."); + assert(false && "Unknown CudaAllocatorKind"); + return; } } diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index 81dee25576816..95ca59a584aea 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -221,11 +221,14 @@ class PluginKernelCollector { std::lock_guard lock(mutex_); entries_.push_back(fn); } - const std::vector& Entries() const { return entries_; } + std::vector Entries() const { + std::lock_guard lock(mutex_); + return entries_; + } private: std::vector entries_; - std::mutex mutex_; + mutable std::mutex mutex_; }; } // namespace cuda @@ -963,14 +966,22 @@ class CudaKernel : public OpKernel { return IAllocatorUniquePtr(static_cast(p), [s, used_async_alloc](T* ptr) { if (ptr) { - if (used_async_alloc && s) { + // Guard: only attempt async free if the stream is still registered. + // CudaSyncStream::~CudaSyncStream guarantees UnregisterStream() is + // called before cudaStreamDestroy(), so a non-null lookup here means + // the raw cudaStream_t handle is still valid. + if (used_async_alloc && s && + cuda_plugin::CudaSyncStream::FromCudaStream(static_cast(s)) != nullptr) { cudaError_t free_result = cudaFreeAsync(ptr, static_cast(s)); - if (free_result == cudaErrorNotSupported || free_result == cudaErrorInvalidValue) { - cudaFree(ptr); + if (free_result == cudaSuccess) { + return; } - } else { - cudaFree(ptr); } + + // Fall back to synchronous free if async free is unsupported or if the + // stream is no longer registered. cudaFree is valid for allocations + // returned by cudaMallocAsync and avoids using a stale stream handle. + cudaFree(ptr); } }); } @@ -985,6 +996,20 @@ class CudaKernel : public OpKernel { sync->EnqueueDeferredCPUBuffer(p); return; } + + if (s != nullptr) { + cudaError_t sync_result = cudaStreamSynchronize(static_cast(s)); + if (sync_result != cudaSuccess) { + // If the raw stream handle is already invalid during teardown, prefer a + // bounded leak over freeing pinned memory that could still be in use by + // an in-flight async copy. + LOGS_DEFAULT(WARNING) << "AddDeferredReleaseCPUPtr: cudaStreamSynchronize failed (" + << cudaGetErrorString(sync_result) + << "); leaking pinned buffer to avoid use-after-free"; + return; + } + } + cudaFreeHost(p); } template diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.cu b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.cu index 684784400cbeb..b5b3f19d8a7c9 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.cu +++ b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.cu @@ -41,7 +41,7 @@ OrtStatus* CreateCudaKernelRegistry(const OrtEpApi& /*ep_api*/, ::onnxruntime::ep::adapter::KernelRegistry registry; // Iterate all self-registered BuildKernelCreateInfoFn pointers. - const auto& entries = ::onnxruntime::cuda::PluginKernelCollector::Instance().Entries(); + auto entries = ::onnxruntime::cuda::PluginKernelCollector::Instance().Entries(); for (auto build_fn : entries) { ::onnxruntime::ep::adapter::KernelCreateInfo info = build_fn(); if (info.kernel_def != nullptr) { // filter the BuildKernelCreateInfo sentinel From 3f7ec373303f25a8fe7c7e2900734d37ced350e2 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 30 Mar 2026 16:50:37 -0700 Subject: [PATCH 47/48] refine code and doc --- cmake/CMakeLists.txt | 4 +- docs/cuda_plugin_ep/cuda_plugin_ep_design.md | 5 +++ .../cuda/plugin/cuda_controlflow_plugin.cc | 22 ++++++++++- .../cuda/plugin/cuda_kernel_adapter.h | 38 +++++++++++++------ .../test/framework/dynamic_plugin_ep_test.cc | 18 +++++++++ .../transformers/cuda_plugin_ep_helper.py | 12 +++--- 6 files changed, 80 insertions(+), 19 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 1e399cc6ce948..cb653b102aea2 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1440,7 +1440,7 @@ if (Git_FOUND) if (onnxruntime_USE_FP8_KV_CACHE) string(APPEND ORT_BUILD_INFO "fp8-kv-cache=1, ") endif() - if (onnxruntime_BUILD_CUDA_EP_AS_PLUGIN) + if (onnxruntime_USE_CUDA AND onnxruntime_BUILD_CUDA_EP_AS_PLUGIN) string(APPEND ORT_BUILD_INFO "cuda-plugin-ep=1, ") endif() if (onnxruntime_DUMP_TENSOR) @@ -1777,7 +1777,7 @@ foreach(onnxruntime_cmake_file ${ONNXRUNTIME_CMAKE_FILES}) endforeach() # CUDA EP Plugin build (independent shared library) -if (onnxruntime_BUILD_CUDA_EP_AS_PLUGIN) +if (onnxruntime_USE_CUDA AND onnxruntime_BUILD_CUDA_EP_AS_PLUGIN) include(onnxruntime_providers_cuda_plugin.cmake) endif() if (UNIX) diff --git a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md index c03453eafbd2d..e4e6794b18f94 100644 --- a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md +++ b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md @@ -995,4 +995,9 @@ include/onnxruntime/ep/ PR #27595 also introduces `layering_annotations` — node-level `"layer_ann"` metadata that routes nodes to specific EPs or CPU during partitioning. The expected model is that plugin EPs participate through the same `GetCapability` flow and therefore observe whatever node set ORT presents after applying layering rules. In practice that should mean no plugin-specific changes are needed to respect annotations that exclude nodes from the plugin. However, the plugin design should avoid depending on undocumented filtering details in the `OrtGraph*` contract. If the plugin EP itself needs to *read* layering annotations for internal decisions, or if the API needs to make filtered-vs-unfiltered graph semantics explicit, that would require new `OrtEp` API surface. + Current known limitations to keep in future work: + + - The `cuda(...)` device selector currently matches only the built-in `CUDAExecutionProvider`. It does not match the plugin EP name `CudaPluginExecutionProvider`, so layer assignment settings written against `cuda(...)` do not work with the CUDA plugin EP today. + - The `gpu:(...)` selector is currently matched using `OrtHardwareDevice::device_id`. That field is not a stable CUDA ordinal and is not guaranteed to uniquely identify one physical GPU, so index-based layer assignment is unreliable for the CUDA plugin EP, especially on hosts with multiple similar NVIDIA GPUs. + **Recommended action:** Combine with the recently added `OrtEpGraphSupportInfo_RequestResourceForNode` C API explicitly (completed in PR #27595 on the ORT core side) to correctly assign nodes within the budget in the plugin's `CudaEp::GetCapabilityImpl()` when layer assignments exist. diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc index 31ca602a60b77..a65c4e925c97f 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc @@ -110,20 +110,28 @@ OrtStatus* ORT_API_CALL PluginLoopHelper::ConcatOutputImpl( Ort::ConstValue first_output(per_iteration_outputs[0]); size_t bytes_per_iteration = first_output.GetTensorSizeInBytes(); + if (bytes_per_iteration > output_size_in_bytes) { + return Ort::Status("Loop ConcatOutput: output buffer too small for first iteration", ORT_FAIL).release(); + } char* cur = static_cast(output); + size_t total_bytes_copied = 0; for (size_t i = 0; i < num_per_iteration_outputs; i++) { Ort::ConstValue val(per_iteration_outputs[i]); size_t cur_bytes = val.GetTensorSizeInBytes(); if (cur_bytes != bytes_per_iteration) { return Ort::Status("Inconsistent size in loop output iteration", ORT_FAIL).release(); } + if (cur_bytes > output_size_in_bytes - total_bytes_copied) { + return Ort::Status("Loop ConcatOutput: output buffer too small", ORT_FAIL).release(); + } PL_CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cur, val.GetTensorRawData(), bytes_per_iteration, cudaMemcpyDeviceToDevice, cuda_stream)); cur += bytes_per_iteration; + total_bytes_copied += bytes_per_iteration; } - if (static_cast(cur - static_cast(output)) != output_size_in_bytes) { + if (total_bytes_copied != output_size_in_bytes) { return Ort::Status("Loop ConcatOutput: output buffer not fully filled", ORT_FAIL).release(); } @@ -196,6 +204,18 @@ OrtStatus* ORT_API_CALL PluginScanHelper::TransposeImpl( return Ort::Status("Scan Transpose: permutation size does not match input rank", ORT_FAIL).release(); } + std::vector seen_permutation_indices(num_dims, false); + for (size_t i = 0; i < num_permutation_elems; ++i) { + const size_t perm_index = permutation[i]; + if (perm_index >= num_dims) { + return Ort::Status("Scan Transpose: permutation index is out of range", ORT_FAIL).release(); + } + if (seen_permutation_indices[perm_index]) { + return Ort::Status("Scan Transpose: permutation contains duplicate indices", ORT_FAIL).release(); + } + seen_permutation_indices[perm_index] = true; + } + if (total_elements == 0) return nullptr; // Determine element size from the data type diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index 95ca59a584aea..b72058dc90baa 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -442,10 +442,22 @@ template <> struct SizeOf { static constexpr size_t value = 0; }; -[[nodiscard]] inline size_t BytesForCount(size_t count_or_bytes, size_t element_size) { - if (element_size == 0) return count_or_bytes; - if (count_or_bytes > (std::numeric_limits::max() / element_size)) return 0; - return count_or_bytes * element_size; + +[[nodiscard]] inline bool TryBytesForCount(size_t count_or_bytes, size_t element_size, size_t& bytes) { + if (element_size == 0) { + // `element_size == 0` is the sentinel for the `T = void` path. + // In that mode callers already pass a raw byte count to helpers like + // GetScratchBuffer(workspace_bytes, ...), so no multiplication is needed. + bytes = count_or_bytes; + return true; + } + + if (count_or_bytes > (std::numeric_limits::max() / element_size)) { + return false; + } + + bytes = count_or_bytes * element_size; + return true; } template @@ -942,8 +954,8 @@ class CudaKernel : public OpKernel { template inline IAllocatorUniquePtr GetScratchBuffer(size_t cnt, void* s) const { if (cnt == 0) return IAllocatorUniquePtr(nullptr, [](T*) {}); - size_t sz = detail::BytesForCount(cnt, detail::SizeOf::value); - if (sz == 0) { + size_t sz = 0; + if (!detail::TryBytesForCount(cnt, detail::SizeOf::value, sz)) { ORT_THROW("CUDA scratch buffer allocation size overflow for ", cnt, " elements"); } @@ -1015,8 +1027,8 @@ class CudaKernel : public OpKernel { template inline IAllocatorUniquePtr AllocateBufferOnCPUPinned(size_t cnt) const { if (cnt == 0) return IAllocatorUniquePtr(nullptr, [](T*) {}); - size_t sz = detail::BytesForCount(cnt, detail::SizeOf::value); - if (sz == 0) { + size_t sz = 0; + if (!detail::TryBytesForCount(cnt, detail::SizeOf::value, sz)) { ORT_THROW("CUDA pinned CPU buffer allocation size overflow for ", cnt, " elements"); } void* p = nullptr; @@ -1034,7 +1046,11 @@ class CudaKernel : public OpKernel { for (size_t i = 0; i != n; ++i) *p++ = v; } CudaAsyncBuffer(const CudaKernel* ok, gsl::span vec) : CudaAsyncBuffer(ok, vec.size()) { - memcpy(CpuPtr(), vec.data(), detail::BytesForCount(vec.size(), sizeof(T))); + size_t bytes = 0; + if (!detail::TryBytesForCount(vec.size(), sizeof(T), bytes)) { + ORT_THROW("CUDA async buffer host copy size overflow for ", vec.size(), " elements"); + } + memcpy(CpuPtr(), vec.data(), bytes); } void AllocCpuPtr(size_t n) { cpu_ = op_kernel_->AllocateBufferOnCPUPinned(n); @@ -1044,8 +1060,8 @@ class CudaKernel : public OpKernel { Status CopyToGpu(void* s) { if (cpu_) { gpu_ = op_kernel_->GetScratchBuffer(count_, s); - const size_t bytes = detail::BytesForCount(count_, sizeof(T)); - if (count_ > 0 && bytes == 0) { + size_t bytes = 0; + if (!detail::TryBytesForCount(count_, sizeof(T), bytes)) { ORT_THROW("CUDA async buffer copy size overflow for ", count_, " elements"); } if (cudaMemcpyAsync(gpu_.get(), cpu_.get(), bytes, cudaMemcpyHostToDevice, static_cast(s)) != cudaSuccess) return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, "Memcpy fail"); diff --git a/onnxruntime/test/framework/dynamic_plugin_ep_test.cc b/onnxruntime/test/framework/dynamic_plugin_ep_test.cc index 9c7b9fc3523f0..be2225ee66b80 100644 --- a/onnxruntime/test/framework/dynamic_plugin_ep_test.cc +++ b/onnxruntime/test/framework/dynamic_plugin_ep_test.cc @@ -6,6 +6,7 @@ #include "core/framework/execution_provider.h" #include "test/unittest_util/test_dynamic_plugin_ep.h" +#include #include #include "test/util/include/asserts.h" @@ -113,6 +114,23 @@ TEST(DynamicPluginEpInfraTest, CudaKernelAdapterRuntimeConfigExposesFuseConvBias EXPECT_FALSE(attention_kernel_options->UseEfficientAttention()); EXPECT_FALSE(attention_kernel_options->UseCudnnFlashAttention()); } + +TEST(DynamicPluginEpInfraTest, CudaKernelAdapterTryBytesForCountDetectsOverflow) { + size_t bytes = 0; + EXPECT_FALSE(onnxruntime::cuda::detail::TryBytesForCount(std::numeric_limits::max(), 2, bytes)); +} + +TEST(DynamicPluginEpInfraTest, CudaKernelAdapterTryBytesForCountPreservesRawByteCounts) { + size_t bytes = 0; + ASSERT_TRUE(onnxruntime::cuda::detail::TryBytesForCount(123, 0, bytes)); + EXPECT_EQ(bytes, size_t{123}); +} + +TEST(DynamicPluginEpInfraTest, CudaKernelAdapterTryBytesForCountNormalCase) { + size_t bytes = 0; + ASSERT_TRUE(onnxruntime::cuda::detail::TryBytesForCount(10, 4, bytes)); + EXPECT_EQ(bytes, size_t{40}); +} #endif } // namespace onnxruntime::test diff --git a/onnxruntime/test/python/transformers/cuda_plugin_ep_helper.py b/onnxruntime/test/python/transformers/cuda_plugin_ep_helper.py index 581edc5940c77..ebba84ccd0c27 100644 --- a/onnxruntime/test/python/transformers/cuda_plugin_ep_helper.py +++ b/onnxruntime/test/python/transformers/cuda_plugin_ep_helper.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import logging import os import sys from importlib.metadata import PackageNotFoundError, distribution @@ -10,16 +11,16 @@ import onnxruntime as onnxrt +CUDA_PLUGIN_EP_NAME = "CudaPluginExecutionProvider" +enable_debug_print = False +logger = logging.getLogger(__name__) + class _CudaPluginRegistrationState: attempted = False registered = False -CUDA_PLUGIN_EP_NAME = "CudaPluginExecutionProvider" -enable_debug_print = False - - def should_test_with_cuda_plugin_ep(default_value: bool = True) -> bool: return os.getenv("ORT_TEST_CUDA_PLUGIN_EP", "1" if default_value else "0") == "1" @@ -188,5 +189,6 @@ def get_cuda_provider_name() -> str | None: def _is_plugin_provider_type_available() -> bool: try: return CUDA_PLUGIN_EP_NAME in onnxrt.get_available_providers() - except Exception: + except Exception as e: + logger.warning("Failed to query available providers while checking %s availability: %s", CUDA_PLUGIN_EP_NAME, e) return False From a7fb8f1f698dd121b943cad213480f1010075bc2 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 30 Mar 2026 21:43:59 -0700 Subject: [PATCH 48/48] fix webgpu --- include/onnxruntime/ep/adapter/op_kernel.h | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/include/onnxruntime/ep/adapter/op_kernel.h b/include/onnxruntime/ep/adapter/op_kernel.h index 8c9c16b62692e..273461b36e75f 100644 --- a/include/onnxruntime/ep/adapter/op_kernel.h +++ b/include/onnxruntime/ep/adapter/op_kernel.h @@ -132,16 +132,18 @@ struct OpKernelContext { return Output(index, TensorShape{shape}); } [[nodiscard]] Status GetTempSpaceCPUAllocator(AllocatorPtr* output) const { - const auto* execution_provider = op_kernel_.Info().GetExecutionProvider(); - ORT_ENFORCE(execution_provider != nullptr, "Kernel does not have an execution provider."); - const auto* ort_ep = execution_provider->GetOrtEp(); + // Use GetOrtEp() directly from the cached KernelInfoCache rather than going through + // GetExecutionProvider()->GetOrtEp(). GetExecutionProvider() returns the native EP impl + // (e.g. WebGpuExecutionProvider), which doesn't override GetOrtEp() and returns nullptr. + // The cached ort_ep_ is resolved from the plugin wrapper's IExecutionProvider during + // KernelInfoCache construction, so it correctly holds the OrtEp instance. + const auto* ort_ep = op_kernel_.Info().GetOrtEp(); ORT_ENFORCE(ort_ep != nullptr, "Kernel execution provider is not associated with an OrtEp instance."); return static_cast(ort_ep)->GetTempSpaceCPUAllocator(output); } [[nodiscard]] Status GetTempSpaceAllocator(AllocatorPtr* output) const { - const auto* execution_provider = op_kernel_.Info().GetExecutionProvider(); - ORT_ENFORCE(execution_provider != nullptr, "Kernel does not have an execution provider."); - const auto* ort_ep = execution_provider->GetOrtEp(); + // See comment in GetTempSpaceCPUAllocator for why we use GetOrtEp() directly. + const auto* ort_ep = op_kernel_.Info().GetOrtEp(); ORT_ENFORCE(ort_ep != nullptr, "Kernel execution provider is not associated with an OrtEp instance."); return static_cast(ort_ep)->GetTempSpaceAllocator(output); }