diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 0b70e01d15dbe..cb653b102aea2 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -84,6 +84,7 @@ option(onnxruntime_USE_CUDA "Build with CUDA support" OFF) cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS" OFF) cmake_dependent_option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" ON "onnxruntime_USE_CUDA" OFF) +cmake_dependent_option(onnxruntime_BUILD_CUDA_EP_AS_PLUGIN "Build CUDA EP as a separate plugin shared library" OFF "onnxruntime_USE_CUDA" OFF) option(onnxruntime_CUDA_MINIMAL "Build CUDA without any operations apart from memcpy ops. Usefuel for a very minial TRT build" OFF) option(onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO "When building with CUDA support, generate device code line number information." OFF) option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF) @@ -1439,6 +1440,9 @@ if (Git_FOUND) if (onnxruntime_USE_FP8_KV_CACHE) string(APPEND ORT_BUILD_INFO "fp8-kv-cache=1, ") endif() + if (onnxruntime_USE_CUDA AND onnxruntime_BUILD_CUDA_EP_AS_PLUGIN) + string(APPEND ORT_BUILD_INFO "cuda-plugin-ep=1, ") + endif() if (onnxruntime_DUMP_TENSOR) string(APPEND ORT_BUILD_INFO "dump-tensor=1, ") endif() @@ -1771,6 +1775,11 @@ endif() foreach(onnxruntime_cmake_file ${ONNXRUNTIME_CMAKE_FILES}) include(${onnxruntime_cmake_file}.cmake) endforeach() + +# CUDA EP Plugin build (independent shared library) +if (onnxruntime_USE_CUDA AND onnxruntime_BUILD_CUDA_EP_AS_PLUGIN) + include(onnxruntime_providers_cuda_plugin.cmake) +endif() if (UNIX) option(BUILD_PKGCONFIG_FILES "Build and install pkg-config files" ON) else() diff --git a/cmake/external/cuda_configuration.cmake b/cmake/external/cuda_configuration.cmake index d8378c934f0cc..df180d185a268 100644 --- a/cmake/external/cuda_configuration.cmake +++ b/cmake/external/cuda_configuration.cmake @@ -161,6 +161,20 @@ macro(setup_cuda_architectures) set(CMAKE_CUDA_ARCHITECTURES_ORIG "${CMAKE_CUDA_ARCHITECTURES}") message(STATUS "GPU architectures: ${CMAKE_CUDA_ARCHITECTURES_ORIG}") + unset(ORT_HAS_SM80_OR_LATER) + foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES_ORIG) + if(CUDA_ARCH MATCHES "^([0-9]+)") + if(CMAKE_MATCH_1 GREATER_EQUAL 80) + set(ORT_HAS_SM80_OR_LATER ON) + break() + endif() + endif() + endforeach() + + if(ORT_HAS_SM80_OR_LATER) + add_definitions("-DHAS_SM80_OR_LATER") + endif() + set(ARCHITECTURES_WITH_KERNELS "80" "86" "89" "90" "100" "110" "120") foreach(CUDA_ARCH IN LISTS ARCHITECTURES_WITH_KERNELS) if(NOT "${CUDA_ARCH}" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG) diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index ff667d8f117e0..6336070e836c3 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -20,6 +20,9 @@ "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc" ) endif() + # Exclude plugin directory if it was picked up by GLOB_RECURSE + list(FILTER onnxruntime_providers_cuda_cc_srcs EXCLUDE REGEX "core/providers/cuda/plugin/.*") + # Remove pch files list(REMOVE_ITEM onnxruntime_providers_cuda_cc_srcs "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_pch.h" @@ -43,6 +46,8 @@ "${ONNXRUNTIME_ROOT}/core/providers/cuda/math/unary_elementwise_ops_impl.cu" ) endif() + # Exclude plugin directory if it was picked up by GLOB_RECURSE + list(FILTER onnxruntime_providers_cuda_cu_srcs EXCLUDE REGEX "core/providers/cuda/plugin/.*") source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs}) set(onnxruntime_providers_cuda_src ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs}) diff --git a/cmake/onnxruntime_providers_cuda_plugin.cmake b/cmake/onnxruntime_providers_cuda_plugin.cmake new file mode 100644 index 0000000000000..9dbcf3721b06b --- /dev/null +++ b/cmake/onnxruntime_providers_cuda_plugin.cmake @@ -0,0 +1,287 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +# Build the CUDA Execution Provider as a plugin shared library. +# This file is included from the main CMakeLists.txt when onnxruntime_BUILD_CUDA_EP_AS_PLUGIN=ON. + +message(STATUS "Building CUDA EP as plugin shared library") + + + +set(CUDA_PLUGIN_EP_DIR "${ONNXRUNTIME_ROOT}/core/providers/cuda/plugin") + +# --- Collect standard CUDA EP sources --- +file(GLOB_RECURSE CUDA_EP_CC_SRCS CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc" +) + +file(GLOB_RECURSE CUDA_EP_CU_SRCS CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu" +) + +# --- Collect contrib ops sources --- +file(GLOB_RECURSE CUDA_CONTRIB_OPS_CC_SRCS CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cc" +) + +file(GLOB_RECURSE CUDA_CONTRIB_OPS_CU_SRCS CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/*.cu" +) + +list(APPEND CUDA_PLUGIN_EP_CC_SRCS + ${CUDA_EP_CC_SRCS} + ${CUDA_CONTRIB_OPS_CC_SRCS} +) + +list(APPEND CUDA_PLUGIN_EP_CU_SRCS + ${CUDA_EP_CU_SRCS} + ${CUDA_CONTRIB_OPS_CU_SRCS} +) + +list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX "onnxruntime/contrib_ops/cuda/aten_ops/.*") +list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX "onnxruntime/contrib_ops/cuda/collective/.*") + +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX "onnxruntime/contrib_ops/cuda/aten_ops/.*") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX "onnxruntime/contrib_ops/cuda/collective/.*") + +# Exclude files that include cuda_execution_provider.h (directly or transitively), +# which conflicts with the adapter shim CUDAExecutionProvider class. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_execution_provider\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_provider_factory\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_provider_interface\\.cc$") + +# Exclude the framework controlflow/ subdirectory — these inherit from CPU base +# classes (If, Loop, Scan). The plugin has its own control flow wrappers in +# plugin/cuda_controlflow_plugin.cc that delegate to OrtEpApi. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/core/providers/cuda/controlflow/.*") + +# Exclude the entire tunable/ subdirectory — it depends on the real CudaTuningContext +# and CUDAExecutionProvider which are not available in the plugin build. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tunable/.*") + +# Exclude real EP infrastructure files (replaced by plugin/ equivalents). +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_stream_handle\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_execution_provider_info\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_graph\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_mempool_arena\\.cc$") + +# Exclude cuda_common.cc — its HalfGemmOptions definitions conflict with the +# adapter's inline shim. Utility functions are replaced or not needed. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_common\\.cc$") + +# Exclude cuda_nhwc_kernels.cc and cuda_contrib_kernels.cc — these files contain +# explicit BuildKernelCreateInfo<> registration tables that reference ALL kernel +# classes (including those in excluded source files like space_depth_ops.cc, +# controlflow/, transformers/, etc.), causing undefined symbols at link time. +# With PluginKernelCollector, individual kernel files self-register via macro +# overrides, so these centralized tables are not needed in the plugin build. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_nhwc_kernels\\.cc$") +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/cuda_contrib_kernels\\.cc$") + +# Exclude sequence_op.cc — uses TensorSeq (incomplete type in plugin build). +# identity_op.cc is now included: TensorSeq code path is guarded by +# BUILD_CUDA_EP_AS_PLUGIN and opset 14+ registrations use Tensor-only types. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/sequence_op\\.cc$") + +# Permanently excluded — pure CPU ops, handled by GetCpuPreferredNodes. +# size.cc registers onnxruntime::Size (CPU op) whose Compute() body lives +# in the CPU provider and is not linked into the plugin. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/size\\.cc$") + +# Permanently excluded — pure CPU ops, handled by GetCpuPreferredNodes. +# shape_op.cc inherits from onnxruntime::OpKernel (framework) +# which cannot convert to ep::adapter::OpKernel in the plugin build. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/shape_op\\.cc$") + +# Exclude contrib llm/ for now. The core CUDA llm kernels are adapter-safe, but +# contrib llm kernels still need their own plugin pass. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/llm/.*") +list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/llm/.*") + +# Exclude contrib training ops (shrunken_gather depends on provider_api.h in header). +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/shrunken_gather\\.cc$") + + +# Exclude contrib transformers/ (beam search, greedy search, sampling). Those need subgraph inference. +list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/transformers/.*") +list(FILTER CUDA_PLUGIN_EP_CU_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/transformers/.*") + +# Create shared library target using the ORT helper function for plugins +onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_plugin + ${CUDA_PLUGIN_EP_CC_SRCS} + ${CUDA_PLUGIN_EP_CU_SRCS} +) +# Keep the plugin CUDA target aligned with the repo-wide C++20 baseline. +# Forcing CUDA C++17 here breaks newer protobuf/absl headers used by the plugin +# build, as absl::compare expects standard ordering support in this configuration. +set_target_properties(onnxruntime_providers_cuda_plugin PROPERTIES + CUDA_STANDARD 20 + CUDA_STANDARD_REQUIRED ON +) + +# Suppress -Werror=maybe-uninitialized for local variables written by +# adapter OpKernelInfo::GetAttr<> (GCC falsely warns about variables that are +# initialized inside GetAttr’s output parameter path). +target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE + $<$,$>:-Wno-maybe-uninitialized> +) +target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE + # Flash-attention, XQA, MoE, and other pure CUDA kernel .cu files must NOT + # receive the ORT-framework force-include (it conflicts with cute::Tensor etc.). + # cuda_plugin_kernels.cu already #include "cuda_kernel_adapter.h" directly. + # Op-registration .cc files do not include it directly, so they need it here. + # + # IMPORTANT: The CXX force-include order matters — adapters.h MUST precede + # cuda_kernel_adapter.h because the adapter establishes type aliases that the + # kernel adapter header depends on. + # + # Force NVCC onto C++20 explicitly. With the VS generator the CUDA standard + # property alone still leaves `-std=c++17` in AdditionalOptions. + # Suppress NVCC cudafe warnings: + # 550 - variable set but never used (in adapter headers) + # 2810 - [[nodiscard]] false positive on Status assignments in op_kernel.h / kernel_registry.h + "$<$:SHELL:--std c++20>" + "$<$:--expt-relaxed-constexpr;-Xcudafe;--diag_suppress=550>" + "$<$:SHELL:-Xcudafe --diag_suppress=2810>" + "$<$:-include;${REPO_ROOT}/include/onnxruntime/ep/adapters.h>" + "$<$:SHELL:-include ${CUDA_PLUGIN_EP_DIR}/cuda_kernel_adapter.h>" +) + +if (MSVC) + target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE + "$<$:SHELL:-Xcompiler /permissive>" + "$<$:SHELL:-Xcompiler /wd4834>" + "$<$:SHELL:-Xcompiler /wd4127>" + "$<$:SHELL:-Xcompiler /wd4211>" + "$<$:SHELL:-Xcompiler /Zc:__cplusplus>" + "$<$:SHELL:-Xcompiler /bigobj>" + ) + + target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE + "$<$:/wd4127>" + ) +endif() + +# Mirror the core CUDA provider's CUDA 12.8+ NVCC workarounds so the plugin +# target handles stricter cudafe diagnostics consistently. +if (DEFINED onnxruntime_NVCC_THREADS) + set(onnxruntime_plugin_nvcc_threads "${onnxruntime_NVCC_THREADS}") +else() + set(onnxruntime_plugin_nvcc_threads "1") +endif() +target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE + "$<$:SHELL:--threads \"${onnxruntime_plugin_nvcc_threads}\">" + "$<$:--diag-suppress=177>" +) + +if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8) + target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE + "$<$:--static-global-template-stub=false>" + "$<$:--diag-suppress=221>" + ) + + if (MSVC) + target_compile_options(onnxruntime_providers_cuda_plugin PRIVATE + "$<$:SHELL:-Xcompiler /wd4505>" + ) + endif() +endif() + +include(cudnn_frontend) +include(cutlass) + +# --- Find cuDNN (may be at a custom path via onnxruntime_CUDNN_HOME) --- +set(_CUDNN_SEARCH_PATHS "") +if(onnxruntime_CUDNN_HOME) + list(APPEND _CUDNN_SEARCH_PATHS "${onnxruntime_CUDNN_HOME}") +endif() +if(DEFINED ENV{CUDNN_HOME}) + list(APPEND _CUDNN_SEARCH_PATHS "$ENV{CUDNN_HOME}") +endif() + +set(CUDA_PLUGIN_CUDNN_INCLUDE_DIR ${CUDNN_INCLUDE_DIR}) +set(CUDA_PLUGIN_CUDNN_LIBRARY ${cudnn_LIBRARY}) + +if(NOT CUDA_PLUGIN_CUDNN_INCLUDE_DIR OR NOT CUDA_PLUGIN_CUDNN_LIBRARY) + message(FATAL_ERROR "cuDNN not found (from main ORT search) for CUDA Plugin EP.") +endif() + +message(STATUS "CUDA Plugin EP: cuDNN include: ${CUDA_PLUGIN_CUDNN_INCLUDE_DIR}") +message(STATUS "CUDA Plugin EP: cuDNN library: ${CUDA_PLUGIN_CUDNN_LIBRARY}") + +# Include directories — only public ORT headers + CUDA toolkit + cuDNN + internal headers for adapter +target_include_directories(onnxruntime_providers_cuda_plugin PRIVATE + ${REPO_ROOT}/include + ${REPO_ROOT}/include/onnxruntime/core/session + ${ONNXRUNTIME_ROOT} + ${CUDAToolkit_INCLUDE_DIRS} + ${CUDA_PLUGIN_CUDNN_INCLUDE_DIR} + ${Eigen3_SOURCE_DIR} + ${cutlass_SOURCE_DIR}/include + ${cutlass_SOURCE_DIR}/examples + ${cutlass_SOURCE_DIR}/tools/util/include +) + +onnxruntime_add_include_to_target( + onnxruntime_providers_cuda_plugin + onnxruntime_common + onnx + onnx_proto + ${PROTOBUF_LIB} + flatbuffers::flatbuffers +) + +# Link libraries +target_link_libraries(onnxruntime_providers_cuda_plugin PRIVATE + CUDA::cudart + CUDA::cublas + CUDA::cublasLt + CUDA::cufft + CUDNN::cudnn_all + cudnn_frontend + Boost::mp11 + safeint_interface + onnxruntime_framework + onnxruntime_graph + onnxruntime_mlas + onnxruntime_flatbuffers + onnxruntime_common + cpuinfo::cpuinfo + onnx + onnx_proto + ${PROTOBUF_LIB} +) + +# Symbol visibility — only export CreateEpFactories and ReleaseEpFactory +target_compile_definitions(onnxruntime_providers_cuda_plugin PRIVATE ORT_API_MANUAL_INIT BUILD_CUDA_EP_AS_PLUGIN ORT_USE_EP_API_ADAPTERS=1 ONNX_ML=1 ONNX_NAMESPACE=onnx ONNX_USE_LITE_PROTO=1) + +if (onnxruntime_USE_CUDA_NHWC_OPS) + target_compile_definitions(onnxruntime_providers_cuda_plugin PRIVATE ENABLE_CUDA_NHWC_OPS) +endif() + +if(WIN32) + # Windows: use .def file for symbol exports + set(CUDA_PLUGIN_DEF_FILE ${CUDA_PLUGIN_EP_DIR}/cuda_plugin_ep_symbols.def) + if(EXISTS ${CUDA_PLUGIN_DEF_FILE}) + target_sources(onnxruntime_providers_cuda_plugin PRIVATE ${CUDA_PLUGIN_DEF_FILE}) + endif() +else() + # Linux/macOS: hide all symbols by default, explicitly export via __attribute__((visibility("default"))) + set_target_properties(onnxruntime_providers_cuda_plugin PROPERTIES + C_VISIBILITY_PRESET hidden + CXX_VISIBILITY_PRESET hidden + ) +endif() + + + +# Set output name +set_target_properties(onnxruntime_providers_cuda_plugin PROPERTIES + OUTPUT_NAME "onnxruntime_providers_cuda_plugin" +) + +# Install +install(TARGETS onnxruntime_providers_cuda_plugin + LIBRARY DESTINATION lib + RUNTIME DESTINATION bin +) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 8137f8b3a2529..cdd67fa55266f 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1090,6 +1090,10 @@ target_include_directories(onnxruntime_test_all PRIVATE ${ONNXRUNTIME_ROOT}/core onnxruntime_apply_test_target_workarounds(onnxruntime_test_all) +if (onnxruntime_USE_CUDA AND onnxruntime_BUILD_CUDA_EP_AS_PLUGIN) + target_compile_definitions(onnxruntime_test_all PRIVATE ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP=1) +endif() + if (MSVC) # The warning means the type of two integral values around a binary operator is narrow than their result. # If we promote the two input values first, it could be more tolerant to integer overflow. @@ -1264,6 +1268,10 @@ block() onnxruntime_apply_test_target_workarounds(onnxruntime_provider_test) onnxruntime_set_plugin_ep_test_environment(onnxruntime_provider_test) + if (onnxruntime_USE_CUDA AND onnxruntime_BUILD_CUDA_EP_AS_PLUGIN) + target_compile_definitions(onnxruntime_provider_test PRIVATE ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP=1) + endif() + # Expose QNN SDK headers to unit tests via an interface target if(onnxruntime_USE_QNN) add_library(qnn_sdk_headers_include INTERFACE) @@ -1474,6 +1482,11 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() else() target_link_libraries(onnxruntime_perf_test PRIVATE onnx_test_runner_common absl::flags absl::flags_parse ${onnx_test_libs}) + # When onnxruntime_BUILD_SHARED_LIB is OFF (the plugin build path), perf test was missing CUDA include directories and CUDA::cudart linkage. + if (onnxruntime_USE_CUDA OR onnxruntime_USE_NV OR onnxruntime_USE_TENSORRT) + target_include_directories(onnxruntime_perf_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_link_libraries(onnxruntime_perf_test PRIVATE CUDA::cudart) + endif() endif() set_target_properties(onnxruntime_perf_test PROPERTIES FOLDER "ONNXRuntimeTest") diff --git a/docs/cuda_plugin_ep/QUICK_START.md b/docs/cuda_plugin_ep/QUICK_START.md new file mode 100644 index 0000000000000..2a94d895892dc --- /dev/null +++ b/docs/cuda_plugin_ep/QUICK_START.md @@ -0,0 +1,108 @@ +# CUDA Plugin EP Quick Start + +## Build Instructions + +To build ONNX Runtime with the CUDA Plugin Execution Provider instead of the statically linked CUDA EP, use the `--build_cuda_ep_as_plugin` flag with the build script. + +```bash +# Build the core framework and the CUDA Plugin EP +./build.sh --config RelWithDebInfo --build_shared_lib --use_cuda --build_cuda_ep_as_plugin +``` + +## Running + +When the plugin is built, it will produce `libonnxruntime_providers_cuda_plugin.so` (or `.dll` on Windows) in the build output directory alongside `libonnxruntime.so`. + +The plugin EP is registered under the name **`CudaPluginExecutionProvider`** and uses the EP Plugin API (`RegisterExecutionProviderLibrary` / `GetEpDevices` / `SessionOptionsAppendExecutionProvider_V2`). It is **not** a drop-in replacement for the in-tree `CUDAExecutionProvider` — you must register the plugin library, enumerate its devices, and add them to the session. + +### C++ API + +Use `Env::RegisterExecutionProviderLibrary` to load the plugin, `Env::GetEpDevices` to discover the CUDA devices it exposes, and `SessionOptions::AppendExecutionProvider_V2` to add the selected device to the session. + +```cpp +#include "onnxruntime_cxx_api.h" + +Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "PluginTest"); + +// 1. Register the plugin library. +env.RegisterExecutionProviderLibrary("CudaPluginExecutionProvider", + ORT_TSTR("libonnxruntime_providers_cuda_plugin.so")); + +// 2. Enumerate available EP devices and pick the CUDA plugin device. +auto ep_devices = env.GetEpDevices(); +std::vector plugin_devices; +for (const auto& dev : ep_devices) { + if (std::string(dev.EpName()) == "CudaPluginExecutionProvider") { + plugin_devices.push_back(dev); + break; // use the first CUDA plugin device + } +} + +// 3. Add the plugin device to session options. +Ort::SessionOptions session_options; +session_options.AppendExecutionProvider_V2(env, plugin_devices, {}); + +Ort::Session session(env, "model.onnx", session_options); +``` + +### Python API + +Use `onnxruntime.register_execution_provider_library` to load the plugin, `onnxruntime.get_ep_devices` to discover devices, and `SessionOptions.add_provider_for_devices` to add the selected device. + +**Device-based approach (recommended):** + +```python +import onnxruntime as ort + +# 1. Register the plugin library. +ort.register_execution_provider_library( + "CudaPluginExecutionProvider", + "libonnxruntime_providers_cuda_plugin.so", +) + +# 2. Enumerate devices and pick the CUDA plugin device. +devices = ort.get_ep_devices() +plugin_device = next(d for d in devices if d.ep_name == "CudaPluginExecutionProvider") + +# 3. Create session with the plugin device. +sess_options = ort.SessionOptions() +sess_options.add_provider_for_devices([plugin_device], {}) + +sess = ort.InferenceSession("model.onnx", sess_options=sess_options) +``` + +**Provider-name approach:** + +You can also pass `CudaPluginExecutionProvider` by name in the `providers` list +(the plugin library must already be registered): + +```python +import onnxruntime as ort + +ort.register_execution_provider_library( + "CudaPluginExecutionProvider", + "libonnxruntime_providers_cuda_plugin.so", +) + +sess = ort.InferenceSession( + "model.onnx", + providers=[ + ("CudaPluginExecutionProvider", {"device_id": "0"}), + "CPUExecutionProvider", + ], +) +``` + +## Known Limitations +* The plugin does not currently support CUDA Graphs. +* The plugin direct-allocates memory using `cudaMalloc` resulting in a potential performance penalty compared to the integrated Memory Arena. + +## Verification +You can generate a parity report comparing the kernels available in the plugin EP versus the statically linked CUDA EP. +```bash +# Check static source registration parity: +python tools/ci_build/cuda_plugin_parity_report.py + +# Check runtime registry parity: +python tools/ci_build/cuda_plugin_parity_report.py --runtime --plugin-ep-lib build/Linux/RelWithDebInfo/libonnxruntime_providers_cuda_plugin.so +``` diff --git a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md new file mode 100644 index 0000000000000..e4e6794b18f94 --- /dev/null +++ b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md @@ -0,0 +1,1003 @@ +# CUDA Plugin EP — Design Document + +## 1. Overview + +The CUDA Plugin EP is an alternative build of the ONNX Runtime CUDA Execution Provider that compiles as a standalone shared library (`libonnxruntime_providers_cuda_plugin.so`). It loads at runtime through the ORT EP Plugin API instead of being statically linked into the main runtime binary. + +**Goals:** +- Allow CUDA EP updates independent of ORT core releases +- Support all operators currently supported by the in-tree CUDA EP (tunable ops are low priority) +- Minimize changes to existing CUDA kernel source files + +**Current status:** The plugin build is functional on this branch and the focused plugin validation script (`./cuda_plugin.sh --build --test_plugin`) passes. Most core CUDA kernels now compile in the plugin build; the remaining source-level exclusions are documented in [Section 7](#7-excluded-operators). + +--- + +## 2. Architecture + +### 2.1 Build Targets + +The ORT CUDA build produces four separate libraries: + +| Target | Output | Type | Description | +|--------|--------|------|-------------| +| `onnxruntime_providers` | `libonnxruntime_providers.a` | Static lib | CPU provider + framework ops | +| `onnxruntime_providers_shared` | `libonnxruntime_providers_shared.so` | Shared lib | DLL-boundary bridge for in-tree EPs | +| `onnxruntime_providers_cuda` | `libonnxruntime_providers_cuda.so` | Shared module | In-tree CUDA EP (uses `SHARED_PROVIDER` bridge) | +| `onnxruntime_providers_cuda_plugin` | `libonnxruntime_providers_cuda_plugin.so` | Shared module | Plugin CUDA EP (uses EP API adapters) | + +### 2.2 Preprocessor Defines + +Each build target uses different preprocessor defines that control how framework types are resolved: + +| Define | Set In | Purpose | +|--------|--------|---------| +| `SHARED_PROVIDER` | `onnxruntime_providers_shared`, `onnxruntime_providers_cuda` | Activates the DLL-boundary proxy types in `provider_api.h` | +| `BUILD_CUDA_EP_AS_PLUGIN` | `onnxruntime_providers_cuda_plugin` | Makes `provider_api.h` a no-op; activates plugin-specific code paths | +| `ORT_USE_EP_API_ADAPTERS` | `onnxruntime_providers_cuda_plugin` | Enables the EP adapter type aliases (`ep/adapters.h`) | +| `ORT_API_MANUAL_INIT` | `onnxruntime_providers_cuda_plugin` | Manual ORT API initialization in plugin DLL | + +### 2.3 Class Hierarchy + +``` +OrtEpFactory OrtEp + ↑ ↑ +CudaEpFactory adapter::Ep + │ ↑ + ├─ creates OrtEpDevice CudaEp + ├─ creates CudaSyncStream ├─ stores session-derived Config + ├─ caches kernel registry └─ owns a real shim CUDAExecutionProvider via EpImpl() + ├─ caches stable OrtMemoryInfo objects + └─ maps OrtHardwareDevice* → CUDA ordinal + +Migrated CUDA kernels + └─ use CudaKernel / cuda_kernel_adapter.h + ├─ cache a shared runtime-config handle during construction + ├─ use CudaKernel accessors for provider settings during Compute() + └─ resolve stream-local handles via CudaSyncStream::FromCudaStream() +``` + +Key ownership relationships: +- `CudaEpFactory` implements raw `OrtEpFactory` callbacks and owns shared factory-level state such as the kernel registry, cached `OrtMemoryInfo` instances, and the hardware-device to CUDA-ordinal map. +- `CudaEp` inherits from `ep::adapter::Ep`, which itself derives from `OrtEp` and owns a framework-facing `IExecutionProvider` object. +- The plugin-local `CUDAExecutionProvider` in `cuda_kernel_adapter.h` is a real shim object owned by `ep::adapter::Ep`. It is not the full in-tree CUDA EP, but it has its own object identity and stores plugin-specific members — including the wrapped `OrtEp*` and a provider-owned shared runtime-config object. +- Runtime configuration needed by migrated kernels is stored on that shim provider and exposed to kernels as a cached `shared_ptr`, rather than through a separate global map keyed by the provider address. +- `CudaSyncStream` owns `cudaStream_t`, `cublasHandle_t`, `cudnnHandle_t`, and `cublasLtHandle_t` for each sync stream created through the EP API. + +### 2.4 Plugin DLL Entry Points + +The plugin exports exactly two C symbols: +- `CreateEpFactories()` — called by ORT to create the EP factory +- `ReleaseEpFactory()` — called by ORT to destroy the factory + +All other symbols have hidden visibility. + +--- + +## 3. Type Resolution — How Kernel Code Compiles Unchanged + +The core design principle is that existing CUDA kernel `.cc` files compile in the plugin build with **zero or minimal source changes**. This is achieved through a two-layer force-include mechanism. + +### 3.1 Force-Include Chain + +For every `.cc` file in the plugin build, CMake injects two force-includes before any source code: + +``` +1. ep/adapters.h — adapter type aliasing +2. cuda_kernel_adapter.h — CudaKernel base class, macros, CPU shims +``` + +Note: `.cu` files do NOT receive force-includes (conflicts with CUTLASS/cute). They must include `cuda_kernel_adapter.h` explicitly if needed. + +### 3.2 Adapter Type Aliasing (`ep/adapters.h`) + +`ep/adapters.h` defines `using` aliases in both `onnxruntime::cuda` and `onnxruntime::contrib::cuda` namespaces: + +```cpp +namespace onnxruntime::cuda { + using OpKernel = ep::adapter::OpKernel; + using OpKernelContext = ep::adapter::OpKernelContext; + using OpKernelInfo = ep::adapter::OpKernelInfo; + using KernelRegistry = ep::adapter::KernelRegistry; + using KernelDefBuilder = ep::adapter::KernelDefBuilder; + using DataTransferManager = ep::adapter::DataTransferManager; + // ... etc +} +``` + +When kernel code in `namespace onnxruntime::cuda` references `OpKernelContext`, it resolves to the adapter type instead of the framework type. **No kernel source changes needed.** + +### 3.3 Provider API Bypass + +In the plugin build, `provider_api.h` (normally included from `cuda_common.h`) is a **complete no-op** — it does NOT define `SHARED_PROVIDER`. This means: + +- `#ifndef SHARED_PROVIDER` guards in framework headers remain **active**, exposing real types +- Header-inlined utility methods (see [Section 4](#4-cpu-base-class-helpers)) get their inline bodies +- The `ProviderHostCPU` virtual table bridge is bypassed entirely + +### 3.4 Kernel Adapter (`cuda_kernel_adapter.h`) + +This 1100+ line header provides everything CUDA kernels need that would normally come from framework infrastructure: + +| Section | What It Provides | +|---------|-----------------| +| Error macros | `CUDA_RETURN_IF_ERROR`, `CUBLAS_RETURN_IF_ERROR`, `CUDNN_RETURN_IF_ERROR`, `CUFFT_RETURN_IF_ERROR` | +| Type mappings | `ToCudaType::MappedType = half`, etc. | +| CudaKernel base | Scratch buffers, handle access, `Stream()`, `GetComputeStream()` | +| Kernel registration | Self-registering `ONNX_OPERATOR_*_KERNEL_EX` macro overrides via `PluginKernelCollector`, including `ONNX_OPERATOR_TWO_TYPED_KERNEL_EX`, `ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX`, and `ONNX_OPERATOR_THREE_TYPED_KERNEL_EX` variants | +| CPU shims | Lightweight reimplementations of CPU helpers not linked into plugin | +| Math helpers | `HalfGemmOptions`, `CublasMathModeSetter` | +| Stream shim | `OrtStreamAdapter`/`PluginStreamShim` to present a framework-compatible `Stream*` view over a raw `cudaStream_t` where needed | + +### 3.5 Kernel Registration + +In the in-tree build, kernels register through centralized tables (`cuda_nhwc_kernels.cc`, `cuda_contrib_kernels.cc`). In the plugin build, the `ONNX_OPERATOR_*_KERNEL_EX` macros are overridden to auto-register each kernel into the `PluginKernelCollector` singleton at static initialization time: + +```cpp +// Macro override generates: +// 1. BuildKernelCreateInfo() function +// 2. Static PluginKernelCollector::Register() call + +// At plugin startup, CreateCudaKernelRegistry() iterates the collector +// and registers each kernel into an adapter::KernelRegistry. +``` + +#### 3.5.1 Type Constraint Names and OpSchema Access + +Every kernel registration includes type constraint names — string literals such as `"T"`, `"T1"`, `"U"` — that must exactly match the formal parameter type-constraint strings defined in the ONNX operator schema. In the current plugin build, these names are **hard-coded** at compile time with no runtime validation against the actual schema. If a constraint name is wrong, kernel matching silently fails during `GetCapability`. + +PR #27713 adds `OrtEpApi` functions that let plugin EPs query ONNX operator schemas from ORT's global schema registry at runtime (available from ORT 1.25): + +| `OrtEpApi` Function | C++ Wrapper | Purpose | +|---------------------|-------------|----------| +| `GetOpSchema(name, max_ver, domain)` | `Ort::GetOpSchema()` | Look up a schema by op name, max opset version, and domain | +| `OpSchema_GetSinceVersion` | `ConstOpSchema::GetSinceVersion()` | Opset version that introduced this schema entry | +| `OpSchema_GetNumInputs` / `_GetNumOutputs` | `GetNumInputs()` / `GetNumOutputs()` | Formal input/output count | +| `OpSchema_GetInputName` / `_GetOutputName` | `GetInputName(i)` / `GetOutputName(i)` | Formal parameter name | +| `OpSchema_GetInputTypeStr` / `_GetOutputTypeStr` | `GetInputTypeStr(i)` / `GetOutputTypeStr(i)` | Type constraint string (e.g., `"T"`) | +| `OpSchema_HasTypeConstraint` | `HasTypeConstraint(str)` | Whether a string is a valid type constraint name in the schema | + +The returned `OrtOpSchema*` is non-owning — it points into the global `ONNX_NAMESPACE::OpSchemaRegistry` singleton and is valid for the lifetime of the ORT process. + +**Why the plugin cannot link its own ONNX library:** The `OpSchemaRegistry` is a Meyers singleton (`static` local in `Instance()`). Each shared library gets its own copy of that static variable — on Windows each DLL is isolated by default, on macOS two-level namespaces have the same effect, and on Linux behavior depends on `dlopen` flags. Even when isolation doesn't occur, the EP's registry would lack ORT's contrib and internal schemas, and version mismatches between the EP's ONNX library and ORT's vendored copy could cause silent divergence. The `OrtEpApi` route is the only reliable, portable way to query the schemas ORT actually uses. + +**Impact on the CUDA plugin EP:** + +1. **Registration-time validation.** `CreateCudaKernelRegistry()` can optionally validate each registered kernel's type constraint names against the schema after collecting all entries from `PluginKernelCollector`. A mismatch can be logged as a warning (debug builds) or an error, catching drift when ONNX spec updates rename constraint strings. + +2. **NHWC / internal-domain diagnostics.** For rewritten `com.ms.internal.nhwc` nodes, the schema API can confirm that the kernel's registered domain, version range, and constraint names actually match the internal-domain schema entry, improving the diagnostics called for in [Section 5.3.1.3](#5313-nhwc-design-requirements). + +3. **Parity tooling.** `cuda_plugin_parity_report.py` can use the C++ wrapper to compare the plugin's registered constraint names against the schema, flagging incorrect or missing constraints in the parity report. + +4. **Future: schema-driven registration helpers.** A `KernelDefBuilder` helper could derive constraint names automatically from the schema rather than relying on hard-coded strings, reducing the manual maintenance burden when new opset versions change constraint names. See [Section 11.6](#116-if-opschema-access-is-available-schema-validated-type-constraints). + +--- + +## 4. CPU Base Class Helpers — The SHARED_PROVIDER Pattern + +Many CUDA kernels inherit from CPU base classes and call utility methods (e.g., `PadBase::HandleDimValueZero`, `SliceBase::PrepareForCompute`). In the in-tree build, these call across the DLL boundary through `ProviderHostCPU`. The plugin doesn't use this bridge. + +### 4.1 Pattern: Inline in Header + +The primary approach moves pure-computation helpers from CPU `.cc` files to headers: + +```cpp +// In padbase.h: +#ifdef SHARED_PROVIDER + // In-tree build: declaration only, body in ProviderHostCPU bridge + static void HandleDimValueZero(Mode mode, const TensorShape& input_shape, TensorShape& output_shape); +#else + // Plugin build + CPU provider: inline body + static inline void HandleDimValueZero(Mode mode, const TensorShape& input_shape, + TensorShape& output_shape) { + // ... implementation ... + } +#endif +``` + +**Files refactored with this pattern:** +- `padbase.h` — `HandleDimValueZero`, `ComputePads` (delegates to `ComputePadsImpl` template) +- `scatter_nd.h` — `ValidateShapes` +- `split.h` — `PrepareForCompute` +- `tile.h` — `IsTileMemcpy` +- `slice.h` — `PrepareForCompute`, `FlattenOutputDims` +- `cumsum.h` — `cumsum_op::GetAxis` +- `bias_gelu_helper.h` — `bias_gelu_helper::CheckInputs` +- `concatbase.h` — `PrepareForCompute` +- `gatherbase.h` — `PrepareForCompute`/`PrepareForComputeImpl` (template) +- `unsqueeze.h` — `PrepareCompute` +- `embed_layer_norm_helper.h` — `embed_layer_norm::CheckInputs` (templatized on context type) +- `non_max_suppression_helper.h` — `NonMaxSuppressionBaseImpl` template class (new file) +- `attention_base.h` — `AttentionBase::CheckInputs`, `CheckMask`, `GetPresent` (templatized on context type) +- `longformer_attention_base.h` — `LongformerAttentionBase::CheckInputs` +- `roialign.h` — `CheckROIAlignValidInput`, `RoiAlignBase` constructor (templatized on info type) +- `upsamplebase.h` — `UpsampleBase::AdjustOutputSizeAsPolicy` +- `crop.h` — `CropBase` constructor (templatized on info type) +- `space_depth_ops.h` — `SpaceDepthBase` constructor (templatized on info type) +- `clip.h` — Clip min/max attribute handling (removed `Clip_6Base` CPU dependency) +- `cuda_common_type_helpers.h` — CUDA type conversion and handle error string helpers (moved from `cuda_common.cc`) + +### 4.2 Pattern: Template Methods + +For methods that take `OpKernelContext&` (which differs between plugin and in-tree builds), template versions accept any context type: + +```cpp +// In padbase.h: +template +static void ComputePadsImpl(KernelContextType& ctx, size_t data_rank, + gsl::span pads_data, PadsVector& pads) { ... } +``` + +The CUDA kernel calls `PadBase::ComputePadsImpl(*ctx, ...)` directly, avoiding the `OpKernelContext&` type mismatch. + +The same pattern is applied to constructors that receive `OpKernelInfo`: + +```cpp +// In roialign.h: +template +RoiAlignBase(const TKernelInfo& info) { + info.template GetAttr("mode", &mode_string); + info.template GetAttr("output_height", &output_height_); + // ... +} +``` + +This allows the base class constructor to work with both the framework `OpKernelInfo` and the plugin adapter's `OpKernelInfo`. Applied to: `RoiAlignBase`, `CropBase`, `SpaceDepthBase` (#27628). + +### 4.3 Files That Cannot Be Inlined + +Some CPU base classes have heavy dependencies (protobuf, `UnpackTensor`) that make inlining impractical: + +- **`ConstantOfShapeBase`** — depends on `TensorProto` and `UnpackTensor`. The plugin path in `constant_of_shape.h` stays self-contained: it reuses `ConstantOfShapeCore` but fetches the `value` attribute through the ORT C++ API instead of depending on the full CPU base implementation. +- **`UpsampleBase`** — partially addressed: `AdjustOutputSizeAsPolicy` moved to header (#27628). Still depends on `InputDefs()` and `OpKernelInfo::GetAllocator()` which are not in the adapter. + +--- + +## 5. Handle and Stream Management + +### 5.1 Stream Ownership + +`CudaSyncStream` is the plugin's CUDA sync-stream implementation: +- Owns `cudaStream_t`, `cublasHandle_t`, `cudnnHandle_t`, `cublasLtHandle_t` +- Is created by `CudaEpFactory::CreateSyncStreamForDevice` +- Registers itself in a global `cudaStream_t -> CudaSyncStream*` map so migrated kernels can recover per-stream handles from a raw CUDA stream +- Defers host-buffer cleanup until `OnSessionRunEnd()` after the stream is synchronized + +### 5.2 Handle Access Path + +``` +CudaKernel::GetCublasHandle(OpKernelContext* ctx) + → Stream(ctx) // raw cudaStream_t from adapter ctx + → CudaSyncStream::FromCudaStream() // global stream map + TLS cache + → sync_stream->GetCublasHandle() +``` + +The stream lookup path uses a thread-local last-hit cache plus a generation counter so destroyed streams invalidate stale TLS entries without requiring per-thread cleanup. + +For code paths that need handles without an active stream, `cuda_kernel_adapter.h` also provides thread-local default cuBLAS/cuDNN handles via `GetDefaultCudaHandlesForDevice(device_id)`. + +### 5.3 Provider Access + +Kernels still discover the shim provider through the pointer returned by `info.GetExecutionProvider()` at construction time. In the plugin build, `ep::adapter::OpKernelInfo` snapshots three related pointers from the framework `OpKernelInfo` when the kernel is created: + +- the session-owned outer `PluginExecutionProvider` +- the wrapped plugin `OrtEp` / `CudaEp` +- the inner shim provider returned by `static_cast(ort_ep)->EpImpl()` + +`OpKernelInfo::GetExecutionProvider()` then returns that cached shim pointer, so migrated kernels receive the real shim `CUDAExecutionProvider` object owned by `ep::adapter::Ep`, not the outer `PluginExecutionProvider` and not the `OrtEp`/`CudaEp` object reinterpreted at the same address. + +Caching the shim pointer at kernel-creation time is important for the NHWC path. Re-querying `OrtKernelInfo -> OrtEp -> EpImpl()` during execution was fragile after layout transformation. The current implementation resolves the shim once in the `CudaKernel` constructor, caches a `shared_ptr`, and routes later provider-setting reads through `CudaKernel` accessors instead of repeated provider-pointer casts during `Compute()`. + +This changes the safety model from the earlier "phantom shim" design: +- The shim no longer needs to remain layout-compatible with `IExecutionProvider`. +- Adding plugin-local members to the shim is safe as long as normal C++ object lifetime/ownership rules are respected. +- The shim still is not the full bundled `CUDAExecutionProvider`; it only exposes the subset of methods that migrated kernels currently need. + +Provider options flow through the plugin in two stages: +- `CudaEpFactory` parses session/provider options into `CudaEp::Config`. +- `CudaEp` copies the subset needed by migrated kernels into the shim provider's runtime config via `SetCudaKernelAdapterRuntimeConfigForProvider(EpImpl(), ...)` during EP construction. + +Because the runtime config is provider-owned and cached by kernels as a shared pointer, there is no global map and no mutex. Today that stored subset includes TF32, device ID/device properties, cuDNN convolution settings, skip-layer-norm strict mode, fused-conv-bias, and SDPA kernel selection. Other plugin behaviors, such as preferred layout, are handled directly by `CudaEp` callbacks instead of through the shim. + +For stream bridging, the preferred helpers are: +- `Stream(ctx)` when the kernel only needs a raw `cudaStream_t` +- `GetComputeStream(ctx)` when the kernel API already accepts the adapter's opaque stream pointer +- `GetOrtStream(ctx)` when framework-style `Stream*` plumbing is still needed by shared helper code + +### 5.3.1 NHWC Layout-Transformation Support + +The bundled CUDA EP's NHWC path is not just a kernel-registration feature. It is a coordinated contract between provider configuration, ORT's layout transformer, kernel registration, adapter/provider access, and graph partitioning. On the current branch, the CUDA plugin EP now supports this path when NHWC is compiled in and the session requests `prefer_nhwc`. + +#### 5.3.1.1 End-to-End Flow + +When NHWC is enabled for an EP, the expected ORT flow is: + +1. The EP reports `NHWC` from `GetPreferredLayout()` (or `OrtEp::GetPreferredDataLayout()` for plugins) when `prefer_nhwc` is enabled. +2. During layout transformation, ORT asks the EP whether each layout-sensitive op should be converted via `ShouldConvertDataLayoutForOp()`. +3. For each accepted op, `TransformLayoutForEP()` inserts `Transpose` nodes around the operator and rewrites the operator into the internal NHWC domain (`com.ms.internal.nhwc`). +4. Graph partitioning runs again. The EP must now claim the rewritten internal-domain nodes, not the original ONNX-domain nodes. +5. Kernel lookup must succeed against the EP's kernel registry for the rewritten internal-domain node, with matching domain, opset range, and type constraints. + +This means the plugin must satisfy two distinct contracts at the same time: + +| Contract | Owner | Requirement | +|----------|-------|-------------| +| Layout preference contract | `CudaEp` + ORT plugin bridge | Only request NHWC when the plugin can handle the rewritten graph | +| Kernel/capability contract | Kernel registry + `CudaEp::GetCapabilityImpl()` | Claim the resulting `com.ms.internal.nhwc` nodes during partitioning | + +#### 5.3.1.2 Current Plugin Status + +The current branch already has the core runtime pieces in place: + +| Component | Current state | +|-----------|---------------| +| ORT plugin bridge | `PluginExecutionProvider` already maps `OrtEp::GetPreferredDataLayout()` and `OrtEp::ShouldConvertDataLayoutForOp()` into the normal `IExecutionProvider` layout APIs | +| Plugin callback implementations | `CudaEp` installs `GetPreferredDataLayoutImpl()` and `ShouldConvertDataLayoutForOpImpl()` and advertises NHWC when `prefer_nhwc` is enabled | +| Provider option parsing | `CudaEpFactory` already parses `prefer_nhwc` / `prefer_nhwc_layout` into `CudaEp::Config` | +| Build-time gating | `cmake/onnxruntime_providers_cuda_plugin.cmake` propagates `ENABLE_CUDA_NHWC_OPS` to the plugin target when `onnxruntime_USE_CUDA_NHWC_OPS=ON` | +| NHWC kernel registration | NHWC kernels are compiled from the normal CUDA kernel sources and self-register through `PluginKernelCollector`; the centralized `cuda_nhwc_kernels.cc` table stays excluded in plugin builds | +| Second capability pass | `CudaEp::GetCapabilityImpl()` preserves nodes already assigned to `CudaPluginExecutionProvider`, so ORT's post-layout-transformation partitioning pass does not drop rewritten NHWC nodes that were previously selected by the plugin | +| Adapter provider access | `ep::adapter::OpKernelInfo` caches the inner shim `EpImpl()` pointer at kernel-creation time, avoiding a fragile runtime `OrtKernelInfo -> OrtEp -> EpImpl()` round-trip in NHWC kernels | +| Focused validation | `test_cuda_plugin_ep.py` Stage 3 now runs NHWC-requested sessions for Conv, BatchNormalization, MaxPool, and AveragePool and requires plugin-backed execution to succeed numerically | + +The fixes that made this work were not limited to turning the callbacks back on: + +- The plugin now keeps both newly discovered candidate nodes and nodes already assigned to `CudaPluginExecutionProvider` during the second `GetCapability()` pass that runs after layout transformation. +- NHWC kernels now obtain provider configuration through the cached shim pointer in `ep::adapter::OpKernelInfo`, which removed a runtime crash path in migrated kernels such as NHWC `Conv`. + +With those pieces in place, NHWC-requested sessions take the real plugin execution path rather than silently falling back to the stable NCHW path. + +#### 5.3.1.3 NHWC Design Requirements + +The implementation should preserve the following invariants: + +| Requirement | Why it matters | +|-------------|----------------| +| The plugin must never advertise NHWC unless it can claim every internal-domain op it requests ORT to generate | Otherwise ORT can create an invalid graph containing `com.ms.internal.nhwc` nodes that no EP selects | +| The NHWC conversion allowlist must come from a single shared source of truth | The bundled CUDA EP and the plugin EP must not drift on which ops are safe to rewrite | +| Kernel coverage checks must validate internal-domain registrations, not just original ONNX-domain registrations | The rewritten graph uses `com.ms.internal.nhwc`, so ONNX-domain coverage alone is insufficient | +| Capability diagnostics must identify internal-domain kernel misses clearly | NHWC failures are difficult to debug after rewrite unless the missing domain/op/version/type information is surfaced | +| Tests must verify plugin-backed NHWC execution explicitly | Output correctness alone is not enough because a fallback path can still pass numerically | + +#### 5.3.1.4 Implemented Design and Remaining Follow-Ups + +The current branch has already landed the minimum runtime fixes required for plugin-side NHWC execution. The remaining work is mostly cleanup, consolidation, and stronger diagnostics. + +**A. Keep partitioning registry-driven and preserve pre-assigned NHWC nodes** + +`CudaEp::GetCapabilityImpl()` should continue to rely on `EpGraphSupportInfo_LookUpKernel()` as the source of truth for whether a rewritten node is supported. The important implementation detail is that it must preserve nodes already assigned to the plugin when ORT reruns partitioning after layout transformation. + +That behavior is now implemented by tracking: +- `tentative_nodes`: newly discovered nodes with matching kernel registrations +- `candidate_nodes`: both tentative nodes and nodes already assigned to `CudaPluginExecutionProvider` + +The final support set is chosen from `candidate_nodes`, with the existing CPU-preferred-node filtering applied only where appropriate. + +**B. Cache the shim provider pointer at kernel creation** + +Migrated CUDA kernels expect `info.GetExecutionProvider()` to return the shim `CUDAExecutionProvider`, not the outer `PluginExecutionProvider`. The adapter now resolves that relationship once during kernel creation, captures the shim provider's runtime-config object, and uses `CudaKernel` accessors for later provider-setting reads. + +This is especially important for NHWC kernels because layout transformation introduces additional runtime paths before the actual CUDA kernel executes. Repeatedly reconstructing provider access from `OrtKernelInfo` during execution proved fragile in that path. The cached-config approach keeps provider access deterministic and matches the actual object model: + +- outer session EP: `PluginExecutionProvider` +- wrapped plugin object: `CudaEp` / `ep::adapter::Ep` +- inner shim: `CUDAExecutionProvider` returned by `EpImpl()` + +**C. Remaining follow-ups** + +The main follow-ups are now design quality items rather than blockers: + +- Unify the NHWC conversion allowlist between the bundled CUDA EP and the plugin CUDA EP instead of keeping separate hard-coded tables. +- Improve diagnostics when kernel lookup fails for a rewritten `com.ms.internal.nhwc` node. +- Extend tests to assert internal-domain rewrite structure directly, not just plugin-backed execution and numerical correctness. + +#### 5.3.1.5 Rollout Status + +The NHWC rollout is effectively in a "runtime enabled, cleanup remaining" state: + +| Phase | Change | Expected outcome | +|-------|--------|------------------| +| 1 | Enable plugin NHWC callbacks and preserve pre-assigned nodes in the second capability pass | Completed on the current branch | +| 2 | Cache the shim provider pointer in the adapter `OpKernelInfo` | Completed on the current branch; fixes the observed NHWC runtime crash | +| 3 | Consolidate allowlists, improve internal-domain diagnostics, and strengthen structural NHWC assertions | Recommended follow-up work | + +### 5.4 CUDA Graph Support + +#### 5.4.1 How CUDA Graph Works in Bundled CUDA EP + +CUDA Graph capture/replay in ORT is a **cooperative protocol** between the ORT session framework and the execution provider. Understanding this protocol is critical for the plugin EP design. + +**Session-level orchestration** (`inference_session.cc`): + +1. During session initialization, if an EP reports `IsGraphCaptureEnabled() == true` and all graph nodes are assigned to that EP (plus allowed CPU shape nodes), the session caches a pointer to the EP in `cached_execution_provider_for_graph_replay_`. + +2. At `Run()` time, the session checks `cached_execution_provider_for_graph_replay_.IsGraphCaptured(annotation_id)`: + - **If captured**: The session **skips the entire kernel dispatch pipeline** — no `OnRunStart`, no executor, no `OnRunEnd` — and calls `ReplayGraph(annotation_id)` directly. This is the fast path. + - **If not yet captured**: The session runs the normal kernel dispatch pipeline (including `OnRunStart` → executor → `OnRunEnd`), which allows the EP to manage warm-up counting and trigger capture. + +3. After each normal run, the session checks if graph capture is enabled but not yet captured, and **recursively calls `Run()`** to accumulate the required warm-up runs and trigger capture — so from the user's perspective, a single `Run()` call handles the entire warm-up + capture sequence. + +**EP-level capture** (`CUDAExecutionProvider`): + +- `OnRunStart()`: If warm-up is complete and graph not yet captured, calls `cudaStreamBeginCapture()`. +- `OnRunEnd()`: If capturing, calls `cudaStreamEndCapture()` + `cudaGraphInstantiate()` + first `Replay()` (since captured kernels don't execute on GPU during capture). +- `IsGraphCaptureEnabled()`: Returns `true` if `enable_cuda_graph` provider option is set. +- `IsGraphCaptured(annotation_id)`: Returns `true` if a graph has been captured for this annotation. +- `ReplayGraph(annotation_id)`: Calls `cudaGraphLaunch()` for the stored `cudaGraphExec_t`. + +The key insight is that the **session-level replay bypass** (`ReplayGraph()` without kernel dispatch) is what makes CUDA Graph efficient. Without it, the EP can capture a graph but can never replay it efficiently — kernels would still be dispatched by the executor on every run. + +``` +Session::Run() + ├── [Graph captured?] ──YES──→ ep->ReplayGraph(id) ──→ return ← FAST PATH + │ + └── [Not captured] ──→ OnRunStart() → executor dispatches kernels → OnRunEnd() + │ │ + │ (EP begins cudaStreamBeginCapture) │ (EP ends capture, first replay) + │ │ + └──────── Session recurses if warmup needed ─┘ +``` + +#### 5.4.2 Current Plugin EP Behavior — API Gap + +The `OrtEp` C API (`onnxruntime_ep_c_api.h`) still does not include: +- `IsGraphCaptureEnabled()` +- `IsGraphCaptured(annotation_id)` +- `ReplayGraph(annotation_id)` + +The current plugin EP does not implement CUDA graph callbacks at all: `CudaEp` sets `OnRunStart = nullptr` and `OnRunEnd = nullptr`, and the previously proposed graph-specific plugin files are not part of the branch. As a result, CUDA graph support is currently disabled rather than partially implemented. + +#### 5.4.3 Current Branch Design + +Given the API gap, the current branch uses the simplest correct design: + +> **The plugin EP does not manage CUDA graph capture/replay internally.** CUDA graph support remains deferred until the `OrtEp` C API grows the required session-cooperative callbacks. + +**Rationale:** + +1. The `OrtEp` C API has no `IsGraphCaptureEnabled`/`IsGraphCaptured`/`ReplayGraph` callbacks. Without these, the session cannot know that the EP supports graph capture, cannot bypass kernel dispatch for replay, and cannot trigger the recursive warm-up sequence. + +2. The plugin branch intentionally removed graph-specific implementation files instead of keeping an incomplete capture-only path. + +3. The session's graph validation logic (all nodes on one EP, no control flow) is also not triggered without `IsGraphCaptureEnabled()`. + +**Recommended approach:** + +| Option | Description | Effort | Status | +|--------|------------|--------|--------| +| **A. Extend the OrtEp C API** | Add `IsGraphCaptureEnabled`, `IsGraphCaptured`, `ReplayGraph` to `OrtEp`. Update `PluginExecutionProvider` to delegate to these. | Medium — requires ORT core changes | Preferred long-term solution | +| **B. Keep graph support disabled in the plugin EP** | Leave graph files and hooks out of the plugin build until Option A exists. | Small | Current branch behavior | + +**Recommendation**: Keep Option B in place until Option A is available. + +#### 5.4.4 What Needs to Change in ORT Core (Option A) + +To enable full CUDA graph support for plugin EPs, the `OrtEp` struct needs three new optional callbacks: + +```c +// Proposed additions to OrtEp (onnxruntime_ep_c_api.h) +struct OrtEp { + // ... existing fields ... + + /// Returns true if CUDA graph capture is enabled for this EP. + /// If nullptr, defaults to false. + ORT_API2_STATUS(IsGraphCaptureEnabled, _In_ const OrtEp* this_ptr, _Out_ bool* enabled); + + /// Returns true if a graph has been captured for the given annotation ID. + /// If nullptr, defaults to false. + ORT_API2_STATUS(IsGraphCaptured, _In_ const OrtEp* this_ptr, + _In_ int graph_annotation_id, _Out_ bool* captured); + + /// Replay a previously captured graph. + /// If nullptr, returns OK (no-op). + ORT_API2_STATUS(ReplayGraph, _In_ OrtEp* this_ptr, _In_ int graph_annotation_id); +}; +``` + +The `PluginExecutionProvider` bridge would then delegate these to the plugin: + +```cpp +// In ep_plugin_provider_interfaces.cc +bool PluginExecutionProvider::IsGraphCaptureEnabled() const { + if (ort_ep_->IsGraphCaptureEnabled == nullptr) return false; + bool enabled = false; + auto* status = ort_ep_->IsGraphCaptureEnabled(ort_ep_.get(), &enabled); + // handle status... + return enabled; +} +``` + +This would plug into the existing `cached_execution_provider_for_graph_replay_` mechanism in `InferenceSession` with no other session-level changes needed. + +#### 5.4.5 Current State + +| Component | Status | Notes | +|-----------|--------|-------| +| `cuda_graph_plugin.h/.cc` | **Removed** | Not present in the current branch. | +| `CudaEp::OnRunStart` / `OnRunEnd` | **Disabled** | `CudaEp` installs `nullptr` for both callbacks. | +| Session-level replay bypass | **Unavailable** | `OrtEp` API still lacks `IsGraphCaptureEnabled`/`IsGraphCaptured`/`ReplayGraph`. | +| Tests | Not included | The plugin test script has no CUDA graph stage. | + +**Action items:** +1. Keep CUDA graph support disabled in the plugin build until the `OrtEp` C API grows the required replay hooks. +2. Add `IsGraphCaptureEnabled`/`IsGraphCaptured`/`ReplayGraph` to the `OrtEp` C API. +3. Reintroduce plugin-side graph management only after the public API is capable of session-cooperative replay. + +--- + +## 6. EP Adapter Layer (`include/onnxruntime/ep/adapter/`) + +The adapter layer provides thin wrappers around the ORT C API that present a C++ interface matching the framework types: + +| Adapter Class | Wraps | Key Methods | +|---------------|-------|-------------| +| `OpKernel` | `OrtKernelImpl` | `Compute()`, `PrePack()` | +| `OpKernelContext` | `OrtKernelContext` | `Input()`, `Output()`, `InputCount()`, `GetGPUComputeStream()`, `GetComputeStream()` | +| `OpKernelInfo` | `Ort::ConstKernelInfo` | `GetAttr()`, `GetExecutionProvider()`, `TryGetConstantInput()`, `GetDataTransferManager()` | +| `KernelRegistry` | `Ort::KernelRegistry` | `Register(KernelCreateInfo&&)` | +| `KernelDefBuilder` | `Ort::KernelDefBuilder` | `TypeConstraint()`, `InputMemoryType()`, `SetName()` | +| `Ep` | `OrtEp` | `EpImpl()`, allocators, data transfer | +| `Logger` | Plugin logger | Logging interface | +| `DataTransferManager` | `IDataTransfer` | `CopyTensor()` | +| `ConstOpSchema` | `const OrtOpSchema*` | `GetSinceVersion()`, `GetNumInputs()`, `GetInputName()`, `GetInputTypeStr()`, `HasTypeConstraint()` (ORT ≥ 1.25, PR #27713) | + +--- + +## 7. Excluded Operators + +Section 7 reflects the current source exclusions in `cmake/onnxruntime_providers_cuda_plugin.cmake`, plus the small set of intentionally out-of-scope directories. This is the source of truth for what the plugin build omits today. + +### 7.1 Infrastructure (Permanently Excluded — Replaced by Plugin Equivalents) + +| File | Reason | +|------|--------| +| `cuda_execution_provider.cc` | Replaced by `cuda_ep.h/.cc` and the plugin adapter/runtime shim | +| `cuda_provider_factory.cc` | Replaced by `cuda_ep_factory.cc` | +| `cuda_provider_interface.cc` | Not needed in plugin architecture | +| `cuda_stream_handle.cc` | Replaced by `cuda_stream_plugin.cc` | +| `cuda_execution_provider_info.cc` | Config parsed directly in `CudaEp::Config` | +| `cuda_graph.cc` | CUDA graph support deferred (files removed pending OrtEp API extension) | +| `cuda_mempool_arena.cc` | Plugin uses `cudaMalloc`/`cudaFree` directly | +| `cuda_common.cc` | Utility functions shimmed in `cuda_kernel_adapter.h` | +| `cuda_nhwc_kernels.cc` | Replaced by `PluginKernelCollector` auto-registration | +| `cuda_contrib_kernels.cc` | Replaced by `PluginKernelCollector` auto-registration | + +### 7.2 Pure CPU Ops (Permanently Excluded) + +| File | Reason | +|------|--------| +| `tensor/size.cc` | Pure CPU op, handled by `GetCpuPreferredNodes` | +| `tensor/shape_op.cc` | Pure CPU op, inherits from `onnxruntime::OpKernel` (framework) | + +### 7.3 Additional Current Source Exclusions + +| File / Pattern | Why It Is Excluded Today | What Would Unblock It | +|----------------|--------------------------|------------------------| +| `core/providers/cuda/controlflow/*` | The framework controlflow kernels inherit from CPU-side controlflow bases (`If`, `Loop`, `Scan`) and are intentionally omitted from the plugin source list | No change is currently planned. The plugin uses its own `cuda_controlflow_plugin.cc` wrappers instead of these framework sources | +| `tunable/*` | Depends on the real `CudaTuningContext` and other framework CUDA EP infrastructure that is not available in the plugin build | Add a plugin-capable tuning context and remove the remaining framework-only tunable dependencies | +| `tensor/sequence_op.cc` | Uses `TensorSeq`, which is still not adapter-safe here | Add `TensorSeq` adapter coverage | +| `contrib_ops/cuda/llm/*` | Contrib LLM sources have not gone through the same adapter-migration pass as the core CUDA LLM kernels | Finish contrib-LLM-specific adapter work | +| `contrib_ops/cuda/tensor/shrunken_gather.cc` | The training-side header path still depends on framework/provider API wiring | Low-priority training-specific adapter work | + +| `contrib_ops/cuda/transformers/*` | Beam search, greedy search, and sampling depend on broader framework/subgraph integration that has not been adapted for the plugin build | Significant adapter and subgraph support work | +| `contrib_ops/cuda/aten_ops/*` | ATen interop is intentionally out of scope for the standalone CUDA plugin build | A separate ATen/plugin strategy | +| `contrib_ops/cuda/collective/*` | Collective/NCCL support is intentionally out of scope for the standalone CUDA plugin build | A separate collective/NCCL plugin strategy | + +### 7.4 Common Exclusion Themes + +The current exclusions fall into a few categories: + +1. **Tunable/framework-dependent infrastructure** — `tunable/*`, contrib transformers, and some contrib LLM paths still rely on framework-only execution-provider services. + +2. **Remaining adapter gaps** — `TensorSeq` (needed for `sequence_op.cc`) and contrib-LLM-specific plumbing still need dedicated adapter work. + +3. **Deliberate scope cuts** — ATen and collective/NCCL sources remain intentionally out of scope for the standalone CUDA plugin. + +--- + +## 8. Remaining `#ifdef` Guards in Kernel Code + +The branch still contains a small set of plugin guards in both infrastructure and operator code. The important pattern has not changed: + +- Infrastructure files such as `cuda_kernel.h`, `cuda_common.h`, and `cudnn_common.h` still need build-mode gates. +- `generator/constant_of_shape.h` still needs a plugin-specific path because `ConstantOfShapeBase` depends on framework-only tensor-attribute helpers. +- Tunable kernels such as `math/matmul.cc` still gate framework-only registration paths. +- `tensor/identity_op.h` guards the `TensorSeq` code path and `context->InputType()` call with `#ifndef BUILD_CUDA_EP_AS_PLUGIN` — the plugin build handles only the `Tensor` path. `identity_op.cc` uses conditional macros (`IDENTITY_V_TYPES` / `IDENTITY_V_TYPES_IRv9`) so opset 14+ registrations use `AllFixedSizeTensorTypes()` in the plugin build. Additionally, old Dropout opset 7–9 and 10–11 kernel registrations were moved from `identity_op.cc` to `nn/dropout.cc` so that each op's registrations live in that op's own source file. +- A few tensor kernels (`pad.cc`, `tile.cc`, `unsqueeze.cc`, `upsample.*`, `space_depth_ops.h`, `scatter_nd.*`) still contain localized plugin guards where adapter and framework paths have not fully converged. + +The broad trend remains positive: most operator-level plugin conditionals were removed by moving reusable CPU/helper logic into shared headers and by centralizing stream bridging in `CudaKernel` helpers. + +--- + +## 9. Building + +### 9.1 CMake Flag + +The plugin is enabled by setting `onnxruntime_BUILD_CUDA_EP_AS_PLUGIN=ON`: + +```bash +sh build.sh --config Release --build_dir build/cuda --parallel --use_cuda \ + --cuda_version 12.8 --cuda_home /path/to/cuda \ + --cudnn_home /path/to/cudnn \ + --build_wheel --skip_tests \ + --cmake_generator Ninja \ + --enable_cuda_nhwc_ops \ + --cmake_extra_defines onnxruntime_BUILD_CUDA_EP_AS_PLUGIN=ON \ + --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES="90" +``` + +### 9.2 Impact on Other Build Targets + +The `onnxruntime_BUILD_CUDA_EP_AS_PLUGIN=ON` flag has **no impact** on `libonnxruntime_providers_cuda.so` or `libonnxruntime_providers_shared.so`. It only: + +1. Adds the `onnxruntime_providers_cuda_plugin` target (producing `libonnxruntime_providers_cuda_plugin.so`) +2. Appends `"cuda-plugin-ep=1"` to the build info string (cosmetic) + +The in-tree CUDA EP and shared provider bridge are compiled identically regardless of this flag. A single build with the flag ON produces all four libraries — there is no need for separate build scripts or build directories. + +### 9.3 Plugin Independence + +`libonnxruntime_providers_cuda_plugin.so` is **fully self-contained**. It does not depend on `libonnxruntime_providers_cuda.so` or `libonnxruntime_providers_shared.so` at load time. It statically links against `onnxruntime_framework`, `onnxruntime_graph`, `onnxruntime_common`, `onnxruntime_mlas`, `onnxruntime_flatbuffers`, and links dynamically against CUDA (`cudart`, `cublas`, `cublasLt`, `cufft`), cuDNN, and protobuf. Communication with the ORT runtime happens exclusively through the C API (`OrtApi`/`OrtEpApi`) passed at load time. + +### 9.4 Build Outputs + +After a successful build with the plugin flag ON, `build/cuda/Release/` contains: + +| File | Description | +|------|-------------| +| `libonnxruntime_providers.a` | CPU provider (static, linked into main binary) | +| `libonnxruntime_providers_shared.so` | Shared provider bridge (for in-tree CUDA EP) | +| `libonnxruntime_providers_cuda.so` | In-tree CUDA EP (uses shared provider bridge) | +| `libonnxruntime_providers_cuda_plugin.so` | Plugin CUDA EP (standalone, uses C API) | + +### 9.5 Deployment + +To use the plugin EP, copy the `.so` to the ORT Python package's `capi/` directory: + +```bash +cp build/cuda/Release/libonnxruntime_providers_cuda_plugin.so \ + $(python -c "import onnxruntime; print(onnxruntime.__path__[0])")/capi/ +``` + +The plugin is then available as `CudaPluginExecutionProvider` in session provider lists. + +--- + +## 10. Testing + +### 10.1 Test Script + +`onnxruntime/test/python/transformers/test_cuda_plugin_ep.py` provides the current focused plugin validation flow: + +| Category | What It Tests | +|----------|---------------| +| Registration | Dynamic loading via `register_execution_provider_library()` and EP device discovery (Add, MatMul, Gemm, Conv) | +| Provider options | Valid option parsing, invalid device rejection, multi-device selection | +| NHWC | NHWC-requested sessions: Conv, BatchNorm, MaxPool, AveragePool. These validate correctness under `prefer_nhwc` and require plugin-backed NHWC execution to succeed; they are the focused regression suite for the plugin NHWC path | +| Tensor ops | Reshape, Split, Concat, Gather, Unsqueeze, Tile, Pad, Slice, Transpose, Cast, Where, Flatten, ArgMax, TopK, Trilu, NonZero | +| Math ops | Softmax, Relu, Sigmoid, Tanh, Einsum (single and batched) | +| Reduce | ReduceMean, ReduceSum | +| Space/depth | SpaceToDepth, DepthToSpace, Upsample | +| Shape ops | CumSum, ConstantOfShape, Resize, Sum (variadic) | +| Normalization | LayerNormalization, InstanceNormalization | +| Conv | ConvTranspose | +| Scatter/gather | GatherND, ScatterElements, OneHot | +| Spatial | GridSample | +| Contrib ops | FastGelu, Gelu, BiasGelu, SkipLayerNorm, BiasDropout, FusedMatMul | +| Dropout | Dropout opset 7 and opset 10 — verifies registrations moved to `dropout.cc` | +| Quantization | DequantizeLinear / QuantizeLinear opset 21 — exercises `TWO_TYPED_KERNEL_EX` adapter macro; MatMulInteger | +| GatherBlockQuantized | Contrib `GatherBlockQuantized` — exercises `THREE_TYPED_KERNEL_EX` adapter macro | +| Identity | Identity opset 13 and opset 25 — re-enabled op with `TensorSeq` path guarded | +| Crop | Crop (opset 1) — previously excluded contrib op, now re-enabled | +| Memcpy | Explicit `MemcpyFromHost` and `MemcpyToHost` standalone tests to ensure copy ops are dispatched | +| Key-ops probe | Session-based probing that all key ops are assigned to `CudaPluginExecutionProvider` | + +### 10.2 Running Tests + +After building and deploying the plugin (see [Section 9.5](#95-deployment)): + +```bash +# Run tests from /tmp to avoid module shadowing +cd /tmp +python /path/to/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py +``` + +The current branch has been validated with `./cuda_plugin.sh --build --test_plugin`, which runs this script against the locally built plugin library. + +### 10.3 Parity Report + +`tools/ci_build/cuda_plugin_parity_report.py` generates both static and runtime parity reports: + +- **Static mode** (default): Parses CMake exclusion patterns and kernel registration macros from source to compare what the plugin build includes vs. the bundled CUDA EP. + ```bash + python tools/ci_build/cuda_plugin_parity_report.py + ``` +- **Runtime mode** (`--runtime`): Uses the pybind `get_registered_ep_kernel_defs()` API (added in `onnxruntime_pybind_schema.cc`) to query actual kernel registries from both the bundled and plugin EPs, providing an accurate comparison of registered op/domain/version/type-constraint coverage. + ```bash + python tools/ci_build/cuda_plugin_parity_report.py --runtime [--plugin-ep-lib /path/to/plugin.so] + ``` + +The runtime API (`get_registered_ep_kernel_defs(ep_name)`) creates a temporary EP factory and EP instance for the named EP, iterates its kernel registry, and returns `KernelDef` objects with `op_name`, `domain`, `version_range`, `provider`, and `type_constraints` fields. + +--- + +## 11. How to Add a New Kernel to the Plugin + +### 11.1 If the kernel compiles as-is + +Most kernels that don't use `GetComputeStream()` (returning `Stream*`) or inherit from excluded CPU base classes will compile without changes. The force-include mechanism handles type resolution automatically. + +Just verify it's not in the exclusion list in `cmake/onnxruntime_providers_cuda_plugin.cmake`. + +### 11.2 If the kernel calls a CPU base class helper + +Apply the inline-header pattern: + +1. Move the helper implementation from the CPU `.cc` file to the `.h` file +2. Wrap with `#ifdef SHARED_PROVIDER` (declaration only) / `#else` (inline body) +3. In the CUDA kernel, call the base class method directly (remove any local wrappers) +4. Verify all 4 build targets compile + +Example from `cumsum.h`: +```cpp +namespace cumsum_op { +#ifdef SHARED_PROVIDER +Status GetAxis(const Tensor* axis_tensor, int64_t input_rank, int64_t& axis_out); +#else +inline Status GetAxis(const Tensor* axis_tensor, int64_t input_rank, int64_t& axis_out) { + // ... implementation ... +} +#endif +} +``` + +### 11.3 If the helper takes OpKernelContext& + +Use a template version that accepts any context type: + +```cpp +template +static void ComputePadsImpl(KernelContextType& ctx, ...) { ... } +``` + +The CUDA kernel calls `ComputePadsImpl(*ctx, ...)` directly. + +### 11.4 If the kernel uses stream helpers + +Prefer the shared helpers in `CudaKernel` instead of introducing new plugin-only stream shims: + +- If the code only needs a raw CUDA stream, use `Stream(ctx)`. +- If the shared helper API accepts the adapter's opaque stream handle, use `GetComputeStream(ctx)`. +- If framework-style helper code still expects `onnxruntime::Stream*`, use `GetOrtStream(ctx)`. +- Prefer `GetCublasHandle(ctx)`, `GetCudnnHandle(ctx)`, and `GetCublasLtHandle(ctx)` over re-discovering handles from the stream manually. + +### 11.5 If the kernel uses handle accessors + +Use the plugin-compatible overloads already in `CudaKernel`: + +```cpp +// Instead of: GetCublasHandle(ctx->GetComputeStream()) +// Use: GetCublasHandle(ctx) // or GetCublasHandle(Stream(ctx)) +``` + +### 11.6 If OpSchema access is available — schema-validated type constraints + +With the `OrtEpApi` OpSchema APIs (ORT ≥ 1.25, PR #27713), the plugin can validate or derive type constraint names at kernel registration time rather than relying solely on hard-coded strings. + +#### 11.6.1 Validation mode (recommended first step) + +Add a debug-mode validation pass in `CreateCudaKernelRegistry()` that runs after all kernels are collected from `PluginKernelCollector`. For each registered kernel, look up its `OrtOpSchema` and verify that every type constraint name used in the `KernelDef` actually appears in the schema's type constraint map: + +```cpp +// In CreateCudaKernelRegistry(), after building the registry: +#ifndef NDEBUG +for (auto build_fn : entries) { + auto info = build_fn(); + if (info.kernel_def == nullptr) continue; + + // Retrieve the op name, domain, and since_version from the KernelDef. + const char* op_name = info.kernel_def->GetOpName(); + const char* domain = info.kernel_def->GetDomain(); + int since_version = info.kernel_def->GetSinceVersion(); + + // Look up the ONNX schema from ORT's global registry. + Ort::ConstOpSchema schema = Ort::GetOpSchema(op_name, since_version, domain); + if (!schema) continue; // contrib/internal ops may not have an ONNX schema + + // Validate each type constraint name against the schema. + for (const auto& [constraint_name, types] : info.kernel_def->GetTypeConstraints()) { + if (!schema.HasTypeConstraint(constraint_name.c_str())) { + LOGS_DEFAULT(WARNING) << "Plugin kernel " << op_name + << ": type constraint '" << constraint_name + << "' not found in OpSchema (domain=" << domain + << ", version=" << since_version << ")"; + } + } +} +#endif +``` + +This catches hard-to-debug kernel-matching failures caused by constraint name typos or opset version drift. + +#### 11.6.2 Schema-driven constraint helper (future) + +A `KernelDefBuilder` extension could derive constraint names from the schema automatically: + +```cpp +/// Look up the type constraint string for a given input index from the OpSchema. +/// Falls back to the provided default if the schema is not found (e.g., contrib ops). +inline const char* GetInputTypeConstraintName( + const char* op_name, int opset_version, const char* domain, size_t input_index, + const char* fallback = "T") { + Ort::ConstOpSchema schema = Ort::GetOpSchema(op_name, opset_version, domain); + if (!schema || input_index >= schema.GetNumInputs()) return fallback; + // Cache the result to avoid repeated lookups for typed kernel variants. + static thread_local std::string cached_name; + cached_name = schema.GetInputTypeStr(input_index); + return cached_name.c_str(); +} +``` + +This is a quality-of-life improvement rather than a required change — the existing hard-coded constraint names are correct for all currently registered kernels. + +--- + +## 12. File Layout + +``` +onnxruntime/core/providers/cuda/plugin/ +├── cuda_kernel_adapter.h # CudaKernel base, macros, CPU shims (force-included) +├── cuda_ep.h / .cc # CudaEp : OrtEp implementation +├── cuda_ep_factory.h / .cc # CudaEpFactory : OrtEpFactory +├── cuda_plugin_ep.cc # DLL entry points (CreateEpFactories/ReleaseEpFactory) +├── cuda_plugin_ep_symbols.def # Windows DLL export definitions +├── cuda_plugin_kernels.h / .cu # Kernel registry creation +├── cuda_stream_plugin.h / .cc # CudaSyncStream (handles, notifications) +├── cuda_allocator_plugin.h / .cc # Device/pinned allocators +├── cuda_data_transfer_plugin.h / .cc # GPU↔CPU data transfer +├── cuda_memcpy_plugin.cc # MemcpyFromHost/MemcpyToHost standalone kernels +├── cuda_controlflow_plugin.h / .cc / .cu # If/Loop/Scan wrappers +├── cuda_plugin_utils.h # Common macros, error handling +└── provider_api_shims.cc # Reimplemented utility functions + +include/onnxruntime/ep/ +├── README.md # EP adapter layer overview +├── adapters.h # Master include + type aliasing (force-included) +├── api.h # ORT C API includes +├── common.h # EP common utilities +├── get_capability_utils.h # GetCapability helper utilities +└── adapter/ + ├── allocator.h # IAllocator adapter + ├── data_transfer_manager.h # DataTransferManager adapter + ├── ep.h # Ep base class (wraps IExecutionProvider) + ├── kernel_def.h # KernelDef adapter + ├── kernel_def_builder.h # KernelDefBuilder adapter + ├── kernel_registry.h # KernelRegistry adapter + ├── logging.h # Logger adapter + ├── node.h # Node adapter + ├── op_kernel.h # OpKernel + OpKernelContext adapters + ├── op_kernel_info.h # OpKernelInfo adapter + └── tensor_helper.h # Tensor creation from C API values +``` + +--- + +## 13. Future Work + +1. **Memory arena / allocator parity** — The plugin currently relies on direct `cudaMalloc`/`cudaFree` in `CudaDeviceAllocator` instead of an arena-backed allocator. Two complementary improvements are planned: + + **A. `CudaMempoolArena` (commit e6023b0c)** + + The in-tree CUDA EP gained a native-CUDA-mempool allocator (`cuda_mempool_arena.h/.cc`) that uses `cudaMallocFromPoolAsync` / `cudaFreeAsync` on stream-ordered allocation paths, with a configurable `cudaMemPoolAttrReleaseThreshold` to return memory to the device as it becomes idle. Enabling this in the plugin requires: + + 1. **Make `CudaMempoolArena` compilable in the plugin build.** `cuda_mempool_arena.h` currently includes `cuda_stream_handle.h` and `provider_api.h` (both `SHARED_PROVIDER`-only). The only real dependency is resolving the stream framework pointer. When migrating for plugin use, this class can be refactored to accept a raw `cudaStream_t` directly (or an `OrtSyncStream*`), bypassing the internal `stream->GetHandle()` logic. + + 2. **Implement a thin `OrtAllocator` wrapper around `CudaMempoolArena`.** The plugin factory's `CreateAllocatorImpl` returns an `OrtAllocator*`, while `CudaMempoolArena` is an `IArena` / `IAllocator`. A new class (e.g., `CudaMempoolOrtAllocator`) should own a `CudaMempoolArena` instance and forward the `OrtAllocator` callbacks to it: + + | `OrtAllocator` callback | Implementation | + |-------------------------|----------------| + | `Alloc(size)` | `arena_->Alloc(size)` (allocates on the legacy default stream) | + | `Free(ptr)` | `arena_->Free(ptr)` | + | `Reserve(size)` | `arena_->Reserve(size)` | + | `AllocOnStream(size, stream)` | `cudaStream_t cu_stream = (cudaStream_t)api->SyncStream_GetHandle(stream);`
`arena_->AllocWithCudaStream(size, cu_stream);` | + | `GetStats(kvps)` | Populate from `arena_->GetStats()` | + | `Info()` | Return the `OrtMemoryInfo*` used at construction | + + The `OrtAllocator` C API already supports stream-aware allocation via the optional `AllocOnStream` callback (set on `OrtAllocator` when `version >= kOrtAllocatorAllocOnStreamMinVersion`). ORT core wraps every plugin `OrtAllocator` into `IAllocatorImplWrappingOrtAllocator` (`allocator_adapters.cc`), which dispatches to `AllocOnStream` when the wrapper reports `IsStreamAware() == true`. So there is **no additional plumbing needed in the adapter or framework** — the plugin allocator just needs to set `AllocOnStream` to a non-null function pointer to get full stream-ordered semantics. + + **Important:** The `OrtMemoryInfo::alloc_type` returned by the wrapper must be `OrtDeviceAllocator`, **not** `OrtArenaAllocator`. Both `PluginExecutionProvider::CreatePreferredAllocators()` and `Environment::CreateSharedAllocatorImpl()` explicitly reject `OrtArenaAllocator` from plugin factories — the arena is expected to be opaque to ORT. + + 3. **Parse mempool options.** ORT can pass allocator configuration to the plugin factory through the `allocator_options` (`OrtKeyValuePairs*`) argument of `OrtEpFactory::CreateAllocator`. The relevant keys are defined in `OrtArenaCfg::Keys` (in `allocator.h`): + - `arena.use_cuda_mempool` — set to `"1"` to enable + - `arena.cuda_mempool_release_threshold` — bytes; `0` disables the threshold + - `arena.cuda_mempool_bytes_to_keep_on_shrink` — bytes retained after `Shrink()` + + **How options reach the plugin factory — two paths:** + + | Path | How it calls `CreateAllocator` | `allocator_options` | + |------|-------------------------------|---------------------| + | **Shared allocator** (`OrtApi::CreateSharedAllocator`) | `Environment::CreateSharedAllocatorImpl` → `ep_factory->CreateAllocator(factory, &mem_info, allocator_options, &alloc)` | Caller-provided `OrtKeyValuePairs*` — can carry arena keys | + | **Per-EP allocator** (`PluginExecutionProvider::CreatePreferredAllocators`) | `ep_factory.CreateAllocator(&ep_factory, memory_info, /*options*/ nullptr, &alloc)` | Always `nullptr` today | + + The per-EP path currently passes `nullptr` for options. To support mempool configuration on this path, either: + - **(a)** Parse the arena keys from session options inside `CudaEp` / `CudaEpFactory` (similar to how `CudaEp::Config` already parses other provider options) and store them so `CreateAllocatorImpl` can read them without needing `allocator_options`. + - **(b)** Extend the ORT core per-EP allocator path to forward the config entries to `CreateAllocator` (requires an ORT core change). + + Option (a) is self-contained within the plugin and does not require ORT core changes. + + 4. **Thread the factory logger.** `CudaMempoolArena` takes a `const logging::Logger*`. The plugin factory already owns a logger (`factory.default_logger_` / the `OrtLogger` passed at EP creation). Convert or wrap it and pass it to the arena constructor. + + 5. **Handle `ReleaseAllocatorImpl`.** The factory's `ReleaseAllocatorImpl` switch currently only knows about `CudaDeviceAllocator` and `CudaPinnedAllocator`. Add a third case (`kMempool` or similar) to correctly destroy the new wrapper and its owned `CudaMempoolArena`. + + **B. BFC arena (longer term)** + + If BFC-style arena behavior (`gpu_mem_limit`, `arena_extend_strategy`) is also needed, a similar `OrtAllocator`-wrapping approach would work for `BFCArena`, once its `SHARED_PROVIDER`-only dependencies are removed. The same `AllocOnStream` / `OrtDeviceAllocator` / option-parsing patterns apply. + +2. **Profiling and observability** — The in-tree CUDA EP exposes an EP profiler, while the plugin shim currently does not surface equivalent profiling hooks. Future work should wire up `GetProfiler()` for the plugin path, integrate CUDA/NVTX/CUPTI-based tracing where appropriate, and make plugin execution visible in the same profiling flows users already rely on for the bundled CUDA EP. + +3. **Stream/adapter parity for framework-style `Stream*` consumers** — A number of excluded or recently re-included kernels still assume access to a richer framework `Stream*` object rather than only a raw `cudaStream_t` view. Extending the adapter path here would unblock additional LLM, FFT, quantization, diffusion, and other CUDA kernels. + +4. **Contrib LLM migration pass** — The core CUDA LLM attention path is now adapter-safe, but `contrib_ops/cuda/llm/*` remains excluded as a separate follow-up. + +5. **Tunable ops** — Implement a plugin-side `ITuningContext` and remove the `ORT_USE_EP_API_ADAPTERS` guards in `matmul.cc`/`gemm.cc` so the plugin can recover runtime kernel selection and profiling-based tuning behavior. + +6. **TensorSeq and additional C API coverage** — Add enough sequence/tensor-sequence support to unblock `sequence_op.cc` (the last remaining TensorSeq-dependent file), and extend the ORT C API where needed for remaining framework-style attribute accessors such as string-array attributes used by RNN kernels. Note: `identity_op.cc` is now included in the plugin build — its TensorSeq code path is guarded by `#ifndef BUILD_CUDA_EP_AS_PLUGIN` and opset 14+ registrations use `AllFixedSizeTensorTypes()` (Tensor-only) instead of `AllFixedSizeTensorAndSequenceTensorTypes()`. + +7. **Remaining contrib exclusions** — The FFT (`fft_ops.cc`), crop (`crop.cc`), and dynamicslice (`dynamicslice.cc`) exclusions have been removed. These files now compile in the plugin build: FFT ops use `Stream(context)` (which works in both builds) and the `CUFFT_RETURN_IF_ERROR` macro was added to the adapter; crop and dynamicslice had no real framework blockers once tested. The plugin CMake now links `CUDA::cufft` for cuFFT symbol resolution. Remaining contrib exclusions are: `shrunken_gather.cc` (training), `transformers/*` (subgraph), `aten_ops/*` (ATen), `collective/*` (NCCL), and `llm/*` (contrib LLM pass). + +8. **CI integration and targeted benchmarking** — Add plugin build + test coverage to CI and include perf-oriented validation so allocator, profiling, and tunable-op regressions are caught early. + +9. **NHWC cleanup and hardening** — Complete the follow-up work described in [Section 5.3.1](#531-nhwc-layout-transformation-support): unify the allowlist, improve internal-domain diagnostics, and add stronger structural NHWC assertions. + +10. **CUDA Graph API for plugin EPs** — Add `IsGraphCaptureEnabled`, `IsGraphCaptured`, and `ReplayGraph` callbacks to the `OrtEp` C API (see [Section 5.4.4](#544-what-needs-to-change-in-ort-core-option-a)). This is required for efficient CUDA graph replay in the plugin EP. The capture/replay infrastructure will be reintroduced once the API is extended. + +11. **OpSchema-validated kernel registration (PR #27713)** — PR #27713 adds `OrtEpApi` functions that let plugin EPs query ONNX operator schemas from ORT's global registry (see [Section 3.5.1](#351-type-constraint-names-and-opschema-access)). Concrete follow-up work for the CUDA plugin EP: + + **A. Registration-time validation pass** + + Add a debug/diagnostic pass in `CreateCudaKernelRegistry()` that validates every registered kernel's type constraint names against the schema. This is the highest-value, lowest-risk change — it catches silent kernel-matching failures caused by constraint name drift without altering the registration flow. See [Section 11.6.1](#1161-validation-mode-recommended-first-step) for the implementation pattern. + + **B. NHWC internal-domain schema diagnostics** + + Extend the validation pass to cover `com.ms.internal.nhwc`-domain registrations. When kernel lookup fails for a rewritten NHWC node, the diagnostic can now report exactly which constraint name was expected vs. what the kernel registered, directly addressing the diagnostic requirement in [Section 5.3.1.3](#5313-nhwc-design-requirements). + + **C. Parity report enhancement** + + Update `cuda_plugin_parity_report.py` to use the schema API (via a small C++ test harness or Python ONNX bindings) to flag type-constraint mismatches between the plugin's registered kernels and the ONNX schema, in addition to the existing op-coverage comparison. + + **D. Schema-driven `KernelDefBuilder` helpers (longer term)** + + Create a `KernelDefBuilder` helper that auto-derives constraint names from the schema instead of requiring hard-coded strings. This reduces maintenance burden when new opset versions introduce constraint name changes, but is lower priority than the validation pass since all current constraint names are correct. + + **E. Potential code locations for changes** + + | File | Change | + |------|--------| + | `cuda_plugin_kernels.cu` / `CreateCudaKernelRegistry()` | Add schema validation loop after kernel collection | + | `cuda_kernel_adapter.h` | (Optional) Add schema-aware macro variant or post-registration hook | + | `include/onnxruntime/ep/adapter/kernel_def_builder.h` | (Optional) Add schema-lookup helper for constraint names | + | `cuda_ep.cc` / `GetCapabilityImpl()` | (Optional) Add schema-based diagnostic when `EpGraphSupportInfo_LookUpKernel` returns nullptr | + | `test_cuda_plugin_ep.py` | Add a validation stage that exercises schema-validated registration | + +12. **Resource accounting and annotation-based partitioning (PR #27595)** — ORT is acquiring two related features that affect how graph nodes are partitioned to EPs: + + **A. Resource accounting** + + `IResourceAccountant` lets an EP declare a resource budget (e.g., available VRAM) and have the partitioner stop assigning nodes once that budget is exhausted. The framework passes an `IResourceAccountant*` to `IExecutionProvider::GetCapability()`; the in-tree CUDA EP uses it to compute per-node estimated VRAM cost from initializer sizes. + + For plugin EPs, the `OrtEp::GetCapability` callback currently has no mechanism to receive or report resource usage — the `OrtEp` C API does not expose `IResourceAccountant`. Two design options: + + - **Option A (preferred — ORT core change, completed in PR #27595):** Add an `OrtEp` analogue of the current `IResourceAccountant` flow. PR #27595 introduced `OrtEpGraphSupportInfo_RequestResourceForNode` and `OrtEpGraphSupportInfo_StopAssigningNodesDueToResourceExhaustion` to the C API. This is the implementation path moving forward. + + - **Option B (plugin-side workaround):** Expose the VRAM threshold through a plugin-specific session option key. During `GetCapabilityImpl`, the plugin reads the threshold from the parsed config and performs its own initializer-size accounting using `OrtEp_GetNodeAttributes` / node-graph-view APIs already present in the `OrtEp` API surface. This avoids an ORT core change but duplicates budget-tracking logic. + + **B. Annotation-based layering** + + PR #27595 also introduces `layering_annotations` — node-level `"layer_ann"` metadata that routes nodes to specific EPs or CPU during partitioning. The expected model is that plugin EPs participate through the same `GetCapability` flow and therefore observe whatever node set ORT presents after applying layering rules. In practice that should mean no plugin-specific changes are needed to respect annotations that exclude nodes from the plugin. However, the plugin design should avoid depending on undocumented filtering details in the `OrtGraph*` contract. If the plugin EP itself needs to *read* layering annotations for internal decisions, or if the API needs to make filtered-vs-unfiltered graph semantics explicit, that would require new `OrtEp` API surface. + + Current known limitations to keep in future work: + + - The `cuda(...)` device selector currently matches only the built-in `CUDAExecutionProvider`. It does not match the plugin EP name `CudaPluginExecutionProvider`, so layer assignment settings written against `cuda(...)` do not work with the CUDA plugin EP today. + - The `gpu:(...)` selector is currently matched using `OrtHardwareDevice::device_id`. That field is not a stable CUDA ordinal and is not guaranteed to uniquely identify one physical GPU, so index-based layer assignment is unreliable for the CUDA plugin EP, especially on hosts with multiple similar NVIDIA GPUs. + + **Recommended action:** Combine with the recently added `OrtEpGraphSupportInfo_RequestResourceForNode` C API explicitly (completed in PR #27595 on the ORT core side) to correctly assign nodes within the budget in the plugin's `CudaEp::GetCapabilityImpl()` when layer assignments exist. diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h index c7f7f23f70334..bd87c49d39ea5 100644 --- a/include/onnxruntime/core/framework/tensor.h +++ b/include/onnxruntime/core/framework/tensor.h @@ -43,6 +43,16 @@ class Tensor final { // Strive not to allocate Tensor with new/delete as it is a shallow class and using it by value is just fine. // Use InitOrtValue() methods to allocate for OrtValue. +#ifdef BUILD_CUDA_EP_AS_PLUGIN + /// Static factory kept for plugin EP kernels that still call Tensor::Create(). + /// The main tree deprecated these in favor of constructors, but dynamically-linked + /// plugin code relies on the static method. + static std::unique_ptr Create(MLDataType elt_type, const TensorShape& shape, + std::shared_ptr allocator) { + return std::make_unique(elt_type, shape, std::move(allocator)); + } +#endif + Tensor() = default; // to allow creating vector to support seq(tensor) /** diff --git a/include/onnxruntime/ep/adapter/allocator.h b/include/onnxruntime/ep/adapter/allocator.h index 4f107ae72c0e9..4e601bb22252b 100644 --- a/include/onnxruntime/ep/adapter/allocator.h +++ b/include/onnxruntime/ep/adapter/allocator.h @@ -7,6 +7,8 @@ #error "This header should not be included directly. Include ep/adapters.h instead." #endif +#include + #include "core/framework/allocator.h" namespace onnxruntime { @@ -39,7 +41,7 @@ class Allocator : public OrtAllocator { private: explicit Allocator(const OrtMemoryInfo* memory_info) - : OrtAllocator{}, memory_info_(memory_info) { + : OrtAllocator{}, memory_info_(memory_info), get_allocator_impl_(nullptr) { version = ORT_API_VERSION; Alloc = AllocImpl; Free = FreeImpl; diff --git a/include/onnxruntime/ep/adapter/kernel_def_builder.h b/include/onnxruntime/ep/adapter/kernel_def_builder.h index 279c2782ef8eb..276c793a4311c 100644 --- a/include/onnxruntime/ep/adapter/kernel_def_builder.h +++ b/include/onnxruntime/ep/adapter/kernel_def_builder.h @@ -130,6 +130,8 @@ struct KernelDefBuilder { return *this; } + // ExecQueueId is intentionally a no-op. The plugin EP manages stream + // assignment externally; the queue id hint is not needed. KernelDefBuilder& ExecQueueId(int /*queue_id*/) { return *this; } Ort::KernelDef Build() { return builder_.Build(); } diff --git a/include/onnxruntime/ep/adapter/node.h b/include/onnxruntime/ep/adapter/node.h index 8510f7bc5031b..91aff7d670b2f 100644 --- a/include/onnxruntime/ep/adapter/node.h +++ b/include/onnxruntime/ep/adapter/node.h @@ -26,11 +26,29 @@ struct Node { return kernel_info_.GetOperatorType(); } + /** Gets the Node's domain. */ + std::string Domain() const { + return kernel_info_.GetOperatorDomain(); + } + /** Gets the since version of the operator. */ int SinceVersion() const noexcept { return kernel_info_.GetOperatorSinceVersion(); } + /** Gets the number of outputs. */ + size_t OutputCount() const noexcept { + return kernel_info_.GetOutputCount(); + } + + /** Gets whether an output exists or is an omitted optional output. */ + bool OutputExists(size_t index) const { + // KernelInfo_GetOutputName returns an empty string for omitted optional + // outputs, which lets adapter consumers mirror NodeArg::Exists() without + // pulling in full NodeArg metadata. + return index < OutputCount() && !kernel_info_.GetOutputName(index).empty(); + } + private: const Ort::ConstKernelInfo kernel_info_; }; diff --git a/include/onnxruntime/ep/adapter/op_kernel.h b/include/onnxruntime/ep/adapter/op_kernel.h index 0a4b16321a8eb..273461b36e75f 100644 --- a/include/onnxruntime/ep/adapter/op_kernel.h +++ b/include/onnxruntime/ep/adapter/op_kernel.h @@ -8,6 +8,7 @@ #endif #include +#include #include #include "core/framework/allocator.h" @@ -35,7 +36,7 @@ struct OpKernel { explicit OpKernel(const OpKernelInfo& info) : op_kernel_info_{info} {} virtual ~OpKernel() {} - Node Node() const { + adapter::Node Node() const { return op_kernel_info_.node(); } const OpKernelInfo& Info() const { @@ -93,6 +94,15 @@ struct OpKernelContext { input_tensors_[index] = CreateTensorFromApiValue(const_cast(static_cast(input))); return &input_tensors_[index]; } + /// Get a required (non-optional) input tensor. Throws if the input is null. + /// Use Input() for optional inputs that may legitimately be absent. + template >> + const T& RequiredInput(int index) const { + auto* input = Input(index); + ORT_ENFORCE(input != nullptr, "Required input ", index, " is null"); + return *input; + } Tensor* Output(int index, const TensorShape& shape) { if (index < 0 || static_cast(index) >= output_tensors_.size()) { return nullptr; @@ -109,6 +119,12 @@ struct OpKernelContext { output_tensors_[index] = CreateTensorFromApiValue(output); return &output_tensors_[index]; } + /// Get a required (non-optional) output tensor. Throws if the output is null. + Tensor& RequiredOutput(int index, const TensorShape& shape) { + auto* output = Output(index, shape); + ORT_ENFORCE(output != nullptr, "Required output ", index, " is null"); + return *output; + } Tensor* Output(int index, const std::vector& shape) { return Output(index, TensorShape{shape}); } @@ -116,10 +132,20 @@ struct OpKernelContext { return Output(index, TensorShape{shape}); } [[nodiscard]] Status GetTempSpaceCPUAllocator(AllocatorPtr* output) const { - return static_cast(op_kernel_.Info().GetKernelInfo().GetEp())->GetTempSpaceCPUAllocator(output); + // Use GetOrtEp() directly from the cached KernelInfoCache rather than going through + // GetExecutionProvider()->GetOrtEp(). GetExecutionProvider() returns the native EP impl + // (e.g. WebGpuExecutionProvider), which doesn't override GetOrtEp() and returns nullptr. + // The cached ort_ep_ is resolved from the plugin wrapper's IExecutionProvider during + // KernelInfoCache construction, so it correctly holds the OrtEp instance. + const auto* ort_ep = op_kernel_.Info().GetOrtEp(); + ORT_ENFORCE(ort_ep != nullptr, "Kernel execution provider is not associated with an OrtEp instance."); + return static_cast(ort_ep)->GetTempSpaceCPUAllocator(output); } [[nodiscard]] Status GetTempSpaceAllocator(AllocatorPtr* output) const { - return static_cast(op_kernel_.Info().GetKernelInfo().GetEp())->GetTempSpaceAllocator(output); + // See comment in GetTempSpaceCPUAllocator for why we use GetOrtEp() directly. + const auto* ort_ep = op_kernel_.Info().GetOrtEp(); + ORT_ENFORCE(ort_ep != nullptr, "Kernel execution provider is not associated with an OrtEp instance."); + return static_cast(ort_ep)->GetTempSpaceAllocator(output); } int InputCount() const { return static_cast(input_tensors_.size()); @@ -131,7 +157,6 @@ struct OpKernelContext { // TODO(fs-eire): Implement GetUseDeterministicCompute(). return false; } - void* GetGPUComputeStream() const { return context_.GetGPUComputeStream(); } @@ -146,7 +171,7 @@ struct OpKernelContext { }; /// -/// A bridge class between `onnxruntime::ep::adapter::OpKernel` and `::OrtKernelImpl`. +/// A bridge class between `onnxruntime::ep::adapter::OpKernel` and `onnxruntime::OrtKernelImpl`. /// struct KernelImpl : OrtKernelImpl { explicit KernelImpl(std::unique_ptr impl) diff --git a/include/onnxruntime/ep/adapter/op_kernel_info.h b/include/onnxruntime/ep/adapter/op_kernel_info.h index 644cb30788ec6..f0b620c334d40 100644 --- a/include/onnxruntime/ep/adapter/op_kernel_info.h +++ b/include/onnxruntime/ep/adapter/op_kernel_info.h @@ -9,8 +9,10 @@ #include +#include "core/common/narrow.h" #include "core/common/status.h" #include "core/framework/config_options.h" +#include "core/framework/op_kernel_info.h" #include "core/framework/tensor_shape.h" #include "core/framework/tensor.h" @@ -41,18 +43,26 @@ struct OpKernelInfo { // to manage the lifetime of the cached data. struct KernelInfoCache { explicit KernelInfoCache(const OrtKernelInfo* kernel_info) : kernel_info_(kernel_info) { + const auto* core_kernel_info = reinterpret_cast(kernel_info); + execution_provider_ = core_kernel_info->GetExecutionProvider(); + ort_ep_ = execution_provider_ != nullptr ? execution_provider_->GetOrtEp() : nullptr; + ep_impl_ = ort_ep_ != nullptr ? (static_cast(ort_ep_))->EpImpl() : execution_provider_; + Ort::ConstKernelInfo info{kernel_info}; - const int input_count = info.GetInputCount(); + const size_t input_count = info.GetInputCount(); constant_input_tensors.resize(input_count); - for (int i = 0; i < input_count; ++i) { + for (size_t i = 0; i < input_count; ++i) { int is_constant = 0; - Ort::ConstValue const_input = info.GetTensorConstantInput(i, &is_constant); + Ort::ConstValue const_input = info.GetTensorConstantInput(gsl::narrow_cast(i), &is_constant); if (is_constant && const_input != nullptr && const_input.IsTensor()) { constant_input_tensors[i] = CreateTensorFromApiValue(const_cast(static_cast(const_input))); } } } const OrtKernelInfo* kernel_info_; + const ::onnxruntime::IExecutionProvider* execution_provider_{}; + const OrtEp* ort_ep_{}; + const ::onnxruntime::IExecutionProvider* ep_impl_{}; std::vector constant_input_tensors; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(KernelInfoCache); }; @@ -61,13 +71,16 @@ struct OpKernelInfo { } const DataTransferManager& GetDataTransferManager() const noexcept { - return (static_cast(info_.GetEp()))->GetDataTransferManager(); + return (static_cast(cache_->ort_ep_))->GetDataTransferManager(); } Node node() const noexcept { return Node{cache_->kernel_info_}; } const IExecutionProvider* GetExecutionProvider() const noexcept { - return (static_cast(info_.GetEp()))->EpImpl(); + return cache_->ep_impl_; + } + const OrtEp* GetOrtEp() const noexcept { + return cache_->ort_ep_; } KernelDef GetKernelDef() const noexcept { @@ -75,7 +88,7 @@ struct OpKernelInfo { } const Ort::ConstKernelInfo GetKernelInfo() const noexcept { - return info_; + return Ort::ConstKernelInfo{cache_->kernel_info_}; } ConfigOptions GetConfigOptions() const noexcept { @@ -85,7 +98,7 @@ struct OpKernelInfo { } int GetInputCount() const noexcept { - return info_.GetInputCount(); + return gsl::narrow_cast(info_.GetInputCount()); } const std::vector& GetConstantInputTensors() const noexcept { diff --git a/include/onnxruntime/ep/api.h b/include/onnxruntime/ep/api.h index 36d99e5d44d45..c22e52ed8aaa5 100644 --- a/include/onnxruntime/ep/api.h +++ b/include/onnxruntime/ep/api.h @@ -3,7 +3,9 @@ #pragma once +#include #include +#include #pragma push_macro("ORT_API_MANUAL_INIT") #undef ORT_API_MANUAL_INIT @@ -15,6 +17,8 @@ namespace onnxruntime { namespace ep { struct ApiPtrs { + ApiPtrs(const OrtApi& ort_, const OrtEpApi& ep_, const OrtModelEditorApi& model_editor_) + : ort(ort_), ep(ep_), model_editor(model_editor_) {} const OrtApi& ort; const OrtEpApi& ep; const OrtModelEditorApi& model_editor; @@ -28,23 +32,28 @@ inline std::optional g_api_ptrs; /// Get the global instance of ApiPtrs. /// inline const ApiPtrs& Api() { + if (!detail::g_api_ptrs.has_value()) { + throw std::logic_error("onnxruntime::ep::Api() called before ApiInit()."); + } return *detail::g_api_ptrs; } /// /// Initialize the EP API pointers and global OrtEnv if not already done. +/// Thread-safe via std::call_once. /// inline void ApiInit(const OrtApiBase* ort_api_base) { - // Manual init for the C++ API - const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); - const OrtEpApi* ep_api = ort_api->GetEpApi(); - const OrtModelEditorApi* model_editor_api = ort_api->GetModelEditorApi(); - Ort::InitApi(ort_api); - - // Initialize the global API instance - if (!detail::g_api_ptrs) { + static std::once_flag init_flag; + std::call_once(init_flag, [&]() { + // Manual init for the C++ API + const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); + const OrtEpApi* ep_api = ort_api->GetEpApi(); + const OrtModelEditorApi* model_editor_api = ort_api->GetModelEditorApi(); + Ort::InitApi(ort_api); + + // Initialize the global API instance detail::g_api_ptrs.emplace(*ort_api, *ep_api, *model_editor_api); - } + }); } } // namespace ep diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index cff5d5d320423..83b6237dcc2a6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -58,6 +58,8 @@ Attention::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionB template Status Attention::ComputeInternal(OpKernelContext* context) const { + auto ort_stream = GetOrtStream(context); + const Tensor* input = context->Input(0); const Tensor* weights = context->Input(1); const Tensor* bias = context->Input(2); @@ -139,14 +141,14 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { softmax_lse_accum_bytes = slse_accum_bytes; out_accum_bytes = o_accum_bytes; } - auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); - auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); - auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); + auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, GetComputeStream(context)); + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, GetComputeStream(context)); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, GetComputeStream(context)); #else constexpr bool use_flash_attention = false; - auto softmax_lse_buffer = GetScratchBuffer(0, context->GetComputeStream()); - auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr - auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto softmax_lse_buffer = GetScratchBuffer(0, GetComputeStream(context)); + auto softmax_lse_accum_buffer = GetScratchBuffer(0, GetComputeStream(context)); // nullptr + auto out_accum_buffer = GetScratchBuffer(0, GetComputeStream(context)); // nullptr #endif if (!use_flash_attention) { @@ -238,12 +240,10 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; - AllocatorPtr allocator; - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); int m = batch_size * sequence_length; int n = (parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size); int k = parameters.input_hidden_size; - IAllocatorUniquePtr gemm_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(m * n) * sizeof(T), false, context->GetComputeStream()); + IAllocatorUniquePtr gemm_buffer = GetScratchBuffer(static_cast(m * n) * sizeof(T), GetComputeStream(context)); CudaT one = ToCudaType::FromFloat(1.0f); CudaT zero = ToCudaType::FromFloat(0.0f); @@ -275,7 +275,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { use_memory_efficient_attention, use_cudnn_flash_attention, false); - IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, workSpaceSize, false, context->GetComputeStream()); + IAllocatorUniquePtr work_space = GetScratchBuffer(workSpaceSize, GetComputeStream(context)); data.gemm_buffer = reinterpret_cast(gemm_buffer.get()); if (nullptr != bias) { @@ -313,7 +313,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { } cudnnHandle_t cudnn = GetCudnnHandle(context); - return QkvToContext(device_prop, cublas, cudnn, context->GetComputeStream(), parameters, data); + return QkvToContext(device_prop, cublas, cudnn, ort_stream.get(), parameters, data); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu index ff7ac67852427..ba647bc3f6d33 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu @@ -974,7 +974,7 @@ Status ComputeSoftmaxWithRawMask(Stream* ort_stream, if (use_persistent_softmax) { return onnxruntime::cuda::dispatch_warpwise_softmax_forward( - ort_stream, + stream, output, persistent_softmax_workspace, total_sequence_length, diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc index 5e5f909415fff..e13f13fc8b245 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc @@ -175,6 +175,8 @@ DecoderAttention::DecoderAttention(const OpKernelInfo& info) : CudaKernel(inf template Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { + auto ort_stream = GetOrtStream(context); + const Tensor* query(context->Input(0)); const Tensor* key(context->Input(1)); const Tensor* q_weights(context->Input(2)); @@ -262,7 +264,7 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { // calculate q gemm_query_buffer_p = GetScratchBuffer(static_cast(batch_size) * sequence_length * hidden_size, - context->GetComputeStream()); + GetComputeStream(context)); m = sequence_length * batch_size; n = hidden_size; k = hidden_size; @@ -288,7 +290,7 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { if (!has_layer_state_ || !use_past_) { if (!static_kv_) { gemm_kv_buffer_p = GetScratchBuffer(static_cast(batch_size) * 2 * sequence_length * hidden_size, - context->GetComputeStream()); + GetComputeStream(context)); m = sequence_length * batch_size; n = 2 * hidden_size; k = hidden_size; @@ -308,7 +310,7 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { // gemm_kv_buffer in col-base: (2*h2, T_S*B) } else { gemm_kv_buffer_p = GetScratchBuffer(static_cast(batch_size) * 2 * key_sequence_length * hidden_size, - context->GetComputeStream()); + GetComputeStream(context)); m = key_sequence_length * batch_size; n = 2 * hidden_size; k = hidden_size; @@ -334,7 +336,7 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { int cache_sequence_length = static_cast(cache_shape[2]); if (!static_kv_) { gemm_kv_buffer_p = GetScratchBuffer(static_cast(batch_size) * 2 * sequence_length * hidden_size, - context->GetComputeStream()); + GetComputeStream(context)); m = sequence_length * batch_size; kv_sequence_length = cache_sequence_length + sequence_length; // broadcast bias for key and value: (2*h2, T_S*B) @@ -357,11 +359,11 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { size_t bytes = element_size * batch_size * (static_cast(sequence_length) + static_cast(2) * kv_sequence_length) * hidden_size; - auto qkv_buffer_p = GetScratchBuffer(bytes, context->GetComputeStream()); + auto qkv_buffer_p = GetScratchBuffer(bytes, GetComputeStream(context)); bytes = element_size * 2 * batch_size * sequence_length * num_heads_ * (static_cast(2) * head_size + static_cast(kv_sequence_length)); - auto workspace_p = GetScratchBuffer(bytes, context->GetComputeStream()); + auto workspace_p = GetScratchBuffer(bytes, GetComputeStream(context)); Tensor* output(context->Output(0, query_shape)); TensorShape new_cache_shape({batch_size, num_heads_, kv_sequence_length, head_size}); @@ -371,7 +373,7 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { return LaunchDecoderAttentionKernel( device_prop, UseTF32(), - context->GetComputeStream(), + ort_stream.get(), cublas, element_size, batch_size, diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc index b4643da58eba5..416bf24dd818c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc @@ -134,7 +134,7 @@ Status DecoderMaskedSelfAttention::ComputeInternal(OpKernelContext* cont int m = batch_size * sequence_length; int n = (parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size); int k = parameters.input_hidden_size; - gemm_buffer = GetScratchBuffer(static_cast(m) * n, context->GetComputeStream()); + gemm_buffer = GetScratchBuffer(static_cast(m) * n, GetComputeStream(context)); CudaT one = ToCudaType::FromFloat(1.0f); CudaT zero = ToCudaType::FromFloat(0.0f); diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc index e7ed96d7f5ee2..0a5e2ef55197b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc @@ -30,6 +30,37 @@ REGISTER_KERNEL_TYPED(double) using namespace ONNX_NAMESPACE; +#ifdef BUILD_CUDA_EP_AS_PLUGIN +// PLUGIN BUILD ADAPTATION: bias_gelu_helper::CheckInputs lives in the CPU +// provider and cannot be linked into the plugin. Reimplement the same input +// validation (rank checks, bias shape matching) inline. +// Keep in sync with contrib_ops/cpu/bert/bias_gelu_helper.h. +static Status CheckInputsForPlugin(const OpKernelContext* context) { + const Tensor* input = context->Input(0); + const Tensor* bias = context->Input(1); + + const auto& input_dims = input->Shape().GetDims(); + if (input_dims.size() < 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 0 is expected to have 1 or more dimensions, got ", input_dims.size()); + } + + if (nullptr != bias) { + const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 1 is expected to have 1 dimensions, got ", bias_dims.size()); + } + if (bias_dims[0] != input_dims[input_dims.size() - 1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 1 dimension 0 should have same length as the last dimension of input 0"); + } + } + + return Status::OK(); +} +#endif + template FastGelu::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) { const TransformerOptions* options = TransformerOptions::GetInstance(); @@ -38,7 +69,11 @@ FastGelu::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel template Status FastGelu::ComputeInternal(OpKernelContext* context) const { +#ifdef BUILD_CUDA_EP_AS_PLUGIN + ORT_RETURN_IF_ERROR(CheckInputsForPlugin(context)); +#else ORT_RETURN_IF_ERROR(bias_gelu_helper::CheckInputs(context)); +#endif const Tensor* input = context->Input(0); const Tensor* bias = context->Input(1); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 6e6d463889007..3b6b5f9079ebe 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -142,6 +142,8 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) // 11. head_sink (Tensor) - Attention sink for GPT-OSS template Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { + auto ort_stream = GetOrtStream(context); + const Tensor* query = context->Input(0); const Tensor* key = context->Input(1); const Tensor* value = context->Input(2); @@ -259,8 +261,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, dense_head_size}; TensorShape present_shape(present_dims); - Tensor* present_key_tensor = context->Output(1, present_shape); - Tensor* present_value_tensor = context->Output(2, present_shape); + Tensor* present_key_output = context->Output(1, present_shape); // present_key + Tensor* present_value_output = context->Output(2, present_shape); // present_value IAllocatorUniquePtr k_buffer; IAllocatorUniquePtr v_buffer; @@ -288,8 +290,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data()); data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast(past_value->Data()); - data.present_key = reinterpret_cast(present_key_tensor->MutableData()); - data.present_value = reinterpret_cast(present_value_tensor->MutableData()); + data.present_key = reinterpret_cast(present_key_output->MutableData()); + data.present_value = reinterpret_cast(present_value_output->MutableData()); // Compute past_present_share_buffer early since it's needed for flash attention path selection. // This compares the final pointer values after quantization handling. @@ -370,7 +372,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons xqa_total_bytes += q_bytes + k_bytes; } - xqa_scratch_buffer = this->GetScratchBuffer(xqa_total_bytes, context->GetComputeStream()); + xqa_scratch_buffer = this->GetScratchBuffer(xqa_total_bytes, GetComputeStream(context)); data.xqa_buffer = xqa_scratch_buffer.get(); data.xqa_buffer_bytes = xqa_internal_bytes; @@ -413,11 +415,11 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons out_accum_bytes = onnxruntime::flash::get_out_accum_size(num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length, round_multiple(parameters.head_size, 32)); } - softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); - softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); - out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); + softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, GetComputeStream(context)); + softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, GetComputeStream(context)); + out_accum_buffer = GetScratchBuffer(out_accum_bytes, GetComputeStream(context)); - auto cuda_stream = static_cast(context->GetComputeStream()->GetHandle()); + auto cuda_stream = Stream(context); if (softmax_lse_accum_bytes > 0) { // Initialize to 0 is fine because Flash kernel will write -inf to it if needed. // However, the standard Flash kernel often doesn't zero it globally. @@ -442,8 +444,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons } else { // Compute sequence length buffers (past_seq_lens and total_seq_lens). // Allocate buffer for both: first half is past_seq_lens, second half is total_seq_lens. - seq_lens_buffer = GetScratchBuffer(3 * parameters.batch_size, context->GetComputeStream()); - auto cuda_stream = static_cast(context->GetComputeStream()->GetHandle()); + seq_lens_buffer = GetScratchBuffer(3 * parameters.batch_size, GetComputeStream(context)); + auto cuda_stream = Stream(context); data.past_seq_lens = seq_lens_buffer.get(); data.total_seq_lens = seq_lens_buffer.get() + parameters.batch_size; data.padded_seq_lens = data.total_seq_lens + parameters.batch_size; @@ -480,9 +482,9 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons ? (sizeof(float) * parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size) : 0; - k_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); - v_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); - fmha_buffer = GetScratchBuffer(fmha_buffer_bytes, context->GetComputeStream()); + k_buffer = GetScratchBuffer(kv_buffer_bytes, GetComputeStream(context)); + v_buffer = GetScratchBuffer(kv_buffer_bytes, GetComputeStream(context)); + fmha_buffer = GetScratchBuffer(fmha_buffer_bytes, GetComputeStream(context)); data.k = reinterpret_cast(k_buffer.get()); data.v = reinterpret_cast(v_buffer.get()); @@ -501,7 +503,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons data.use_memory_efficient_attention); if (buffer_req.qkv_buffer_bytes > 0) { - unpacked_qkv_buffer = GetScratchBuffer(buffer_req.qkv_buffer_bytes, context->GetComputeStream()); + unpacked_qkv_buffer = GetScratchBuffer(buffer_req.qkv_buffer_bytes, GetComputeStream(context)); data.qkv_buffer = reinterpret_cast(unpacked_qkv_buffer.get()); } @@ -556,7 +558,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons cublasHandle_t cublas = GetCublasHandle(context); ORT_RETURN_IF_ERROR((QkvToContext( - device_prop, cublas, context->GetComputeStream(), parameters, data))); + device_prop, cublas, ort_stream.get(), parameters, data))); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc b/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc index 9c5d0e9834f6f..3501c3baff0b6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc @@ -78,11 +78,11 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { // TODO(tianleiwu): only calculate global index once per model instead of once per LongformerAttention node. // Build Global Index - auto global_index_buffer = GetScratchBuffer(static_cast(batch_size) * sequence_length, context->GetComputeStream()); - auto batch_global_num_buffer = GetScratchBuffer(batch_size, context->GetComputeStream()); + auto global_index_buffer = GetScratchBuffer(static_cast(batch_size) * sequence_length, GetComputeStream(context)); + auto batch_global_num_buffer = GetScratchBuffer(batch_size, GetComputeStream(context)); size_t global_scratch_bytes = GetGlobalScratchSize(sequence_length); - auto global_scratch_buffer = GetScratchBuffer(global_scratch_bytes, context->GetComputeStream()); + auto global_scratch_buffer = GetScratchBuffer(global_scratch_bytes, GetComputeStream(context)); auto& device_prop = GetDeviceProp(); ORT_RETURN_IF_ERROR(BuildGlobalIndex( @@ -116,7 +116,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { size_t qkv_size = static_cast(batch_size) * sequence_length * 3 * hidden_size * element_size; // Buffer for GEMM outputs of q, k, v, global_q, global_k and global_v // TODO(tianleiwu): compact global_q only need batch_size * window * hidden_size * element_size buffer size. - auto gemm_buffer = GetScratchBuffer(qkv_size + qkv_size, context->GetComputeStream()); + auto gemm_buffer = GetScratchBuffer(qkv_size + qkv_size, GetComputeStream(context)); bool use_merged_qkv_weights = (weights->Shape().NumDimensions() == 2); @@ -257,7 +257,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { max_num_global, window_, disable_compact_memory); - auto workspace_buffer = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); + auto workspace_buffer = GetScratchBuffer(workSpaceSize, GetComputeStream(context)); ORT_RETURN_IF_ERROR(LaunchLongformerAttentionKernel( device_prop, cublas, @@ -285,7 +285,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { use_half4_)); // Defer release of pinned memory since cudaStreamSynchronize is not used here and kernel need access the buffer. - this->AddDeferredReleaseCPUPtr(pinned_buffer.release(), context->GetComputeStream()); + this->AddDeferredReleaseCPUPtr(pinned_buffer.release(), GetComputeStream(context)); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 452a00deeb21c..a2af4831a3a00 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -88,6 +88,8 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) template Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { + auto ort_stream = GetOrtStream(context); + const Tensor* query = context->Input(0); const Tensor* key = context->Input(1); const Tensor* value = context->Input(2); @@ -290,7 +292,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) cons kernel_type = AttentionKernelType::AttentionKernel_LeanAttention; } - auto lean_sync_flag_buffer = GetScratchBuffer(sync_flag_bytes, context->GetComputeStream()); + auto lean_sync_flag_buffer = GetScratchBuffer(sync_flag_bytes, GetComputeStream(context)); data.lean_sync_flag = reinterpret_cast(lean_sync_flag_buffer.get()); #else constexpr bool use_lean_attention = false; @@ -336,9 +338,9 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) cons #endif #if USE_LEAN_ATTENTION || USE_FLASH_ATTENTION - auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); - auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); - auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); + auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, GetComputeStream(context)); + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, GetComputeStream(context)); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, GetComputeStream(context)); if (use_flash_attention || use_lean_attention) { data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); @@ -485,7 +487,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) cons data.use_memory_efficient_attention = use_memory_efficient_attention; data.use_decoder_masked_multihead_attention = use_decoder_masked_multihead_attention; data.kernel_type = kernel_type; - data.allocator = Info().GetAllocator(OrtMemType::OrtMemTypeDefault); + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&data.allocator)); // Cache of cumulated sequence length that could help when sequence length does not change (for example, image model). // The cache will be initialized only once, and become readonly after that. @@ -515,7 +517,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) cons use_memory_efficient_attention, use_cudnn_sdpa, no_qkv_workspace); - auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + auto work_space = GetScratchBuffer(workspace_bytes, GetComputeStream(context)); data.has_qkv_workspace = !no_qkv_workspace; data.workspace = reinterpret_cast(work_space.get()); @@ -528,7 +530,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) cons std::vector seqlens_k(parameters.batch_size, parameters.total_sequence_length - 1); size_t seqlens_k_bytes = 0; seqlens_k_bytes = sizeof(int) * parameters.batch_size; - auto seqlens_k_buffer = GetScratchBuffer(seqlens_k_bytes, context->GetComputeStream()); + auto seqlens_k_buffer = GetScratchBuffer(seqlens_k_bytes, GetComputeStream(context)); if (seqlens_k_buffer != nullptr) { data.seqlens_k_total = reinterpret_cast(seqlens_k_buffer.get()); CUDA_RETURN_IF_ERROR(cudaMemcpy(data.seqlens_k_total, seqlens_k.data(), seqlens_k_bytes, cudaMemcpyHostToDevice)); @@ -557,7 +559,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) cons cudnnHandle_t cudnn = GetCudnnHandle(context); DUMP_STRING("Run QkvToContext from MHA CUDA"); return QkvToContext( - device_prop, cublas, cudnn, context->GetComputeStream(), parameters, data); + device_prop, cublas, cudnn, ort_stream.get(), parameters, data); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc index 10bd1170d8f07..df4bc7b170106 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc @@ -286,7 +286,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { int m = parameters.token_count; int n = parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size; int k = parameters.input_hidden_size; - gemm_buffer = this->template GetScratchBuffer(static_cast(m) * n, context->GetComputeStream()); + gemm_buffer = this->template GetScratchBuffer(static_cast(m) * n, this->GetComputeStream(context)); cublasHandle_t cublas = this->GetCublasHandle(context); @@ -310,7 +310,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { false, use_memory_efficient_attention, no_qkv_workspace); - auto work_space = this->template GetScratchBuffer(workSpaceSize, context->GetComputeStream()); + auto work_space = this->template GetScratchBuffer(workSpaceSize, this->GetComputeStream(context)); typedef typename ToCudaType::MappedType CudaT; PackedAttentionData data; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc index 928a8b17229e6..20a45d18367bc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc @@ -267,7 +267,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co use_flash_attention, use_memory_efficient_attention, no_qkv_workspace); - auto work_space = this->template GetScratchBuffer(workSpaceSize, context->GetComputeStream()); + auto work_space = this->template GetScratchBuffer(workSpaceSize, this->GetComputeStream(context)); PackedMultiHeadAttentionData data; data.query = reinterpret_cast(query->Data()); diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc b/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc index b9f30f71e9a66..5df2c8b438771 100644 --- a/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc @@ -54,6 +54,8 @@ PagedAttention::PagedAttention(const OpKernelInfo& info) template Status PagedAttention::ComputeInternal(OpKernelContext* context) const { + auto ort_stream = GetOrtStream(context); + const Tensor* query = context->Input(0); const Tensor* key = context->Input(1); const Tensor* value = context->Input(2); @@ -151,10 +153,10 @@ Status PagedAttention::ComputeInternal(OpKernelContext* context) const { softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.token_count, parameters.num_heads); } - auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); + auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, GetComputeStream(context)); #else constexpr bool use_flash_attention = false; - auto softmax_lse_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto softmax_lse_buffer = GetScratchBuffer(0, GetComputeStream(context)); // nullptr #endif if (!use_flash_attention) { @@ -163,7 +165,7 @@ Status PagedAttention::ComputeInternal(OpKernelContext* context) const { } size_t cumulative_seqlens_kv_bytes = sizeof(int) * (parameters.batch_size + 1); - auto cumulative_seqlens_kv_buffer = GetScratchBuffer(cumulative_seqlens_kv_bytes, context->GetComputeStream()); + auto cumulative_seqlens_kv_buffer = GetScratchBuffer(cumulative_seqlens_kv_bytes, GetComputeStream(context)); size_t workspace_buffer_bytes = 0; if (do_rotary_) { @@ -171,7 +173,7 @@ Status PagedAttention::ComputeInternal(OpKernelContext* context) const { } else if (parameters.is_packed_qkv) { workspace_buffer_bytes = sizeof(T) * parameters.token_count * parameters.hidden_size; } - auto workspace_buffer = GetScratchBuffer(workspace_buffer_bytes, context->GetComputeStream()); + auto workspace_buffer = GetScratchBuffer(workspace_buffer_bytes, GetComputeStream(context)); // Print debug info if (kernel_options_->AllowDebugInfo()) { @@ -210,7 +212,7 @@ Status PagedAttention::ComputeInternal(OpKernelContext* context) const { cublasHandle_t cublas = GetCublasHandle(context); return QkvToContext( - device_prop, cublas, context->GetComputeStream(), parameters, data); + device_prop, cublas, ort_stream.get(), parameters, data); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc index 05f55d9106d0e..e62020a09216d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc @@ -163,7 +163,7 @@ Status GatedRelativePositionBias::ComputeInternal(OpKernelContext* context) c const size_t elements_after_gemm = (size_t)BNS * (size_t)D; bool reuse_output = (seq_len >= D); size_t workspace_size = sizeof(T) * (elements_in_query + (reuse_output ? (size_t)0 : elements_after_gemm)); - auto workspace = GetScratchBuffer(workspace_size, context->GetComputeStream()); + auto workspace = this->GetScratchBuffer(workspace_size, this->GetComputeStream(context)); cudaStream_t stream = Stream(context); if (!is_padding_removed) { diff --git a/onnxruntime/contrib_ops/cuda/bert/remove_padding.cc b/onnxruntime/contrib_ops/cuda/bert/remove_padding.cc index ec9617982c776..eba4c48301cf3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/remove_padding.cc +++ b/onnxruntime/contrib_ops/cuda/bert/remove_padding.cc @@ -53,7 +53,7 @@ Status RemovePadding::ComputeInternal(OpKernelContext* context) const { int64_t sequence_length = dims[1]; int64_t hidden_size = dims[2]; - auto token_count_buffer = GetScratchBuffer(2, context->GetComputeStream()); + auto token_count_buffer = GetScratchBuffer(2, GetComputeStream(context)); TensorShapeVector token_offset_shape(2); token_offset_shape[0] = batch_size; diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc index 92ae7e81fb5bd..aefd86a6ebd10 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @@ -43,9 +43,14 @@ SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); ORT_ENFORCE(epsilon_ >= 0); +#ifdef BUILD_CUDA_EP_AS_PLUGIN + // Plugin adapter cannot static_cast to CUDAExecutionProvider directly. + // Use the adapter shim that reads the config from the per-EP runtime map. + strict_ = onnxruntime::cuda::GetCudaKernelAdapterSkipLayerNormStrictMode(op_kernel_info.GetExecutionProvider()); +#else const CUDAExecutionProvider* cuda_ep = static_cast(op_kernel_info.GetExecutionProvider()); - strict_ = cuda_ep->IsSkipLayerNormInStrictMode(); +#endif } template diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh index b1a6badc6b3f1..d132fba85988c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh @@ -17,26 +17,19 @@ namespace NAMESPACE_NAME { // XQA kernels require SM80+ (Ampere or newer). We need a guard that works correctly // during both host and device compilation passes: // - Device pass: __CUDA_ARCH__ is defined, check it directly. -// - Host pass: __CUDA_ARCH__ is NOT defined. Use __CUDA_ARCH_LIST__ (available since -// CUDA 11.5), which nvcc defines as a comma-separated ascending list of target -// architectures (e.g. 750,800). In #if, the comma operator returns the rightmost -// (highest) value, so __CUDA_ARCH_LIST__ >= 800 checks the max target arch. -// - Non-nvcc (e.g. IDE parser): cuda_hint.cuh defines __CUDA_ARCH__ 900, which -// takes the first branch. -// Using !defined(__CUDA_ARCH__) here would be WRONG: it always evaluates true during -// the host pass, causing the kernel to be declared even when no SM80+ device code -// exists. CUDA 13+ then fails to generate a host stub, producing C2129 / LNK2001. +// - Host pass: rely on HAS_SM80_OR_LATER from cmake/external/cuda_configuration.cmake. +// If any SM80+ arch is enabled, the host stub must be emitted. +// - Non-nvcc parsers usually won't see the CMake-provided define, so keep editor parsing +// intact by taking the fallback branch when __CUDACC__ is not defined. +// Using only !defined(__CUDA_ARCH__) here would be WRONG: it always evaluates true during +// the host pass, causing the kernel to be declared even when no SM80+ device code exists. +// CUDA 13+ then fails to generate a host stub, producing C2129 / LNK2001. #undef XQA_HAS_SM80_TARGET #ifdef __CUDA_ARCH__ #if __CUDA_ARCH__ >= 800 #define XQA_HAS_SM80_TARGET 1 #endif -#elif defined(__CUDA_ARCH_LIST__) -#if __CUDA_ARCH_LIST__ >= 800 -#define XQA_HAS_SM80_TARGET 1 -#endif -#else -// Non-nvcc fallback: assume supported (IDE parsers, etc.) +#elif defined(HAS_SM80_OR_LATER) || !defined(__CUDACC__) #define XQA_HAS_SM80_TARGET 1 #endif diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc index 967f30a304ac2..00e50a80c327b 100644 --- a/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc @@ -83,9 +83,11 @@ Status DistributedReduceBase::ComputeInternal(OpKernelContext* context) const const bool enable_fast_but_non_deterministic_reduction = !context->GetUseDeterministicCompute(); return onnxruntime::cuda::ReduceComputeCore( /* GPU allocator */ Info().GetAllocator(OrtMemType::OrtMemTypeDefault), + /* kernel */ this, *input_tensor, metadata, *output_tensor, cudnn_reduce_op_, axes_span, /* calculate_log */ false, /* calculate_sqt */ false, /* log_sum_exp_ */ false, - enable_fast_but_non_deterministic_reduction, context->GetComputeStream()); + enable_fast_but_non_deterministic_reduction, + Stream(context), GetComputeStream(context), GetCudnnHandle(context)); } return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index dea5391c7629b..8f729f913c036 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -65,6 +65,55 @@ struct DispatchGroupNorm { broadcast_skip, channels_per_block); } + +#ifdef BUILD_CUDA_EP_AS_PLUGIN + // Plugin overload: accepts PluginTuningContextStub* (unused) and raw void* + // stream handle instead of IKernelExplorer*/Stream* which are not available + // in the plugin build. Uses OrtStreamAdapter to bridge to the _impl kernel. + Status operator()(CudaKernel::PluginTuningContextStub* tuning_ctx, + void* ort_stream, + Tensor* output, + Tensor* add_out, + const Tensor* input, + const Tensor* skip, + const Tensor* bias, + const Tensor* gamma, + const Tensor* beta, + void* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_swish_activation, + bool broadcast_skip, + int channels_per_block) { + ORT_UNUSED_PARAMETER(tuning_ctx); + typedef typename ToCudaType::MappedType CudaT; + onnxruntime::OrtStreamAdapter ort_stream_adapter(ort_stream); + return LaunchGroupNormKernel( + nullptr, + ort_stream_adapter.get(), + reinterpret_cast(output->MutableData()), + add_out == nullptr ? nullptr : reinterpret_cast(add_out->MutableData()), + reinterpret_cast(input->Data()), + skip == nullptr ? nullptr : reinterpret_cast(skip->Data()), + bias == nullptr ? nullptr : reinterpret_cast(bias->Data()), + gamma->Data(), + beta->Data(), + workspace, + epsilon, + batch_size, + num_channels, + height, + width, + num_groups, + use_swish_activation, + broadcast_skip, + channels_per_block); + } +#endif }; } // namespace @@ -208,11 +257,11 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { } auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups_), - context->GetComputeStream()); + GetComputeStream(context)); utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); return dispatcher.InvokeRet(GetTuningContext(), - context->GetComputeStream(), output, add_out, input, skip, bias, + GetComputeStream(context), output, add_out, input, skip, bias, gamma, beta, workspace.get(), epsilon_, batch_size, diff --git a/onnxruntime/contrib_ops/cuda/fused_conv.cc b/onnxruntime/contrib_ops/cuda/fused_conv.cc index 0554cc34933f1..34732798a656a 100644 --- a/onnxruntime/contrib_ops/cuda/fused_conv.cc +++ b/onnxruntime/contrib_ops/cuda/fused_conv.cc @@ -181,7 +181,7 @@ class FusedConv : public onnxruntime::cuda::CudaKernel { // Post slicing needed. Create and fill in the Conv results in an intermediate buffer. s_.memory_for_cudnn_conv_results = GetScratchBuffer(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size, - context->GetComputeStream()); + GetComputeStream(context)); s_.y_data = reinterpret_cast(s_.memory_for_cudnn_conv_results.get()); } else { // No post slicing needed. Fill the output tensor's buffer directly. @@ -338,7 +338,7 @@ class FusedConv : public onnxruntime::cuda::CudaKernel { } if (s_.post_slicing_required) { s_.memory_for_cudnn_conv_results = GetScratchBuffer( - TensorShape(s_.y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream()); + TensorShape(s_.y_dims_with_adjusted_pads).Size() * s_.element_size, GetComputeStream(context)); s_.y_data = reinterpret_cast(s_.memory_for_cudnn_conv_results.get()); } else { s_.y_data = reinterpret_cast(s_.Y->MutableData()); @@ -358,7 +358,7 @@ class FusedConv : public onnxruntime::cuda::CudaKernel { bool has_b = nullptr != s_.b_data; const auto alpha = onnxruntime::cuda::Consts::One; const auto beta = onnxruntime::cuda::Consts::Zero; - IAllocatorUniquePtr workspace = GetWorkSpace(context->GetComputeStream()); + IAllocatorUniquePtr workspace = GetWorkSpace(GetComputeStream(context)); auto cudnn_status = cudnnConvolutionBiasActivationForward(cudnnHandle, &alpha, s_.x_tensor, @@ -422,7 +422,7 @@ class FusedConv : public onnxruntime::cuda::CudaKernel { return Status::OK(); } - inline IAllocatorUniquePtr GetWorkSpace(onnxruntime::Stream* stream) const { + inline IAllocatorUniquePtr GetWorkSpace(void* stream) const { return GetScratchBuffer(s_.workspace_bytes, stream); } @@ -455,4 +455,4 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( } // namespace cuda } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/inverse.cc b/onnxruntime/contrib_ops/cuda/inverse.cc index 9075dda26f86b..68f06b51c0067 100644 --- a/onnxruntime/contrib_ops/cuda/inverse.cc +++ b/onnxruntime/contrib_ops/cuda/inverse.cc @@ -65,7 +65,8 @@ Status CheckForSingularity(cudaStream_t stream, const IAllocatorUniquePtr& template struct Inverse::ComputeImpl { - Status operator()(onnxruntime::Stream* ort_stream, Inverse::CublasHandle cublas_h, const Inverse* inst, const Tensor& input, Tensor& output, + Status operator()(void* ort_stream, cudaStream_t stream, Inverse::CublasHandle cublas_h, const Inverse* inst, + const Tensor& input, Tensor& output, const IAllocatorUniquePtr& info, const IAllocatorUniquePtr& pivots, size_t num_batches, size_t rows) const { using namespace onnxruntime::cuda; @@ -75,7 +76,6 @@ struct Inverse::ComputeImpl { auto info_cpu = std::make_unique(num_batches); const auto dim = static_cast(rows); const auto n_batches = static_cast(num_batches); - cudaStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; // Make a copy of the input which will serve as a workspace as well. if constexpr (std::is_same::value || std::is_same::value) { @@ -150,13 +150,13 @@ Status Inverse::ComputeInternal(OpKernelContext* ctx) const { num_batches = static_cast(input_shape.SizeToDimension(num_dim - 2)); } - IAllocatorUniquePtr info = GetScratchBuffer(num_batches, ctx->GetComputeStream()); + IAllocatorUniquePtr info = GetScratchBuffer(num_batches, GetComputeStream(ctx)); CUDA_RETURN_IF_ERROR(cudaMemsetAsync(info.get(), 0, num_batches * sizeof(int), Stream(ctx))); - IAllocatorUniquePtr pivots = GetScratchBuffer(rows * num_batches, ctx->GetComputeStream()); + IAllocatorUniquePtr pivots = GetScratchBuffer(rows * num_batches, GetComputeStream(ctx)); utils::MLTypeCallDispatcher t_disp(input->GetElementType()); return t_disp.InvokeRet( - ctx->GetComputeStream(), GetCublasHandle(ctx), this, *input, *output, info, pivots, num_batches, rows); + GetComputeStream(ctx), Stream(ctx), GetCublasHandle(ctx), this, *input, *output, info, pivots, num_batches, rows); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/math/bias_dropout.cc b/onnxruntime/contrib_ops/cuda/math/bias_dropout.cc index 79cc656ba0491..163d85b458ba1 100644 --- a/onnxruntime/contrib_ops/cuda/math/bias_dropout.cc +++ b/onnxruntime/contrib_ops/cuda/math/bias_dropout.cc @@ -124,7 +124,7 @@ Status BiasDropout::ComputeInternal(OpKernelContext* context) const } IAllocatorUniquePtr temp_mask_buffer{}; // buffer to use if mask is not provided - auto* ort_stream = context->GetComputeStream(); + auto* ort_stream = GetComputeStream(context); void* const mask_data = [this, mask_element_count, mask, &temp_mask_buffer, ort_stream]() { if (mask) return mask->MutableDataRaw(); temp_mask_buffer = diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index 707eb24a386a9..ffd1b219da03c 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -49,8 +49,6 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { 0)); // no block-wise quantization for regular MoE using CudaT = typename OrtToCudaType::type; - auto stream = context->GetComputeStream(); - auto& device_prop = GetDeviceProp(); const int sm = device_prop.major * 10 + device_prop.minor; @@ -68,18 +66,13 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int); size_t expert_for_source_row_size = k_ * moe_params.num_rows * sizeof(int); - AllocatorPtr allocator; - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - - // TODO: allocate one buffer and reuse it. - IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, ws_size, false, stream); - IAllocatorUniquePtr fc2_output = IAllocator::MakeUniquePtr(allocator, fc2_output_size, false, stream); - IAllocatorUniquePtr expert_scales = - IAllocator::MakeUniquePtr(allocator, expert_scales_size, false, stream); + IAllocatorUniquePtr work_space = this->template GetScratchBuffer(ws_size, this->GetComputeStream(context)); + IAllocatorUniquePtr fc2_output = this->template GetScratchBuffer(fc2_output_size, this->GetComputeStream(context)); + IAllocatorUniquePtr expert_scales = this->template GetScratchBuffer(expert_scales_size, this->GetComputeStream(context)); IAllocatorUniquePtr expanded_source_row_to_expanded_dest_row = - IAllocator::MakeUniquePtr(allocator, expanded_source_row_to_expanded_dest_row_size, false, stream); + this->template GetScratchBuffer(expanded_source_row_to_expanded_dest_row_size, this->GetComputeStream(context)); IAllocatorUniquePtr expert_for_source_row = - IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, stream); + this->template GetScratchBuffer(expert_for_source_row_size, this->GetComputeStream(context)); const CudaT* fc_scales_ptr = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index 33cd906508bcf..3dcc03e9597e3 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -95,6 +95,8 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { // Output 0 - output : (batch_size, sequence_length, hidden_size) // Output 1 - present : (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size) + auto ort_stream = GetOrtStream(context); + const Tensor* input = context->Input(0); const Tensor* weights = context->Input(1); const Tensor* bias = context->Input(2); @@ -138,8 +140,8 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { int n = 3 * hidden_size; int k = parameters.input_hidden_size; size_t num_elements = SafeInt(m) * n; - auto gemm_buffer = GetScratchBuffer(num_elements * element_size, context->GetComputeStream()); - auto gemm_buffer_quantized = GetScratchBuffer(num_elements, context->GetComputeStream()); + auto gemm_buffer = GetScratchBuffer(num_elements * element_size, GetComputeStream(context)); + auto gemm_buffer_quantized = GetScratchBuffer(num_elements, GetComputeStream(context)); typedef typename ToCudaType::MappedType CudaT; @@ -149,7 +151,8 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { weights->Data(), n, gemm_buffer_quantized.get(), n, this, - context->GetComputeStream())); + GetComputeStream(context), Stream(context), + GetCublasHandle(context))); CudaT dequant_scale; CudaT input_scale = *(reinterpret_cast(input_scale_tensor->Data())); @@ -197,7 +200,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { use_cudnn_flash_attention, true); - auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); + auto work_space = GetScratchBuffer(workSpaceSize, GetComputeStream(context)); typedef typename ToCudaType::MappedType CudaT; AttentionData data; @@ -220,7 +223,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { } cudnnHandle_t cudnn = GetCudnnHandle(context); - return QkvToContext(GetDeviceProp(), cublas, cudnn, context->GetComputeStream(), parameters, data); + return QkvToContext(GetDeviceProp(), cublas, cudnn, ort_stream.get(), parameters, data); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc index 0534ed6dc7fc0..d5d7153d0c8b9 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc @@ -56,12 +56,12 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { // TODO: find a better way to create the quant_map without using a buffer // don't want to use malloc directly so asking from the caller // can create a __device__ static array for float but doesn't work for half - IAllocatorUniquePtr quant_map_buffer = GetScratchBuffer(16, ctx->GetComputeStream()); + IAllocatorUniquePtr quant_map_buffer = this->template GetScratchBuffer(16, this->GetComputeStream(ctx)); auto* quant_map_buffer_data = quant_map_buffer.get(); ORT_RETURN_IF_ERROR(SetBnbQuantMap( SafeInt(quant_type_), reinterpret_cast(quant_map_buffer_data), - static_cast(ctx->GetComputeStream()->GetHandle()))); + this->Stream(ctx))); constexpr bool transa = false; const bool transb = transB_; @@ -85,10 +85,10 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { SafeInt(helper.N()), SafeInt(helper.K()), SafeInt(block_size_), - static_cast(ctx->GetComputeStream()->GetHandle())); + this->Stream(ctx)); if (!is_4bit_done) { - IAllocatorUniquePtr b_dequant_ptr = GetScratchBuffer(N_ * K_, ctx->GetComputeStream()); + IAllocatorUniquePtr b_dequant_ptr = this->template GetScratchBuffer(N_ * K_, this->GetComputeStream(ctx)); auto* b_dequant_data = b_dequant_ptr.get(); ORT_RETURN_IF_ERROR(DequantizeBnb4( reinterpret_cast(quant_map_buffer_data), @@ -97,7 +97,7 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { reinterpret_cast(absmax_data), SafeInt(block_size_), SafeInt(N_ * K_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + this->Stream(ctx))); const CudaT alpha = ToCudaType::FromFloat(1.f); const CudaT zero = ToCudaType::FromFloat(0.f); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 3b667bf68634c..699ead654c83f 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -304,7 +304,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { if (Y->Shape().Size() == 0) return Status::OK(); - cudaStream_t stream = static_cast(ctx->GetComputeStream()->GetHandle()); + cudaStream_t stream = this->Stream(ctx); typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT; @@ -352,7 +352,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { onnxruntime::llm::kernels::fpA_intB_gemv::kernel_launcher(sm_, params, stream); } else { const size_t workspace_size = weightOnlyGemmRunner_->getWorkspaceSize(m, n, k); - auto workspace_buffer = GetScratchBuffer(workspace_size, ctx->GetComputeStream()); + auto workspace_buffer = this->template GetScratchBuffer(workspace_size, this->GetComputeStream(ctx)); weightOnlyGemmRunner_->gemm( a_data, @@ -394,7 +394,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { } int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; - IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); + IAllocatorUniquePtr b_data_ptr = this->template GetScratchBuffer(N_ * K_padded, this->GetComputeStream(ctx)); auto* b_data = b_data_ptr.get(); if (nbits_ == 8) { diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h index a7f0a9516584c..3345856fad98b 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h @@ -44,21 +44,33 @@ class MatMulNBits final : public CudaKernel { ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); - constexpr size_t kInputIndexScale = 2; - constexpr size_t kInputIndexZeroPoints = 3; - constexpr size_t kInputIndexGroupIndex = 4; - constexpr size_t kInputIndexBias = 5; - + constexpr int kInputIndexScale = 2; + constexpr int kInputIndexZeroPoints = 3; + constexpr int kInputIndexGroupIndex = 4; + constexpr int kInputIndexBias = 5; + +#ifdef BUILD_CUDA_EP_AS_PLUGIN + // PLUGIN BUILD ADAPTATION: The adapter Node does not expose InputDefs(), + // so we cannot check whether optional inputs (zero_points, g_idx, bias) + // truly exist at construction time. Instead, we check input count here + // and verify actual tensor presence in ComputeInternal. + ORT_UNUSED_PARAMETER(kInputIndexScale); // only used in non-plugin path for type checking + has_zero_points_ = info.GetInputCount() > kInputIndexZeroPoints; + has_g_idx_ = info.GetInputCount() > kInputIndexGroupIndex; + has_bias_ = info.GetInputCount() > kInputIndexBias; + // is_zero_points_scale_same_type_ defaults to false; checked at runtime in plugin path. +#else has_zero_points_ = info.GetInputCount() > kInputIndexZeroPoints && info.node().InputDefs()[kInputIndexZeroPoints]->Exists(); has_g_idx_ = info.GetInputCount() > kInputIndexGroupIndex && info.node().InputDefs()[kInputIndexGroupIndex]->Exists(); has_bias_ = info.GetInputCount() > kInputIndexBias && info.node().InputDefs()[kInputIndexBias]->Exists(); - sm_ = this->GetDeviceProp().major * 10 + this->GetDeviceProp().minor; if (has_zero_points_) { int32_t zero_point_type = info.node().InputDefs()[kInputIndexZeroPoints]->TypeAsProto()->tensor_type().elem_type(); int32_t scale_type = info.node().InputDefs()[kInputIndexScale]->TypeAsProto()->tensor_type().elem_type(); is_zero_points_scale_same_type_ = (zero_point_type == scale_type); } +#endif + sm_ = this->GetDeviceProp().major * 10 + this->GetDeviceProp().minor; #if USE_FPA_INTB_GEMM if constexpr (std::is_same::value || std::is_same::value) { diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc index 31a793bb86f17..4b261346887f6 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc @@ -57,13 +57,8 @@ Status QMoE::QuantizedMoEImpl(OpKernelContext* context, const Tensor* fc2_scales, const Tensor* fc3_scales_optional, const cudaDeviceProp& device_prop) const { - auto stream = context->GetComputeStream(); - const int sm = device_prop.major * 10 + device_prop.minor; - AllocatorPtr allocator; - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - using CudaT = typename OrtToCudaType::type; ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, @@ -81,14 +76,13 @@ Status QMoE::QuantizedMoEImpl(OpKernelContext* context, size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int); size_t expert_for_source_row_size = k_ * moe_params.num_rows * sizeof(int); - IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, ws_size, false, stream); - IAllocatorUniquePtr fc2_output = IAllocator::MakeUniquePtr(allocator, fc2_output_size, false, stream); - IAllocatorUniquePtr expert_scales = - IAllocator::MakeUniquePtr(allocator, expert_scales_size, false, stream); + IAllocatorUniquePtr work_space = this->template GetScratchBuffer(ws_size, this->GetComputeStream(context)); + IAllocatorUniquePtr fc2_output = this->template GetScratchBuffer(fc2_output_size, this->GetComputeStream(context)); + IAllocatorUniquePtr expert_scales = this->template GetScratchBuffer(expert_scales_size, this->GetComputeStream(context)); IAllocatorUniquePtr expanded_source_row_to_expanded_dest_row = - IAllocator::MakeUniquePtr(allocator, expanded_source_row_to_expanded_dest_row_size, false, stream); + this->template GetScratchBuffer(expanded_source_row_to_expanded_dest_row_size, this->GetComputeStream(context)); IAllocatorUniquePtr expert_for_source_row = - IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, stream); + this->template GetScratchBuffer(expert_for_source_row_size, this->GetComputeStream(context)); moe_runner.run_moe_fc( reinterpret_cast(input->template Data()), diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc index 3e93a527877c5..5ba8833d0d8b7 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc @@ -228,7 +228,7 @@ Status QOrderedAttention::ComputeInternal(OpKernelContext* context) const { output_shape[2] = static_cast(hidden_size); Tensor* output = context->Output(0, output_shape); - cublasLtHandle_t cublasLt = CublasLtHandle(); + cublasLtHandle_t cublasLt = this->GetCublasLtHandle(context); // Use GEMM for fully connection. int m = batch_size * sequence_length; int n = 3 * hidden_size; @@ -236,7 +236,7 @@ Status QOrderedAttention::ComputeInternal(OpKernelContext* context) const { int64_t size_of_attention_scores = ((int64_t)batch_size) * num_heads_ * sequence_length * sequence_length; // transposed qkv_layer, union(stacked, attention probs + attention scores) - auto gemm_buffer_quantized = GetScratchBuffer((int64_t)m * n + std::max((int64_t)m * n, 2 * size_of_attention_scores), context->GetComputeStream()); + auto gemm_buffer_quantized = GetScratchBuffer((int64_t)m * n + std::max((int64_t)m * n, 2 * size_of_attention_scores), GetComputeStream(context)); int8_t* stacked_qkv_layers = gemm_buffer_quantized.get() + ((int64_t)m * n); int8_t* tranposed_qkv_layers = gemm_buffer_quantized.get(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_longformer_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_longformer_attention.cc index 4e0140f34e869..65b09a98f139c 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_longformer_attention.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_longformer_attention.cc @@ -100,7 +100,7 @@ QOrderedLongformerAttention::ComputeInternal(OpKernelContext* context) const { Tensor* output = context->Output(0, shape); cublasHandle_t cublas = GetCublasHandle(context); - cublasLtHandle_t cublasLt = CublasLtHandle(); + cublasLtHandle_t cublasLt = this->GetCublasLtHandle(context); cudaStream_t stream = Stream(context); CUBLAS_RETURN_IF_ERROR(cublasSetStream(cublas, stream)); @@ -109,11 +109,11 @@ QOrderedLongformerAttention::ComputeInternal(OpKernelContext* context) const { // TODO: only calculate once per model. // Build Global Index - auto global_index_buffer = GetScratchBuffer(static_cast(batch_size) * static_cast(sequence_length), context->GetComputeStream()); - auto batch_global_num_buffer = GetScratchBuffer(batch_size, context->GetComputeStream()); + auto global_index_buffer = GetScratchBuffer(static_cast(batch_size) * static_cast(sequence_length), GetComputeStream(context)); + auto batch_global_num_buffer = GetScratchBuffer(batch_size, GetComputeStream(context)); size_t global_scratch_bytes = GetGlobalScratchSize(sequence_length); - auto global_scratch_buffer = GetScratchBuffer(global_scratch_bytes, context->GetComputeStream()); + auto global_scratch_buffer = GetScratchBuffer(global_scratch_bytes, GetComputeStream(context)); auto& device_prop = GetDeviceProp(); ORT_RETURN_IF_ERROR(BuildGlobalIndex(device_prop, @@ -152,7 +152,7 @@ QOrderedLongformerAttention::ComputeInternal(OpKernelContext* context) const { // Buffer for GEMM outputs of q, k, v, global_q, global_k and global_v // TODO(tianleiwu): compact global_q only need batch_size * window * hidden_size * element_size buffer size. size_t qkv_3 = qkv_size + qkv_size + 2 * qkv_count * sizeof(int8_t); - auto gemm_buffer = GetScratchBuffer(qkv_3, context->GetComputeStream()); + auto gemm_buffer = GetScratchBuffer(qkv_3, GetComputeStream(context)); const float* scale_input = context->Input(1)->Data(); const float* scale_weight = context->Input(3)->Data(); @@ -220,7 +220,7 @@ QOrderedLongformerAttention::ComputeInternal(OpKernelContext* context) const { window_, disable_compact_memory); - auto workspace_buffer = GetScratchBuffer(workSpaceSize + output_elements * element_size, context->GetComputeStream()); + auto workspace_buffer = GetScratchBuffer(workSpaceSize + output_elements * element_size, GetComputeStream(context)); MLFloat16* out_fp16 = (MLFloat16*)(((int8_t*)workspace_buffer.get()) + workSpaceSize); ORT_RETURN_IF_ERROR(LaunchLongformerAttentionKernel(device_prop, cublas, @@ -252,7 +252,7 @@ QOrderedLongformerAttention::ComputeInternal(OpKernelContext* context) const { *scale_output, batch_size, sequence_length, hidden_size)); // Defer release of pinned memory since cudaStreamSynchronize is not used here and kernel need access the buffer. - this->AddDeferredReleaseCPUPtr(pinned_buffer.release(), context->GetComputeStream()); + this->AddDeferredReleaseCPUPtr(pinned_buffer.release(), GetComputeStream(context)); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc index a64f628f245e6..520c205179dc7 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_matmul.cc @@ -146,7 +146,7 @@ Status QOrderedMatMul::ComputeInternal(OpKernelContext* context) const { } Tensor* tensor_Y = context->Output(0, shapeY); - cublasLtHandle_t cublasLt = CublasLtHandle(); + cublasLtHandle_t cublasLt = this->GetCublasLtHandle(context); cudaStream_t stream = Stream(context); auto& device_prop = GetDeviceProp(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq.cc index 347697e588b8b..d7ac36d16e8e2 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq.cc @@ -138,7 +138,7 @@ Status QuantizeWithOrder::ComputeInternal(OpKernelContext* context) const { input_tensor, (cublasLtOrder_t)order_input_, (cublasLtOrder_t)order_output_, rows, cols, batch, n)); const float* scale = context->Input(1)->Data(); Tensor* output_tensor = context->Output(0, input_tensor.Shape()); - cublasLtHandle_t cublasLt = CublasLtHandle(); + cublasLtHandle_t cublasLt = this->GetCublasLtHandle(context); cudaStream_t stream = Stream(context); const auto& device_prop = GetDeviceProp(); @@ -154,7 +154,7 @@ Status QuantizeWithOrder::ComputeInternal(OpKernelContext* context) const { *scale, gsl::narrow(batch), gsl::narrow(rows), gsl::narrow(cols))); } } else { - auto q8_buffer = GetScratchBuffer(order_input_ == order_output_ ? 0LL : n, context->GetComputeStream()); + auto q8_buffer = GetScratchBuffer(order_input_ == order_output_ ? 0LL : n, GetComputeStream(context)); int8_t* dst = (order_input_ == order_output_ ? output_tensor->MutableData() : q8_buffer.get()); if (input_tensor.IsDataType()) { ORT_RETURN_IF_ERROR(QOrderQuantize_Strict(stream, device_prop, (const __half*)input_tensor.Data(), dst, *scale, n)); diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc index 865a1dc29ce47..656fde2f46ab8 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc @@ -60,6 +60,8 @@ SparseAttention::SparseAttention(const OpKernelInfo& info) template Status SparseAttention::ComputeInternal(OpKernelContext* context) const { + auto ort_stream = GetOrtStream(context); + auto& device_prop = GetDeviceProp(); if constexpr (std::is_same::value) { if (device_prop.major < 8) { @@ -219,8 +221,7 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { parameters.sequence_length * parameters.head_size; rotary_buffer_bytes += sizeof(int64_t) * parameters.batch_size * parameters.sequence_length; } - onnxruntime::Stream* stream = context->GetComputeStream(); - auto rotary_buffer = GetScratchBuffer(rotary_buffer_bytes, stream); + auto rotary_buffer = GetScratchBuffer(rotary_buffer_bytes, GetComputeStream(context)); data.rotary_buffer = reinterpret_cast(rotary_buffer.get()); size_t transposed_q_bytes = 0; @@ -228,7 +229,7 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { transposed_q_bytes = parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(T); } - auto transposed_q_buffer = GetScratchBuffer(transposed_q_bytes, stream); + auto transposed_q_buffer = GetScratchBuffer(transposed_q_bytes, GetComputeStream(context)); if (transposed_q_buffer) { data.transposed_q_buffer = reinterpret_cast(transposed_q_buffer.get()); } @@ -239,7 +240,7 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { (parameters.num_heads + 2 * parameters.kv_num_heads) * parameters.head_size * sizeof(T)); } - auto unpacked_qkv_buffer = GetScratchBuffer(unpacked_qkv_bytes, stream); + auto unpacked_qkv_buffer = GetScratchBuffer(unpacked_qkv_bytes, GetComputeStream(context)); if (unpacked_qkv_buffer) { data.unpacked_qkv_buffer = reinterpret_cast(unpacked_qkv_buffer.get()); } @@ -303,7 +304,7 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { } } - v2_kernel_buffer = GetScratchBuffer(v2_kernel_buffer_size, stream); + v2_kernel_buffer = GetScratchBuffer(v2_kernel_buffer_size, GetComputeStream(context)); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(v2_kernel_buffer.get(), v2_kernel_inputs_pinned, sizeof(int32_t) * v2_kernel_buffer_size, cudaMemcpyHostToDevice, cuda_stream)); @@ -317,7 +318,7 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { data.active_q_blocks = active_q_blocks; } - return QkvToContext(device_prop, stream, parameters, data); + return QkvToContext(device_prop, ort_stream.get(), parameters, data); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.cc b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.cc index 381316f605fc9..2c9f5c1cd9014 100644 --- a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.cc +++ b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.cc @@ -35,7 +35,7 @@ Status DynamicTimeWarping::ComputeInternal(OpKernelContext* ctx) const { size_t max_index_len = 0; size_t buffer_size_in_bytes = GetDynamicTimeWarpingBufferSize(1, rows, cols, max_index_len); - IAllocatorUniquePtr buffer = GetScratchBuffer(buffer_size_in_bytes, ctx->GetComputeStream()); + IAllocatorUniquePtr buffer = GetScratchBuffer(buffer_size_in_bytes, GetComputeStream(ctx)); size_t result_len = 0; ORT_RETURN_IF_ERROR(LaunchDynamicTimeWarping( diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index e2e60066ec36d..36d0287685838 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -139,10 +139,12 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest, output_indices = std::move(*Tensor::Create(DataTypeImpl::GetType(), output_shape, std::move(allocator))); Status result; + auto cuda_stream = stream ? static_cast(stream->GetHandle()) : nullptr; if (input->IsDataType()) { result = TopKImpl(nullptr, // We limit number of beams in BeamSearchParameters, so K <= 256 and use NULL here false /*use_deterministic_compute*/, - stream, + cuda_stream, + nullptr, // alloc_stream not needed when kernel is nullptr input->Data(), static_cast(output_values.MutableDataRaw()), static_cast(output_indices.MutableDataRaw()), @@ -157,7 +159,8 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest, } else if (input->IsDataType()) { result = TopKImpl(nullptr, false /*use_deterministic_compute*/, - stream, + cuda_stream, + nullptr, // alloc_stream not needed when kernel is nullptr input->Data(), static_cast(output_values.MutableDataRaw()), static_cast(output_indices.MutableDataRaw()), @@ -411,7 +414,7 @@ Status ProcessLogits(const OrtValue& logits, // const CudaT* X_data = is_reuse_logits_buffer ? logits_data : reinterpret_cast(next_token_logits.data()); ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward( - ort_stream, Y_data, X_data, vocab_size, + ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr, Y_data, X_data, vocab_size, is_reuse_logits_buffer ? padded_vocab_size : vocab_size, vocab_size, batch_size * num_beams))); diff --git a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h index 7be3a4851aaed..d4e7e02f51fdd 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h @@ -93,7 +93,7 @@ Status Sample(AllocatorPtr& allocator, #endif gsl::span& d_sorted_softmaxed_score = sampling_state->d_sorted_softmaxed_score; - ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward(stream, + ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward(cuda_stream, d_sorted_softmaxed_score.data(), reinterpret_cast(d_sorted_score.data()), parameters->vocab_size, @@ -127,7 +127,7 @@ Status Sample(AllocatorPtr& allocator, #endif gsl::span& d_softmaxed_score = sampling_state->d_softmaxed_score; - ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward(stream, + ORT_RETURN_IF_ERROR((dispatch_blockwise_softmax_forward(cuda_stream, d_softmaxed_score.data(), reinterpret_cast(next_token_scores.data()), parameters->vocab_size, diff --git a/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h b/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h index f08f134d0c080..f979b0e6230b5 100644 --- a/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h +++ b/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h @@ -3,6 +3,11 @@ #pragma once +// SHARED_PROVIDER is defined in the in-tree CUDA EP shared library build +// (onnxruntime_providers_cuda). It gates out framework headers that are +// re-exported via the DLL-boundary proxy. The plugin EP build uses a +// different flag (BUILD_CUDA_EP_AS_PLUGIN) and the force-include adapter +// headers instead. Both builds need these headers excluded. #ifndef SHARED_PROVIDER #include "core/common/common.h" #include "core/common/type_list.h" @@ -66,32 +71,41 @@ using ConstantOfShapeDefaultOutputTypesOpset23 = uint8_t, uint16_t, uint32_t, uint64_t, bool>; -template -class ConstantOfShapeBase { +#define ORT_CONSTANT_OF_SHAPE_VALUE_TYPES(M) \ + M(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL) \ + M(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) \ + M(MLFloat16, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) \ + M(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) \ + M(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) \ + M(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16) \ + M(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) \ + M(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) \ + M(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) \ + M(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16) \ + M(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32) \ + M(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64) \ + M(BFloat16, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16) + +template +struct ConstantOfShapeOrtType; + +#define DEFINE_CONSTANT_OF_SHAPE_ORT_TYPE(c_type, ort_type) \ + template <> \ + struct ConstantOfShapeOrtType { \ + static constexpr ONNXTensorElementDataType value = ort_type; \ + }; + +ORT_CONSTANT_OF_SHAPE_VALUE_TYPES(DEFINE_CONSTANT_OF_SHAPE_ORT_TYPE) + +#undef DEFINE_CONSTANT_OF_SHAPE_ORT_TYPE + +class ConstantOfShapeCore { protected: - ConstantOfShapeBase(const OpKernelInfo& info) { -#ifndef SHARED_PROVIDER - ONNX_NAMESPACE::TensorProto t_proto; - auto* t_proto_p = &t_proto; -#else - auto t_proto = ONNX_NAMESPACE::TensorProto::Create(); - auto* t_proto_p = t_proto.get(); -#endif - if (info.GetAttr("value", t_proto_p).IsOK()) { - for (auto dim : t_proto_p->dims()) { - ORT_ENFORCE(dim == 1, "The value attribute of ConstantOfShape must be a single-element tensor"); - } - SetValueFromTensorProto(*t_proto_p); - } else { - float f_value = 0.f; - SetValue(sizeof(float), reinterpret_cast(&f_value)); - } - } - void* GetValuePtr() const { return p_value_; } - static Status PrepareCompute(OpKernelContext* ctx, Tensor** output_tensor) { - const auto shape_tensor = ctx->Input(0); + template + static Status PrepareCompute(ContextType* ctx, Tensor** output_tensor) { + const auto shape_tensor = ctx->template Input(0); const auto& input_shape = shape_tensor->Shape(); // If empty the output is a scalar with empty shape @@ -99,7 +113,7 @@ class ConstantOfShapeBase { // one value ORT_RETURN_IF_NOT(input_shape.NumDimensions() > 0, "Must have a valid input shape."); - const auto span = shape_tensor->DataAsSpan(); + const auto span = shape_tensor->template DataAsSpan(); TensorShape output_shape(span); (*output_tensor) = ctx->Output(0, output_shape); @@ -107,31 +121,48 @@ class ConstantOfShapeBase { return Status::OK(); } - private: - union SizeBasedValue { - int8_t int8_; - int16_t int16_; - int32_t int32_; - int64_t int64_; - } s_value_; - void* p_value_; + void SetDefaultValue() { + float f_value = 0.f; + SetValue(sizeof(float), &f_value); + } + + template + void SetValueFromOrtTensor(ONNXTensorElementDataType tensor_type, const void* data) { + bool handled = false; + switch (tensor_type) { +#define CASE_SET_ORT_VALUE(c_type, ort_type) \ + case ConstantOfShapeOrtType::value: { \ + if (utils::HasType()) { \ + SetValue(sizeof(c_type), data); \ + handled = true; \ + } \ + break; \ + } + ORT_CONSTANT_OF_SHAPE_VALUE_TYPES(CASE_SET_ORT_VALUE) +#undef CASE_SET_ORT_VALUE + default: + ORT_THROW("Unsupported value attribute datatype: ", static_cast(tensor_type)); + } - void SetValue(size_t size, void* value) { + ORT_ENFORCE(handled, "Unsupported value attribute datatype in this build: ", static_cast(tensor_type)); + } + + void SetValue(size_t size, const void* value) { switch (size) { case sizeof(int8_t): - s_value_.int8_ = *(reinterpret_cast(value)); + s_value_.int8_ = *(reinterpret_cast(value)); p_value_ = reinterpret_cast(&(s_value_.int8_)); break; case sizeof(int16_t): - s_value_.int16_ = *(reinterpret_cast(value)); + s_value_.int16_ = *(reinterpret_cast(value)); p_value_ = reinterpret_cast(&(s_value_.int16_)); break; case sizeof(int32_t): - s_value_.int32_ = *(reinterpret_cast(value)); + s_value_.int32_ = *(reinterpret_cast(value)); p_value_ = reinterpret_cast(&(s_value_.int32_)); break; case sizeof(int64_t): - s_value_.int64_ = *(reinterpret_cast(value)); + s_value_.int64_ = *(reinterpret_cast(value)); p_value_ = reinterpret_cast(&(s_value_.int64_)); break; default: @@ -139,10 +170,43 @@ class ConstantOfShapeBase { } } + private: + union SizeBasedValue { + int8_t int8_; + int16_t int16_; + int32_t int32_; + int64_t int64_; + }; + mutable SizeBasedValue s_value_{}; + mutable void* p_value_ = nullptr; +}; + +template +class ConstantOfShapeBase : public ConstantOfShapeCore { + protected: + ConstantOfShapeBase(const OpKernelInfo& info) { +#ifndef SHARED_PROVIDER + ONNX_NAMESPACE::TensorProto t_proto; + auto* t_proto_p = &t_proto; +#else + auto t_proto = ONNX_NAMESPACE::TensorProto::Create(); + auto* t_proto_p = t_proto.get(); +#endif + if (info.GetAttr("value", t_proto_p).IsOK()) { + for (auto dim : t_proto_p->dims()) { + ORT_ENFORCE(dim == 1, "The value attribute of ConstantOfShape must be a single-element tensor"); + } + SetValueFromTensorProto(*t_proto_p); + } else { + SetDefaultValue(); + } + } + void SetValueFromTensorProto(const ONNX_NAMESPACE::TensorProto&); }; -#define CASE_FETCH_VALUE_DATA(c_type) \ +// ort_type parameter unused here but required for ORT_CONSTANT_OF_SHAPE_VALUE_TYPES X-macro conformance. +#define CASE_FETCH_VALUE_DATA(c_type, ort_type) \ case utils::ToTensorProtoElementType(): { \ if (utils::HasType()) { \ c_type val; \ @@ -164,19 +228,7 @@ void ConstantOfShapeBase::SetValueFromTensorProto(const O const size_t raw_data_len = utils::HasRawData(t_proto) ? t_proto.raw_data().size() : 0; bool handled = false; switch (tensor_type) { - CASE_FETCH_VALUE_DATA(bool) - CASE_FETCH_VALUE_DATA(float) - CASE_FETCH_VALUE_DATA(MLFloat16) - CASE_FETCH_VALUE_DATA(double) - CASE_FETCH_VALUE_DATA(int8_t) - CASE_FETCH_VALUE_DATA(int16_t) - CASE_FETCH_VALUE_DATA(int32_t) - CASE_FETCH_VALUE_DATA(int64_t) - CASE_FETCH_VALUE_DATA(uint8_t) - CASE_FETCH_VALUE_DATA(uint16_t) - CASE_FETCH_VALUE_DATA(uint32_t) - CASE_FETCH_VALUE_DATA(uint64_t) - CASE_FETCH_VALUE_DATA(BFloat16) + ORT_CONSTANT_OF_SHAPE_VALUE_TYPES(CASE_FETCH_VALUE_DATA) default: ORT_THROW("Unsupported value attribute datatype: ", tensor_type); } @@ -185,5 +237,6 @@ void ConstantOfShapeBase::SetValueFromTensorProto(const O } #undef CASE_FETCH_VALUE_DATA +#undef ORT_CONSTANT_OF_SHAPE_VALUE_TYPES } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h index 8bc891bb4f377..a5c0606a55e7d 100644 --- a/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/deform_conv_attributes.h @@ -15,15 +15,16 @@ namespace onnxruntime { // See https://onnx.ai/onnx/operators/onnx__DeformConv.html // Used by both CPU and CUDA implementations (CUDA includes from here). struct DeformConvAttributes { - explicit DeformConvAttributes(const OpKernelInfo& info) { + template + explicit DeformConvAttributes(const KernelInfoType& info) { // Optional attributes. // If not present, they will be empty/default, and handled in Compute/ComputeInternal. (void)info.GetAttrs("kernel_shape", kernel_shape); (void)info.GetAttrs("strides", strides); (void)info.GetAttrs("pads", pads); (void)info.GetAttrs("dilations", dilations); - group = info.GetAttrOrDefault("group", 1); - offset_group = info.GetAttrOrDefault("offset_group", 1); + group = info.template GetAttrOrDefault("group", 1); + offset_group = info.template GetAttrOrDefault("offset_group", 1); } TensorShapeVector kernel_shape; diff --git a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h index ff5498c0b4644..ded4813276b1d 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h +++ b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h @@ -16,6 +16,7 @@ #include "core/common/status.h" #include #include +#include #ifndef SHARED_PROVIDER #include "core/framework/op_kernel.h" #endif @@ -166,12 +167,28 @@ inline void AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span +struct has_input_defs : std::false_type {}; +template +struct has_input_defs().InputDefs())>> : std::true_type {}; +} // namespace upsamplebase_detail + class UpsampleBase { public: // Make this available in other EP via provider bridge // it works iff output_shape is specified +#ifdef SHARED_PROVIDER void AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims, InlinedVector& scales) const; +#else + void AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims, + InlinedVector& scales) const { + upsamplebase_helper::AdjustOutputSizeAsPolicy(output_dims, input_dims, scales, keep_aspect_ratio_policy_, axes_); + } +#endif protected: template @@ -266,11 +283,9 @@ class UpsampleBase { const Tensor* scale; bool get_scale = info.TryGetConstantInput(scales_input_idx_, &scale); int64_t rank = -1; - if constexpr (std::is_same_v) { + if constexpr (upsamplebase_detail::has_input_defs::value) { auto x_shape = node.InputDefs()[0]->Shape(); - if (x_shape != nullptr) { - rank = x_shape->dim_size(); - } + rank = x_shape ? x_shape->dim_size() : -1; } else { auto type_info = info.GetKernelInfo().GetInputTypeInfo(0); auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); @@ -652,13 +667,6 @@ class UpsampleBase { } }; // UpsampleBase -#ifndef SHARED_PROVIDER -inline void UpsampleBase::AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims, - InlinedVector& scales) const { - upsamplebase_helper::AdjustOutputSizeAsPolicy(output_dims, input_dims, scales, keep_aspect_ratio_policy_, axes_); -} -#endif - } // namespace onnxruntime #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index 42989d09cfa85..f90eb2813afc4 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -3,6 +3,8 @@ #pragma once +#ifndef BUILD_CUDA_EP_AS_PLUGIN + // The following three lines were copied from ABSL // cutlass needs them, because cutlass uses "and"/"or" keywords #ifdef __cplusplus @@ -52,6 +54,7 @@ namespace cuda { #define CUDNN2_RETURN_IF_ERROR(expr, m) ORT_RETURN_IF_ERROR(CUDNN_CALL2(expr, m)) #define CUFFT_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUFFT_CALL(expr)) #endif + // Type mapping for MLFloat16 to half template class ToCudaType { @@ -238,3 +241,7 @@ class HalfGemmOptions { } // namespace onnxruntime #include "core/providers/cuda/cuda_common_type_helpers.h" +#else +// Define shims and basic types needed by kernels in plugin build when cuda_common.h is included +#include "core/providers/cuda/plugin/cuda_kernel_adapter.h" +#endif diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index 70ba7657e6747..13bf5b37490e0 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -3,6 +3,8 @@ #pragma once +#ifndef BUILD_CUDA_EP_AS_PLUGIN + #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cuda_execution_provider.h" #include "core/providers/cuda/cuda_fwd.h" @@ -12,6 +14,23 @@ namespace onnxruntime { namespace cuda { +class OrtStreamAdapter { + public: + explicit OrtStreamAdapter(onnxruntime::Stream* stream) : stream_(stream) {} + explicit OrtStreamAdapter(void* stream) : stream_(static_cast(stream)) {} + + onnxruntime::Stream* get() const { return stream_; } + operator onnxruntime::Stream*() const { return stream_; } + + private: + onnxruntime::Stream* stream_; +}; + +#ifndef CUDA_STREAM_FROM_CTX +// Helper for kernels that need a cudaStream_t from OpKernelContext in both framework and plugin builds. +#define CUDA_STREAM_FROM_CTX(ctx) static_cast(GetComputeStream(ctx)) +#endif + // ----------------------------------------------------------------------- // Base class for CUDA kernels // ----------------------------------------------------------------------- @@ -49,6 +68,19 @@ class CudaKernel : public OpKernel { stream); } + // void* overload for dual-build compatibility with the plugin EP. + // In the framework build, the void* is always a static_cast(onnxruntime::Stream*). + template + inline IAllocatorUniquePtr GetScratchBuffer(size_t count_or_bytes, void* stream) const { + return GetScratchBuffer(count_or_bytes, static_cast(stream)); + } + + // Resolve nullptr ambiguity between Stream* and void* overloads. + template + inline IAllocatorUniquePtr GetScratchBuffer(size_t count_or_bytes, std::nullptr_t) const { + return GetScratchBuffer(count_or_bytes, static_cast(nullptr)); + } + // Different from GetScratchBuffer which use IAllocator::Alloc() to allocate memory, // this GetTransientScratchBuffer will call IAllocator::Reserve() to allocate memory. // IAllocator::Reserve() optionally implement some allocation logic that by-passes any arena-based @@ -65,6 +97,11 @@ class CudaKernel : public OpKernel { cuda_ep_stream->EnqueDeferredCPUBuffer(p); } + // void* overload for dual-build compatibility with the plugin EP. + inline void AddDeferredReleaseCPUPtr(void* p, void* stream) const { + AddDeferredReleaseCPUPtr(p, static_cast(stream)); + } + template inline IAllocatorUniquePtr AllocateBufferOnCPUPinned(size_t count_or_bytes) const { if (count_or_bytes == 0) return nullptr; @@ -72,12 +109,20 @@ class CudaKernel : public OpKernel { } const cudaDeviceProp& GetDeviceProp() const { return provider_->GetDeviceProp(); } + int GetCudnnConvAlgo() const { return provider_->GetCudnnConvAlgo(); } + bool GetCudnnConvUseMaxWorkspace() const { return provider_->GetCudnnConvUseMaxWorkspace(); } + bool GetCudnnConv1dPadToNc1d() const { return provider_->GetCudnnConv1dPadToNc1d(); } + bool IsFuseConvBias() const { return provider_->IsFuseConvBias(); } // Compatibility helper used by kernels that need the underlying ORT stream object. inline onnxruntime::Stream* GetComputeStream(OpKernelContext* ctx) const { return ctx ? ctx->GetComputeStream() : nullptr; } + inline OrtStreamAdapter GetOrtStream(OpKernelContext* ctx) const { + return OrtStreamAdapter(GetComputeStream(ctx)); + } + inline cudaStream_t Stream(OpKernelContext* ctx) const { auto* stream = ctx->GetComputeStream(); return stream ? static_cast(stream->GetHandle()) : nullptr; @@ -119,6 +164,10 @@ class CudaKernel : public OpKernel { return stream ? GetCublasHandle(stream) : DefaultCublasHandle(); } + inline cublasLtHandle_t GetCublasLtHandle(OpKernelContext* /*ctx*/) const { + return provider_->PerThreadCublasLtHandle(); + } + bool UseTF32() const { return provider_->UseTF32(); } @@ -174,6 +223,11 @@ class CudaKernel : public OpKernel { return Status::OK(); } + // void* overload for dual-build compatibility with the plugin EP. + Status CopyToGpu(void* stream) { + return CopyToGpu(static_cast(stream)); + } + T* CpuPtr() const { return cpu_pinned_copy_.get(); } @@ -234,3 +288,7 @@ class CudaKernel : public OpKernel { } // namespace cuda } // namespace onnxruntime + +#else +#include "core/providers/cuda/plugin/cuda_kernel_adapter.h" +#endif diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc index 5839d10b4345f..e36f745a8fdfa 100755 --- a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc +++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc @@ -176,7 +176,7 @@ class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(16, 19, float, GridSample); class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(20, 21, float, GridSample); class CUDA_NHWC_OP_TYPED_CLASS_NAME(22, float, GridSample); -onnxruntime::common::Status RegisterCudaNhwcContribKernels(onnxruntime::KernelRegistry& kernel_registry) { +onnxruntime::common::Status RegisterCudaNhwcContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn nhwc_function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.h b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.h index 5a4d3493fdae6..f554d48f4d1f5 100644 --- a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.h +++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.h @@ -8,14 +8,14 @@ namespace onnxruntime::cuda { -onnxruntime::common::Status RegisterCudaNhwcKernels(onnxruntime::KernelRegistry& kernel_registry); +onnxruntime::common::Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry); } // namespace onnxruntime::cuda #ifndef DISABLE_CONTRIB_OPS namespace onnxruntime::contrib::cuda { -onnxruntime::common::Status RegisterCudaNhwcContribKernels(onnxruntime::KernelRegistry& kernel_registry); +onnxruntime::common::Status RegisterCudaNhwcContribKernels(KernelRegistry& kernel_registry); } // namespace onnxruntime::contrib::cuda #endif diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h index b267ef6bed64f..6725d1120be23 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.h +++ b/onnxruntime/core/providers/cuda/cudnn_common.h @@ -149,9 +149,11 @@ struct Consts { inline double ClampCudnnBatchNormEpsilon(double epsilon) { if (epsilon < CUDNN_BN_MIN_EPSILON) { +#ifndef BUILD_CUDA_EP_AS_PLUGIN if (CUDNN_BN_MIN_EPSILON - epsilon > FLT_EPSILON) LOGS_DEFAULT(WARNING) << "Provided epsilon is smaller than CUDNN_BN_MIN_EPSILON. " << "Setting it to CUDNN_BN_MIN_EPSILON"; +#endif return CUDNN_BN_MIN_EPSILON; } return epsilon; diff --git a/onnxruntime/core/providers/cuda/generator/constant_of_shape.h b/onnxruntime/core/providers/cuda/generator/constant_of_shape.h index 99c5da0615ede..7f0268923cb3e 100644 --- a/onnxruntime/core/providers/cuda/generator/constant_of_shape.h +++ b/onnxruntime/core/providers/cuda/generator/constant_of_shape.h @@ -10,6 +10,46 @@ namespace onnxruntime { namespace cuda { +#ifdef BUILD_CUDA_EP_AS_PLUGIN + +// Plugin build: keep the attribute fetch self-contained while reusing the shared +// ConstantOfShapeCore helpers for default handling and supported type mapping. +// ConstantOfShapeBase still depends on TensorProto/UnpackTensor utilities that the +// plugin build avoids, so the plugin path reads the attribute via the ORT C API instead. +class ConstantOfShape final : public ConstantOfShapeCore, public CudaKernel { + public: + explicit ConstantOfShape(const OpKernelInfo& info) : CudaKernel(info) { + InitValue(info); + } + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConstantOfShape); + + Status ComputeInternal(OpKernelContext* ctx) const override; + + private: + void InitValue(const OpKernelInfo& info) { + Ort::AllocatorWithDefaultOptions allocator; + auto ort_info = info.GetKernelInfo(); + try { + Ort::Value value_tensor = ort_info.GetTensorAttribute("value", allocator); + auto type_and_shape = value_tensor.GetTensorTypeAndShapeInfo(); + size_t elem_count = type_and_shape.GetElementCount(); + ORT_ENFORCE(elem_count == 1 || elem_count == 0, + "The value attribute of ConstantOfShape must be a single-element tensor"); + if (elem_count == 1) { + SetValueFromOrtTensor( + type_and_shape.GetElementType(), value_tensor.GetTensorRawData()); + } else { + SetDefaultValue(); + } + } catch (const Ort::Exception&) { + SetDefaultValue(); + } + } +}; + +#else // !BUILD_CUDA_EP_AS_PLUGIN + class ConstantOfShape final : public ConstantOfShapeBase<>, public CudaKernel { public: explicit ConstantOfShape(const OpKernelInfo& info) : ConstantOfShapeBase(info), CudaKernel(info) {} @@ -19,5 +59,7 @@ class ConstantOfShape final : public ConstantOfShapeBase<>, public CudaKernel { Status ComputeInternal(OpKernelContext* ctx) const override; }; +#endif // BUILD_CUDA_EP_AS_PLUGIN + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/integer_gemm.cc b/onnxruntime/core/providers/cuda/integer_gemm.cc index ef22f4793befc..a3e8ba0cc41af 100644 --- a/onnxruntime/core/providers/cuda/integer_gemm.cc +++ b/onnxruntime/core/providers/cuda/integer_gemm.cc @@ -17,12 +17,13 @@ constexpr int roundoff(int v, int d) { Status GemmInt8(int m, int n, int k, int32_t alpha, int32_t beta, const int8_t* a, int lda, const int8_t* b, int ldb, int32_t* c, int ldc, - const CudaKernel* cuda_kernel, onnxruntime::Stream* ort_stream) { + const CudaKernel* cuda_kernel, void* alloc_stream, cudaStream_t cuda_stream, + cublasHandle_t cublas_handle) { ORT_ENFORCE(a != nullptr && b != nullptr && c != nullptr, "input matrix should not be null"); ORT_ENFORCE(cuda_kernel != nullptr, "kernel is null"); - ORT_ENFORCE(ort_stream != nullptr, "Cuda kernel must have the stream instance"); + ORT_ENFORCE(cuda_stream != nullptr, "Cuda kernel must have the cuda stream"); - cudaStream_t stream = static_cast(ort_stream->GetHandle()); + cudaStream_t stream = cuda_stream; // pad A and B to make their leading dimension be multiples of 32 // because cublasGemmEx requires: @@ -34,7 +35,7 @@ Status GemmInt8(int m, int n, int k, IAllocatorUniquePtr a_padded; if ((mask & lda_aligned) != 0) { lda_aligned = roundoff(lda, 32); - a_padded = cuda_kernel->GetScratchBuffer(SafeInt(m) * lda_aligned, ort_stream); + a_padded = cuda_kernel->GetScratchBuffer(SafeInt(m) * lda_aligned, alloc_stream); cudaMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, cudaMemcpyDeviceToDevice, stream); } @@ -42,14 +43,12 @@ Status GemmInt8(int m, int n, int k, IAllocatorUniquePtr b_padded; if ((mask & ldb_aligned) != 0) { ldb_aligned = roundoff(ldb, 32); - b_padded = cuda_kernel->GetScratchBuffer(SafeInt(k) * ldb_aligned, ort_stream); + b_padded = cuda_kernel->GetScratchBuffer(SafeInt(k) * ldb_aligned, alloc_stream); cudaMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, cudaMemcpyDeviceToDevice, stream); } - auto cublas = cuda_kernel->GetCublasHandleOrDefault(ort_stream); - CUBLAS_RETURN_IF_ERROR(cublasGemmEx( - cublas, + cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc index 0ed6210fb4d29..27b81691c70af 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.cc +++ b/onnxruntime/core/providers/cuda/llm/attention.cc @@ -19,6 +19,19 @@ using namespace onnxruntime::cuda; namespace onnxruntime { namespace cuda { +namespace llm_attention_detail { + +template +bool HasOutput(const NodeType& node, size_t output_index) { + if constexpr (requires(const NodeType& candidate) { candidate.OutputCount(); candidate.OutputExists(output_index); }) { + return node.OutputCount() > output_index && node.OutputExists(output_index); + } else { + return node.OutputDefs().size() > output_index && node.OutputDefs()[output_index]->Exists(); + } +} + +} // namespace llm_attention_detail + #define REGISTER_KERNEL_TYPED(T) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ Attention, \ @@ -60,7 +73,8 @@ Attention::Attention(const OpKernelInfo& info) : CudaKernel(info) { kv_num_heads_ = static_cast(info.GetAttrOrDefault("kv_num_heads", 0)); q_num_heads_ = static_cast(info.GetAttrOrDefault("q_num_heads", 0)); int mode = static_cast(info.GetAttrOrDefault("qk_matmul_output_mode", 0)); - qk_matmul_output_mode_ = info.node().OutputDefs().size() >= 4 && info.node().OutputDefs()[3]->Exists() + const auto& node = info.node(); + qk_matmul_output_mode_ = llm_attention_detail::HasOutput(node, 3) ? static_cast(mode) : attention_helper::QKMatMulOutputMode::kNone; ORT_ENFORCE(qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kNone || @@ -138,7 +152,7 @@ Status Attention::ConvertAttnMaskToBias( using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType::type; int64_t num_elements = attn_mask->Shape().Size(); converted_mask_buffer = GetScratchBuffer( - num_elements * sizeof(NativeCudaT), context->GetComputeStream()); + num_elements * sizeof(NativeCudaT), GetComputeStream(context)); float mask_filter_value = static_cast(std::numeric_limits::lowest()); ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias( attn_mask->Data(), @@ -219,7 +233,7 @@ Status Attention::RunFlashAttention( const attention_helper::AttentionParameters& parameters) const { #if USE_FLASH_ATTENTION auto& device_prop = GetDeviceProp(); - auto cuda_stream = static_cast(context->GetComputeStream()->GetHandle()); + auto cuda_stream = Stream(context); const bool is_bf16 = std::is_same::value; const bool is_bsnh = parameters.transpose_output; // 3D inputs → BSNH @@ -233,9 +247,9 @@ Status Attention::RunFlashAttention( parameters.total_sequence_length, parameters.q_num_heads, parameters.head_size, device_prop.multiProcessorCount); - auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); - auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); - auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); + auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, GetComputeStream(context)); + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, GetComputeStream(context)); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, GetComputeStream(context)); if (softmax_lse_accum_bytes > 0) { CUDA_RETURN_IF_ERROR(cudaMemsetAsync(softmax_lse_accum_buffer.get(), 0, @@ -252,7 +266,7 @@ Status Attention::RunFlashAttention( if (!is_bsnh) { size_t q_bytes = sizeof(T) * parameters.batch_size * parameters.q_sequence_length * parameters.q_num_heads * parameters.head_size; - q_bsnh_buffer = GetScratchBuffer(q_bytes, context->GetComputeStream()); + q_bsnh_buffer = GetScratchBuffer(q_bytes, GetComputeStream(context)); ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH( parameters.batch_size, parameters.q_sequence_length, parameters.q_num_heads, parameters.head_size, @@ -267,7 +281,7 @@ Status Attention::RunFlashAttention( if (!is_bsnh) { size_t out_bytes = sizeof(T) * parameters.batch_size * parameters.q_sequence_length * parameters.q_num_heads * parameters.v_head_size; - out_bsnh_buffer = GetScratchBuffer(out_bytes, context->GetComputeStream()); + out_bsnh_buffer = GetScratchBuffer(out_bytes, GetComputeStream(context)); out_data = out_bsnh_buffer.get(); } @@ -280,7 +294,7 @@ Status Attention::RunFlashAttention( "(past_sequence_length must be 0, got ", parameters.past_sequence_length, ")."); - auto seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, context->GetComputeStream()); + auto seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, GetComputeStream(context)); ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToFlashSeqlensK( nonpad_kv_seqlen->Data(), seqlens_k_buffer.get(), @@ -334,7 +348,7 @@ Status Attention::RunFlashAttention( // Step 1: Compute per-batch past sequence lengths for the concat kernel. // The concat kernel needs past_seq_lens to know where past data ends and new begins. - auto past_seqlens_buffer = GetScratchBuffer(parameters.batch_size, context->GetComputeStream()); + auto past_seqlens_buffer = GetScratchBuffer(parameters.batch_size, GetComputeStream(context)); if (attn_mask != nullptr && attn_mask->IsDataType()) { size_t mask_dims = attn_mask->Shape().NumDimensions(); auto dims = attn_mask->Shape().GetDims(); @@ -364,8 +378,8 @@ Status Attention::RunFlashAttention( parameters.kv_num_heads * parameters.head_size; size_t v_bytes = sizeof(T) * parameters.batch_size * parameters.kv_sequence_length * parameters.kv_num_heads * parameters.v_head_size; - k_bsnh_buffer = GetScratchBuffer(k_bytes, context->GetComputeStream()); - v_bsnh_buffer = GetScratchBuffer(v_bytes, context->GetComputeStream()); + k_bsnh_buffer = GetScratchBuffer(k_bytes, GetComputeStream(context)); + v_bsnh_buffer = GetScratchBuffer(v_bytes, GetComputeStream(context)); ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH( parameters.batch_size, parameters.kv_sequence_length, parameters.kv_num_heads, parameters.head_size, @@ -413,7 +427,7 @@ Status Attention::RunFlashAttention( // Step 4: Compute total seqlens for mha_fwd_kvcache. // With k_new=nullptr, the kernel treats seqlens_k as the total valid token count // (not pre-append count), so we need past + new. - auto seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, context->GetComputeStream()); + auto seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, GetComputeStream(context)); if (attn_mask != nullptr && attn_mask->IsDataType()) { size_t mask_dims = attn_mask->Shape().NumDimensions(); auto dims = attn_mask->Shape().GetDims(); @@ -567,7 +581,7 @@ Status Attention::RunMemoryEfficientAttention( ORT_UNUSED_PARAMETER(past_key); ORT_UNUSED_PARAMETER(past_value); auto& device_prop = GetDeviceProp(); - auto cuda_stream = static_cast(context->GetComputeStream()->GetHandle()); + auto cuda_stream = Stream(context); const bool is_bsnh = parameters.transpose_output; const int sm = device_prop.major * 10 + device_prop.minor; @@ -581,7 +595,7 @@ Status Attention::RunMemoryEfficientAttention( if (!is_bsnh) { size_t q_bytes = sizeof(T) * parameters.batch_size * parameters.q_sequence_length * parameters.q_num_heads * parameters.head_size; - q_bsnh_buffer = GetScratchBuffer(q_bytes, context->GetComputeStream()); + q_bsnh_buffer = GetScratchBuffer(q_bytes, GetComputeStream(context)); ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH( parameters.batch_size, parameters.q_sequence_length, parameters.q_num_heads, parameters.head_size, @@ -596,7 +610,7 @@ Status Attention::RunMemoryEfficientAttention( if (!is_bsnh) { size_t out_bytes = sizeof(T) * parameters.batch_size * parameters.q_sequence_length * parameters.q_num_heads * parameters.v_head_size; - out_bsnh_buffer = GetScratchBuffer(out_bytes, context->GetComputeStream()); + out_bsnh_buffer = GetScratchBuffer(out_bytes, GetComputeStream(context)); out_data = out_bsnh_buffer.get(); } @@ -622,8 +636,8 @@ Status Attention::RunMemoryEfficientAttention( static_cast(parameters.total_sequence_length) * static_cast(parameters.q_num_heads) * static_cast(parameters.head_size); - k_expand_buffer = GetScratchBuffer(expanded_kv_elements * sizeof(T), context->GetComputeStream()); - v_expand_buffer = GetScratchBuffer(expanded_kv_elements * sizeof(T), context->GetComputeStream()); + k_expand_buffer = GetScratchBuffer(expanded_kv_elements * sizeof(T), GetComputeStream(context)); + v_expand_buffer = GetScratchBuffer(expanded_kv_elements * sizeof(T), GetComputeStream(context)); onnxruntime::contrib::GroupQueryAttentionParameters ungroup_params = {}; ungroup_params.batch_size = parameters.batch_size; @@ -661,7 +675,7 @@ Status Attention::RunMemoryEfficientAttention( if (nonpad_kv_seqlen != nullptr) { // Convert nonpad_kv_seqlen to seqlens_k for custom right padding. // MEA expects actual token count (not count-1), so use FlashSeqlensK variant. - auto seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, context->GetComputeStream()); + auto seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, GetComputeStream(context)); ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToFlashSeqlensK( nonpad_kv_seqlen->Data(), seqlens_k_buffer.get(), @@ -710,7 +724,7 @@ Status Attention::RunMemoryEfficientAttention( parameters.v_head_size, sizeof(T) == sizeof(float))) { size_t workspace_bytes = sizeof(float) * parameters.batch_size * parameters.q_sequence_length * parameters.q_num_heads * parameters.v_head_size; - workspace_buffer = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + workspace_buffer = GetScratchBuffer(workspace_bytes, GetComputeStream(context)); p.workspace = workspace_buffer.get(); } else { p.workspace = nullptr; @@ -758,7 +772,7 @@ Status Attention::RunMemoryEfficientAttention( parameters.v_head_size, sizeof(T) == sizeof(float))) { size_t workspace_bytes = sizeof(float) * parameters.batch_size * parameters.q_sequence_length * parameters.q_num_heads * parameters.v_head_size; - workspace_buffer = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + workspace_buffer = GetScratchBuffer(workspace_bytes, GetComputeStream(context)); p.workspace = workspace_buffer.get(); } else { p.workspace = nullptr; @@ -845,7 +859,8 @@ Status Attention::RunUnfusedAttention( const attention_helper::AttentionParameters& parameters) const { using CudaT = typename ToCudaType::MappedType; auto& device_prop = GetDeviceProp(); - auto cuda_stream = static_cast(context->GetComputeStream()->GetHandle()); + auto cuda_stream = Stream(context); + auto ort_stream = GetOrtStream(context); // Bridge to contrib::AttentionParameters for the MHA unfused path onnxruntime::contrib::AttentionParameters contribop_parameters; @@ -927,7 +942,7 @@ Status Attention::RunUnfusedAttention( int64_t bias_elements = static_cast(parameters.batch_size) * parameters.q_sequence_length * parameters.total_sequence_length; - converted_mask_buffer = GetScratchBuffer(bias_elements * sizeof(NativeCudaT), context->GetComputeStream()); + converted_mask_buffer = GetScratchBuffer(bias_elements * sizeof(NativeCudaT), GetComputeStream(context)); ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToAttentionBias( nonpad_kv_seqlen->Data(), reinterpret_cast(converted_mask_buffer.get()), @@ -957,7 +972,7 @@ Status Attention::RunUnfusedAttention( if (attn_mask->IsDataType()) { // Convert bool mask to additive bias in a temp buffer, then add in-place. - mask_bias_buffer = GetScratchBuffer(mask_elements * sizeof(NativeCudaT), context->GetComputeStream()); + mask_bias_buffer = GetScratchBuffer(mask_elements * sizeof(NativeCudaT), GetComputeStream(context)); ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias( attn_mask->Data(), reinterpret_cast(mask_bias_buffer.get()), @@ -991,7 +1006,7 @@ Status Attention::RunUnfusedAttention( if (attn_mask->IsDataType()) { using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType::type; int64_t num_elements = attn_mask->Shape().Size(); - converted_mask_buffer = GetScratchBuffer(num_elements * sizeof(NativeCudaT), context->GetComputeStream()); + converted_mask_buffer = GetScratchBuffer(num_elements * sizeof(NativeCudaT), GetComputeStream(context)); ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias( attn_mask->Data(), reinterpret_cast(converted_mask_buffer.get()), @@ -1025,7 +1040,7 @@ Status Attention::RunUnfusedAttention( contribop_parameters.total_sequence_length, nullptr, false, false, false, false, false, no_qkv_workspace); - auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); + auto work_space = GetScratchBuffer(workspace_bytes, GetComputeStream(context)); data.has_qkv_workspace = !no_qkv_workspace; data.workspace = reinterpret_cast(work_space.get()); @@ -1035,7 +1050,7 @@ Status Attention::RunUnfusedAttention( cudnnHandle_t cudnn = GetCudnnHandle(context); return onnxruntime::contrib::cuda::QkvToContext( - device_prop, cublas, cudnn, context->GetComputeStream(), contribop_parameters, data); + device_prop, cublas, cudnn, ort_stream.get(), contribop_parameters, data); } // ============================================================================ diff --git a/onnxruntime/core/providers/cuda/math/cumsum.cc b/onnxruntime/core/providers/cuda/math/cumsum.cc index b8b1f29c643d8..899a37f5bff16 100644 --- a/onnxruntime/core/providers/cuda/math/cumsum.cc +++ b/onnxruntime/core/providers/cuda/math/cumsum.cc @@ -3,12 +3,36 @@ #include "cumsum.h" #include "cumsum_impl.h" -#include "core/providers/cpu/math/cumsum.h" #include "core/providers/common.h" namespace onnxruntime { namespace cuda { +namespace { + +Status GetAxisFromInput(const Tensor* axis_tensor, int64_t input_rank, int64_t& axis_out) { + if (!axis_tensor) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Axis tensor must be provided to the CumSum op"); + } + + if (axis_tensor->Shape().NumDimensions() > 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Axis tensor should be 0D or 1D"); + } + + if (axis_tensor->IsDataType()) { + axis_out = static_cast(axis_tensor->Data()[0]); + } else if (axis_tensor->IsDataType()) { + axis_out = axis_tensor->Data()[0]; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Axis tensor should be of type `int32_t` or `int64_t`"); + } + + axis_out = HandleNegativeAxis(axis_out, input_rank); + return Status::OK(); +} + +} // namespace + ONNX_OPERATOR_VERSIONED_KERNEL_EX( CumSum, kOnnxDomain, @@ -53,7 +77,7 @@ Status CumSum::ComputeInternal(OpKernelContext* ctx) const { const Tensor* axis_tensor = ctx->Input(1); // axis input tensor int64_t axis = 0; - ORT_THROW_IF_ERROR(cumsum_op::GetAxis(axis_tensor, rank, axis)); + ORT_THROW_IF_ERROR(GetAxisFromInput(axis_tensor, rank, axis)); TensorShape output_shape(input->Shape()); auto& output = *ctx->Output(0, output_shape); // output tensor diff --git a/onnxruntime/core/providers/cuda/math/einsum.cc b/onnxruntime/core/providers/cuda/math/einsum.cc index 6f2ed41cacab6..648cb30f0fe55 100644 --- a/onnxruntime/core/providers/cuda/math/einsum.cc +++ b/onnxruntime/core/providers/cuda/math/einsum.cc @@ -42,13 +42,14 @@ Status Einsum::ComputeInternal(OpKernelContext* context) const { EinsumEquationPreprocessor einsum_equation_preprocessor(*einsum_equation_preprocessor_); + auto ort_stream = GetOrtStream(context); EinsumOp::EinsumCudaAssets einsum_cuda_assets( - GetComputeStream(context), + ort_stream, GetDeviceProp(), GetCublasHandle(context), GetCudnnHandle(context), allocator, - cuda_ep_->UseTF32()); + UseTF32()); EinsumComputePreprocessor einsum_compute_preprocessor(einsum_equation_preprocessor, inputs, allocator, &einsum_cuda_assets); diff --git a/onnxruntime/core/providers/cuda/math/einsum.h b/onnxruntime/core/providers/cuda/math/einsum.h index 42fdb52c22c38..60facd285154a 100644 --- a/onnxruntime/core/providers/cuda/math/einsum.h +++ b/onnxruntime/core/providers/cuda/math/einsum.h @@ -6,7 +6,6 @@ #include "core/platform/threadpool.h" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cuda_kernel.h" -#include "core/providers/cuda/cuda_execution_provider.h" #include "core/providers/cpu/math/einsum_utils/einsum_typed_compute_processor.h" #include "einsum_utils/einsum_auxiliary_ops.h" @@ -19,7 +18,6 @@ class Einsum final : public CudaKernel { ORT_ENFORCE(info.GetAttr("equation", &equation_).IsOK(), "Missing 'equation' attribute"); einsum_equation_preprocessor_ = std::make_unique(equation_); - cuda_ep_ = static_cast(info.GetExecutionProvider()); } Status ComputeInternal(OpKernelContext* context) const override; @@ -27,7 +25,6 @@ class Einsum final : public CudaKernel { private: std::string equation_; std::unique_ptr einsum_equation_preprocessor_; - const CUDAExecutionProvider* cuda_ep_; }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/math/gemm.cc b/onnxruntime/core/providers/cuda/math/gemm.cc index 7fa5e74b54248..59602c2458a7b 100644 --- a/onnxruntime/core/providers/cuda/math/gemm.cc +++ b/onnxruntime/core/providers/cuda/math/gemm.cc @@ -5,7 +5,9 @@ #include "core/providers/cpu/math/gemm_helper.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" +#ifndef BUILD_CUDA_EP_AS_PLUGIN #include "core/providers/cuda/tunable/math/gemm.h" +#endif namespace onnxruntime { namespace cuda { @@ -72,9 +74,11 @@ Status Gemm::ComputeInternal(OpKernelContext* ctx) const { // Bail out early if the output is going to be empty if (Y->Shape().Size() == 0) return Status::OK(); +#ifndef BUILD_CUDA_EP_AS_PLUGIN if (GetTuningContext()->IsTunableOpEnabled()) { return tunable::TunableGemm(M, N, K, trans_A_, trans_B_, alpha_, B ? beta_ : 0.0f, this, ctx); } +#endif return ComputeDefault(ctx, M, N, K); } diff --git a/onnxruntime/core/providers/cuda/math/matmul.cc b/onnxruntime/core/providers/cuda/math/matmul.cc index 04ffa875c1b9d..67e84ac99322c 100644 --- a/onnxruntime/core/providers/cuda/math/matmul.cc +++ b/onnxruntime/core/providers/cuda/math/matmul.cc @@ -5,7 +5,9 @@ #include "core/providers/cuda/shared_inc/fpgeneric.h" #include "core/providers/cuda/cuda_allocator.h" +#ifndef BUILD_CUDA_EP_AS_PLUGIN #include "core/providers/cuda/tunable/math/matmul.h" +#endif namespace onnxruntime { namespace cuda { @@ -121,9 +123,11 @@ Status MatMul::ComputeInternal(OpKernelContext* ctx) const { return Status::OK(); } +#ifndef BUILD_CUDA_EP_AS_PLUGIN if (GetTuningContext()->IsTunableOpEnabled()) { return tunable::TunableMatMul(alpha_, trans_a, trans_b, trans_batch_a_, trans_batch_b_, helper, this, ctx); } +#endif return ComputeDefault(ctx, helper); } @@ -219,9 +223,9 @@ Status FuncMatMul( MatMulComputeHelper::OffsetToArrays(reinterpret_cast(A->Data()), helper.LeftOffsets(), A_arrays.CpuSpan()); MatMulComputeHelper::OffsetToArrays(reinterpret_cast(B->Data()), helper.RightOffsets(), B_arrays.CpuSpan()); MatMulComputeHelper::OffsetToArrays(reinterpret_cast(Y->MutableData()), helper.OutputOffsets(), Y_arrays.CpuSpan()); - ORT_RETURN_IF_ERROR(A_arrays.CopyToGpu(ctx->GetComputeStream())); - ORT_RETURN_IF_ERROR(B_arrays.CopyToGpu(ctx->GetComputeStream())); - ORT_RETURN_IF_ERROR(Y_arrays.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(A_arrays.CopyToGpu(cuda_kernel->GetComputeStream(ctx))); + ORT_RETURN_IF_ERROR(B_arrays.CopyToGpu(cuda_kernel->GetComputeStream(ctx))); + ORT_RETURN_IF_ERROR(Y_arrays.CopyToGpu(cuda_kernel->GetComputeStream(ctx))); // TF32 provides a huge performance gain for training and inference while preserving FP32 levels of accuracy. // It requires Ampere or newer GPU, and pointers of matrices shall be aligned (ideal alignment is 16-byte). @@ -370,9 +374,9 @@ Status MatMul::ComputeDefault(OpKernelContext* ctx, MatMulComputeHelper& help MatMulComputeHelper::OffsetToArrays(reinterpret_cast(left_X->Data()), helper.LeftOffsets(), left_arrays.CpuSpan()); MatMulComputeHelper::OffsetToArrays(reinterpret_cast(right_X->Data()), helper.RightOffsets(), right_arrays.CpuSpan()); MatMulComputeHelper::OffsetToArrays(reinterpret_cast(Y->MutableData()), helper.OutputOffsets(), output_arrays.CpuSpan()); - ORT_RETURN_IF_ERROR(left_arrays.CopyToGpu(ctx->GetComputeStream())); - ORT_RETURN_IF_ERROR(right_arrays.CopyToGpu(ctx->GetComputeStream())); - ORT_RETURN_IF_ERROR(output_arrays.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(left_arrays.CopyToGpu(this->GetComputeStream(ctx))); + ORT_RETURN_IF_ERROR(right_arrays.CopyToGpu(this->GetComputeStream(ctx))); + ORT_RETURN_IF_ERROR(output_arrays.CopyToGpu(this->GetComputeStream(ctx))); // TF32 provides a huge performance gain for training and inference while preserving FP32 levels of accuracy. // It requires Ampere or newer GPU, and pointers of matrices shall be aligned (ideal alignment is 16-byte). diff --git a/onnxruntime/core/providers/cuda/math/matmul_integer.cc b/onnxruntime/core/providers/cuda/math/matmul_integer.cc index 6e43834bdf1ef..b42f204efa430 100644 --- a/onnxruntime/core/providers/cuda/math/matmul_integer.cc +++ b/onnxruntime/core/providers/cuda/math/matmul_integer.cc @@ -69,13 +69,13 @@ Status MatMulInteger::ComputeInternal(OpKernelContext* ctx) cons // OffsetOutput computes gets the final result IAllocatorUniquePtr a_row_buf; if (b_offset != 0) { - a_row_buf = GetScratchBuffer(helper.OutputShape().Size() / helper.N(), ctx->GetComputeStream()); + a_row_buf = GetScratchBuffer(helper.OutputShape().Size() / helper.N(), GetComputeStream(ctx)); ORT_RETURN_IF_ERROR(ReduceRowSumOnMatrixA(Stream(ctx), a_ptr, a_row_buf.get(), b_offset, helper)); } IAllocatorUniquePtr b_col_buf; if (a_offset != 0) { - b_col_buf = GetScratchBuffer(helper.OutputShape().Size() / helper.M(), ctx->GetComputeStream()); + b_col_buf = GetScratchBuffer(helper.OutputShape().Size() / helper.M(), GetComputeStream(ctx)); ORT_RETURN_IF_ERROR(ReduceColSumOnMatrixB(Stream(ctx), b_ptr, b_col_buf.get(), a_offset, helper)); } @@ -105,7 +105,8 @@ Status MatMulInteger::ComputeInternal(OpKernelContext* ctx) cons output_ptr + helper.OutputOffsets()[batch], static_cast(helper.N()), this, - ctx->GetComputeStream())); + GetComputeStream(ctx), Stream(ctx), + GetCublasHandle(ctx))); } return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/math/softmax.cc b/onnxruntime/core/providers/cuda/math/softmax.cc index 9402232a24737..609f17891fc0a 100644 --- a/onnxruntime/core/providers/cuda/math/softmax.cc +++ b/onnxruntime/core/providers/cuda/math/softmax.cc @@ -13,7 +13,7 @@ namespace cuda { template Status SoftMaxComputeHelper( - Stream* stream, + SoftmaxComputeStreamT stream, const T* X, const TensorShape& input_shape, TOut* Y, @@ -40,9 +40,9 @@ Status SoftMaxComputeHelper( } #define SPECIALIZED_SOFTMAX_HELPER_IMPL(T, TOut) \ - template Status SoftMaxComputeHelper(Stream * stream, const T* input, \ + template Status SoftMaxComputeHelper(SoftmaxComputeStreamT stream, const T* input, \ const TensorShape& shape, TOut* Y, int64_t axis); \ - template Status SoftMaxComputeHelper(Stream * stream, const T* input, \ + template Status SoftMaxComputeHelper(SoftmaxComputeStreamT stream, const T* input, \ const TensorShape& shape, TOut* Y, int64_t axis); SPECIALIZED_SOFTMAX_HELPER_IMPL(MLFloat16, float) @@ -177,12 +177,13 @@ Status Softmax::ComputeInternal(OpKernelContext* ctx) const { } Status status; + auto compute_stream = Stream(ctx); if (log_softmax_) { - status = SoftMaxComputeHelper(ctx->GetComputeStream(), X_data, *compute_input_shape, Y_data, + status = SoftMaxComputeHelper(compute_stream, X_data, *compute_input_shape, Y_data, is_transpose_required ? static_cast(rank) - 1 : static_cast(axis)); } else { - status = SoftMaxComputeHelper(ctx->GetComputeStream(), X_data, *compute_input_shape, Y_data, + status = SoftMaxComputeHelper(compute_stream, X_data, *compute_input_shape, Y_data, is_transpose_required ? static_cast(rank) - 1 : static_cast(axis)); } diff --git a/onnxruntime/core/providers/cuda/math/softmax.h b/onnxruntime/core/providers/cuda/math/softmax.h index 6f4016b655c96..c0c0818042c15 100644 --- a/onnxruntime/core/providers/cuda/math/softmax.h +++ b/onnxruntime/core/providers/cuda/math/softmax.h @@ -9,20 +9,22 @@ namespace onnxruntime { namespace cuda { +using SoftmaxComputeStreamT = cudaStream_t; + template Status SoftMaxComputeHelper( - Stream* stream, + SoftmaxComputeStreamT stream, const T* input, const TensorShape& shape, TOut* Y, int64_t axis); template -Status dispatch_warpwise_softmax_forward(Stream* stream, output_t* dst, const input_t* src, +Status dispatch_warpwise_softmax_forward(SoftmaxComputeStreamT stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count); template -Status dispatch_blockwise_softmax_forward(Stream* stream, output_t* output, const input_t* input, +Status dispatch_blockwise_softmax_forward(SoftmaxComputeStreamT stream, output_t* output, const input_t* input, int softmax_elements, int input_stride, int output_stride, int batch_count); template @@ -32,7 +34,7 @@ class Softmax final : public CudaKernel { const auto& node = info.node(); opset_ = node.SinceVersion(); - int64_t axis; + int64_t axis = 0; Status status = info.GetAttr("axis", &axis); if (status.IsOK()) { diff --git a/onnxruntime/core/providers/cuda/math/softmax_impl.cu b/onnxruntime/core/providers/cuda/math/softmax_impl.cu index 04e66e9e1529e..10550fb0cb810 100644 --- a/onnxruntime/core/providers/cuda/math/softmax_impl.cu +++ b/onnxruntime/core/providers/cuda/math/softmax_impl.cu @@ -29,9 +29,9 @@ namespace onnxruntime { namespace cuda { template -Status dispatch_warpwise_softmax_forward(Stream* ort_stream, output_t* dst, const input_t* src, int softmax_elements, +Status dispatch_warpwise_softmax_forward(SoftmaxComputeStreamT ort_stream, output_t* dst, const input_t* src, int softmax_elements, int softmax_elements_stride, int batch_count) { - auto stream = static_cast(ort_stream->GetHandle()); + auto stream = ort_stream; if (softmax_elements == 0) { return Status::OK(); } else { @@ -95,18 +95,18 @@ Status dispatch_warpwise_softmax_forward(Stream* ort_stream, output_t* dst, cons return CUDA_CALL(cudaGetLastError()); } -#define SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \ - template Status dispatch_warpwise_softmax_forward(Stream * ort_stream, \ - output_t * dst, \ - const input_t* src, \ - int softmax_elements, \ - int softmax_elements_stride, \ - int batch_count); \ - template Status dispatch_warpwise_softmax_forward(Stream * ort_stream, \ - output_t * dst, \ - const input_t* src, \ - int softmax_elements, \ - int softmax_elements_stride, \ +#define SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \ + template Status dispatch_warpwise_softmax_forward(SoftmaxComputeStreamT ort_stream, \ + output_t * dst, \ + const input_t* src, \ + int softmax_elements, \ + int softmax_elements_stride, \ + int batch_count); \ + template Status dispatch_warpwise_softmax_forward(SoftmaxComputeStreamT ort_stream, \ + output_t * dst, \ + const input_t* src, \ + int softmax_elements, \ + int softmax_elements_stride, \ int batch_count); SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(float, float, float) @@ -116,9 +116,9 @@ SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(double, double, double) SPECIALIZED_WRAPWISE_SOFTMAX_IMPL(BFloat16, BFloat16, float) template -Status dispatch_blockwise_softmax_forward(Stream* ort_stream, output_t* output, const input_t* input, int softmax_elements, +Status dispatch_blockwise_softmax_forward(SoftmaxComputeStreamT ort_stream, output_t* output, const input_t* input, int softmax_elements, int input_stride, int output_stride, int batch_count) { - auto stream = static_cast(ort_stream->GetHandle()); + auto stream = ort_stream; dim3 grid(batch_count); constexpr int ILP = sizeof(float4) / sizeof(input_t); dim3 block = SoftMax_getBlockSize(ILP, softmax_elements); @@ -134,12 +134,12 @@ Status dispatch_blockwise_softmax_forward(Stream* ort_stream, output_t* output, return CUDA_CALL(cudaGetLastError()); } -#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \ - template Status dispatch_blockwise_softmax_forward( \ - Stream * ort_stream, output_t * output, const input_t* src, int softmax_elements, \ - int input_stride, int output_stride, int batch_count); \ - template Status dispatch_blockwise_softmax_forward( \ - Stream * ort_stream, output_t * output, const input_t* src, int softmax_elements, \ +#define SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(input_t, output_t, acc_t) \ + template Status dispatch_blockwise_softmax_forward( \ + SoftmaxComputeStreamT ort_stream, output_t * output, const input_t* src, int softmax_elements, \ + int input_stride, int output_stride, int batch_count); \ + template Status dispatch_blockwise_softmax_forward( \ + SoftmaxComputeStreamT ort_stream, output_t * output, const input_t* src, int softmax_elements, \ int input_stride, int output_stride, int batch_count); SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(float, float, float) diff --git a/onnxruntime/core/providers/cuda/math/topk.cc b/onnxruntime/core/providers/cuda/math/topk.cc index bab6f15f2c774..fcfc27c5917d9 100644 --- a/onnxruntime/core/providers/cuda/math/topk.cc +++ b/onnxruntime/core/providers/cuda/math/topk.cc @@ -78,7 +78,8 @@ TopK::TopK(const OpKernelInfo& info) : CudaKernel(info) { #define IS_PRIM_TYPE(T) utils::IsPrimitiveDataType(prim_type) #define TOPKIMPL(T) TopKImpl(this, use_deterministic_compute, \ - ctx->GetComputeStream(), tensor_X->Data(), \ + Stream(ctx), GetComputeStream(ctx), \ + tensor_X->Data(), \ static_cast(tensor_V->MutableDataRaw()), \ static_cast(tensor_I->MutableDataRaw()), \ elem_nums_cuda, \ diff --git a/onnxruntime/core/providers/cuda/math/topk_impl.cuh b/onnxruntime/core/providers/cuda/math/topk_impl.cuh index a0db0330b050c..65960a36f150f 100644 --- a/onnxruntime/core/providers/cuda/math/topk_impl.cuh +++ b/onnxruntime/core/providers/cuda/math/topk_impl.cuh @@ -186,6 +186,12 @@ __device__ __forceinline__ bool SamePrefix(const BFloat16* f0, const BFloat16* f return SamePrefix((const int16_t*)f0, (const int16_t*)f1, skip); } +#ifdef __CUDACC__ +__device__ __forceinline__ bool SamePrefix(const nv_bfloat16* f0, const nv_bfloat16* f1, int64_t skip) { + return SamePrefix((const int16_t*)f0, (const int16_t*)f1, skip); +} +#endif + __device__ __forceinline__ bool SamePrefix(const float* f0, const float* f1, int64_t skip) { return SamePrefix((const int32_t*)f0, (const int32_t*)f1, skip); } @@ -207,6 +213,12 @@ __device__ __forceinline__ int32_t Radix(const BFloat16* f, int64_t skip) { return Radix((const int16_t*)f, skip); } +#ifdef __CUDACC__ +__device__ __forceinline__ int32_t Radix(const nv_bfloat16* f, int64_t skip) { + return Radix((const int16_t*)f, skip); +} +#endif + __device__ __forceinline__ int32_t Radix(const float* f, int64_t skip) { return Radix((const int32_t*)f, skip); } @@ -228,6 +240,12 @@ __device__ __forceinline__ void SetByte(BFloat16* f, int64_t byte) { SetByte((int16_t*)f, byte); } +#ifdef __CUDACC__ +__device__ __forceinline__ void SetByte(nv_bfloat16* f, int64_t byte) { + SetByte((int16_t*)f, byte); +} +#endif + __device__ __forceinline__ void SetByte(float* f, int64_t byte) { SetByte((int32_t*)f, byte); } @@ -420,14 +438,13 @@ __global__ void ExcludeOutput(T* output_i, T K, T dimension) { template Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute, - Stream* ort_stream, const T* input_x, T* output_v, int64_t* output_i, + cudaStream_t stream, void* alloc_stream, const T* input_x, T* output_v, int64_t* output_i, const TArray& elem_nums, size_t size, int32_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension) { typedef typename ToCudaType::MappedType CudaT; using CubT = typename CubSortType::type; const CudaT* input_x_ptr = reinterpret_cast(input_x); CudaT* output_v_ptr = reinterpret_cast(output_v); - cudaStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; auto aligned_K = ALIGN(K); auto aligned_dimension = ALIGN(dimension); @@ -462,10 +479,10 @@ Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute, NumericLimits::Lowest(), NumericLimits::Max()); } } else { - auto input_key_buffer = kernel->GetScratchBuffer(dimension, ort_stream); - auto output_key_buffer = kernel->GetScratchBuffer(dimension, ort_stream); - auto input_value_buffer = kernel->GetScratchBuffer(dimension, ort_stream); - auto output_value_buffer = kernel->GetScratchBuffer(dimension, ort_stream); + auto input_key_buffer = kernel->GetScratchBuffer(dimension, alloc_stream); + auto output_key_buffer = kernel->GetScratchBuffer(dimension, alloc_stream); + auto input_value_buffer = kernel->GetScratchBuffer(dimension, alloc_stream); + auto output_value_buffer = kernel->GetScratchBuffer(dimension, alloc_stream); auto* input_key = input_key_buffer.get(); auto* output_key = output_key_buffer.get(); auto* input_value = input_value_buffer.get(); @@ -475,7 +492,7 @@ Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute, auto* output_key_cub = reinterpret_cast(output_key); size_t temp_bytes = 0; CUDA_RETURN_IF_ERROR(cub::DeviceRadixSort::SortPairs(nullptr, temp_bytes, input_key_cub, output_key_cub, input_value, output_value, dimension, 0, sizeof(CubT) * 8, stream)); - auto temp_storage_buffer = kernel->GetScratchBuffer(temp_bytes, ort_stream); + auto temp_storage_buffer = kernel->GetScratchBuffer(temp_bytes, alloc_stream); auto* temp_storage = temp_storage_buffer.get(); auto blocks_per_grid_D = (int)(ceil(static_cast(dimension) / BT)); auto blocks_per_grid_K = (int)(ceil(static_cast(K) / BT)); @@ -497,7 +514,8 @@ Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute, #define TOPKIMPLE(T) template Status TopKImpl(const CudaKernel* kernel, \ bool use_deterministic_compute, \ - Stream* ort_stream, \ + cudaStream_t stream, \ + void* alloc_stream, \ const T* input_x, \ T* output_v, \ int64_t* output_i, \ diff --git a/onnxruntime/core/providers/cuda/math/topk_impl.h b/onnxruntime/core/providers/cuda/math/topk_impl.h index c5f63aadc402a..f072dd4e66107 100644 --- a/onnxruntime/core/providers/cuda/math/topk_impl.h +++ b/onnxruntime/core/providers/cuda/math/topk_impl.h @@ -11,7 +11,7 @@ namespace onnxruntime { namespace cuda { template -Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute, Stream* ort_stream, +Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute, cudaStream_t stream, void* alloc_stream, const T* input_x, T* output_v, int64_t* output_i, const TArray& elem_nums, size_t size, int32_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension); diff --git a/onnxruntime/core/providers/cuda/math/variadic_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/variadic_elementwise_ops.cc index f543ba9c975e1..bb05c04a4dc1e 100644 --- a/onnxruntime/core/providers/cuda/math/variadic_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/variadic_elementwise_ops.cc @@ -158,7 +158,7 @@ Status VariadicElementwiseOp OpKernelContext* context) const { const auto& node = Node(); const auto& node_name = node.Name(); - auto input_count = node.InputArgCount().front(); + auto input_count = context->InputCount(); ORT_RETURN_IF_NOT(input_count >= 1, "Must have 1 or more inputs"); const InputTensorVector input_tensors = diff --git a/onnxruntime/core/providers/cuda/nn/batch_norm.cc b/onnxruntime/core/providers/cuda/nn/batch_norm.cc index 02da1a2c99dfd..b5aacff1d22ef 100644 --- a/onnxruntime/core/providers/cuda/nn/batch_norm.cc +++ b/onnxruntime/core/providers/cuda/nn/batch_norm.cc @@ -99,10 +99,10 @@ Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) // Convert the scale, B, mean, var to float const int64_t C = x_shape.GetDims()[NHWC ? 3 : 1]; - auto f_scale = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); - auto f_B = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); - auto f_mean = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); - auto f_var = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); + auto f_scale = GetScratchBuffer(C, GetComputeStream(p_op_kernel_context)); + auto f_B = GetScratchBuffer(C, GetComputeStream(p_op_kernel_context)); + auto f_mean = GetScratchBuffer(C, GetComputeStream(p_op_kernel_context)); + auto f_var = GetScratchBuffer(C, GetComputeStream(p_op_kernel_context)); Impl_Cast(Stream(p_op_kernel_context), scale_data, f_scale.get(), C); Impl_Cast(Stream(p_op_kernel_context), b_data, f_B.get(), C); Impl_Cast(Stream(p_op_kernel_context), mean_data, f_mean.get(), C); @@ -137,7 +137,7 @@ Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) auto saved_mean_data = reinterpret_cast(saved_mean->MutableData()); auto saved_inv_var_data = reinterpret_cast(saved_var->MutableData()); - auto stream = static_cast(p_op_kernel_context->GetComputeStream()->GetHandle()); + auto stream = Stream(p_op_kernel_context); CUDA_RETURN_IF_ERROR( cudaMemcpyAsync(running_mean_data, mean_data, mean->SizeInBytes(), cudaMemcpyDeviceToDevice, stream)); CUDA_RETURN_IF_ERROR( diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index dbda2d3ad0c88..20bc990aaee24 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -237,8 +237,12 @@ Status Conv::CreateCudnnFeExecutionPlan(const onnxruntime::TensorShap CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_operation_graph(handle)); CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->create_execution_plans({heur_mode})); } catch (const std::exception& ex) { - std::string message = MakeString("Failed to initialize CUDNN Frontend", ex.what(), - "with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); +#ifndef BUILD_CUDA_EP_AS_PLUGIN + std::string message = MakeString("Failed to initialize CUDNN Frontend: ", ex.what(), + " with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); +#else + std::string message = MakeString("Failed to initialize CUDNN Frontend: ", ex.what()); +#endif return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); } @@ -249,8 +253,12 @@ Status Conv::CreateCudnnFeExecutionPlan(const onnxruntime::TensorShap CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_plans(handle)); } catch (const std::exception& ex) { if (!fuse_bias && !fuse_act && use_tf32) { - std::string message = MakeString("OP not supported by CUDNN Frontend", ex.what(), - "with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); +#ifndef BUILD_CUDA_EP_AS_PLUGIN + std::string message = MakeString("OP not supported by CUDNN Frontend: ", ex.what(), + " with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); +#else + std::string message = MakeString("OP not supported by CUDNN Frontend: ", ex.what()); +#endif return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); } @@ -367,8 +375,6 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected s_.Y = context->Output(0, TensorShape(s_.y_dims)); s_.y_data = reinterpret_cast(s_.Y->MutableData()); - const CUDAExecutionProvider* cuda_ep = - static_cast(this->Info().GetExecutionProvider()); TensorShapeVector x_dims_cudnn{x_dims.begin(), x_dims.end()}; TensorShapeVector y_dims_cudnn{y_dims.begin(), y_dims.end()}; @@ -395,7 +401,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected // PyTorch also pads to [N,C,1,D]. For inference build, we still pad it to [N, C, D, 1] as this seems // to be the sweet spot for all algo search options: EXHAUSTIVE, HEURISTIC, and DEFAULT. // See PR #7348 and #7702 for more context. - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + if (this->GetCudnnConv1dPadToNc1d()) { x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); w_dims_cudnn.insert(w_dims_cudnn.begin() + 2, 1); @@ -423,7 +429,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected auto handle = GetCudnnHandle(context); - int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo(); + int cudnn_conv_algo = this->GetCudnnConvAlgo(); #if !defined(__CUDACC__) cudnn_frontend::HeurMode_t heur_mode; switch (cudnn_conv_algo) { @@ -443,9 +449,9 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected break; } - const auto use_tf32 = cuda_ep->UseTF32(); + const auto use_tf32 = this->UseTF32(); // fuse if this op is part of a FusedConv or if the EP is set to fuse ops - const auto fuse_bias = cuda_ep->IsFuseConvBias() || is_fused_node_; + const auto fuse_bias = this->IsFuseConvBias() || is_fused_node_; const auto fuse_act = is_fused_node_; ORT_RETURN_IF_ERROR(CreateCudnnFeExecutionPlan(x_dims_cudnn, w_dims_cudnn, B, Z, y_dims_cudnn, handle, heur_mode, @@ -491,7 +497,7 @@ Status Conv::ComputeInternal(OpKernelContext* context) const { CUDA_RETURN_IF_ERROR(cudaMemset(s_.y_data, 0, s_.Y->SizeInBytes())); } } - auto ws = GetWorkSpace(context->GetComputeStream()); + auto ws = GetWorkSpace(GetComputeStream(context)); CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_graph->execute(cudnn_handle, s_.variant_pack, diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index e4047a6af272e..fa9808a6d3d3e 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -224,7 +224,7 @@ class Conv : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; protected: - inline IAllocatorUniquePtr GetWorkSpace(onnxruntime::Stream* stream) const { + inline IAllocatorUniquePtr GetWorkSpace(void* stream) const { return GetScratchBuffer(s_.workspace_bytes, stream); } diff --git a/onnxruntime/core/providers/cuda/nn/conv_8.h b/onnxruntime/core/providers/cuda/nn/conv_8.h index bcee1bcb7e231..2ce213b92810b 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_8.h +++ b/onnxruntime/core/providers/cuda/nn/conv_8.h @@ -189,16 +189,13 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) s_.Y = context->Output(0, TensorShape(s_.y_dims)); if (post_slicing_required) { // Post slicing needed. Create and fill in the Conv results in an intermediate buffer. - s_.memory_for_cudnn_conv_results = GetScratchBuffer(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream()); + s_.memory_for_cudnn_conv_results = GetScratchBuffer(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size, GetComputeStream(context)); s_.y_data = reinterpret_cast(s_.memory_for_cudnn_conv_results.get()); } else { // No post slicing needed. Fill the output tensor's buffer directly. s_.y_data = reinterpret_cast(s_.Y->MutableData()); } - const CUDAExecutionProvider* cuda_ep = - static_cast(this->Info().GetExecutionProvider()); - TensorShapeVector x_dims_cudnn{x_dims.begin(), x_dims.end()}; TensorShapeVector y_dims_cudnn = !post_slicing_required ? y_dims : y_dims_with_adjusted_pads; if (kernel_rank < 2) { @@ -210,7 +207,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) // PyTorch also pads to [N,C,1,D]. For inference build, we still pad it to [N, C, D, 1] as this seems // to be the sweet spot for all algo search options: EXHAUSTIVE, HEURISTIC, and DEFAULT. // See PR #7348 and #7702 for more context. - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + if (this->GetCudnnConv1dPadToNc1d()) { x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); w_dims.insert(w_dims.begin() + 2, 1); @@ -313,13 +310,13 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) cudnnConvolutionFwdAlgoPerf_t perf; int algo_count = 1; - int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo(); + int cudnn_conv_algo = this->GetCudnnConvAlgo(); ORT_ENFORCE(cudnn_conv_algo > -1 && cudnn_conv_algo < 3, "cudnn_conv_algo should be 0, 1 or 2, but got ", cudnn_conv_algo); switch (cudnn_conv_algo) { case 0: { static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; - size_t max_ws_size = cuda_ep->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetCudnnHandle(context), s_, kAllAlgos, num_algos) - : AlgoSearchWorkspaceSize; + size_t max_ws_size = this->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetCudnnHandle(context), s_, kAllAlgos, num_algos) + : AlgoSearchWorkspaceSize; // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. IAllocatorUniquePtr algo_search_workspace = GetTransientScratchBuffer(max_ws_size); @@ -376,7 +373,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) return Status::OK(); } if (s_.post_slicing_required) { - s_.memory_for_cudnn_conv_results = GetScratchBuffer(TensorShape(s_.y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream()); + s_.memory_for_cudnn_conv_results = GetScratchBuffer(TensorShape(s_.y_dims_with_adjusted_pads).Size() * s_.element_size, GetComputeStream(context)); s_.y_data = reinterpret_cast(s_.memory_for_cudnn_conv_results.get()); } else { s_.y_data = reinterpret_cast(s_.Y->MutableData()); @@ -394,7 +391,7 @@ Status Conv::ComputeInternal(OpKernelContext* context) const { } const auto alpha = Consts::One; const auto beta = Consts::Zero; - IAllocatorUniquePtr workspace = GetWorkSpace(context->GetComputeStream()); + IAllocatorUniquePtr workspace = GetWorkSpace(GetComputeStream(context)); auto cudnn_handle = GetCudnnHandle(context); CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward(cudnn_handle, &alpha, @@ -481,4 +478,4 @@ Status CudnnConvolutionDescriptor::Set( return Status::OK(); } } // namespace cuda -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 28197c20af052..c7eeadefeb555 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -194,8 +194,7 @@ Status ConvTranspose::CreateCudnnFeExecutionPlan(const onnxruntime::T CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_operation_graph(handle)); CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->create_execution_plans({heur_mode})); } catch (const std::exception& ex) { - std::string message = MakeString("Failed to initialize CUDNN Frontend", ex.what(), - "with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); + std::string message = MakeString("Failed to initialize CUDNN Frontend: ", ex.what()); return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); } @@ -206,8 +205,7 @@ Status ConvTranspose::CreateCudnnFeExecutionPlan(const onnxruntime::T CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_plans(handle)); } catch (const std::exception& ex) { if (!fuse_bias && !fuse_act && use_tf32) { - std::string message = MakeString("OP not supported by CUDNN Frontend", ex.what(), - "with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); + std::string message = MakeString("OP not supported by CUDNN Frontend: ", ex.what()); return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); } @@ -225,8 +223,9 @@ Status ConvTranspose::CreateCudnnFeExecutionPlan(const onnxruntime::T template Status ConvTranspose::UpdateState(OpKernelContext* context, bool dynamic_padding) const { constexpr bool channels_last = Layout == LAYOUT_NHWC; - - size_t num_inputs = OpKernel::Node().InputDefs().size(); + size_t num_inputs = static_cast(Info().GetInputCount()); + // Standard ONNX ConvTranspose has inputs X, W, optional B. + // ConvTransposeWithDynamicPads inserts Pads at input 2, so bias becomes input 3. bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3; // set X @@ -362,8 +361,6 @@ Status ConvTranspose::UpdateState(OpKernelContext* context, bool dyna s_.Y = context->Output(0, s_.y_dims); s_.y_data = reinterpret_cast(s_.Y->MutableData()); - const CUDAExecutionProvider* cuda_ep = - static_cast(this->Info().GetExecutionProvider()); TensorShapeVector x_dims_cudnn{x_dims.begin(), x_dims.end()}; TensorShapeVector y_dims_cudnn{y_dims.begin(), y_dims.end()}; @@ -390,7 +387,7 @@ Status ConvTranspose::UpdateState(OpKernelContext* context, bool dyna // PyTorch also pads to [N,C,1,D]. For inference build, we still pad it to [N, C, D, 1] as this seems // to be the sweet spot for all algo search options: EXHAUSTIVE, HEURISTIC, and DEFAULT. // See PR #7348 and #7702 for more context. - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + if (this->GetCudnnConv1dPadToNc1d()) { x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); w_dims_cudnn.insert(w_dims_cudnn.begin() + 2, 1); @@ -418,7 +415,7 @@ Status ConvTranspose::UpdateState(OpKernelContext* context, bool dyna auto handle = GetCudnnHandle(context); - int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo(); + int cudnn_conv_algo = this->GetCudnnConvAlgo(); #if !defined(__CUDACC__) cudnn_frontend::HeurMode_t heur_mode; switch (cudnn_conv_algo) { @@ -436,8 +433,8 @@ Status ConvTranspose::UpdateState(OpKernelContext* context, bool dyna break; } - auto use_tf32 = cuda_ep->UseTF32(); - const auto fuse_bias = cuda_ep->IsFuseConvBias() || is_fused_node_; + auto use_tf32 = this->UseTF32(); + const auto fuse_bias = this->IsFuseConvBias() || is_fused_node_; const auto fuse_act = is_fused_node_; ORT_RETURN_IF_ERROR(CreateCudnnFeExecutionPlan(x_dims_cudnn, w_dims_cudnn, B, y_dims_cudnn, handle, heur_mode, @@ -483,7 +480,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool CUDA_RETURN_IF_ERROR(cudaMemset(s_.y_data, 0, s_.Y->SizeInBytes())); } } - auto ws = GetWorkSpace(context->GetComputeStream()); + auto ws = GetWorkSpace(GetComputeStream(context)); CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_graph->execute(cudnn_handle, s_.variant_pack, diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.h b/onnxruntime/core/providers/cuda/nn/conv_transpose.h index 1a6957164d22f..072ef5c3221ef 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.h @@ -24,6 +24,8 @@ class ConvTranspose : public CudaKernel { Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) override; Status ComputeInternal(OpKernelContext* context) const override; + // `dynamic_padding` is used by the contrib op ConvTransposeWithDynamicPads, + // which adds a Pads input before the optional bias input. Status DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const; private: @@ -37,7 +39,7 @@ class ConvTranspose : public CudaKernel { bool W_already_nhwc = false; // In case NHWC == true and Conv is not in kMSInternalNHWCDomain protected: - inline IAllocatorUniquePtr GetWorkSpace(onnxruntime::Stream* stream) const { + inline IAllocatorUniquePtr GetWorkSpace(void* stream) const { return GetScratchBuffer(s_.workspace_bytes, stream); } diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h b/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h index aa1fe26ac97db..10feb1acf8187 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h @@ -48,19 +48,19 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy TensorShapeVector w_dims = w_shape.AsShapeVector(); auto w_data = reinterpret_cast(W->Data()); - size_t num_inputs = OpKernel::Node().InputDefs().size(); + const size_t num_inputs = static_cast(Info().GetInputCount()); + // Standard ONNX ConvTranspose has inputs X, W, optional B. + // ConvTransposeWithDynamicPads inserts Pads at input 2, so bias becomes input 3. bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3; CudaT* y_data = nullptr; - const auto* cuda_ep = static_cast(Info().GetExecutionProvider()); - // convert 1D to 2D if (x_dimensions == 3) { // we can either add a fake H or W dimension with a value of 1. to be consistent with the Conv behavior we use // GetCudnnConv1dPadToNc1d to determine which is added. // see Conv::UpdateState in /onnxruntime/core/providers/cuda/nn/conv.cc for more details. - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + if (this->GetCudnnConv1dPadToNc1d()) { // add fake H dimension const auto insert_at = NHWC ? 1 : 2; @@ -112,7 +112,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy auto y_dims = p.Y->Shape().AsShapeVector(); if (x_dimensions == 3) { - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + if (this->GetCudnnConv1dPadToNc1d()) { // add fake H dimension of 1 // NCHW: N, M, d1 -> N, M, 1, d1 or // NHWC: N, d1, M -> N, 1, d1, M @@ -190,7 +190,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy if (!s_.cached_benchmark_results.contains(x_dims)) { IAllocatorUniquePtr algo_search_workspace = - GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); + GetScratchBuffer(AlgoSearchWorkspaceSize, GetComputeStream(context)); // set math type to tensor core before algorithm search if constexpr (std::is_same::value) { @@ -220,7 +220,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy if (!y_data) { auto y_dims = s_.y_dims.AsShapeVector(); if (x_dimensions == 3) { - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + if (this->GetCudnnConv1dPadToNc1d()) { // erase the fake H dimension y_dims.erase(y_dims.begin() + (NHWC ? 1 : 2)); } else { @@ -241,7 +241,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy const auto alpha = Consts::One; const auto beta = Consts::Zero; - IAllocatorUniquePtr workspace = GetScratchBuffer(s_.workspace_bytes, context->GetComputeStream()); + IAllocatorUniquePtr workspace = GetScratchBuffer(s_.workspace_bytes, GetComputeStream(context)); CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardData(GetCudnnHandle(context), &alpha, s_.w_desc, w_data, s_.x_tensor, x_data, s_.conv_desc, s_.algo, workspace.get(), diff --git a/onnxruntime/core/providers/cuda/nn/dropout.cc b/onnxruntime/core/providers/cuda/nn/dropout.cc index 16818d010361a..5011e7aef7872 100644 --- a/onnxruntime/core/providers/cuda/nn/dropout.cc +++ b/onnxruntime/core/providers/cuda/nn/dropout.cc @@ -35,6 +35,17 @@ struct DropoutComputeImpl { } // namespace +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Dropout, kOnnxDomain, 7, 9, kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()), + Dropout); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Dropout, kOnnxDomain, 10, 11, kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Dropout); + ONNX_OPERATOR_VERSIONED_KERNEL_EX(Dropout, kOnnxDomain, 12, 12, kCudaExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()) @@ -93,14 +104,22 @@ Status Dropout::ComputeInternal(OpKernelContext* context) const { CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(Y_data, X_data, X->SizeInBytes(), cudaMemcpyDeviceToDevice, Stream(context))); } - // If mask is requested, return all 1s. + // If mask is requested, fill it appropriately. + // BitmaskDropout (UseBitmask=true): mask is always bitmask where 1 = kept. All 1s for inference. + // Opset 12+: mask is bool, spec says "mask will contain all ones". All true for inference. + // Opset 7-11: mask is memset to 0, consistent with CPU IdentityOp behavior. if (mask) { - if (UseBitmask) { + if constexpr (UseBitmask) { + // BitmaskDropout always uses bitmask semantics (all bits set = all kept). CUDA_RETURN_IF_ERROR( cudaMemsetAsync(mask->MutableDataRaw(), -1, mask_element_count * sizeof(BitmaskElementType), Stream(context))); - } else { + } else if (opset_ >= 12) { CUDA_RETURN_IF_ERROR( cudaMemsetAsync(mask->MutableData(), true, mask_element_count * sizeof(bool), Stream(context))); + } else { + // Opset 7-11: zero-fill mask to match CPU IdentityOp behavior. + CUDA_RETURN_IF_ERROR( + cudaMemsetAsync(mask->MutableDataRaw(), 0, mask->SizeInBytes(), Stream(context))); } } @@ -111,7 +130,7 @@ Status Dropout::ComputeInternal(OpKernelContext* context) const { void* const mask_data = [this, mask_element_count, mask, &temp_mask_buffer, context]() { if (mask) return mask->MutableDataRaw(); temp_mask_buffer = - GetScratchBuffer(mask_element_count * (UseBitmask ? sizeof(BitmaskElementType) : sizeof(bool)), context->GetComputeStream()); + GetScratchBuffer(mask_element_count * (UseBitmask ? sizeof(BitmaskElementType) : sizeof(bool)), GetComputeStream(context)); return temp_mask_buffer.get(); }(); diff --git a/onnxruntime/core/providers/cuda/nn/dropout.h b/onnxruntime/core/providers/cuda/nn/dropout.h index 183456573f317..10049f7320981 100644 --- a/onnxruntime/core/providers/cuda/nn/dropout.h +++ b/onnxruntime/core/providers/cuda/nn/dropout.h @@ -18,6 +18,7 @@ class Dropout final : public CudaKernel { if (info.GetAttr("seed", &seed).IsOK()) { generator_ = std::make_unique(static_cast(seed)); } + opset_ = info.node().SinceVersion(); } Status ComputeInternal(OpKernelContext* context) const override; @@ -25,6 +26,7 @@ class Dropout final : public CudaKernel { private: mutable std::unique_ptr generator_; static constexpr float default_ratio_ = 0.5f; + int opset_ = 12; }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/nn/instance_norm.cc b/onnxruntime/core/providers/cuda/nn/instance_norm.cc index 30ba80dc8b05a..46c8086b74c3f 100644 --- a/onnxruntime/core/providers/cuda/nn/instance_norm.cc +++ b/onnxruntime/core/providers/cuda/nn/instance_norm.cc @@ -104,15 +104,15 @@ Status InstanceNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) co const size_t stats_byte_count = stats_count * sizeof(CudaT); // Mean & Variance are inputs & outputs and must be initialized to zero to work properly - auto mean = GetScratchBuffer(stats_count, p_op_kernel_context->GetComputeStream()); + auto mean = GetScratchBuffer(stats_count, GetComputeStream(p_op_kernel_context)); CUDA_RETURN_IF_ERROR(cudaMemsetAsync(mean.get(), 0, stats_byte_count, Stream(p_op_kernel_context))); - auto variance = GetScratchBuffer(stats_count, p_op_kernel_context->GetComputeStream()); + auto variance = GetScratchBuffer(stats_count, GetComputeStream(p_op_kernel_context)); CUDA_RETURN_IF_ERROR(cudaMemsetAsync(variance.get(), 0, stats_byte_count, Stream(p_op_kernel_context))); // We must set the scale & bias inputs to zero as they are inputs to the calculation - auto unused_scale = GetScratchBuffer(stats_count, p_op_kernel_context->GetComputeStream()); + auto unused_scale = GetScratchBuffer(stats_count, GetComputeStream(p_op_kernel_context)); CUDA_RETURN_IF_ERROR(cudaMemsetAsync(unused_scale.get(), 0, stats_byte_count, Stream(p_op_kernel_context))); - auto unused_bias = GetScratchBuffer(stats_count, p_op_kernel_context->GetComputeStream()); + auto unused_bias = GetScratchBuffer(stats_count, GetComputeStream(p_op_kernel_context)); CUDA_RETURN_IF_ERROR(cudaMemsetAsync(unused_bias.get(), 0, stats_byte_count, Stream(p_op_kernel_context))); // first, compute mean and variance per-instance per-channel using cudnnBatchNorm training @@ -201,10 +201,10 @@ Status InstanceNorm::ComputeInternal(OpKernelContext* p_op_kernel_con // alpha, beta will be of type float as the Consts struct specialization // for MLFloat16 type take care of that. Only Convert the scale, bias to float) - auto scale_data_fp32 = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); + auto scale_data_fp32 = GetScratchBuffer(C, GetComputeStream(p_op_kernel_context)); Impl_Cast(Stream(p_op_kernel_context), scale_data, scale_data_fp32.get(), C); - auto bias_data_fp32 = GetScratchBuffer(C, p_op_kernel_context->GetComputeStream()); + auto bias_data_fp32 = GetScratchBuffer(C, GetComputeStream(p_op_kernel_context)); Impl_Cast(Stream(p_op_kernel_context), bias_data, bias_data_fp32.get(), C); CUDNN_RETURN_IF_ERROR(BatchNormalizationForwardTrainingHelper( @@ -247,15 +247,15 @@ Status InstanceNorm::ComputeInternal(OpKernelContext* p_op_kernel_con const size_t stats_byte_count = stats_count * sizeof(float); // Mean & Variance are inputs & outputs and must be initialized to zero to work properly - auto mean = GetScratchBuffer(stats_count, p_op_kernel_context->GetComputeStream()); + auto mean = GetScratchBuffer(stats_count, GetComputeStream(p_op_kernel_context)); CUDA_RETURN_IF_ERROR(cudaMemsetAsync(mean.get(), 0, stats_byte_count, Stream(p_op_kernel_context))); - auto variance = GetScratchBuffer(stats_count, p_op_kernel_context->GetComputeStream()); + auto variance = GetScratchBuffer(stats_count, GetComputeStream(p_op_kernel_context)); CUDA_RETURN_IF_ERROR(cudaMemsetAsync(variance.get(), 0, stats_byte_count, Stream(p_op_kernel_context))); // We must set the scale & bias inputs to zero as they are inputs to the calculation - auto unused_scale = GetScratchBuffer(stats_count, p_op_kernel_context->GetComputeStream()); + auto unused_scale = GetScratchBuffer(stats_count, GetComputeStream(p_op_kernel_context)); CUDA_RETURN_IF_ERROR(cudaMemsetAsync(unused_scale.get(), 0, stats_byte_count, Stream(p_op_kernel_context))); - auto unused_bias = GetScratchBuffer(stats_count, p_op_kernel_context->GetComputeStream()); + auto unused_bias = GetScratchBuffer(stats_count, GetComputeStream(p_op_kernel_context)); CUDA_RETURN_IF_ERROR(cudaMemsetAsync(unused_bias.get(), 0, stats_byte_count, Stream(p_op_kernel_context))); // first, compute mean and variance per-instance per-channel using cudnnBatchNorm training diff --git a/onnxruntime/core/providers/cuda/nn/pool.cc b/onnxruntime/core/providers/cuda/nn/pool.cc index 3a97a5f2481e7..89c6047769cf0 100644 --- a/onnxruntime/core/providers/cuda/nn/pool.cc +++ b/onnxruntime/core/providers/cuda/nn/pool.cc @@ -246,8 +246,8 @@ Status Pool::ComputeInternal(OpKernelContext* context) cons const auto input_count = x_shape.Size(); const auto output_count = y_shape.Size(); - IAllocatorUniquePtr temp_X = GetScratchBuffer(input_count, context->GetComputeStream()); - auto temp_Y = GetScratchBuffer(output_count, context->GetComputeStream()); + IAllocatorUniquePtr temp_X = GetScratchBuffer(input_count, GetComputeStream(context)); + auto temp_Y = GetScratchBuffer(output_count, GetComputeStream(context)); Impl_Cast(Stream(context), reinterpret_cast(x_data), temp_X.get(), input_count); CUDNN_RETURN_IF_ERROR(PoolingForwardHelper(GetCudnnHandle(context), pooling_desc, &alpha, x_tensor, temp_X.get(), &beta, y_tensor, temp_Y.get())); diff --git a/onnxruntime/core/providers/cuda/object_detection/non_max_suppression.cc b/onnxruntime/core/providers/cuda/object_detection/non_max_suppression.cc index b55efc1180f10..1fbe0573301c4 100644 --- a/onnxruntime/core/providers/cuda/object_detection/non_max_suppression.cc +++ b/onnxruntime/core/providers/cuda/object_detection/non_max_suppression.cc @@ -62,7 +62,7 @@ Status NonMaxSuppression::ComputeInternal(OpKernelContext* ctx) const { IAllocatorUniquePtr d_selected_indices{}; IAllocatorUniquePtr h_number_selected_ptr{AllocateBufferOnCPUPinned(sizeof(int))}; auto* h_number_selected = static_cast(h_number_selected_ptr.get()); - auto* stream = ctx->GetComputeStream(); + auto* stream = GetComputeStream(ctx); ORT_RETURN_IF_ERROR(NonMaxSuppressionImpl( Stream(ctx), [this, stream](size_t bytes) { return GetScratchBuffer(bytes, stream); }, @@ -114,10 +114,10 @@ Status NonMaxSuppression::ComputeInternal(OpKernelContext* ctx) const { } } - ORT_RETURN_IF_ERROR(concat_sizes_gpu.CopyToGpu(ctx->GetComputeStream())); - ORT_RETURN_IF_ERROR(axis_dimension_input_output_mapping_gpu.CopyToGpu(ctx->GetComputeStream())); - ORT_RETURN_IF_ERROR(concat_sizes_range_gpu.CopyToGpu(ctx->GetComputeStream())); - ORT_RETURN_IF_ERROR(input_ptr.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(concat_sizes_gpu.CopyToGpu(GetComputeStream(ctx))); + ORT_RETURN_IF_ERROR(axis_dimension_input_output_mapping_gpu.CopyToGpu(GetComputeStream(ctx))); + ORT_RETURN_IF_ERROR(concat_sizes_range_gpu.CopyToGpu(GetComputeStream(ctx))); + ORT_RETURN_IF_ERROR(input_ptr.CopyToGpu(GetComputeStream(ctx))); ORT_RETURN_IF_ERROR(ConcatImpl(Stream(ctx), sizeof(int64_t), diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc new file mode 100644 index 0000000000000..8f2195b03d1a1 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.cc @@ -0,0 +1,131 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "cuda_allocator_plugin.h" + +namespace onnxruntime { +namespace cuda_plugin { + +namespace { + +void RestoreDeviceIfKnown(bool restore_prev_device, int prev_device) noexcept { + if (restore_prev_device) { + static_cast(cudaSetDevice(prev_device)); + } +} + +} // namespace + +// --------------------------------------------------------------------------- +// CudaDeviceAllocator — uses cudaMalloc/cudaFree for GPU device memory. +// +// PERFORMANCE NOTE (Direct cudaMalloc Penalty): +// No arena or caching layer is provided within this plugin. Every allocation +// goes directly to CUDA (cudaMalloc). For models with dynamic shape resizing +// or many intermediate buffers, this can cause substantial overhead. +// Compared to the built-in CUDA Execution Provider, which has an integrated +// memory arena, this is a notable performance gap unless an external +// memory pool/arena is injected or configured by the application. +// --------------------------------------------------------------------------- + +CudaDeviceAllocator::CudaDeviceAllocator(const OrtMemoryInfo* memory_info, int device_id) + : CudaAllocatorBase(CudaAllocatorKind::kDevice, memory_info), + device_id_(device_id) { + version = ORT_API_VERSION; + Alloc = AllocImpl; + Free = FreeImpl; + Info = InfoImpl; + Reserve = ReserveImpl; + GetStats = nullptr; + AllocOnStream = nullptr; +} + +/*static*/ void* ORT_API_CALL CudaDeviceAllocator::AllocImpl(OrtAllocator* this_ptr, size_t size) noexcept { + auto* alloc = static_cast(this_ptr); + void* p = nullptr; + if (size == 0) return nullptr; + // Save and restore CUDA device context to avoid corrupting the calling + // thread's device state in multi-GPU scenarios. + int prev_device = -1; + const bool restore_prev_device = cudaGetDevice(&prev_device) == cudaSuccess; + if (cudaSetDevice(alloc->device_id_) != cudaSuccess) { + RestoreDeviceIfKnown(restore_prev_device, prev_device); + return nullptr; + } + cudaError_t err = cudaMalloc(&p, size); + RestoreDeviceIfKnown(restore_prev_device, prev_device); + if (err != cudaSuccess) { + return nullptr; + } + return p; +} + +/*static*/ void ORT_API_CALL CudaDeviceAllocator::FreeImpl(OrtAllocator* this_ptr, void* p) noexcept { + auto* alloc = static_cast(this_ptr); + if (p != nullptr) { + int prev_device = -1; + const bool restore_prev_device = cudaGetDevice(&prev_device) == cudaSuccess; + if (cudaSetDevice(alloc->device_id_) != cudaSuccess) { + RestoreDeviceIfKnown(restore_prev_device, prev_device); + return; + } + + static_cast(cudaFree(p)); + RestoreDeviceIfKnown(restore_prev_device, prev_device); + } +} + +/*static*/ const OrtMemoryInfo* ORT_API_CALL CudaDeviceAllocator::InfoImpl(const OrtAllocator* this_ptr) noexcept { + const auto* alloc = static_cast(this_ptr); + return alloc->GetMemoryInfo(); +} + +/*static*/ void* ORT_API_CALL CudaDeviceAllocator::ReserveImpl(OrtAllocator* this_ptr, size_t size) noexcept { + // Reserve currently delegates to Alloc (no separate reservation pool). + return AllocImpl(this_ptr, size); +} + +// --------------------------------------------------------------------------- +// CudaPinnedAllocator — uses cudaHostAlloc/cudaFreeHost for page-locked +// host memory visible to the GPU. +// --------------------------------------------------------------------------- + +CudaPinnedAllocator::CudaPinnedAllocator(const OrtMemoryInfo* memory_info) + : CudaAllocatorBase(CudaAllocatorKind::kPinned, memory_info) { + version = ORT_API_VERSION; + Alloc = AllocImpl; + Free = FreeImpl; + Info = InfoImpl; + Reserve = ReserveImpl; + GetStats = nullptr; + AllocOnStream = nullptr; +} + +/*static*/ void* ORT_API_CALL CudaPinnedAllocator::AllocImpl(OrtAllocator* /*this_ptr*/, size_t size) noexcept { + void* p = nullptr; + if (size == 0) return nullptr; + cudaError_t err = cudaHostAlloc(&p, size, cudaHostAllocDefault); + if (err != cudaSuccess) { + return nullptr; + } + return p; +} + +/*static*/ void ORT_API_CALL CudaPinnedAllocator::FreeImpl(OrtAllocator* /*this_ptr*/, void* p) noexcept { + if (p != nullptr) { + cudaFreeHost(p); + } +} + +/*static*/ const OrtMemoryInfo* ORT_API_CALL CudaPinnedAllocator::InfoImpl(const OrtAllocator* this_ptr) noexcept { + const auto* alloc = static_cast(this_ptr); + return alloc->GetMemoryInfo(); +} + +/*static*/ void* ORT_API_CALL CudaPinnedAllocator::ReserveImpl(OrtAllocator* this_ptr, size_t size) noexcept { + // Reserve currently delegates to Alloc (no separate reservation pool). + return AllocImpl(this_ptr, size); +} + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.h b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.h new file mode 100644 index 0000000000000..8b0d41cad6541 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_allocator_plugin.h @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// CUDA device and pinned memory allocator implementations for the plugin EP. +// Provides CudaDeviceAllocator (cudaMalloc/cudaFree) and CudaPinnedAllocator +// (cudaHostAlloc/cudaFreeHost) conforming to the OrtAllocator interface. +// No arena or caching layer; every allocation goes directly to CUDA. + +#pragma once + +#include "cuda_plugin_utils.h" + +namespace onnxruntime { +namespace cuda_plugin { + +/// Allocator type: device memory (GPU) or pinned (page-locked host) memory. +enum class CudaAllocatorKind { + kDevice, ///< GPU device memory via cudaMalloc + kPinned, ///< Page-locked host memory via cudaHostAlloc +}; + +/// Base class for CUDA allocators implementing the OrtAllocator C interface. +class CudaAllocatorBase : public OrtAllocator { + public: + explicit CudaAllocatorBase(CudaAllocatorKind kind, const OrtMemoryInfo* memory_info) + : OrtAllocator{}, + kind_(kind), + memory_info_(memory_info) {} + + CudaAllocatorKind GetKind() const { return kind_; } + const OrtMemoryInfo* GetMemoryInfo() const { return memory_info_; } + + private: + CudaAllocatorKind kind_; + const OrtMemoryInfo* memory_info_; +}; + +/// CUDA device memory allocator using cudaMalloc/cudaFree. +/// Lifetime is managed by the EP factory (ReleaseAllocatorImpl), not by a Release callback. +class CudaDeviceAllocator final : public CudaAllocatorBase { + public: + CudaDeviceAllocator(const OrtMemoryInfo* memory_info, int device_id); + ~CudaDeviceAllocator() = default; + + private: + static void* ORT_API_CALL AllocImpl(OrtAllocator* this_ptr, size_t size) noexcept; + static void ORT_API_CALL FreeImpl(OrtAllocator* this_ptr, void* p) noexcept; + static const OrtMemoryInfo* ORT_API_CALL InfoImpl(const OrtAllocator* this_ptr) noexcept; + static void* ORT_API_CALL ReserveImpl(OrtAllocator* this_ptr, size_t size) noexcept; + + int device_id_; +}; + +/// CUDA pinned (host) memory allocator using cudaHostAlloc/cudaFreeHost. +/// Lifetime is managed by the EP factory (ReleaseAllocatorImpl), not by a Release callback. +class CudaPinnedAllocator final : public CudaAllocatorBase { + public: + CudaPinnedAllocator(const OrtMemoryInfo* memory_info); + ~CudaPinnedAllocator() = default; + + private: + static void* ORT_API_CALL AllocImpl(OrtAllocator* this_ptr, size_t size) noexcept; + static void ORT_API_CALL FreeImpl(OrtAllocator* this_ptr, void* p) noexcept; + static const OrtMemoryInfo* ORT_API_CALL InfoImpl(const OrtAllocator* this_ptr) noexcept; + static void* ORT_API_CALL ReserveImpl(OrtAllocator* this_ptr, size_t size) noexcept; +}; + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc new file mode 100644 index 0000000000000..a65c4e925c97f --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cc @@ -0,0 +1,506 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Plugin EP control flow kernel implementations for If, Loop, and Scan. +// These delegate to OrtEpApi::CreateIfKernel/CreateLoopKernel/CreateScanKernel +// instead of inheriting from CPU base classes. + +#include "core/providers/cuda/plugin/cuda_controlflow_plugin.h" +#include + +namespace onnxruntime { +namespace cuda { +namespace plugin { + +namespace { + +/// Determine byte size of a single element for the given ONNX data type. +/// Used by Scan transpose kernel to allocate and copy tensor data. +/// Returns error for sub-byte types (INT4, UINT4) and strings. +Status GetTensorElementStorageSize(ONNXTensorElementDataType elem_type, size_t& element_size) { + switch (elem_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ: + element_size = 1; + return Status::OK(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + element_size = 2; + return Status::OK(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + element_size = 4; + return Status::OK(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: + element_size = 8; + return Status::OK(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: + element_size = 16; + return Status::OK(); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT2: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT2: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Scan Transpose: packed sub-byte tensor types are unsupported"); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Scan Transpose: string tensors are unsupported"); + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED: + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Scan Transpose: unsupported element type ", static_cast(elem_type)); + } +} + +} // namespace + +// =================================================================== +// If kernel +// =================================================================== + +Status PluginIfKernel::CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) { + OrtStatus* status = Ort::GetEpApi().CreateIfKernel(info, impl); + if (status) { + std::string msg = Ort::GetApi().GetErrorMessage(status); + Ort::GetApi().ReleaseStatus(status); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, msg); + } + return Status::OK(); +} + +// =================================================================== +// Loop kernel helper +// =================================================================== + +PluginLoopHelper::PluginLoopHelper() : OrtLoopKernelHelper{} { + ort_version_supported = ORT_API_VERSION; + Release = ReleaseImpl; + ConcatOutput = ConcatOutputImpl; +} + +/*static*/ +void ORT_API_CALL PluginLoopHelper::ReleaseImpl(_In_ OrtLoopKernelHelper* this_ptr) noexcept { + delete static_cast(this_ptr); +} + +/*static*/ +OrtStatus* ORT_API_CALL PluginLoopHelper::ConcatOutputImpl( + _In_ OrtLoopKernelHelper* /*this_ptr*/, + _In_opt_ void* stream_handle, + _In_reads_(num_per_iteration_outputs) const OrtValue* const* per_iteration_outputs, + _In_ size_t num_per_iteration_outputs, + _Out_writes_bytes_all_(output_size_in_bytes) void* output, + _In_ size_t output_size_in_bytes) noexcept { + try { + if (num_per_iteration_outputs == 0) return nullptr; + + cudaStream_t cuda_stream = static_cast(stream_handle); + + Ort::ConstValue first_output(per_iteration_outputs[0]); + size_t bytes_per_iteration = first_output.GetTensorSizeInBytes(); + if (bytes_per_iteration > output_size_in_bytes) { + return Ort::Status("Loop ConcatOutput: output buffer too small for first iteration", ORT_FAIL).release(); + } + + char* cur = static_cast(output); + size_t total_bytes_copied = 0; + for (size_t i = 0; i < num_per_iteration_outputs; i++) { + Ort::ConstValue val(per_iteration_outputs[i]); + size_t cur_bytes = val.GetTensorSizeInBytes(); + if (cur_bytes != bytes_per_iteration) { + return Ort::Status("Inconsistent size in loop output iteration", ORT_FAIL).release(); + } + if (cur_bytes > output_size_in_bytes - total_bytes_copied) { + return Ort::Status("Loop ConcatOutput: output buffer too small", ORT_FAIL).release(); + } + PL_CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cur, val.GetTensorRawData(), bytes_per_iteration, + cudaMemcpyDeviceToDevice, cuda_stream)); + cur += bytes_per_iteration; + total_bytes_copied += bytes_per_iteration; + } + + if (total_bytes_copied != output_size_in_bytes) { + return Ort::Status("Loop ConcatOutput: output buffer not fully filled", ORT_FAIL).release(); + } + + return nullptr; + } catch (const std::exception& ex) { + return Ort::Status(ex.what(), ORT_RUNTIME_EXCEPTION).release(); + } +} + +// =================================================================== +// Loop kernel +// =================================================================== + +Status PluginLoopKernel::CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) { + auto helper = std::make_unique(); + OrtStatus* status = Ort::GetEpApi().CreateLoopKernel(info, helper.get(), impl); + if (status) { + std::string msg = Ort::GetApi().GetErrorMessage(status); + Ort::GetApi().ReleaseStatus(status); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, msg); + } + helper.release(); // ORT takes ownership on success + return Status::OK(); +} + +// =================================================================== +// Scan kernel helper +// =================================================================== + +PluginScanHelper::PluginScanHelper() : OrtScanKernelHelper{} { + ort_version_supported = ORT_API_VERSION; + Release = ReleaseImpl; + Transpose = TransposeImpl; +} + +/*static*/ +void ORT_API_CALL PluginScanHelper::ReleaseImpl(_In_ OrtScanKernelHelper* this_ptr) noexcept { + delete static_cast(this_ptr); +} + +/*static*/ +OrtStatus* ORT_API_CALL PluginScanHelper::TransposeImpl( + _In_ OrtScanKernelHelper* /*this_ptr*/, + _In_reads_(num_permutation_elems) const size_t* permutation, + _In_ size_t num_permutation_elems, + _In_ const OrtValue* ort_input, + _In_opt_ OrtSyncStream* stream, + _Inout_ OrtValue* ort_output) noexcept { + try { + // Get the CUDA stream from the OrtSyncStream + cudaStream_t cuda_stream = nullptr; + if (stream) { + const OrtSyncStreamImpl* impl = Ort::GetEpApi().SyncStream_GetImpl(stream); + if (impl) { + // GetHandle is a function pointer on OrtSyncStreamImpl + cuda_stream = static_cast( + const_cast(impl)->GetHandle(const_cast(impl))); + } + } + + Ort::ConstValue input(ort_input); + Ort::UnownedValue output(ort_output); + + Ort::TensorTypeAndShapeInfo input_info = input.GetTensorTypeAndShapeInfo(); + std::vector input_shape = input_info.GetShape(); + size_t num_dims = input_shape.size(); + size_t total_elements = input_info.GetElementCount(); + + if (num_dims != num_permutation_elems) { + return Ort::Status("Scan Transpose: permutation size does not match input rank", ORT_FAIL).release(); + } + + std::vector seen_permutation_indices(num_dims, false); + for (size_t i = 0; i < num_permutation_elems; ++i) { + const size_t perm_index = permutation[i]; + if (perm_index >= num_dims) { + return Ort::Status("Scan Transpose: permutation index is out of range", ORT_FAIL).release(); + } + if (seen_permutation_indices[perm_index]) { + return Ort::Status("Scan Transpose: permutation contains duplicate indices", ORT_FAIL).release(); + } + seen_permutation_indices[perm_index] = true; + } + + if (total_elements == 0) return nullptr; + + // Determine element size from the data type + ONNXTensorElementDataType elem_type = input_info.GetElementType(); + size_t element_size = 0; + auto status = GetTensorElementStorageSize(elem_type, element_size); + if (!status.IsOK()) { + return Ort::Status(status.ErrorMessage().c_str(), ORT_EP_FAIL).release(); + } + + const void* input_data = input.GetTensorRawData(); + void* output_data = output.GetTensorMutableData(); + + // Launch the GPU transpose kernel + OrtStatus* ort_status = LaunchTransposeKernel(input_data, output_data, + input_shape.data(), permutation, + num_dims, element_size, total_elements, + cuda_stream); + if (ort_status != nullptr) { + return ort_status; + } + + return nullptr; + } catch (const std::exception& ex) { + return Ort::Status(ex.what(), ORT_RUNTIME_EXCEPTION).release(); + } +} + +// =================================================================== +// Scan kernel +// =================================================================== + +Status PluginScanKernel::CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) { + auto helper = std::make_unique(); + OrtStatus* status = Ort::GetEpApi().CreateScanKernel(info, helper.get(), impl); + if (status) { + std::string msg = Ort::GetApi().GetErrorMessage(status); + Ort::GetApi().ReleaseStatus(status); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, msg); + } + helper.release(); // ORT takes ownership on success + return Status::OK(); +} + +} // namespace plugin +} // namespace cuda +} // namespace onnxruntime + +// =================================================================== +// Kernel Registrations — same opset versions as the framework CUDA EP +// =================================================================== + +using namespace onnxruntime::cuda::plugin; + +namespace onnxruntime { +namespace cuda { + +// --- If --- + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 1, 10, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + PluginIfKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 11, 12, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + PluginIfKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 13, 18, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + // The adapter EP API currently exposes tensor OrtDataType creation only. + .TypeConstraint("V", DataTypeImpl::AllTensorTypes()), + PluginIfKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 19, 20, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorTypesIRv9()), + PluginIfKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 21, 22, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorTypesIRv9()), + PluginIfKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(If, + kOnnxDomain, + 23, 24, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorTypesIRv9()), + PluginIfKernel); + +ONNX_OPERATOR_KERNEL_EX(If, + kOnnxDomain, + 25, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorTypesIRv9()), + PluginIfKernel); + +// --- Loop --- + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Loop, + kOnnxDomain, + 1, 10, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("I", DataTypeImpl::GetTensorType()) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + PluginLoopKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Loop, + kOnnxDomain, + 11, 12, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("I", DataTypeImpl::GetTensorType()) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + PluginLoopKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Loop, + kOnnxDomain, + 13, 18, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("I", DataTypeImpl::GetTensorType()) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorTypes()), + PluginLoopKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Loop, + kOnnxDomain, + 19, 20, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("I", DataTypeImpl::GetTensorType()) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorTypesIRv9()), + PluginLoopKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Loop, + kOnnxDomain, + 21, 22, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("I", DataTypeImpl::GetTensorType()) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorTypesIRv9()), + PluginLoopKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Loop, + kOnnxDomain, + 23, 24, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("I", DataTypeImpl::GetTensorType()) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorTypesIRv9()), + PluginLoopKernel); + +ONNX_OPERATOR_KERNEL_EX(Loop, + kOnnxDomain, + 25, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .TypeConstraint("I", DataTypeImpl::GetTensorType()) + .TypeConstraint("B", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorTypesIRv9()), + PluginLoopKernel); + +// --- Scan (opset 8) --- + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Scan, + kOnnxDomain, + 8, 8, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("I", DataTypeImpl::GetTensorType()) + .TypeConstraint("V", DataTypeImpl::AllTensorTypes()), + PluginScanKernel); + +// --- Scan (opset 9+) --- + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Scan, + kOnnxDomain, + 9, 10, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + PluginScanKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Scan, + kOnnxDomain, + 11, 15, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + PluginScanKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Scan, + kOnnxDomain, + 16, 18, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), + PluginScanKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Scan, + kOnnxDomain, + 19, 20, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypesIRv9()), + PluginScanKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Scan, + kOnnxDomain, + 21, 22, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypesIRv9()), + PluginScanKernel); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Scan, + kOnnxDomain, + 23, 24, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypesIRv9()), + PluginScanKernel); + +ONNX_OPERATOR_KERNEL_EX(Scan, + kOnnxDomain, + 25, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypesIRv9()), + PluginScanKernel); + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu new file mode 100644 index 0000000000000..6ff3296dadcb8 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.cu @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// GPU transpose kernel for the Scan control flow helper. +// Supports permutations up to kMaxTransposeDims dimensions by computing +// output coordinates from linear indices. + +#include +#include +#include +#include +#include + +#include "cuda_plugin_utils.h" + +namespace onnxruntime { +namespace cuda { +namespace plugin { + +namespace { + +// Maximum number of dimensions supported by the transpose kernel. +// Most real-world tensors have <= 8 dimensions. +constexpr int kMaxTransposeDims = 8; + +struct TransposeArgs { + int64_t input_strides[kMaxTransposeDims]; + int64_t output_strides[kMaxTransposeDims]; + int perm[kMaxTransposeDims]; +}; + +} // namespace + +// Kernel: each thread handles one element, computing its output position +// from the input position via the permutation. +__global__ void TransposeNDKernel(const char* __restrict__ input, + char* __restrict__ output, + TransposeArgs args, + int num_dims, + size_t element_size, + size_t total_elements) { + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= total_elements) return; + + // Decompose linear index into input coordinates + int64_t coords[kMaxTransposeDims]; + size_t remaining = idx; + for (int d = 0; d < num_dims; d++) { + coords[d] = static_cast(remaining / static_cast(args.input_strides[d])); + remaining %= static_cast(args.input_strides[d]); + } + + // Compute output linear index via permutation + size_t out_idx = 0; + for (int d = 0; d < num_dims; d++) { + out_idx += static_cast(coords[args.perm[d]]) * static_cast(args.output_strides[d]); + } + + // Copy element bytes + const char* src = input + idx * element_size; + char* dst = output + out_idx * element_size; + // Use memcpy for arbitrary element sizes (compiler optimizes for common sizes) + memcpy(dst, src, element_size); +} + +OrtStatus* LaunchTransposeKernel(const void* input, void* output, + const int64_t* input_shape, const size_t* permutation, + size_t num_dims, size_t element_size, size_t total_elements, + cudaStream_t stream) { + if (total_elements == 0 || num_dims == 0) { + return nullptr; + } + + if (num_dims > static_cast(kMaxTransposeDims)) { + return Ort::Status("Scan Transpose: rank exceeds the supported maximum rank", ORT_FAIL).release(); + } + + TransposeArgs args; + + // Compute input strides (row-major) + args.input_strides[num_dims - 1] = 1; + for (int d = static_cast(num_dims) - 2; d >= 0; d--) { + args.input_strides[d] = args.input_strides[d + 1] * input_shape[d + 1]; + } + + // Compute output shape and strides from permutation + int64_t output_shape[kMaxTransposeDims]; + for (size_t d = 0; d < num_dims; d++) { + output_shape[d] = input_shape[permutation[d]]; + args.perm[d] = static_cast(permutation[d]); + } + args.output_strides[num_dims - 1] = 1; + for (int d = static_cast(num_dims) - 2; d >= 0; d--) { + args.output_strides[d] = args.output_strides[d + 1] * output_shape[d + 1]; + } + + constexpr int kBlockSize = 256; + int num_blocks = static_cast((total_elements + kBlockSize - 1) / kBlockSize); + + TransposeNDKernel<<>>( + static_cast(input), + static_cast(output), + args, + static_cast(num_dims), + element_size, + total_elements); + + PL_CUDA_RETURN_IF_ERROR(cudaGetLastError()); + + return nullptr; +} + +} // namespace plugin +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.h b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.h new file mode 100644 index 0000000000000..da6fb94023333 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_controlflow_plugin.h @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Plugin EP control flow kernel wrappers for If, Loop, and Scan. +// These delegate to OrtEpApi::CreateIfKernel/CreateLoopKernel/CreateScanKernel +// instead of inheriting from CPU base classes. + +#pragma once + +#include "core/session/onnxruntime_cxx_api.h" + +namespace onnxruntime { +namespace cuda { +namespace plugin { + +// =================================================================== +// If kernel wrapper — delegates to OrtEpApi::CreateIfKernel +// =================================================================== + +class PluginIfKernel : public OpKernel { + public: + explicit PluginIfKernel(const OpKernelInfo& info) : OpKernel(info) {} + + Status CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) override; + + Status Compute(OpKernelContext*) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Plugin If kernel should not be called directly"); + } +}; + +// =================================================================== +// Loop kernel helper — provides GPU ConcatOutput via cudaMemcpyAsync +// =================================================================== + +class PluginLoopHelper : public OrtLoopKernelHelper { + public: + PluginLoopHelper(); + + static void ORT_API_CALL ReleaseImpl(_In_ OrtLoopKernelHelper* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL ConcatOutputImpl( + _In_ OrtLoopKernelHelper* this_ptr, + _In_opt_ void* stream_handle, + _In_reads_(num_per_iteration_outputs) const OrtValue* const* per_iteration_outputs, + _In_ size_t num_per_iteration_outputs, + _Out_writes_bytes_all_(output_size_in_bytes) void* output, + _In_ size_t output_size_in_bytes) noexcept; +}; + +class PluginLoopKernel : public OpKernel { + public: + explicit PluginLoopKernel(const OpKernelInfo& info) : OpKernel(info) {} + + Status CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) override; + + Status Compute(OpKernelContext*) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Plugin Loop kernel should not be called directly"); + } +}; + +// =================================================================== +// Scan kernel helper — provides GPU Transpose via CUDA kernel +// =================================================================== + +class PluginScanHelper : public OrtScanKernelHelper { + public: + PluginScanHelper(); + + static void ORT_API_CALL ReleaseImpl(_In_ OrtScanKernelHelper* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL TransposeImpl( + _In_ OrtScanKernelHelper* this_ptr, + _In_reads_(num_permutation_elems) const size_t* permutation, + _In_ size_t num_permutation_elems, + _In_ const OrtValue* input, + _In_opt_ OrtSyncStream* stream, + _Inout_ OrtValue* output) noexcept; +}; + +class PluginScanKernel : public OpKernel { + public: + explicit PluginScanKernel(const OpKernelInfo& info) : OpKernel(info) {} + + Status CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) override; + + Status Compute(OpKernelContext*) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Plugin Scan kernel should not be called directly"); + } +}; + +// GPU transpose helper (defined in cuda_controlflow_plugin.cu) +OrtStatus* LaunchTransposeKernel(const void* input, void* output, + const int64_t* input_shape, const size_t* permutation, + size_t num_dims, size_t element_size, size_t total_elements, + cudaStream_t stream); + +} // namespace plugin +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc new file mode 100644 index 0000000000000..e4b3ed8f3c314 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.cc @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "cuda_data_transfer_plugin.h" + +namespace onnxruntime { +namespace cuda_plugin { + +CudaDataTransfer::CudaDataTransfer(const OrtApi& ort_api, const OrtEpApi& ep_api) + : OrtDataTransferImpl{}, + ort_api_(ort_api), + ep_api_(ep_api) { + ort_version_supported = ORT_API_VERSION; + Release = ReleaseImpl; + CanCopy = CanCopyImpl; + CopyTensors = CopyTensorsImpl; +} + +/*static*/ void ORT_API_CALL CudaDataTransfer::ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} + +/*static*/ bool ORT_API_CALL CudaDataTransfer::CanCopyImpl( + const OrtDataTransferImpl* this_ptr, + const OrtMemoryDevice* src_device, + const OrtMemoryDevice* dst_device) noexcept { + auto* dt = static_cast(this_ptr); + const OrtEpApi& ep_api = dt->ep_api_; + auto src_type = ep_api.MemoryDevice_GetDeviceType(src_device); + auto dst_type = ep_api.MemoryDevice_GetDeviceType(dst_device); + + bool src_is_cpu = (src_type == OrtMemoryInfoDeviceType_CPU); + bool dst_is_cpu = (dst_type == OrtMemoryInfoDeviceType_CPU); + bool src_is_gpu = (src_type == OrtMemoryInfoDeviceType_GPU); + bool dst_is_gpu = (dst_type == OrtMemoryInfoDeviceType_GPU); + + if ((src_is_gpu && ep_api.MemoryDevice_GetVendorId(src_device) != OrtDevice::VendorIds::NVIDIA) || + (dst_is_gpu && ep_api.MemoryDevice_GetVendorId(dst_device) != OrtDevice::VendorIds::NVIDIA)) { + return false; + } + + // Support CPU→GPU, GPU→CPU, GPU→GPU + return (src_is_cpu && dst_is_gpu) || + (src_is_gpu && dst_is_cpu) || + (src_is_gpu && dst_is_gpu); +} + +/*static*/ OrtStatus* ORT_API_CALL CudaDataTransfer::CopyTensorsImpl( + OrtDataTransferImpl* this_ptr, + const OrtValue** src_tensors, + OrtValue** dst_tensors, + OrtSyncStream** streams, + size_t count) noexcept { + EXCEPTION_TO_STATUS_BEGIN + + auto* dt = static_cast(this_ptr); + bool need_stream_sync = false; + + for (size_t i = 0; i < count; ++i) { + Ort::ConstValue src{src_tensors[i]}; + Ort::UnownedValue dst{dst_tensors[i]}; + + size_t bytes = 0; + auto* status = dt->ort_api_.GetTensorSizeInBytes(src_tensors[i], &bytes); + if (status != nullptr) { + return status; + } + if (bytes == 0) continue; + + const void* src_data = src.GetTensorRawData(); + void* dst_data = dst.GetTensorMutableRawData(); + + // Determine copy direction + const OrtMemoryInfo* src_mem_info = src.GetTensorMemoryInfo(); + const OrtMemoryInfo* dst_mem_info = dst.GetTensorMemoryInfo(); + const OrtMemoryDevice* src_dev = dt->ep_api_.MemoryInfo_GetMemoryDevice(src_mem_info); + const OrtMemoryDevice* dst_dev = dt->ep_api_.MemoryInfo_GetMemoryDevice(dst_mem_info); + auto src_dev_type = dt->ep_api_.MemoryDevice_GetDeviceType(src_dev); + auto dst_dev_type = dt->ep_api_.MemoryDevice_GetDeviceType(dst_dev); + auto src_mem_type = dt->ep_api_.MemoryDevice_GetMemoryType(src_dev); + + cudaMemcpyKind copy_kind; + if (src_dev_type == OrtMemoryInfoDeviceType_CPU && dst_dev_type == OrtMemoryInfoDeviceType_GPU) { + copy_kind = cudaMemcpyHostToDevice; + } else if (src_dev_type == OrtMemoryInfoDeviceType_GPU && dst_dev_type == OrtMemoryInfoDeviceType_CPU) { + copy_kind = cudaMemcpyDeviceToHost; + } else if (src_dev_type == OrtMemoryInfoDeviceType_GPU && dst_dev_type == OrtMemoryInfoDeviceType_GPU) { + copy_kind = cudaMemcpyDeviceToDevice; + } else { + return dt->ort_api_.CreateStatus(ORT_EP_FAIL, "Unsupported copy direction"); + } + + // Use async copy if stream is provided + if (streams != nullptr && streams[i] != nullptr) { + cudaStream_t cuda_stream = static_cast( + Ort::GetApi().SyncStream_GetHandle(streams[i])); + PL_CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, copy_kind, cuda_stream)); + } else { + if (copy_kind == cudaMemcpyDeviceToDevice && dst_data == src_data) { + continue; + } + + PL_CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, copy_kind)); + + if (copy_kind == cudaMemcpyDeviceToDevice) { + // Match the built-in CUDA EP: cudaMemcpy D2D launches on the default + // stream but does not guarantee host-side completion on return. + need_stream_sync = true; + } else if (copy_kind == cudaMemcpyHostToDevice && src_mem_type != OrtDeviceMemoryType_HOST_ACCESSIBLE) { + // Pageable host memory may still be in flight after cudaMemcpy returns. + need_stream_sync = true; + } + } + } + + if (need_stream_sync) { + PL_CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); + } + + return nullptr; + + EXCEPTION_TO_STATUS_END +} + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.h b/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.h new file mode 100644 index 0000000000000..a43f90cf01f72 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_data_transfer_plugin.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// CUDA data transfer implementation for CPU<->GPU and GPU<->GPU memory copies. +// Implements OrtDataTransferImpl to handle synchronous and async copies +// via cudaMemcpy/cudaMemcpyAsync. + +#pragma once + +#include "cuda_plugin_utils.h" + +namespace onnxruntime { +namespace cuda_plugin { + +/// CUDA data transfer implementation for CPU↔GPU and GPU↔GPU copies. +class CudaDataTransfer : public OrtDataTransferImpl { + public: + CudaDataTransfer(const OrtApi& ort_api, const OrtEpApi& ep_api); + ~CudaDataTransfer() = default; + + private: + static void ORT_API_CALL ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept; + + static bool ORT_API_CALL CanCopyImpl( + const OrtDataTransferImpl* this_ptr, + const OrtMemoryDevice* src_device, + const OrtMemoryDevice* dst_device) noexcept; + + static OrtStatus* ORT_API_CALL CopyTensorsImpl( + OrtDataTransferImpl* this_ptr, + const OrtValue** src_tensors, + OrtValue** dst_tensors, + OrtSyncStream** streams, + size_t count) noexcept; + + const OrtApi& ort_api_; + const OrtEpApi& ep_api_; +}; + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc new file mode 100644 index 0000000000000..e6934548acfd0 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -0,0 +1,245 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "cuda_ep.h" +#include "cuda_ep_factory.h" +#include "cuda_stream_plugin.h" +#include "core/providers/cuda/plugin/cuda_kernel_adapter.h" +#include "core/providers/cuda/cuda_allocator.h" +#include "core/framework/allocator.h" +#include "ep/get_capability_utils.h" + +#include +#include +#include +#include + +namespace onnxruntime { +namespace cuda_plugin { + +namespace { + +std::unique_ptr CreateCudaPluginProvider(std::string_view ep_name, const OrtEp* ort_ep) { + return std::make_unique<::onnxruntime::CUDAExecutionProvider>(std::string{ep_name}, ort_ep); +} + +AllocatorPtr CreateCudaPluginTempSpaceAllocator(int device_id) { + return std::make_shared<::onnxruntime::CUDAAllocator>(device_id, ::onnxruntime::CUDA); +} + +AllocatorPtr CreateCudaPluginTempSpaceCpuAllocator() { + return ::onnxruntime::CPUAllocator::DefaultInstance(); +} + +} // namespace + +CudaEp::CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& logger) + : onnxruntime::ep::adapter::Ep{CreateCudaPluginProvider(factory.GetEpName(), static_cast(this)), + CreateCudaPluginTempSpaceCpuAllocator(), + CreateCudaPluginTempSpaceAllocator(config.device_id)}, + factory_(factory), + name_(factory.GetEpName()), + config_(config), + logger_(logger) { + ort_version_supported = ORT_API_VERSION; + + // Set function pointers for kernel-registry-based EP + GetName = GetNameImpl; + GetCapability = GetCapabilityImpl; + GetKernelRegistry = GetKernelRegistryImpl; + GetPreferredDataLayout = GetPreferredDataLayoutImpl; + ShouldConvertDataLayoutForOp = ShouldConvertDataLayoutForOpImpl; + OnRunStart = nullptr; + OnRunEnd = nullptr; + + // Not a compile-based EP + Compile = nullptr; + ReleaseNodeComputeInfos = nullptr; + + const OrtApi& ort_api = factory_.GetOrtApi(); + Ort::Status log_status(ort_api.Logger_LogMessage(&logger_, ORT_LOGGING_LEVEL_INFO, + "CUDA Plugin EP created", + ORT_FILE, __LINE__, __FUNCTION__)); + + // Store per-EP runtime configuration inside the adapter-wrapped execution + // provider itself. Migrated kernels retrieve a shared config object at + // compute time via GetCudaKernelAdapterRuntimeConfigForProvider(). + // Adding a new config field only requires updating + // CudaKernelAdapterRuntimeConfig, CudaEp::Config, and the struct-initializer + // below — no function-signature change. + onnxruntime::cuda::detail::CudaKernelAdapterRuntimeConfig adapter_config; + adapter_config.use_tf32 = config_.use_tf32; + adapter_config.skip_layer_norm_strict_mode = config_.enable_skip_layer_norm_strict_mode; + adapter_config.cudnn_conv_algo = config_.cudnn_conv_algo; + adapter_config.cudnn_conv_use_max_workspace = config_.cudnn_conv_use_max_workspace; + adapter_config.cudnn_conv1d_pad_to_nc1d = config_.cudnn_conv1d_pad_to_nc1d; + adapter_config.fuse_conv_bias = config_.fuse_conv_bias; + adapter_config.sdpa_kernel = config_.sdpa_kernel; + adapter_config.device_id = config_.device_id; + onnxruntime::cuda::SetCudaKernelAdapterRuntimeConfigForProvider( + static_cast(EpImpl()), adapter_config); +} + +CudaEp::~CudaEp() = default; + +/*static*/ +const char* ORT_API_CALL CudaEp::GetNameImpl(const OrtEp* this_ptr) noexcept { + return static_cast(this_ptr)->name_.c_str(); +} + +/*static*/ +OrtStatus* ORT_API_CALL CudaEp::GetCapabilityImpl( + OrtEp* this_ptr, const OrtGraph* ort_graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept { + EXCEPTION_TO_STATUS_BEGIN + + auto* ep = static_cast(this_ptr); + const OrtEpApi& ep_api = ep->factory_.GetEpApi(); + + Ort::ConstGraph graph{ort_graph}; + std::vector all_nodes = graph.GetNodes(); + + if (all_nodes.empty()) { + return nullptr; + } + + // Three-phase filtering determines which graph nodes run on this EP: + // Phase 1: Collect tentative nodes that have a registered CUDA kernel. + // Phase 2: Filter out CPU-preferred nodes (cheap ops where device-to-host + // copy overhead would exceed the compute benefit). + // Phase 3: Register remaining nodes as supported by this EP. + + // Phase 1: Collect tentative nodes — those for which we have a registered kernel. + std::vector candidate_nodes; + candidate_nodes.reserve(all_nodes.size()); + std::vector tentative_nodes; + tentative_nodes.reserve(all_nodes.size()); + + for (const auto& node : all_nodes) { + std::string ep_name = node.GetEpName(); + if (!ep_name.empty()) { + if (ep_name == ep->name_) { + candidate_nodes.push_back(node); + } + continue; + } + + const OrtKernelDef* kernel_def = nullptr; + RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_LookUpKernel( + graph_support_info, node, &kernel_def)); + + if (kernel_def != nullptr) { + candidate_nodes.push_back(node); + tentative_nodes.push_back(node); + } + } + + // Phase 2: Filter out CPU-preferred nodes (e.g., Shape, NonZero, small compute ops + // that would be cheaper on CPU than incurring device-to-host copy overhead). + std::unordered_set cpu_preferred_nodes; + RETURN_IF_ERROR(ep::GetCpuPreferredNodes( + *ort_graph, *graph_support_info, ep->logger_, + gsl::span(tentative_nodes.data(), tentative_nodes.size()), + cpu_preferred_nodes)); + + // Phase 3: Add final supported nodes (tentative minus CPU-preferred). + for (const OrtNode* ort_node : candidate_nodes) { + if (cpu_preferred_nodes.count(ort_node) == 0) { + Ort::ConstNode node{ort_node}; + RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_AddSingleNode( + graph_support_info, node)); + } + } + + return nullptr; + + EXCEPTION_TO_STATUS_END +} + +/*static*/ +OrtStatus* ORT_API_CALL CudaEp::GetKernelRegistryImpl( + OrtEp* this_ptr, + const OrtKernelRegistry** kernel_registry) noexcept { + auto* ep = static_cast(this_ptr); + *kernel_registry = nullptr; + + RETURN_IF_ERROR(ep->factory_.GetKernelRegistryForEp(*ep, kernel_registry)); + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL CudaEp::GetPreferredDataLayoutImpl( + OrtEp* this_ptr, OrtEpDataLayout* preferred_data_layout) noexcept { + const auto* ep = static_cast(this_ptr); +#ifdef ENABLE_CUDA_NHWC_OPS + *preferred_data_layout = ep->config_.prefer_nhwc ? OrtEpDataLayout_NHWC : OrtEpDataLayout_NCHW; +#else + ORT_UNUSED_PARAMETER(ep); + *preferred_data_layout = OrtEpDataLayout_NCHW; +#endif + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL CudaEp::ShouldConvertDataLayoutForOpImpl( + OrtEp* this_ptr, const char* domain, const char* op_type, + OrtEpDataLayout target_data_layout, int* should_convert) noexcept { + ORT_UNUSED_PARAMETER(this_ptr); + + if (should_convert == nullptr) { + return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "should_convert must not be null."); + } + + const char* safe_domain = domain != nullptr ? domain : ""; + const char* safe_op_type = op_type != nullptr ? op_type : ""; + +#ifndef ENABLE_CUDA_NHWC_OPS + ORT_UNUSED_PARAMETER(safe_domain); + ORT_UNUSED_PARAMETER(safe_op_type); + ORT_UNUSED_PARAMETER(target_data_layout); + *should_convert = 0; // NHWC kernels are not compiled into this plugin build. + return nullptr; +#else + + // Only convert to NHWC; for any other target layout, let ORT decide. + if (target_data_layout != OrtEpDataLayout_NHWC) { + *should_convert = -1; // Let ORT decide + return nullptr; + } + + // ONNX domain ops that have NHWC kernel registrations. + static const std::unordered_set cuda_nhwc_onnx_ops{ + "BatchNormalization", + "Conv", + "ConvTranspose", + "GlobalMaxPool", + "MaxPool", + "GlobalAveragePool", + "AveragePool", + "GridSample", + "DepthToSpace", + "SpaceToDepth", + "LRN", + }; + + // Check ONNX domain (empty string) or MS domain (com.microsoft) + bool is_onnx_domain = (safe_domain[0] == '\0'); + bool is_ms_domain = (std::strcmp(safe_domain, "com.microsoft") == 0); + + if (is_onnx_domain && cuda_nhwc_onnx_ops.count(safe_op_type) > 0) { + *should_convert = 1; // Convert + return nullptr; + } + + if (is_ms_domain && std::strcmp(safe_op_type, "GridSample") == 0) { + *should_convert = 1; // Convert + return nullptr; + } + + *should_convert = 0; // Explicitly decline conversion for unsupported NHWC ops. + return nullptr; +#endif +} + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.h b/onnxruntime/core/providers/cuda/plugin/cuda_ep.h new file mode 100644 index 0000000000000..5f961fe3b0a8c --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "cuda_plugin_utils.h" +#include "ep/adapters.h" + +#include +#include + +namespace onnxruntime { +namespace cuda_plugin { + +class CudaEpFactory; + +/// CUDA execution provider implementation using public OrtEp interface. +class CudaEp : public onnxruntime::ep::adapter::Ep { + public: + /// Configuration parameters for the CUDA EP, parsed from session options. + struct Config { + bool prefer_nhwc = false; ///< Use NHWC data layout when available. + bool use_tf32 = true; ///< Enable TF32 math on Ampere+ GPUs. + bool enable_skip_layer_norm_strict_mode = false; ///< Strict mode for SkipLayerNorm kernel. + int device_id = 0; ///< CUDA device ordinal. + int cudnn_conv_algo = 0; ///< cuDNN convolution algorithm selection. + bool cudnn_conv_use_max_workspace = true; ///< Use maximum workspace for cuDNN conv algo search. + bool cudnn_conv1d_pad_to_nc1d = false; ///< Pad 1D convolutions to NC1D format. + bool fuse_conv_bias = false; ///< Enable cuDNN frontend conv+bias fusion. + int sdpa_kernel = 0; ///< Attention backend bitmask override. + }; + + CudaEp(CudaEpFactory& factory, const Config& config, const OrtLogger& logger); + ~CudaEp(); + + const char* GetEpName() const { return name_.c_str(); } + const Config& GetConfig() const { return config_; } + + private: + // OrtEp callback implementations + static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL GetCapabilityImpl( + OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept; + + static OrtStatus* ORT_API_CALL GetKernelRegistryImpl( + OrtEp* this_ptr, + const OrtKernelRegistry** kernel_registry) noexcept; + + static OrtStatus* ORT_API_CALL GetPreferredDataLayoutImpl( + OrtEp* this_ptr, OrtEpDataLayout* preferred_data_layout) noexcept; + + static OrtStatus* ORT_API_CALL ShouldConvertDataLayoutForOpImpl( + OrtEp* this_ptr, const char* domain, const char* op_type, + OrtEpDataLayout target_data_layout, int* should_convert) noexcept; + + CudaEpFactory& factory_; + std::string name_; + Config config_; + const OrtLogger& logger_; +}; + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc new file mode 100644 index 0000000000000..494deff257b7b --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -0,0 +1,552 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "cuda_ep_factory.h" +#include "cuda_ep.h" +#include "cuda_plugin_kernels.h" +#include "core/common/string_utils.h" + +#include +#include +#include +#include +#include +#include + +namespace onnxruntime { +namespace cuda_plugin { + +CudaEpFactory::CudaEpFactory(const OrtApi& ort_api, const OrtEpApi& ep_api, + const OrtLogger& default_logger) + : OrtEpFactory{}, + ort_api_(ort_api), + ep_api_(ep_api), + default_logger_(default_logger) { + ort_version_supported = ORT_API_VERSION; + + if (!::onnxruntime::ep::adapter::LoggingManager::HasDefaultLogger()) { + ::onnxruntime::ep::adapter::LoggingManager::CreateDefaultLogger(&default_logger); + } + + // Assign callback function pointers + GetName = GetNameImpl; + GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; + GetVersion = GetVersionImpl; + GetSupportedDevices = GetSupportedDevicesImpl; + CreateEp = CreateEpImpl; + ReleaseEp = ReleaseEpImpl; + CreateAllocator = CreateAllocatorImpl; + ReleaseAllocator = ReleaseAllocatorImpl; + CreateDataTransfer = CreateDataTransferImpl; + IsStreamAware = IsStreamAwareImpl; + CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; +} + +CudaEpFactory::~CudaEpFactory() { + if (kernel_registry_ != nullptr) { + ep_api_.ReleaseKernelRegistry(kernel_registry_); + } +} + +OrtStatus* CudaEpFactory::GetKernelRegistryForEp(CudaEp& ep, + const OrtKernelRegistry** out_kernel_registry) { + *out_kernel_registry = nullptr; + + std::lock_guard lock(registry_mutex_); + + if (kernel_registry_ == nullptr) { + const char* ep_name = ep.GetEpName(); + // CreateCudaKernelRegistry dispatches between legacy/generated registrations + // and adapter-mode registration path based on build configuration. + RETURN_IF_ERROR(CreateCudaKernelRegistry(ep_api_, ep_name, nullptr, &kernel_registry_)); + } + + *out_kernel_registry = kernel_registry_; + return nullptr; +} + +// --------------------------------------------------------------------------- +// OrtEpFactory callback implementations +// --------------------------------------------------------------------------- + +/*static*/ +const char* ORT_API_CALL CudaEpFactory::GetNameImpl(const OrtEpFactory* this_ptr) noexcept { + return static_cast(this_ptr)->ep_name_.c_str(); +} + +/*static*/ +const char* ORT_API_CALL CudaEpFactory::GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { + return static_cast(this_ptr)->vendor_.c_str(); +} + +/*static*/ +uint32_t ORT_API_CALL CudaEpFactory::GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { + return static_cast(this_ptr)->vendor_id_; +} + +/*static*/ +const char* ORT_API_CALL CudaEpFactory::GetVersionImpl(const OrtEpFactory* this_ptr) noexcept { + return static_cast(this_ptr)->ep_version_.c_str(); +} + +namespace { + +std::string ToUpper(std::string value) { + std::transform(value.begin(), value.end(), value.begin(), [](unsigned char c) { + return static_cast(std::toupper(c)); + }); + return value; +} + +std::string GetProviderOptionPrefix(std::string_view provider_name) { + return "ep." + onnxruntime::utils::GetLowercaseString(std::string{provider_name}) + "."; +} + +void LogWarning(const OrtApi& ort_api, const OrtLogger& logger, const char* file, int line, + const char* function, const char* msg) { + OrtStatus* st = ort_api.Logger_LogMessage(&logger, ORT_LOGGING_LEVEL_WARNING, msg, file, line, function); + if (st != nullptr) { + ort_api.ReleaseStatus(st); + } +} + +} // namespace + +CudaEpFactory::HardwareDeviceKey CudaEpFactory::MakeDeviceKey(const OrtApi& ort_api, + const OrtHardwareDevice& device) { + return { + ort_api.HardwareDevice_Type(&device), + ort_api.HardwareDevice_VendorId(&device), + ort_api.HardwareDevice_DeviceId(&device), + }; +} + +/*static*/ +OrtStatus* ORT_API_CALL CudaEpFactory::GetSupportedDevicesImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* hw_devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + auto* factory = static_cast(this_ptr); + size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + auto release_ep_devices = [&](OrtStatus* status) -> OrtStatus* { + for (size_t j = 0; j < num_ep_devices; ++j) { + factory->ep_api_.ReleaseEpDevice(ep_devices[j]); + ep_devices[j] = nullptr; + } + num_ep_devices = 0; + return status; + }; + + int cuda_device_index = 0; + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *hw_devices[i]; + auto hw_type = factory->ort_api_.HardwareDevice_Type(&device); + + if (hw_type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + // Filter by vendor ID to avoid claiming non-NVIDIA GPUs on mixed-vendor hosts. + // vendor_id == 0 means the hardware enumeration did not provide a vendor ID, + // in which case we fall through and let the CUDA runtime validate the device. + uint32_t hw_vendor_id = factory->ort_api_.HardwareDevice_VendorId(&device); + if (hw_vendor_id != 0 && hw_vendor_id != factory->vendor_id_) { + continue; // Skip non-NVIDIA GPUs + } + + // CUDA uses contiguous ordinals for CUDA-visible NVIDIA devices. Build that + // mapping from the filtered hardware-device list instead of relying on the + // ORT hardware device id, which is not guaranteed to be a CUDA ordinal. + int current_device_id = cuda_device_index++; + const auto device_key = CudaEpFactory::MakeDeviceKey(factory->ort_api_, device); + DeviceCacheEntry* cache_entry = nullptr; + { + std::lock_guard lock(factory->device_cache_mutex_); + auto [it, inserted] = factory->device_cache_.try_emplace(device_key); + if (inserted) { + it->second.cuda_device_id = current_device_id; + it->second.device_memory_info = Ort::MemoryInfo{"Cuda", + OrtMemoryInfoDeviceType_GPU, + factory->vendor_id_, + static_cast(current_device_id), + OrtDeviceMemoryType_DEFAULT, + /*alignment is default*/ 0, + OrtAllocatorType::OrtDeviceAllocator}; + it->second.pinned_memory_info = Ort::MemoryInfo{"CudaPinned", + OrtAllocatorType::OrtDeviceAllocator, + current_device_id, + OrtMemType::OrtMemTypeCPU}; + } + + cache_entry = &it->second; + current_device_id = cache_entry->cuda_device_id; + } + + OrtKeyValuePairs* ep_metadata = nullptr; + OrtKeyValuePairs* ep_options = nullptr; + factory->ort_api_.CreateKeyValuePairs(&ep_metadata); + factory->ort_api_.CreateKeyValuePairs(&ep_options); + factory->ort_api_.AddKeyValuePair(ep_metadata, "cuda_device_id", std::to_string(current_device_id).c_str()); + factory->ort_api_.AddKeyValuePair(ep_options, "device_id", std::to_string(current_device_id).c_str()); + + // Get CUDA device properties for metadata + int cuda_device_count = 0; + cudaError_t err = cudaGetDeviceCount(&cuda_device_count); + if (err == cudaSuccess && cuda_device_count > 0 && current_device_id < cuda_device_count) { + cudaDeviceProp prop; + if (cudaGetDeviceProperties(&prop, current_device_id) == cudaSuccess) { + factory->ort_api_.AddKeyValuePair(ep_metadata, "cuda_device_name", prop.name); + factory->ort_api_.AddKeyValuePair( + ep_metadata, "cuda_compute_capability", + (std::to_string(prop.major) + "." + std::to_string(prop.minor)).c_str()); + } + } + + OrtEpDevice* ep_device = nullptr; + auto* status = factory->ep_api_.CreateEpDevice(factory, &device, ep_metadata, ep_options, + &ep_device); + factory->ort_api_.ReleaseKeyValuePairs(ep_metadata); + factory->ort_api_.ReleaseKeyValuePairs(ep_options); + + if (status != nullptr) { + return release_ep_devices(status); + } + + auto release_current_ep_device = [factory](OrtEpDevice* device) { + factory->ep_api_.ReleaseEpDevice(device); + }; + // ep_device_guard owns the current device. On error, release_ep_devices cleans up + // previously committed devices [0, num_ep_devices), while the guard cleans up this one. + std::unique_ptr ep_device_guard(ep_device, release_current_ep_device); + + // Register allocator info for GPU device memory + status = factory->ep_api_.EpDevice_AddAllocatorInfo(ep_device, cache_entry->device_memory_info); + if (status != nullptr) { + return release_ep_devices(status); + } + + // Register allocator info for pinned host memory associated with the + // same CUDA ordinal as the device allocator above. + status = factory->ep_api_.EpDevice_AddAllocatorInfo(ep_device, cache_entry->pinned_memory_info); + if (status != nullptr) { + return release_ep_devices(status); + } + + ep_devices[num_ep_devices++] = ep_device_guard.release(); + } + } + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* /*ep_metadata*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* logger, + OrtEp** ep) noexcept { + EXCEPTION_TO_STATUS_BEGIN + + auto* factory = static_cast(this_ptr); + *ep = nullptr; + + if (num_devices != 1) { + return factory->ort_api_.CreateStatus( + ORT_INVALID_ARGUMENT, + "CUDA EP factory currently supports exactly one device per EP instance. " + "Pass a single OrtHardwareDevice when creating the CUDA plugin EP."); + } + if (devices == nullptr || devices[0] == nullptr) { + return factory->ort_api_.CreateStatus( + ORT_INVALID_ARGUMENT, + "CUDA EP factory requires a valid device."); + } + + // Parse configuration from session options. + // The read helpers intentionally swallow errors: if a config entry is + // absent or malformed the default value in Config is kept. + CudaEp::Config config{}; + + { + std::lock_guard lock(factory->device_cache_mutex_); + auto it = factory->device_cache_.find(CudaEpFactory::MakeDeviceKey(factory->ort_api_, *devices[0])); + if (it == factory->device_cache_.end()) { + return factory->ort_api_.CreateStatus( + ORT_INVALID_ARGUMENT, + "CUDA EP factory could not resolve the requested device. " + "Enumerate EP devices again and retry session creation."); + } + config.device_id = it->second.cuda_device_id; + } + + auto try_get_session_config = [&](std::string_view key) -> std::optional { + if (session_options == nullptr) { + return std::nullopt; + } + + size_t size = 0; + OrtStatus* status = factory->ort_api_.GetSessionConfigEntry(session_options, key.data(), nullptr, &size); + if (status != nullptr) { + Ort::Status s(status); + return std::nullopt; + } + if (size == 0) { + return std::nullopt; + } + std::vector buf(size); + status = factory->ort_api_.GetSessionConfigEntry(session_options, key.data(), buf.data(), &size); + if (status != nullptr) { + Ort::Status s(status); + return std::nullopt; + } + return std::string(buf.data()); + }; + + auto log_invalid_session_config = [&](std::string_view key, std::string_view expected) { + if (logger == nullptr) { + return; + } + + const std::string msg = std::string("Failed to parse session config for key '") + + std::string(key) + "'. Expected " + std::string(expected) + + ". Using default value."; + + OrtStatus* st = factory->ort_api_.Logger_LogMessage( + logger, ORT_LOGGING_LEVEL_WARNING, msg.c_str(), "cuda_ep_factory.cc", __LINE__, "CudaEpFactory"); + if (st != nullptr) { + factory->ort_api_.ReleaseStatus(st); + } + }; + + auto read_session_config_bool = [&](std::initializer_list keys, bool& value) { + for (const auto& key : keys) { + auto raw_value = try_get_session_config(key); + if (!raw_value.has_value()) { + continue; + } + + const auto normalized = ToUpper(*raw_value); + if (normalized == "1" || normalized == "TRUE") { + value = true; + return; + } + if (normalized == "0" || normalized == "FALSE") { + value = false; + return; + } + + log_invalid_session_config(key, "a boolean"); + return; + } + }; + + auto read_cudnn_conv_algo = [&](std::initializer_list keys, int& value) { + for (const auto& key : keys) { + auto raw_value = try_get_session_config(key); + if (!raw_value.has_value()) { + continue; + } + + try { + value = std::stoi(*raw_value); + return; + } catch (const std::exception&) { + } + + const auto normalized = ToUpper(*raw_value); + if (normalized == "EXHAUSTIVE") { + value = 0; + return; + } + if (normalized == "HEURISTIC") { + value = 1; + return; + } + if (normalized == "DEFAULT") { + value = 2; + return; + } + + log_invalid_session_config(key, "an integer or one of EXHAUSTIVE/HEURISTIC/DEFAULT"); + return; + } + }; + + auto read_session_config_non_negative_int = [&](std::initializer_list keys, int& value) { + for (const auto& key : keys) { + auto raw_value = try_get_session_config(key); + if (!raw_value.has_value()) { + continue; + } + + try { + int parsed = std::stoi(*raw_value); + if (parsed < 0) { + log_invalid_session_config(key, "a non-negative integer"); + return; + } + + value = parsed; + return; + } catch (const std::exception&) { + } + + log_invalid_session_config(key, "a non-negative integer"); + return; + } + }; + + const std::string ep_options_prefix = GetProviderOptionPrefix(factory->GetEpName()); + const std::string prefer_nhwc_key = ep_options_prefix + "prefer_nhwc"; + const std::string prefer_nhwc_layout_key = ep_options_prefix + "prefer_nhwc_layout"; + const std::string use_tf32_key = ep_options_prefix + "use_tf32"; + const std::string skip_layer_norm_key = ep_options_prefix + "enable_skip_layer_norm_strict_mode"; + const std::string cudnn_use_max_workspace_key = ep_options_prefix + "cudnn_conv_use_max_workspace"; + const std::string cudnn_conv1d_pad_key = ep_options_prefix + "cudnn_conv1d_pad_to_nc1d"; + const std::string cudnn_conv_algo_key = ep_options_prefix + "cudnn_conv_algo"; + const std::string cudnn_conv_algo_search_key = ep_options_prefix + "cudnn_conv_algo_search"; + const std::string fuse_conv_bias_key = ep_options_prefix + "fuse_conv_bias"; + const std::string sdpa_kernel_key = ep_options_prefix + "sdpa_kernel"; + + // Prefer plugin-provider-option keys, then fall back to the legacy ep.cuda.* + // aliases and finally to the historical flat session config names. + read_session_config_bool( + {prefer_nhwc_key, prefer_nhwc_layout_key, "ep.cuda.prefer_nhwc_layout", "prefer_nhwc", "prefer_nhwc_layout"}, + config.prefer_nhwc); + read_session_config_bool({use_tf32_key, "ep.cuda.use_tf32", "use_tf32"}, config.use_tf32); + read_session_config_bool( + {skip_layer_norm_key, "ep.cuda.enable_skip_layer_norm_strict_mode", "enable_skip_layer_norm_strict_mode"}, + config.enable_skip_layer_norm_strict_mode); + read_session_config_bool( + {cudnn_use_max_workspace_key, "ep.cuda.cudnn_conv_use_max_workspace", "cudnn_conv_use_max_workspace"}, + config.cudnn_conv_use_max_workspace); + read_session_config_bool( + {cudnn_conv1d_pad_key, "ep.cuda.cudnn_conv1d_pad_to_nc1d", "cudnn_conv1d_pad_to_nc1d"}, + config.cudnn_conv1d_pad_to_nc1d); + read_cudnn_conv_algo( + {cudnn_conv_algo_search_key, cudnn_conv_algo_key, "ep.cuda.cudnn_conv_algo_search", "ep.cuda.cudnn_conv_algo", + "cudnn_conv_algo_search", "cudnn_conv_algo"}, + config.cudnn_conv_algo); + read_session_config_bool( + {fuse_conv_bias_key, "ep.cuda.fuse_conv_bias", "fuse_conv_bias"}, + config.fuse_conv_bias); + read_session_config_non_negative_int( + {sdpa_kernel_key, "ep.cuda.sdpa_kernel", "sdpa_kernel"}, + config.sdpa_kernel); + + const OrtLogger& ep_logger = logger ? *logger : factory->default_logger_; + auto actual_ep = std::make_unique(*factory, config, ep_logger); + *ep = actual_ep.release(); + + return nullptr; + + EXCEPTION_TO_STATUS_END +} + +/*static*/ +void ORT_API_CALL CudaEpFactory::ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept { + delete static_cast(ep); +} + +/*static*/ +OrtStatus* ORT_API_CALL CudaEpFactory::CreateAllocatorImpl( + OrtEpFactory* this_ptr, + const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept { + auto& factory = *static_cast(this_ptr); + *allocator = nullptr; + + const char* name = ""; + OrtStatus* status = factory.ort_api_.MemoryInfoGetName(memory_info, &name); + if (status != nullptr) { + return status; + } + int req_device_id = 0; + status = factory.ort_api_.MemoryInfoGetId(memory_info, &req_device_id); + if (status != nullptr) { + return status; + } + + if (name != nullptr && strcmp(name, "Cuda") == 0) { + auto cuda_allocator = std::make_unique(memory_info, req_device_id); + *allocator = cuda_allocator.release(); + return nullptr; + } + + if (name != nullptr && strcmp(name, "CudaPinned") == 0) { + auto pinned_allocator = std::make_unique(memory_info); + *allocator = pinned_allocator.release(); + return nullptr; + } + + return factory.ort_api_.CreateStatus( + ORT_INVALID_ARGUMENT, + "Unknown memory info provided to CUDA EP CreateAllocator."); +} + +/*static*/ +void ORT_API_CALL CudaEpFactory::ReleaseAllocatorImpl( + OrtEpFactory* this_ptr, OrtAllocator* allocator) noexcept { + if (!allocator) return; + auto* factory = static_cast(this_ptr); + auto* typed_allocator = static_cast(allocator); + switch (typed_allocator->GetKind()) { + case CudaAllocatorKind::kDevice: + delete static_cast(allocator); + return; + case CudaAllocatorKind::kPinned: + delete static_cast(allocator); + return; + default: + LogWarning(factory->ort_api_, factory->default_logger_, __FILE__, __LINE__, + "CudaEpFactory::ReleaseAllocatorImpl", + "ReleaseAllocatorImpl received an unknown CudaAllocatorKind. Leaking the allocator instance."); + assert(false && "Unknown CudaAllocatorKind"); + return; + } +} + +/*static*/ +OrtStatus* ORT_API_CALL CudaEpFactory::CreateDataTransferImpl( + OrtEpFactory* this_ptr, + OrtDataTransferImpl** data_transfer) noexcept { + auto& factory = *static_cast(this_ptr); + auto data_transfer_impl = std::make_unique(factory.ort_api_, factory.ep_api_); + *data_transfer = data_transfer_impl.release(); + return nullptr; +} + +/*static*/ +bool ORT_API_CALL CudaEpFactory::IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return true; // CUDA EP is stream-aware +} + +/*static*/ +OrtStatus* ORT_API_CALL CudaEpFactory::CreateSyncStreamForDeviceImpl( + OrtEpFactory* this_ptr, + const OrtMemoryDevice* memory_device, + const OrtKeyValuePairs* /*stream_options*/, + OrtSyncStreamImpl** stream) noexcept { + EXCEPTION_TO_STATUS_BEGIN + + auto* factory = static_cast(this_ptr); + int req_device_id = factory->ep_api_.MemoryDevice_GetDeviceId(memory_device); + auto cuda_stream = std::make_unique(*factory, req_device_id, nullptr); + + // Initialize CUDA handles (stream, cuBLAS, cuDNN) + RETURN_IF_ERROR(cuda_stream->InitHandles()); + + *stream = cuda_stream.release(); + return nullptr; + + EXCEPTION_TO_STATUS_END +} + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h new file mode 100644 index 0000000000000..ea4e2da19001d --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.h @@ -0,0 +1,130 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "cuda_plugin_utils.h" +#include "cuda_allocator_plugin.h" +#include "cuda_data_transfer_plugin.h" +#include "cuda_stream_plugin.h" + +#include +#include +#include +#include + +namespace onnxruntime { +namespace cuda_plugin { + +class CudaEp; + +/// CUDA EP factory implementing OrtEpFactory. +/// Manages device enumeration, allocator creation, data transfer, and stream creation. +class CudaEpFactory : public OrtEpFactory { + public: + CudaEpFactory(const OrtApi& ort_api, const OrtEpApi& ep_api, + const OrtLogger& default_logger); + ~CudaEpFactory(); + + const OrtApi& GetOrtApi() const { return ort_api_; } + const OrtEpApi& GetEpApi() const { return ep_api_; } + const std::string& GetEpName() const { return ep_name_; } + + /// Get or create the shared kernel registry for this factory. + /// Lazily created on first call; subsequent calls return the cached instance. + /// Thread-safe: protected by registry_mutex_. + OrtStatus* GetKernelRegistryForEp(CudaEp& ep, + const OrtKernelRegistry** out_kernel_registry); + + private: + // OrtEpFactory callback implementations + static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; + static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept; + static uint32_t ORT_API_CALL GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept; + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, size_t num_devices, + OrtEpDevice** ep_devices, size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept; + + static OrtStatus* ORT_API_CALL CreateEpImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* logger, + OrtEp** ep) noexcept; + + static void ORT_API_CALL ReleaseEpImpl(OrtEpFactory* this_ptr, OrtEp* ep) noexcept; + + static OrtStatus* ORT_API_CALL CreateAllocatorImpl( + OrtEpFactory* this_ptr, + const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept; + + static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* this_ptr, + OrtAllocator* allocator) noexcept; + + static OrtStatus* ORT_API_CALL CreateDataTransferImpl( + OrtEpFactory* this_ptr, + OrtDataTransferImpl** data_transfer) noexcept; + + static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl( + OrtEpFactory* this_ptr, + const OrtMemoryDevice* memory_device, + const OrtKeyValuePairs* stream_options, + OrtSyncStreamImpl** stream) noexcept; + + const OrtApi& ort_api_; + const OrtEpApi& ep_api_; + const OrtLogger& default_logger_; + + const std::string ep_name_{"CudaPluginExecutionProvider"}; + const std::string vendor_{"NVIDIA"}; + const uint32_t vendor_id_ = 0x10DE; // NVIDIA PCI vendor ID + const std::string ep_version_{"1.0.0"}; + + struct DeviceCacheEntry { + int cuda_device_id{-1}; + Ort::MemoryInfo device_memory_info{nullptr}; + Ort::MemoryInfo pinned_memory_info{nullptr}; + }; + + struct HardwareDeviceKey { + OrtHardwareDeviceType type{OrtHardwareDeviceType::OrtHardwareDeviceType_CPU}; + uint32_t vendor_id{0}; + uint32_t device_id{0}; + + bool operator==(const HardwareDeviceKey&) const = default; + }; + + struct HardwareDeviceKeyHasher { + size_t operator()(const HardwareDeviceKey& key) const noexcept { + size_t hash = static_cast(key.type); + hash = (hash * 1315423911u) ^ static_cast(key.vendor_id); + hash = (hash * 1315423911u) ^ static_cast(key.device_id); + return hash; + } + }; + + static HardwareDeviceKey MakeDeviceKey(const OrtApi& ort_api, + const OrtHardwareDevice& device); + + // Stable per-device cache keyed by public hardware-device properties instead + // of the transient OrtHardwareDevice* pointer received during enumeration. + std::mutex device_cache_mutex_; + std::unordered_map device_cache_; + + // Kernel registry (cached, shared across EP instances) + OrtKernelRegistry* kernel_registry_ = nullptr; + std::mutex registry_mutex_; +}; + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h new file mode 100644 index 0000000000000..b72058dc90baa --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -0,0 +1,1146 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// cuda_kernel_adapter.h — Compatibility shim for migrating CUDA kernels to the +// plugin EP architecture. +// +// This header provides: +// - CudaKernel base class (scratch buffers, CUDA handles, etc.) +// - Error-return macros (CUDA_RETURN_IF_ERROR, etc.) +// - Type mapping helpers (ToCudaType) +// - Math/compute shims (HalfGemmOptions, CublasMathModeSetter) +// - Self-registering BuildKernelCreateInfo<> macros via PluginKernelCollector +// - CUDAExecutionProvider shim class +// - CPU provider shims for the plugin build + +#pragma once + +#include "core/common/status.h" +#include "core/common/narrow.h" +#include "core/common/float16.h" +#include "core/common/float8.h" +#include "core/framework/float4.h" +#include "core/framework/allocator.h" +#include "core/framework/tensor_shape.h" +#include "core/util/math.h" +#include + +#include +#include +#include "core/providers/cuda/shared_inc/cuda_call.h" +#include "contrib_ops/cuda/bert/attention_kernel_options.h" + +#ifdef __CUDACC__ +#include +#include +#endif + +// =================================================================== +// Macros will be defined later to override core definitions. +// =================================================================== + +#include "core/providers/cuda/plugin/cuda_stream_plugin.h" +#include "core/providers/cuda/gpu_data_transfer.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" +#include "core/session/onnxruntime_cxx_api.h" + +// Forward-declare CudaStream so adapter overloads accepting CudaStream* compile +// without pulling in cuda_stream_handle.h (which depends on CUDAExecutionProvider). +namespace onnxruntime { +struct CudaStream; + +// Lightweight Stream shim for plugin build: wraps a raw cudaStream_t as a +// framework-compatible Stream* that can be passed to _impl.cu functions which +// call stream->GetHandle(). Stack-allocated; does NOT own the stream. +// Only available in .cc translation units (not .cu) since Stream is incomplete in NVCC context. +#ifndef __CUDACC__ +struct PluginStreamShim : public onnxruntime::Stream { + explicit PluginStreamShim(void* cuda_stream_handle) + : onnxruntime::Stream(cuda_stream_handle, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, + OrtDevice::VendorIds::NVIDIA, 0)) {} +}; + +class OrtStreamAdapter { + public: + explicit OrtStreamAdapter(void* cuda_stream_handle) + : plugin_stream_shim_(cuda_stream_handle), stream_(&plugin_stream_shim_) {} + + onnxruntime::Stream* get() const { return stream_; } + operator onnxruntime::Stream*() const { return stream_; } + + private: + PluginStreamShim plugin_stream_shim_; + onnxruntime::Stream* stream_; +}; +#else +class OrtStreamAdapter { + public: + explicit OrtStreamAdapter(void* cuda_stream_handle) + : stream_(static_cast(cuda_stream_handle)) {} + + onnxruntime::Stream* get() const { return stream_; } + operator onnxruntime::Stream*() const { return stream_; } + + private: + onnxruntime::Stream* stream_; +}; +#endif +} // namespace onnxruntime + +// =================================================================== +// Section 1: Include path selection +// =================================================================== + +#include "core/graph/constants.h" +#include "ep/adapters.h" +#include "core/framework/op_kernel.h" +#include "core/providers/common.h" + +namespace onnxruntime { +inline constexpr const char* kCudaPluginExecutionProvider = "CudaPluginExecutionProvider"; +} + +namespace onnxruntime { +namespace cuda { + +#ifndef CUDA_STREAM_FROM_CTX +// Helper for kernels that need a cudaStream_t from OpKernelContext in plugin build. +#define CUDA_STREAM_FROM_CTX(ctx) static_cast(GetComputeStream(ctx)) +#endif + +// Forward declare the template for kernel registration macros to specialize +// inside the onnxruntime::cuda namespace. +template +KernelCreateInfo BuildKernelCreateInfo(); + +// Tensor creation helper to replace deprecated Tensor::Create +inline std::unique_ptr<::onnxruntime::Tensor> TensorCreate(MLDataType type, const TensorShape& shape, AllocatorPtr allocator) { + return std::make_unique<::onnxruntime::Tensor>(type, shape, std::move(allocator)); +} + +using ::onnxruntime::HandleNegativeAxis; + +} // namespace cuda + +#ifndef DISABLE_CONTRIB_OPS +namespace contrib { +namespace cuda { + +// Forward declare the template for kernel registration macros to specialize +// inside the onnxruntime::contrib::cuda namespace. +template +KernelCreateInfo BuildKernelCreateInfo(); + +inline std::unique_ptr<::onnxruntime::Tensor> TensorCreate(MLDataType type, const TensorShape& shape, AllocatorPtr allocator) { + return std::make_unique<::onnxruntime::Tensor>(type, shape, std::move(allocator)); +} + +using ::onnxruntime::HandleNegativeAxis; + +} // namespace cuda +} // namespace contrib +#endif +} // namespace onnxruntime + +// =================================================================== +// Section 2: Error-return macros (redefined for all plugin paths) +// =================================================================== + +// Redefine error macros for ported code to use our adapter-specific Status translation +#undef CUDA_RETURN_IF_ERROR +#define CUDA_RETURN_IF_ERROR(expr) \ + { \ + cudaError_t _err = (expr); \ + if (_err != cudaSuccess) { \ + return onnxruntime::common::Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, std::string("CUDA error: ") + cudaGetErrorString(_err)); \ + } \ + } + +#undef CUBLAS_RETURN_IF_ERROR +#define CUBLAS_RETURN_IF_ERROR(expr) \ + { \ + cublasStatus_t _status = (expr); \ + if (_status != CUBLAS_STATUS_SUCCESS) { \ + return onnxruntime::common::Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, std::string("cuBLAS error: ") + std::to_string(static_cast(_status))); \ + } \ + } + +#undef CUDNN_RETURN_IF_ERROR +#define CUDNN_RETURN_IF_ERROR(expr) \ + { \ + cudnnStatus_t _status = (expr); \ + if (_status != CUDNN_STATUS_SUCCESS) { \ + return onnxruntime::common::Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, std::string("cuDNN error: ") + cudnnGetErrorString(_status)); \ + } \ + } + +#undef CUFFT_RETURN_IF_ERROR +#define CUFFT_RETURN_IF_ERROR(expr) \ + { \ + cufftResult _status = (expr); \ + if (_status != CUFFT_SUCCESS) { \ + return onnxruntime::common::Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, std::string("cuFFT error: ") + std::to_string((int)_status)); \ + } \ + } + +// =================================================================== +// Section 3: Self-registering kernel collector using BuildKernelCreateInfo<> +// +// Each ONNX_OPERATOR_*_KERNEL_EX macro expansion both: +// 1. Creates a BuildKernelCreateInfo() template specialization +// (identical to the framework's macro output, but using adapter types) +// 2. Auto-registers the BuildKernelCreateInfoFn pointer into a global +// PluginKernelCollector singleton +// +// At registration time, the factory iterates the collector and calls +// adapter::KernelRegistry::Register(build_fn()) for each compiled kernel. +// Only ops whose .cc files are actually compiled get registered. +// =================================================================== + +#include + +namespace onnxruntime { +namespace cuda { + +/// Singleton collector for BuildKernelCreateInfoFn pointers. +/// Each compiled kernel .cc file's macro expansion auto-registers here. +/// +/// Thread-safety: Instance() uses a function-local static (C++11 §6.7/4: +/// constructed exactly once, even under concurrent first-access). Add() +/// is guarded by a mutex for formal correctness across translation units, +/// though in practice all calls occur during static initialization. +class PluginKernelCollector { + public: + static PluginKernelCollector& Instance() { + static PluginKernelCollector instance; + return instance; + } + + void Add(BuildKernelCreateInfoFn fn) { + std::lock_guard lock(mutex_); + entries_.push_back(fn); + } + std::vector Entries() const { + std::lock_guard lock(mutex_); + return entries_; + } + + private: + std::vector entries_; + mutable std::mutex mutex_; +}; + +} // namespace cuda +} // namespace onnxruntime + +// --- Macro overrides: produce BuildKernelCreateInfo<> AND auto-register --- +// +// These macros mirror the framework's ONNX_OPERATOR_*_KERNEL_EX definitions +// from core/framework/op_kernel.h, but additionally register each +// BuildKernelCreateInfoFn into PluginKernelCollector at static init time. +#define ORT_ADAPTER_CONCAT_IMPL(x, y) x##y +#define ORT_ADAPTER_CONCAT(x, y) ORT_ADAPTER_CONCAT_IMPL(x, y) + +// The provider parameter are not used in below macros since we are hardcoding the provider to cuda plugin. +#define CUDA_PLUGIN_EP ::onnxruntime::kCudaPluginExecutionProvider + +#undef ONNX_OPERATOR_KERNEL_EX +#define ONNX_OPERATOR_KERNEL_EX(name, domain, ver, provider, builder, ...) \ + class ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name).SetDomain(domain).SinceVersion(ver).Provider(CUDA_PLUGIN_EP).Build(), \ + static_cast( \ + [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ + out = std::make_unique<__VA_ARGS__>(info); \ + return Status::OK(); \ + })); \ + } \ + static const bool ORT_ADAPTER_CONCAT(ORT_ADAPTER_AUTOREG_##name##_, __COUNTER__) = \ + (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ + &BuildKernelCreateInfo), \ + true); + +#undef ONNX_OPERATOR_VERSIONED_KERNEL_EX +#define ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, domain, startver, endver, provider, builder, ...) \ + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name).SetDomain(domain).SinceVersion(startver, endver).Provider(CUDA_PLUGIN_EP).Build(), \ + static_cast( \ + [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ + out = std::make_unique<__VA_ARGS__>(info); \ + return Status::OK(); \ + })); \ + } \ + static const bool ORT_ADAPTER_CONCAT(ORT_ADAPTER_AUTOREG_##name##_, __COUNTER__) = \ + (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ + &BuildKernelCreateInfo), \ + true); + +#undef ONNX_OPERATOR_TYPED_KERNEL_EX +#define ONNX_OPERATOR_TYPED_KERNEL_EX(name, domain, ver, type, provider, builder, ...) \ + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name).SetDomain(domain).SinceVersion(ver).Provider(CUDA_PLUGIN_EP).Build(), \ + static_cast( \ + [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ + out = std::make_unique<__VA_ARGS__>(info); \ + return Status::OK(); \ + })); \ + } \ + static const bool ORT_ADAPTER_CONCAT(ORT_ADAPTER_AUTOREG_##name##_##type##_, __COUNTER__) = \ + (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ + &BuildKernelCreateInfo), \ + true); + +#undef ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX +#define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, domain, startver, endver, type, provider, builder, ...) \ + class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name).SetDomain(domain).SinceVersion(startver, endver).Provider(CUDA_PLUGIN_EP).Build(), \ + static_cast( \ + [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ + out = std::make_unique<__VA_ARGS__>(info); \ + return Status::OK(); \ + })); \ + } \ + static const bool ORT_ADAPTER_CONCAT(ORT_ADAPTER_AUTOREG_##name##_##type##_, __COUNTER__) = \ + (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ + &BuildKernelCreateInfo), \ + true); + +#undef ONNX_OPERATOR_TWO_TYPED_KERNEL_EX +#define ONNX_OPERATOR_TWO_TYPED_KERNEL_EX(name, domain, ver, type1, type2, provider, builder, ...) \ + class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name).SetDomain(domain).SinceVersion(ver).Provider(CUDA_PLUGIN_EP).Build(), \ + static_cast( \ + [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ + out = std::make_unique<__VA_ARGS__>(info); \ + return Status::OK(); \ + })); \ + } \ + static const bool ORT_ADAPTER_CONCAT(ORT_ADAPTER_AUTOREG_##name##_##type1##_##type2##_, __COUNTER__) = \ + (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ + &BuildKernelCreateInfo), \ + true); + +#undef ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX +#define ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX(name, domain, startver, endver, type1, type2, \ + provider, builder, ...) \ + class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type1, type2, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name).SetDomain(domain).SinceVersion(startver, endver).Provider(CUDA_PLUGIN_EP).Build(), \ + static_cast( \ + [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ + out = std::make_unique<__VA_ARGS__>(info); \ + return Status::OK(); \ + })); \ + } \ + static const bool ORT_ADAPTER_CONCAT(ORT_ADAPTER_AUTOREG_##name##_##type1##_##type2##_, __COUNTER__) = \ + (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ + &BuildKernelCreateInfo), \ + true); + +#undef ONNX_OPERATOR_THREE_TYPED_KERNEL_EX +#define ONNX_OPERATOR_THREE_TYPED_KERNEL_EX(name, domain, ver, type1, type2, type3, provider, builder, ...) \ + class ONNX_OPERATOR_THREE_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, type3, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name).SetDomain(domain).SinceVersion(ver).Provider(CUDA_PLUGIN_EP).Build(), \ + static_cast( \ + [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { \ + out = std::make_unique<__VA_ARGS__>(info); \ + return Status::OK(); \ + })); \ + } \ + static const bool ORT_ADAPTER_CONCAT(ORT_ADAPTER_AUTOREG_##name##_##type1##_##type2##_##type3##_, __COUNTER__) = \ + (::onnxruntime::cuda::PluginKernelCollector::Instance().Add( \ + &BuildKernelCreateInfo), \ + true); + +// =================================================================== +// Section 4: Logging shim (adapter path only) +// LOGS_DEFAULT is re-routed through ep::adapter::LoggingManager, which +// holds the ORT default logger set up in CudaEpFactory::CudaEpFactory. +// All severity levels (including ERROR/WARNING) are forwarded to the +// ORT logger; no log output is suppressed. +// =================================================================== + +// Explicit function instantiation — called once per unique class in each .cc file +#define ONNX_OPERATOR_TYPED_KERNEL_COMPUTE_INSTANTIATION(cls) template Status cls::ComputeInternal(OpKernelContext* context) const; + +// The plugin utilizes ep::adapter::LoggingManager for LOGS_DEFAULT, +// which is initialized in CudaEpFactory::CudaEpFactory. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace onnxruntime { +namespace cuda { + +// =================================================================== +// Section 5: Runtime configuration for migrated kernels +// Fields are written once during CudaEp construction and owned by the +// shim CUDAExecutionProvider. Plugin kernels cache a shared_ptr to this +// config object during construction so Compute() does not need to rely on +// raw provider-pointer casts. +// =================================================================== + +namespace detail { +struct CudaKernelAdapterRuntimeConfig { + bool use_tf32 = true; + bool skip_layer_norm_strict_mode = false; + int cudnn_conv_algo = 0; + bool cudnn_conv_use_max_workspace = true; + bool cudnn_conv1d_pad_to_nc1d = false; + bool fuse_conv_bias = false; + int sdpa_kernel = 0; + int device_id = 0; + cudaDeviceProp device_prop{}; + onnxruntime::AttentionKernelOptions attention_kernel_options; +}; +template +struct SizeOf { + static constexpr size_t value = sizeof(T); +}; +template <> +struct SizeOf { + static constexpr size_t value = 0; +}; + +[[nodiscard]] inline bool TryBytesForCount(size_t count_or_bytes, size_t element_size, size_t& bytes) { + if (element_size == 0) { + // `element_size == 0` is the sentinel for the `T = void` path. + // In that mode callers already pass a raw byte count to helpers like + // GetScratchBuffer(workspace_bytes, ...), so no multiplication is needed. + bytes = count_or_bytes; + return true; + } + + if (count_or_bytes > (std::numeric_limits::max() / element_size)) { + return false; + } + + bytes = count_or_bytes * element_size; + return true; +} + +template +IConstantBuffer* GetConstOnesBufferForDevice(int device_id) { + static std::mutex mutex; + static std::unordered_map>> buffers; + std::lock_guard lock(mutex); + auto& buffer = buffers[device_id]; + if (!buffer) { + buffer = CreateConstantOnes(); + } + return buffer.get(); +} + +struct DefaultCudaHandles { + cublasHandle_t cublas = nullptr; + cudnnHandle_t cudnn = nullptr; + + ~DefaultCudaHandles() { + if (cublas != nullptr) { + cublasDestroy(cublas); + } + if (cudnn != nullptr) { + cudnnDestroy(cudnn); + } + } +}; + +inline DefaultCudaHandles& GetDefaultCudaHandlesForDevice(int device_id) { + // Fallback handles are only used for code paths that need cuBLAS/cuDNN + // without an active CudaSyncStream. Keep them thread-local so they are not + // shared across callers that may use the libraries concurrently. + thread_local std::unordered_map handles_by_device; + auto [it, inserted] = handles_by_device.try_emplace(device_id); + if (inserted) { + int prev_device = -1; + const cudaError_t get_device_result = cudaGetDevice(&prev_device); + PL_CUDA_CALL_THROW(cudaSetDevice(device_id)); + if (cublasCreate(&it->second.cublas) != CUBLAS_STATUS_SUCCESS) { + if (get_device_result == cudaSuccess) { + cudaSetDevice(prev_device); + } + handles_by_device.erase(it); + ORT_THROW("Failed to create default cuBLAS handle for CUDA plugin device ", device_id); + } + if (cudnnCreate(&it->second.cudnn) != CUDNN_STATUS_SUCCESS) { + cublasDestroy(it->second.cublas); + it->second.cublas = nullptr; + if (get_device_result == cudaSuccess) { + cudaSetDevice(prev_device); + } + handles_by_device.erase(it); + ORT_THROW("Failed to create default cuDNN handle for CUDA plugin device ", device_id); + } + if (get_device_result == cudaSuccess) { + PL_CUDA_CALL_THROW(cudaSetDevice(prev_device)); + } + } + + return it->second; +} + +inline const cudaDeviceProp& GetDevicePropForDevice(int device_id) { + static std::mutex mutex; + static std::unordered_map> props; + std::lock_guard lock(mutex); + auto it = props.find(device_id); + if (it == props.end()) { + auto prop = std::make_unique(); + const cudaError_t result = cudaGetDeviceProperties(prop.get(), device_id); + if (result != cudaSuccess) { + ORT_THROW("Failed to query CUDA device properties for device ", device_id, ": ", cudaGetErrorString(result)); + } + it = props.emplace(device_id, std::move(prop)).first; + } + return *it->second; +} +} // namespace detail +} // namespace cuda + +// =================================================================== +// Section 6: CUDAExecutionProvider shim +// Provides the minimal API surface that migrated kernels expect +// (GetCudnnConvAlgo, UseTF32, GetDeviceProp, etc.) without the full +// CUDAExecutionProvider class from onnxruntime/core/providers/cuda/. +// +// In the plugin build this shim is wrapped by adapter::Ep. Plugin kernels +// should prefer the CudaKernel base-class accessors for runtime settings +// instead of re-casting info.GetExecutionProvider() inside Compute(). +// =================================================================== + +// Shim for CUDAExecutionProvider required by conv.cc, einsum, and others +class CUDAExecutionProvider : public onnxruntime::IExecutionProvider { + public: + explicit CUDAExecutionProvider(const std::string& name, const OrtEp* ort_ep = nullptr) + : onnxruntime::IExecutionProvider{name}, ort_ep_{ort_ep} {} + + std::unique_ptr GetDataTransfer() const override { + return std::make_unique(); + } + + const OrtEp* GetOrtEp() const override { + return ort_ep_; + } + + std::shared_ptr GetRuntimeConfig() const { + return config_; + } + + int GetCudnnConvAlgo() const { + return config_->cudnn_conv_algo; + } + bool GetCudnnConvUseMaxWorkspace() const { + return config_->cudnn_conv_use_max_workspace; + } + bool GetCudnnConv1dPadToNc1d() const { + return config_->cudnn_conv1d_pad_to_nc1d; + } + bool UseTF32() const { + return config_->use_tf32; + } + bool IsFuseConvBias() const { + return config_->fuse_conv_bias; + } + const onnxruntime::AttentionKernelOptions* GetAttentionKernelOptions() const { + config_->attention_kernel_options.InitializeOnce(config_->sdpa_kernel, true, true); + return &config_->attention_kernel_options; + } + const cudaDeviceProp& GetDeviceProp() const { + return config_->device_prop; + } + + private: + const OrtEp* ort_ep_ = nullptr; + std::shared_ptr config_ = + std::make_shared(); +}; + +namespace cuda { +namespace detail { + +inline std::shared_ptr GetCudaKernelAdapterRuntimeConfigForProvider(const void* provider) { + return static_cast(provider)->GetRuntimeConfig(); +} + +} // namespace detail + +// Populate the per-provider adapter config from a pre-filled initializer struct. +// Callers (e.g. CudaEp constructor) construct a detail::CudaKernelAdapterRuntimeConfig, +// fill every field they care about, then call this function. Adding a new config +// field only requires updating the struct and the call site — no signature change. +inline void SetCudaKernelAdapterRuntimeConfigForProvider( + const void* provider, const detail::CudaKernelAdapterRuntimeConfig& init_config) { + auto config = detail::GetCudaKernelAdapterRuntimeConfigForProvider(provider); + // AttentionKernelOptions contains std::once_flag (not copyable), so assign + // the plain-data fields individually rather than relying on operator=. + config->use_tf32 = init_config.use_tf32; + config->skip_layer_norm_strict_mode = init_config.skip_layer_norm_strict_mode; + config->cudnn_conv_algo = init_config.cudnn_conv_algo; + config->cudnn_conv_use_max_workspace = init_config.cudnn_conv_use_max_workspace; + config->cudnn_conv1d_pad_to_nc1d = init_config.cudnn_conv1d_pad_to_nc1d; + config->fuse_conv_bias = init_config.fuse_conv_bias; + config->sdpa_kernel = init_config.sdpa_kernel; + config->device_id = init_config.device_id; + PL_CUDA_CALL_THROW(cudaGetDeviceProperties(&config->device_prop, config->device_id)); +} + +inline bool GetCudaKernelAdapterSkipLayerNormStrictMode(const void* provider) { + const auto config = detail::GetCudaKernelAdapterRuntimeConfigForProvider(provider); + return config->skip_layer_norm_strict_mode; +} + +// Global aliases and shims +using Status = onnxruntime::common::Status; +using MLFloat16 = onnxruntime::MLFloat16; +using BFloat16 = onnxruntime::BFloat16; +using Float8E4M3FN = onnxruntime::Float8E4M3FN; +using Float8E4M3FNUZ = onnxruntime::Float8E4M3FNUZ; +using Float8E5M2 = onnxruntime::Float8E5M2; +using Float8E5M2FNUZ = onnxruntime::Float8E5M2FNUZ; + +// Type mapping for CUDA +template +struct ToCudaType { + typedef T MappedType; + static MappedType FromFloat(float f) { return static_cast(f); } +}; + +template <> +struct ToCudaType { + typedef half MappedType; + static MappedType FromFloat(float f) { + uint16_t h = onnxruntime::math::floatToHalf(f); + return *reinterpret_cast(&h); + } +}; + +#ifdef __CUDACC__ +template <> +struct ToCudaType { + typedef nv_bfloat16 MappedType; + static MappedType FromFloat(float f) { + return nv_bfloat16(f); + } +}; + +// Forward declare templates from common.cuh to allow specialization +// Match signatures from common.cuh exactly (no default parameters) +template +struct _IsInf; +template +struct _IsNan; + +namespace bf16_isinf_nan { +template +struct IsInfTyped; +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(nv_bfloat16 a) { + uint16_t val = *reinterpret_cast(&a); + return (val & 0x7F80) == 0x7F80 && (val & 0x007F) == 0x0000; + } + static __device__ __inline__ bool IsInfPos(nv_bfloat16 a) { + return *reinterpret_cast(&a) == 0x7F80; + } + static __device__ __inline__ bool IsInfNeg(nv_bfloat16 a) { + return *reinterpret_cast(&a) == 0xFF80; + } +}; +} // namespace bf16_isinf_nan + +// Specialize for nv_bfloat16 to avoid ambiguity with isnan/isinf overloads +template <> +struct _IsNan { + __device__ __inline__ bool operator()(nv_bfloat16 a) const { + uint16_t val = *reinterpret_cast(&a); + return (val & 0x7F80) == 0x7F80 && (val & 0x007F) != 0x0000; + } +}; + +template +struct _IsInf { + __device__ __inline__ bool operator()(nv_bfloat16 a) const { + if constexpr (detect_positive && detect_negative) { + return bf16_isinf_nan::IsInfTyped::IsInf(a); + } else if constexpr (detect_positive) { + return bf16_isinf_nan::IsInfTyped::IsInfPos(a); + } else if constexpr (detect_negative) { + return bf16_isinf_nan::IsInfTyped::IsInfNeg(a); + } else { + return false; + } + } +}; +#endif + +// =================================================================== +// Section 6b: CPU provider shims for the plugin build +// Inline implementations of CPU utility functions that CUDA kernels +// reference (e.g., OneHot validation, GatherElements shape checks). +// These are normally provided by onnxruntime_providers but the plugin +// does not link against it. +// +// We temporarily close namespace cuda so these shims live directly in +// namespace onnxruntime, where unqualified lookup from onnxruntime::cuda +// will find them, and where onnxruntime::GatherElements resolves correctly. +// =================================================================== + +} // namespace cuda + +// Shim for ValidateInputs from core/providers/cpu/tensor/onehot.h +inline Status ValidateInputs(const Tensor* depth, const Tensor* values) { + if (!depth->Shape().IsScalar()) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Invalid argument for depth; it's not a scalar."); + } + if (!(values->Shape().NumDimensions() == 1 && values->Shape().Size() == 2)) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Invalid argument for values; either it's rank is more than 1" + " or it has more than 2 elements"); + } + return Status::OK(); +} + +// Shim for PrepareOutputShape from core/providers/cpu/tensor/onehot.h +inline Status PrepareOutputShape(const Tensor* indices, const int64_t depth_val, const int64_t axis, + int64_t& prefix_dim_size, int64_t& suffix_dim_size, + TensorShapeVector& output_shape) { + const auto& indices_shape = indices->Shape(); + const auto indices_dims = indices_shape.GetDims(); + const auto indices_num_dims = indices_shape.NumDimensions(); + output_shape = indices_shape.AsShapeVector(); + const auto output_rank = static_cast(indices_num_dims) + 1; + auto true_axis = HandleNegativeAxis(axis, output_rank); + output_shape.insert(output_shape.begin() + true_axis, depth_val); + prefix_dim_size = 1; + for (int64_t i = 0; i < true_axis; ++i) { + prefix_dim_size *= indices_dims[narrow(i)]; + } + suffix_dim_size = indices_shape.Size() / prefix_dim_size; + return Status::OK(); +} + +// Shim for GatherElements::ValidateInputShapes from +// core/providers/cpu/tensor/gather_elements.h +class GatherElements { + public: + static Status ValidateInputShapes(const TensorShape& input_data_shape, + const TensorShape& indices_shape, + int64_t axis) { + int64_t input_data_rank = static_cast(input_data_shape.NumDimensions()); + int64_t indices_rank = static_cast(indices_shape.NumDimensions()); + if (input_data_rank < 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "GatherElements op: Cannot operate on scalar input"); + if (input_data_rank != indices_rank) + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "GatherElements op: Rank of input 'data' needs to be equal to rank of input 'indices'"); + for (int64_t i = 0; i < indices_rank; ++i) { + if (i != axis) { + if (indices_shape[narrow(i)] < 0 || + indices_shape[narrow(i)] > input_data_shape[narrow(i)]) + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "GatherElements op: 'indices' shape should have values within bounds of 'data' shape. " + "Invalid value in indices shape is: ", + indices_shape[narrow(i)]); + } + } + return Status::OK(); + } +}; + +namespace cuda { // re-open onnxruntime::cuda + +// =================================================================== +// Section 7: CudaKernel base class +// Base class for all migrated CUDA kernels. Provides scratch-buffer +// management, CUDA handle access (cuBLAS, cuDNN), device property +// queries, and the CudaAsyncBuffer helper for host→device transfers. +// =================================================================== + +// Additional adapter logic for CudaKernel + +class CudaKernel : public OpKernel { + public: + explicit CudaKernel(const OpKernelInfo& info) : OpKernel(info) { + const auto* provider = info.GetExecutionProvider(); + runtime_config_ = detail::GetCudaKernelAdapterRuntimeConfigForProvider(provider); + use_tf32_ = runtime_config_->use_tf32; + device_id_ = runtime_config_->device_id; + device_prop_ = runtime_config_->device_prop; + } + virtual ~CudaKernel() = default; + Status Compute(OpKernelContext* ctx) const { + Status s = ComputeInternal(ctx); + if (s.IsOK()) { + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, "CUDA error: " + std::string(cudaGetErrorString(err))); + } + return s; + } + virtual Status ComputeInternal(OpKernelContext* ctx) const = 0; + + inline cudaStream_t DefaultCudaStream() const { return Stream(static_cast(nullptr)); } + inline cublasHandle_t DefaultCublasHandle() const { return detail::GetDefaultCudaHandlesForDevice(device_id_).cublas; } + inline cudnnHandle_t DefaultCudnnHandle() const { return detail::GetDefaultCudaHandlesForDevice(device_id_).cudnn; } + + inline Status CopyTensor(const onnxruntime::Tensor& src, onnxruntime::Tensor& dst, onnxruntime::Stream& stream) const { + if (src.Shape().Size() == 0) return Status::OK(); + if (cudaMemcpyAsync(dst.MutableDataRaw(), src.DataRaw(), src.SizeInBytes(), cudaMemcpyDeviceToDevice, (cudaStream_t)stream.GetHandle()) != cudaSuccess) { + return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, "Memcpy fail"); + } + return Status::OK(); + } + + cudaStream_t Stream(OpKernelContext* ctx) const { + if (!ctx) return nullptr; + return static_cast(ctx->GetGPUComputeStream()); + } + + // Returns an opaque stream pointer for passing to GetScratchBuffer/AddDeferredReleaseCPUPtr/CopyToGpu. + // Returns void* for dual-build compatibility: framework wraps Stream*, plugin wraps cudaStream_t. + inline void* GetComputeStream(OpKernelContext* ctx) const { + return ctx->GetGPUComputeStream(); + } + + inline onnxruntime::OrtStreamAdapter GetOrtStream(OpKernelContext* ctx) const { + return onnxruntime::OrtStreamAdapter(GetComputeStream(ctx)); + } + + static cudnnHandle_t GetCudnnHandle(cudaStream_t s) { + auto* sync = cuda_plugin::CudaSyncStream::FromCudaStream(s); + return sync ? sync->GetCudnnHandle() : nullptr; + } + static inline cudnnHandle_t GetCudnnHandle(onnxruntime::CudaStream* stream) { + return stream ? GetCudnnHandle(static_cast(reinterpret_cast(stream)->GetHandle())) : nullptr; + } + static inline cudnnHandle_t GetCudnnHandle(onnxruntime::Stream* stream) { + return stream ? GetCudnnHandle(static_cast(stream->GetHandle())) : nullptr; + } + cudnnHandle_t GetCudnnHandle(OpKernelContext* ctx) const { + auto stream = Stream(ctx); + auto handle = GetCudnnHandle(stream); + if (handle != nullptr) { + return handle; + } + + handle = DefaultCudnnHandle(); + if (stream != nullptr) { + CUDNN_CALL_THROW(cudnnSetStream(handle, stream)); + } + return handle; + } + + static cublasHandle_t GetCublasHandle(cudaStream_t s) { + auto* sync = cuda_plugin::CudaSyncStream::FromCudaStream(s); + return sync ? sync->GetCublasHandle() : nullptr; + } + static inline cublasHandle_t GetCublasHandle(onnxruntime::CudaStream* stream) { + return stream ? GetCublasHandle(static_cast(reinterpret_cast(stream)->GetHandle())) : nullptr; + } + static inline cublasHandle_t GetCublasHandle(onnxruntime::Stream* stream) { + return stream ? GetCublasHandle(static_cast(stream->GetHandle())) : nullptr; + } + cublasHandle_t GetCublasHandle(OpKernelContext* ctx) const { + auto stream = Stream(ctx); + auto handle = GetCublasHandle(stream); + if (handle != nullptr) { + return handle; + } + + handle = DefaultCublasHandle(); + if (stream != nullptr) { + CUBLAS_CALL_THROW(cublasSetStream(handle, stream)); + } + return handle; + } + + static cublasLtHandle_t GetCublasLtHandle(cudaStream_t s) { + auto* sync = cuda_plugin::CudaSyncStream::FromCudaStream(s); + return sync ? sync->GetCublasLtHandle() : nullptr; + } + static inline cublasLtHandle_t GetCublasLtHandle(onnxruntime::CudaStream* stream) { + return stream ? GetCublasLtHandle(static_cast(reinterpret_cast(stream)->GetHandle())) : nullptr; + } + static inline cublasLtHandle_t GetCublasLtHandle(onnxruntime::Stream* stream) { + return stream ? GetCublasLtHandle(static_cast(stream->GetHandle())) : nullptr; + } + cublasLtHandle_t GetCublasLtHandle(OpKernelContext* ctx) const { return GetCublasLtHandle(Stream(ctx)); } + + const cudaDeviceProp& GetDeviceProp() const { + // Some migrated kernels size their launches from device properties. If the + // per-provider cache was not populated for this kernel instance, fall back + // to a direct lookup instead of returning an all-zero struct. + if (device_prop_.maxThreadsPerMultiProcessor == 0 || device_prop_.multiProcessorCount == 0) { + return detail::GetDevicePropForDevice(device_id_); + } + + return device_prop_; + } + int GetCudnnConvAlgo() const { return runtime_config_->cudnn_conv_algo; } + bool GetCudnnConvUseMaxWorkspace() const { return runtime_config_->cudnn_conv_use_max_workspace; } + bool GetCudnnConv1dPadToNc1d() const { return runtime_config_->cudnn_conv1d_pad_to_nc1d; } + bool UseTF32() const { return use_tf32_; } + bool IsFuseConvBias() const { return runtime_config_->fuse_conv_bias; } + bool IsArchAvailable(int arch) const { return GetDeviceProp().major >= arch; } + // Delegate to the base OpKernel::Info() which holds a safe copy of OpKernelInfo. + // Do NOT store a reference to the constructor parameter — it becomes dangling. + const OpKernelInfo& Info() const { return OpKernel::Info(); } + const onnxruntime::AttentionKernelOptions* GetAttentionKernelOptions() const { + runtime_config_->attention_kernel_options.InitializeOnce(runtime_config_->sdpa_kernel, true, true); + return &runtime_config_->attention_kernel_options; + } + + // Stub for GetTuningContext — tunable ops are not supported in the plugin. + struct PluginTuningContextStub { + bool IsTunableOpEnabled() const { return false; } + }; + PluginTuningContextStub* GetTuningContext() const { + static PluginTuningContextStub stub; + return &stub; + } + + // GetConstOnes: returns a device buffer of constant ones. + // Delegates to IConstantBuffer from cuda_utils.h (compiled in cuda_utils.cu). + template + const T* GetConstOnes(size_t count, cudaStream_t stream) const { + auto* buf = detail::GetConstOnesBufferForDevice(device_id_); + return buf->GetBuffer(stream, count); + } + + template + using IAllocatorUniquePtr = std::unique_ptr>; + template + inline IAllocatorUniquePtr GetScratchBuffer(size_t cnt, void* s) const { + if (cnt == 0) return IAllocatorUniquePtr(nullptr, [](T*) {}); + size_t sz = 0; + if (!detail::TryBytesForCount(cnt, detail::SizeOf::value, sz)) { + ORT_THROW("CUDA scratch buffer allocation size overflow for ", cnt, " elements"); + } + + void* p = nullptr; + cudaError_t alloc_result = cudaSuccess; + bool used_async_alloc = false; + if (s) { + alloc_result = cudaMallocAsync(&p, sz, static_cast(s)); + used_async_alloc = (alloc_result == cudaSuccess); + if (!used_async_alloc && (alloc_result == cudaErrorNotSupported || alloc_result == cudaErrorInvalidValue)) { + alloc_result = cudaMalloc(&p, sz); + } + } else { + alloc_result = cudaMalloc(&p, sz); + } + + if (alloc_result != cudaSuccess) { + ORT_THROW("CUDA scratch buffer allocation failed for ", sz, " bytes: ", cudaGetErrorString(alloc_result)); + } + + return IAllocatorUniquePtr(static_cast(p), [s, used_async_alloc](T* ptr) { + if (ptr) { + // Guard: only attempt async free if the stream is still registered. + // CudaSyncStream::~CudaSyncStream guarantees UnregisterStream() is + // called before cudaStreamDestroy(), so a non-null lookup here means + // the raw cudaStream_t handle is still valid. + if (used_async_alloc && s && + cuda_plugin::CudaSyncStream::FromCudaStream(static_cast(s)) != nullptr) { + cudaError_t free_result = cudaFreeAsync(ptr, static_cast(s)); + if (free_result == cudaSuccess) { + return; + } + } + + // Fall back to synchronous free if async free is unsupported or if the + // stream is no longer registered. cudaFree is valid for allocations + // returned by cudaMallocAsync and avoids using a stale stream handle. + cudaFree(ptr); + } + }); + } + template + inline IAllocatorUniquePtr GetTransientScratchBuffer(size_t cnt) const { + return GetScratchBuffer(cnt, nullptr); + } + inline void AddDeferredReleaseCPUPtr(void* p, void* s) const { + if (!p) return; + auto* sync = cuda_plugin::CudaSyncStream::FromCudaStream(static_cast(s)); + if (sync) { + sync->EnqueueDeferredCPUBuffer(p); + return; + } + + if (s != nullptr) { + cudaError_t sync_result = cudaStreamSynchronize(static_cast(s)); + if (sync_result != cudaSuccess) { + // If the raw stream handle is already invalid during teardown, prefer a + // bounded leak over freeing pinned memory that could still be in use by + // an in-flight async copy. + LOGS_DEFAULT(WARNING) << "AddDeferredReleaseCPUPtr: cudaStreamSynchronize failed (" + << cudaGetErrorString(sync_result) + << "); leaking pinned buffer to avoid use-after-free"; + return; + } + } + + cudaFreeHost(p); + } + template + inline IAllocatorUniquePtr AllocateBufferOnCPUPinned(size_t cnt) const { + if (cnt == 0) return IAllocatorUniquePtr(nullptr, [](T*) {}); + size_t sz = 0; + if (!detail::TryBytesForCount(cnt, detail::SizeOf::value, sz)) { + ORT_THROW("CUDA pinned CPU buffer allocation size overflow for ", cnt, " elements"); + } + void* p = nullptr; + if (cudaHostAlloc(&p, sz, cudaHostAllocDefault) != cudaSuccess) return IAllocatorUniquePtr(nullptr, [](T*) {}); + return IAllocatorUniquePtr(static_cast(p), [](T* ptr) { if (ptr) cudaFreeHost(ptr); }); + } + + template + class CudaAsyncBuffer { + public: + CudaAsyncBuffer(const CudaKernel* ok) : gpu_(nullptr, [](T*) {}), count_(0), op_kernel_(ok) {} + CudaAsyncBuffer(const CudaKernel* ok, size_t n) : CudaAsyncBuffer(ok) { AllocCpuPtr(n); } + CudaAsyncBuffer(const CudaKernel* ok, const T& v, size_t n) : CudaAsyncBuffer(ok, n) { + T* p = CpuPtr(); + for (size_t i = 0; i != n; ++i) *p++ = v; + } + CudaAsyncBuffer(const CudaKernel* ok, gsl::span vec) : CudaAsyncBuffer(ok, vec.size()) { + size_t bytes = 0; + if (!detail::TryBytesForCount(vec.size(), sizeof(T), bytes)) { + ORT_THROW("CUDA async buffer host copy size overflow for ", vec.size(), " elements"); + } + memcpy(CpuPtr(), vec.data(), bytes); + } + void AllocCpuPtr(size_t n) { + cpu_ = op_kernel_->AllocateBufferOnCPUPinned(n); + if (!cpu_) throw std::runtime_error("alloc fail"); + count_ = n; + } + Status CopyToGpu(void* s) { + if (cpu_) { + gpu_ = op_kernel_->GetScratchBuffer(count_, s); + size_t bytes = 0; + if (!detail::TryBytesForCount(count_, sizeof(T), bytes)) { + ORT_THROW("CUDA async buffer copy size overflow for ", count_, " elements"); + } + if (cudaMemcpyAsync(gpu_.get(), cpu_.get(), bytes, cudaMemcpyHostToDevice, static_cast(s)) != cudaSuccess) return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, "Memcpy fail"); + op_kernel_->AddDeferredReleaseCPUPtr(cpu_.release(), s); + } + return Status::OK(); + } + T* CpuPtr() const { return cpu_.get(); } + gsl::span CpuSpan() const { return gsl::span(CpuPtr(), count_); } + T* GpuPtr() const { return gpu_.get(); } + size_t count() const { return count_; } + + protected: + IAllocatorUniquePtr gpu_; + std::unique_ptr> cpu_{nullptr, [](T*) {}}; + size_t count_; + const CudaKernel* op_kernel_; + }; + + private: + std::shared_ptr runtime_config_; + cudaDeviceProp device_prop_{}; + bool use_tf32_ = true; + int device_id_ = 0; +}; + +// =================================================================== +// Section 8: Compute helper shims (HalfGemmOptions, CublasMathModeSetter) +// =================================================================== + +// Shims for HalfGemmOptions and CublasMathModeSetter required by fpgeneric.h +class HalfGemmOptions { + public: + static const HalfGemmOptions* GetInstance() { + static HalfGemmOptions instance; + return &instance; + } + cublasMath_t GetMathMode() const { return CUBLAS_DEFAULT_MATH; } + bool IsCompute16F() const { return false; } +#if defined(CUBLAS_COMPUTE_32F) + cublasComputeType_t GetComputeType() const { return CUBLAS_COMPUTE_32F; } +#else + cudaDataType_t GetComputeType() const { return CUDA_R_32F; } +#endif +}; + +class CublasMathModeSetter { + public: + CublasMathModeSetter(const cudaDeviceProp& prop, cublasHandle_t handle, cublasMath_t mode) : handle_(handle) { + enable_ = (mode == CUBLAS_TF32_TENSOR_OP_MATH ? prop.major >= 8 : true); + if (enable_) { + cublasGetMathMode(handle, &mode_); + enable_ = (mode_ != mode); + if (enable_) { + cublasSetMathMode(handle, mode); + } + } + } + + ~CublasMathModeSetter() { + if (enable_) { + cublasSetMathMode(handle_, mode_); + } + } + + private: + cublasHandle_t handle_; + cublasMath_t mode_ = CUBLAS_DEFAULT_MATH; + bool enable_; +}; + +} // namespace cuda + +// Global aliases for convenience +using MLFloat16 = onnxruntime::MLFloat16; +using BFloat16 = onnxruntime::BFloat16; +using Float8E4M3FN = onnxruntime::Float8E4M3FN; +using Float8E4M3FNUZ = onnxruntime::Float8E4M3FNUZ; +using Float8E5M2 = onnxruntime::Float8E5M2; +using Float8E5M2FNUZ = onnxruntime::Float8E5M2FNUZ; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_memcpy_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_memcpy_plugin.cc new file mode 100644 index 0000000000000..f8b4d27e5cade --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_memcpy_plugin.cc @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Plugin-side MemcpyFromHost / MemcpyToHost kernels. +// These handle the common Tensor case using cudaMemcpyAsync directly. + +#include "core/providers/cuda/plugin/cuda_kernel_adapter.h" + +#include + +namespace onnxruntime { +namespace cuda { + +class PluginMemcpy final : public CudaKernel { + public: + PluginMemcpy(const OpKernelInfo& info) : CudaKernel(info) {} + + Status ComputeInternal(OpKernelContext* ctx) const override { + const Tensor* X = ctx->Input(0); + ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr."); + Tensor* Y = ctx->Output(0, X->Shape()); + ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor."); + + if (X->SizeInBytes() == 0) { + return Status::OK(); + } + + const void* src = X->DataRaw(); + void* dst = Y->MutableDataRaw(); + if (src == dst) { + return Status::OK(); + } + + // Determine copy direction from device placement. + const auto& src_loc = X->Location(); + const auto& dst_loc = Y->Location(); + cudaMemcpyKind kind; + if (src_loc.device.Type() == OrtDevice::CPU && dst_loc.device.Type() == OrtDevice::GPU) { + kind = cudaMemcpyHostToDevice; + } else if (src_loc.device.Type() == OrtDevice::GPU && dst_loc.device.Type() == OrtDevice::CPU) { + kind = cudaMemcpyDeviceToHost; + } else if (src_loc.device.Type() == OrtDevice::GPU && dst_loc.device.Type() == OrtDevice::GPU) { + kind = cudaMemcpyDeviceToDevice; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "PluginMemcpy: unsupported copy direction"); + } + + cudaStream_t stream = Stream(ctx); + if (stream != nullptr) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst, src, X->SizeInBytes(), kind, stream)); + } else { + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst, src, X->SizeInBytes(), kind)); + } + return Status::OK(); + } +}; + +ONNX_OPERATOR_KERNEL_EX( + MemcpyFromHost, + kOnnxDomain, + 1, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypesIRv9()), + PluginMemcpy); + +ONNX_OPERATOR_KERNEL_EX( + MemcpyToHost, + kOnnxDomain, + 1, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPUOutput, 0) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypesIRv9()), + PluginMemcpy); + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_plugin_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_ep.cc new file mode 100644 index 0000000000000..8ad5da2fa6ab7 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_ep.cc @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// DLL entry points for the CUDA Plugin Execution Provider. +// Exports CreateEpFactories() and ReleaseEpFactory() as the +// public interface for ORT to load and use the CUDA EP as a plugin. + +#include "onnxruntime_cxx_api.h" + +#include "cuda_ep_factory.h" + +#ifndef _WIN32 +#define EXPORT_SYMBOL __attribute__((visibility("default"))) +#else +#define EXPORT_SYMBOL +#endif + +extern "C" { + +/// Create the CUDA EP factory instances. +/// Called by ORT when loading the CUDA plugin EP DLL. +EXPORT_SYMBOL OrtStatus* CreateEpFactories( + const char* registration_name, + const OrtApiBase* ort_api_base, + const OrtLogger* default_logger, + OrtEpFactory** factories, + size_t max_factories, + size_t* num_factories) { + const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); + const OrtEpApi* ep_api = ort_api->GetEpApi(); + + // Initialize the C++ API FIRST before any C++ wrapper usage + Ort::InitApi(ort_api); + + if (default_logger == nullptr) { + return ort_api->CreateStatus( + ORT_INVALID_ARGUMENT, + "CUDA Plugin EP: default_logger must not be null."); + } + + // Log initialization (default_logger is guaranteed non-null after the check above). + { + std::string msg = "CreateEpFactories: Initializing CUDA Plugin EP with registration name: "; + msg += (registration_name ? registration_name : "NULL"); + auto* status = ort_api->Logger_LogMessage(default_logger, ORT_LOGGING_LEVEL_INFO, + msg.c_str(), + ORT_FILE, __LINE__, __FUNCTION__); + if (status) ort_api->ReleaseStatus(status); + } + + if (max_factories < 1) { + auto* log_status = ort_api->Logger_LogMessage(default_logger, ORT_LOGGING_LEVEL_ERROR, + "CreateEpFactories: max_factories < 1", + ORT_FILE, __LINE__, __FUNCTION__); + if (log_status) ort_api->ReleaseStatus(log_status); + return ort_api->CreateStatus( + ORT_INVALID_ARGUMENT, + "CUDA Plugin EP: Not enough space to return EP factory. Need at least one."); + } + + try { + auto factory = std::make_unique( + *ort_api, *ep_api, *default_logger); + + factories[0] = factory.release(); + *num_factories = 1; + + auto* log_status = ort_api->Logger_LogMessage(default_logger, ORT_LOGGING_LEVEL_INFO, + "CreateEpFactories: Successfully created CUDA EP factory", + ORT_FILE, __LINE__, __FUNCTION__); + if (log_status) ort_api->ReleaseStatus(log_status); + } catch (const std::exception& ex) { + auto* log_status = ort_api->Logger_LogMessage(default_logger, ORT_LOGGING_LEVEL_ERROR, + ex.what(), ORT_FILE, __LINE__, __FUNCTION__); + if (log_status) ort_api->ReleaseStatus(log_status); + return ort_api->CreateStatus(ORT_EP_FAIL, ex.what()); + } + + return nullptr; +} + +/// Release a CUDA EP factory instance. +/// Called by ORT when unloading the CUDA plugin EP DLL. +EXPORT_SYMBOL OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { + delete static_cast(factory); + return nullptr; +} + +} // extern "C" diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_plugin_ep_symbols.def b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_ep_symbols.def new file mode 100644 index 0000000000000..6d0efad28c669 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_ep_symbols.def @@ -0,0 +1,4 @@ +; Symbol export definition file for CUDA Plugin EP (Windows) +EXPORTS + CreateEpFactories + ReleaseEpFactory diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.cu b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.cu new file mode 100644 index 0000000000000..b5b3f19d8a7c9 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.cu @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file provides the CreateCudaKernelRegistry entrypoint for the CUDA plugin EP. +// +// Kernel registration is now fully automatic: each compiled kernel .cc file's +// ONNX_OPERATOR_*_KERNEL_EX macro expansion creates a BuildKernelCreateInfo<>() +// template specialization and auto-registers it in PluginKernelCollector via +// the macro overrides in cuda_kernel_adapter.h. +// +// CreateCudaKernelRegistry iterates the collector and registers each entry +// into an adapter::KernelRegistry which is then returned to the EP factory. + +#include "cuda_plugin_kernels.h" +#include "cuda_stream_plugin.h" +#include "cuda_kernel_adapter.h" + +// Define the BuildKernelCreateInfo() sentinel in onnxruntime::cuda. +// This is normally defined in cuda_execution_provider.cc (excluded from plugin). +// The NHWC registration tables reference it as a placeholder to prevent empty arrays. +namespace onnxruntime::cuda { +template <> +KernelCreateInfo BuildKernelCreateInfo() { + KernelCreateInfo info; + return info; +} +} // namespace onnxruntime::cuda + +namespace onnxruntime { +namespace cuda_plugin { + +OrtStatus* CreateCudaKernelRegistry(const OrtEpApi& /*ep_api*/, + const char* /*ep_name*/, + void* /*create_kernel_state*/, + OrtKernelRegistry** out_registry) { + *out_registry = nullptr; + + EXCEPTION_TO_STATUS_BEGIN + + // adapter::KernelRegistry wraps OrtKernelRegistry via the Ort C++ API. + ::onnxruntime::ep::adapter::KernelRegistry registry; + + // Iterate all self-registered BuildKernelCreateInfoFn pointers. + auto entries = ::onnxruntime::cuda::PluginKernelCollector::Instance().Entries(); + for (auto build_fn : entries) { + ::onnxruntime::ep::adapter::KernelCreateInfo info = build_fn(); + if (info.kernel_def != nullptr) { // filter the BuildKernelCreateInfo sentinel + ORT_THROW_IF_ERROR(registry.Register(std::move(info))); + } + } + + *out_registry = registry.release(); + return nullptr; + + EXCEPTION_TO_STATUS_END +} + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.h b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.h new file mode 100644 index 0000000000000..1ff979de22743 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_kernels.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "cuda_plugin_utils.h" + +namespace onnxruntime { +namespace cuda_plugin { + +/// Create the CUDA kernel registry using self-registered BuildKernelCreateInfo<> +/// entries from PluginKernelCollector. All compiled kernel .cc files automatically +/// contribute their registrations via the macro overrides in cuda_kernel_adapter.h. +OrtStatus* CreateCudaKernelRegistry(const OrtEpApi& ep_api, + const char* ep_name, + void* create_kernel_state, + OrtKernelRegistry** out_registry); + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_plugin_utils.h b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_utils.h new file mode 100644 index 0000000000000..0e4808d07046d --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_plugin_utils.h @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Common utilities, error-handling macros, and type definitions shared by +// all source files in the CUDA plugin EP implementation. + +#pragma once + +#include "onnxruntime_c_api.h" +#include "onnxruntime_cxx_api.h" + +#include +#include +#include +#include + +// Error handling macros + +#ifndef PL_CUDA_RETURN_IF_ERROR +#define PL_CUDA_RETURN_IF_ERROR(cuda_call_expr) \ + do { \ + cudaError_t _cuda_err = (cuda_call_expr); \ + if (_cuda_err != cudaSuccess) { \ + return Ort::GetApi().CreateStatus( \ + ORT_EP_FAIL, \ + (std::string("CUDA error: ") + cudaGetErrorName(_cuda_err) + ": " + \ + cudaGetErrorString(_cuda_err)) \ + .c_str()); \ + } \ + } while (0) +#endif + +// Throwing variant for use in constructors and non-OrtStatus contexts. +// Analogous to CUDA_CALL_THROW in the non-plugin build. +#ifndef PL_CUDA_CALL_THROW +#define PL_CUDA_CALL_THROW(cuda_call_expr) \ + do { \ + cudaError_t _cuda_err = (cuda_call_expr); \ + if (_cuda_err != cudaSuccess) { \ + throw std::runtime_error( \ + std::string("CUDA error: ") + cudaGetErrorName(_cuda_err) + ": " + \ + cudaGetErrorString(_cuda_err)); \ + } \ + } while (0) +#endif + +#ifndef PL_CUBLAS_RETURN_IF_ERROR +#define PL_CUBLAS_RETURN_IF_ERROR(cublas_call_expr) \ + do { \ + cublasStatus_t _cublas_err = (cublas_call_expr); \ + if (_cublas_err != CUBLAS_STATUS_SUCCESS) { \ + return Ort::GetApi().CreateStatus( \ + ORT_EP_FAIL, \ + (std::string("cuBLAS error: ") + \ + std::to_string(static_cast(_cublas_err))) \ + .c_str()); \ + } \ + } while (0) +#endif + +#ifndef PL_CUDNN_RETURN_IF_ERROR +#define PL_CUDNN_RETURN_IF_ERROR(cudnn_call_expr) \ + do { \ + cudnnStatus_t _cudnn_err = (cudnn_call_expr); \ + if (_cudnn_err != CUDNN_STATUS_SUCCESS) { \ + return Ort::GetApi().CreateStatus( \ + ORT_EP_FAIL, \ + (std::string("cuDNN error: ") + \ + cudnnGetErrorString(_cudnn_err)) \ + .c_str()); \ + } \ + } while (0) +#endif + +#define EXCEPTION_TO_STATUS_BEGIN try { +#define EXCEPTION_TO_STATUS_END \ + } \ + catch (const Ort::Exception& ex) { \ + Ort::Status status(ex); \ + return status.release(); \ + } \ + catch (const std::exception& ex) { \ + Ort::Status status(ex.what(), ORT_EP_FAIL); \ + return status.release(); \ + } + +/// Stored API pointers accessible to all plugin components. +struct CudaPluginApis { + const OrtApi& ort_api; + const OrtEpApi& ep_api; +}; diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc new file mode 100644 index 0000000000000..521c6bb15c13f --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc @@ -0,0 +1,279 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "cuda_stream_plugin.h" +#include "cuda_ep_factory.h" +#include +#include +#include + +namespace onnxruntime { +namespace cuda_plugin { + +namespace { + +// Global stream-to-CudaSyncStream mapping. +// Required because migrated CUDA kernels receive only a raw cudaStream_t +// but need access to associated cuBLAS/cuDNN handles. +using StreamMap = std::unordered_map; + +StreamMap& GetStreamMap() { + static StreamMap stream_map; + return stream_map; +} + +std::shared_mutex& GetStreamMapMutex() { + static std::shared_mutex stream_map_mutex; + return stream_map_mutex; +} + +// Monotonically increasing generation counter, bumped on every UnregisterStream +// so that TLS caches can detect stale entries without acquiring a lock. +std::atomic& GetStreamMapGeneration() { + static std::atomic generation{0}; + return generation; +} +} // namespace + +// --------------------------------------------------------------------------- +// CudaSyncStream +// --------------------------------------------------------------------------- + +CudaSyncStream::CudaSyncStream(CudaEpFactory& factory, int device_id, + const OrtEp* /*ep*/) + : OrtSyncStreamImpl{}, + factory_(factory), + device_id_(device_id) { + ort_version_supported = ORT_API_VERSION; + GetHandle = GetHandleImpl; + CreateNotification = CreateNotificationImpl; + Flush = FlushImpl; + OnSessionRunEnd = OnSessionRunEndImpl; + Release = ReleaseImpl; +} + +CudaSyncStream::~CudaSyncStream() { + bool has_deferred_cpu_buffers = false; + { + std::lock_guard lock(deferred_cpu_buffers_mutex_); + has_deferred_cpu_buffers = !deferred_cpu_buffers_.empty(); + } + + if (has_deferred_cpu_buffers) { + if (cuda_stream_ != nullptr) { + OrtStatus* status = OnSessionRunEndImpl(this); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + } + } else { + OrtStatus* status = CleanupDeferredCPUBuffers(); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + } + } + } + + if (cublas_handle_) cublasDestroy(cublas_handle_); + if (cudnn_handle_) cudnnDestroy(cudnn_handle_); + if (cublas_lt_handle_) cublasLtDestroy(cublas_lt_handle_); + if (cuda_stream_) { + // Unregister the stream from the global map *after* destroying handles but + // *before* destroying the stream itself. This ordering ensures: + // 1. No concurrent kernel can obtain cuBLAS/cuDNN handles from a destroyed + // CudaSyncStream during the brief window before unregistration. + // 2. UnregisterStream bumps the TLS generation counter, invalidating cached + // lookups in other threads. + // 3. The stream is destroyed only after it is no longer discoverable. + UnregisterStream(cuda_stream_); + + auto destroy_result = cudaStreamDestroy(cuda_stream_); + if (destroy_result == cudaSuccess && !deferred_cpu_buffers_.empty()) { + // Fallback: we only reach here when the earlier cudaStreamSynchronize in + // OnSessionRunEndImpl failed, leaving some buffers un-freed. + // cudaStreamDestroy on a non-blocking stream returns immediately (async + // cleanup), so in-flight ops may still reference these buffers. However, + // a prior sync failure indicates a serious CUDA error, so best-effort + // cleanup is the most we can do here. + OrtStatus* status = CleanupDeferredCPUBuffers(); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + } + } + } +} + +OrtStatus* CudaSyncStream::InitHandles() { + PL_CUDA_RETURN_IF_ERROR(cudaSetDevice(device_id_)); + + PL_CUDA_RETURN_IF_ERROR(cudaStreamCreateWithFlags(&cuda_stream_, cudaStreamNonBlocking)); + + PL_CUBLAS_RETURN_IF_ERROR(cublasCreate(&cublas_handle_)); + PL_CUBLAS_RETURN_IF_ERROR(cublasSetStream(cublas_handle_, cuda_stream_)); + + PL_CUDNN_RETURN_IF_ERROR(cudnnCreate(&cudnn_handle_)); + PL_CUDNN_RETURN_IF_ERROR(cudnnSetStream(cudnn_handle_, cuda_stream_)); + + PL_CUBLAS_RETURN_IF_ERROR(cublasLtCreate(&cublas_lt_handle_)); + RegisterStream(cuda_stream_, this); + + return nullptr; +} + +void CudaSyncStream::EnqueueDeferredCPUBuffer(void* cpu_buffer) { + std::lock_guard lock(deferred_cpu_buffers_mutex_); + deferred_cpu_buffers_.push_back(cpu_buffer); +} + +OrtStatus* CudaSyncStream::CleanupDeferredCPUBuffers() noexcept { + std::vector buffers_to_free; + { + std::lock_guard lock(deferred_cpu_buffers_mutex_); + buffers_to_free.swap(deferred_cpu_buffers_); + } + + OrtStatus* first_error = nullptr; + for (void* buf : buffers_to_free) { + cudaError_t err = cudaFreeHost(buf); + if (err != cudaSuccess && first_error == nullptr) { + first_error = Ort::GetApi().CreateStatus( + ORT_EP_FAIL, + (std::string("CUDA error: ") + cudaGetErrorName(err) + ": " + cudaGetErrorString(err)).c_str()); + } + } + return first_error; +} + +/*static*/ void* ORT_API_CALL CudaSyncStream::GetHandleImpl(OrtSyncStreamImpl* this_ptr) noexcept { + auto* stream = static_cast(this_ptr); + return stream->cuda_stream_; +} + +/*static*/ OrtStatus* ORT_API_CALL CudaSyncStream::CreateNotificationImpl( + OrtSyncStreamImpl* this_ptr, OrtSyncNotificationImpl** notification) noexcept { + EXCEPTION_TO_STATUS_BEGIN + auto* stream = static_cast(this_ptr); + auto notif = std::make_unique(*stream); + *notification = notif.release(); + return nullptr; + EXCEPTION_TO_STATUS_END +} + +/*static*/ OrtStatus* ORT_API_CALL CudaSyncStream::FlushImpl(OrtSyncStreamImpl* this_ptr) noexcept { + auto* stream = static_cast(this_ptr); + PL_CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream->cuda_stream_)); + return nullptr; +} + +/*static*/ OrtStatus* ORT_API_CALL CudaSyncStream::OnSessionRunEndImpl(OrtSyncStreamImpl* this_ptr) noexcept { + auto* stream = static_cast(this_ptr); + if (stream->cuda_stream_ == nullptr) { + return stream->CleanupDeferredCPUBuffers(); + } + // Synchronize before releasing deferred CPU buffers to ensure + // all async copies using those buffers have completed. + PL_CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream->cuda_stream_)); + return stream->CleanupDeferredCPUBuffers(); +} + +/*static*/ void ORT_API_CALL CudaSyncStream::ReleaseImpl(OrtSyncStreamImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} + +/*static*/ CudaSyncStream* CudaSyncStream::FromCudaStream(cudaStream_t stream) { + if (stream == nullptr) { + return nullptr; + } + + // Thread-local TLS cache to mitigate lock contention on the hot path. + // The generation counter is bumped on every UnregisterStream() so that + // stale TLS entries (pointing to destroyed CudaSyncStream objects) are + // automatically invalidated without requiring per-thread notification. + thread_local cudaStream_t tls_last_stream = nullptr; + thread_local CudaSyncStream* tls_last_sync_stream = nullptr; + thread_local uint64_t tls_generation = 0; + + uint64_t current_gen = GetStreamMapGeneration().load(std::memory_order_acquire); + if (stream == tls_last_stream && tls_generation == current_gen) { + return tls_last_sync_stream; + } + + auto& stream_map = GetStreamMap(); + std::shared_lock lock(GetStreamMapMutex()); + auto it = stream_map.find(stream); + if (it != stream_map.end()) { + tls_last_stream = stream; + tls_last_sync_stream = it->second; + tls_generation = current_gen; + return it->second; + } + return nullptr; +} + +/*static*/ void CudaSyncStream::RegisterStream(cudaStream_t stream, CudaSyncStream* sync_stream) { + auto& stream_map = GetStreamMap(); + std::unique_lock lock(GetStreamMapMutex()); + stream_map[stream] = sync_stream; +} + +/*static*/ void CudaSyncStream::UnregisterStream(cudaStream_t stream) { + auto& stream_map = GetStreamMap(); + std::unique_lock lock(GetStreamMapMutex()); + stream_map.erase(stream); + // Bump generation so TLS caches in other threads are invalidated. + GetStreamMapGeneration().fetch_add(1, std::memory_order_release); +} + +// --------------------------------------------------------------------------- +// CudaSyncNotification +// --------------------------------------------------------------------------- + +CudaSyncNotification::CudaSyncNotification(CudaSyncStream& stream) + : OrtSyncNotificationImpl{}, + stream_(stream) { + ort_version_supported = ORT_API_VERSION; + Activate = ActivateImpl; + WaitOnDevice = WaitOnDeviceImpl; + WaitOnHost = WaitOnHostImpl; + Release = ReleaseImpl; + + // Create a CUDA event for synchronization (disable timing for performance) + PL_CUDA_CALL_THROW(cudaSetDevice(stream_.GetDeviceId())); + PL_CUDA_CALL_THROW(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); +} + +CudaSyncNotification::~CudaSyncNotification() { + if (event_) { + cudaEventDestroy(event_); + } +} + +/*static*/ OrtStatus* ORT_API_CALL CudaSyncNotification::ActivateImpl( + OrtSyncNotificationImpl* this_ptr) noexcept { + auto* notif = static_cast(this_ptr); + PL_CUDA_RETURN_IF_ERROR(cudaEventRecord(notif->event_, notif->stream_.GetCudaStream())); + return nullptr; +} + +/*static*/ OrtStatus* ORT_API_CALL CudaSyncNotification::WaitOnDeviceImpl( + OrtSyncNotificationImpl* this_ptr, OrtSyncStream* stream) noexcept { + auto* notif = static_cast(this_ptr); + // SyncStream_GetHandle is in the main ORT API + cudaStream_t wait_stream = static_cast(Ort::GetApi().SyncStream_GetHandle(stream)); + PL_CUDA_RETURN_IF_ERROR(cudaStreamWaitEvent(wait_stream, notif->event_, 0)); + return nullptr; +} + +/*static*/ OrtStatus* ORT_API_CALL CudaSyncNotification::WaitOnHostImpl( + OrtSyncNotificationImpl* this_ptr) noexcept { + auto* notif = static_cast(this_ptr); + PL_CUDA_RETURN_IF_ERROR(cudaEventSynchronize(notif->event_)); + return nullptr; +} + +/*static*/ void ORT_API_CALL CudaSyncNotification::ReleaseImpl( + OrtSyncNotificationImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h new file mode 100644 index 0000000000000..4b72dee82ca38 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// CUDA stream and event-based synchronization primitives for the plugin EP. +// CudaSyncStream wraps a cudaStream_t plus cuBLAS/cuDNN/cuBLASLt handles. +// CudaSyncNotification wraps a cudaEvent_t for cross-stream synchronization. +// A global stream registry (with TLS-cached lookups) allows migrated kernels +// to obtain their compute handles from a raw cudaStream_t. + +#pragma once + +#include "cuda_plugin_utils.h" + +#include +#include +#include + +namespace onnxruntime { +namespace cuda_plugin { + +class CudaSyncNotification; +class CudaEpFactory; + +/// CUDA stream implementation for the plugin EP. +/// Owns a cudaStream_t and associated cuBLAS/cuDNN handles. +class CudaSyncStream : public OrtSyncStreamImpl { + public: + CudaSyncStream(CudaEpFactory& factory, int device_id, + const OrtEp* ep); + ~CudaSyncStream(); + + int GetDeviceId() const { return device_id_; } + cudaStream_t GetCudaStream() const { return cuda_stream_; } + cublasHandle_t GetCublasHandle() const { return cublas_handle_; } + cudnnHandle_t GetCudnnHandle() const { return cudnn_handle_; } + cublasLtHandle_t GetCublasLtHandle() const { return cublas_lt_handle_; } + + void EnqueueDeferredCPUBuffer(void* cpu_buffer); + OrtStatus* InitHandles(); + + /// Look up the CudaSyncStream wrapper from a raw cudaStream_t handle. + /// Uses a thread-local TLS cache with a generation counter to avoid lock + /// contention on this hot path (called on every kernel launch). + static CudaSyncStream* FromCudaStream(cudaStream_t stream); + + private: + static void RegisterStream(cudaStream_t stream, CudaSyncStream* sync_stream); + static void UnregisterStream(cudaStream_t stream); + static void* ORT_API_CALL GetHandleImpl(OrtSyncStreamImpl* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL CreateNotificationImpl( + OrtSyncStreamImpl* this_ptr, OrtSyncNotificationImpl** notification) noexcept; + static OrtStatus* ORT_API_CALL FlushImpl(OrtSyncStreamImpl* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL OnSessionRunEndImpl(OrtSyncStreamImpl* this_ptr) noexcept; + static void ORT_API_CALL ReleaseImpl(OrtSyncStreamImpl* this_ptr) noexcept; + + OrtStatus* CleanupDeferredCPUBuffers() noexcept; + + CudaEpFactory& factory_; + int device_id_; + cudaStream_t cuda_stream_ = nullptr; + cublasHandle_t cublas_handle_ = nullptr; + cudnnHandle_t cudnn_handle_ = nullptr; + cublasLtHandle_t cublas_lt_handle_ = nullptr; + + // CPU buffers whose deallocation is deferred to OnSessionRunEnd. + // Pinned memory must remain valid until all async device operations that + // reference it have completed, so we synchronize the stream first. + mutable std::mutex deferred_cpu_buffers_mutex_; + std::vector deferred_cpu_buffers_; +}; + +/// CUDA event-based notification for stream synchronization. +class CudaSyncNotification : public OrtSyncNotificationImpl { + public: + explicit CudaSyncNotification(CudaSyncStream& stream); + ~CudaSyncNotification(); + + private: + static OrtStatus* ORT_API_CALL ActivateImpl(OrtSyncNotificationImpl* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL WaitOnDeviceImpl( + OrtSyncNotificationImpl* this_ptr, OrtSyncStream* stream) noexcept; + static OrtStatus* ORT_API_CALL WaitOnHostImpl(OrtSyncNotificationImpl* this_ptr) noexcept; + static void ORT_API_CALL ReleaseImpl(OrtSyncNotificationImpl* this_ptr) noexcept; + + CudaSyncStream& stream_; + cudaEvent_t event_ = nullptr; +}; + +} // namespace cuda_plugin +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/provider_api_shims.cc b/onnxruntime/core/providers/cuda/plugin/provider_api_shims.cc new file mode 100644 index 0000000000000..2d6851aae07d2 --- /dev/null +++ b/onnxruntime/core/providers/cuda/plugin/provider_api_shims.cc @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Provider API shims used by migrated CUDA kernels. +// Provides direct implementations of utility functions that in-tree kernels +// obtain via the SHARED_PROVIDER bridge (GetEnvironmentVar, floatToHalf, +// halfToFloat). Plugin builds skip SHARED_PROVIDER entirely, so these thin +// wrappers ensure the migrated kernel code compiles and links. + +#include +#include +#include "core/common/float16.h" + +namespace onnxruntime { + +std::string GetEnvironmentVar(const std::string& var_name) { + const char* val = std::getenv(var_name.c_str()); + return val ? std::string(val) : std::string(); +} + +namespace math { + +uint16_t floatToHalf(float f) { + return MLFloat16(f).val; +} + +float halfToFloat(uint16_t h) { + return MLFloat16::FromBits(h).ToFloat(); +} + +} // namespace math + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index 006b2366af0a5..127cfcc557fd5 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -335,17 +335,30 @@ Status PrepareForReduce(const Tensor* X, return Status::OK(); } +// Unified scratch buffer allocation: when a CudaKernel pointer is available +// (plugin build path), use GetScratchBuffer which routes through the adapter +// allocator. Otherwise fall back to IAllocator::MakeUniquePtr (in-tree path). +template +inline IAllocatorUniquePtr AllocateScratchBuffer( + const AllocatorPtr& gpu_allocator, const CudaKernel* kernel, size_t count, void* compute_stream) { + if (count == 0) return nullptr; + if (kernel) { + return kernel->GetScratchBuffer(count, compute_stream); + } + return IAllocator::MakeUniquePtr(gpu_allocator, count, false, static_cast(compute_stream)); +} + // `input_shape_override` is the input shape for compute purposes (if provided) template -Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, +Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const CudaKernel* kernel, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, gsl::span axes, bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, - Stream* ort_stream, cudnnHandle_t cudnn_handle, + cudaStream_t cuda_stream, void* compute_stream, cudnnHandle_t cudnn_handle, const TensorShape* input_shape_override) { typedef typename ToCudaType::MappedType CudaT; const TensorShape& input_shape = input_shape_override ? *input_shape_override : input.Shape(); - cudaStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; + cudaStream_t stream = cuda_stream; int64_t input_count = prepare_reduce_metadata.input_count; int64_t output_count = prepare_reduce_metadata.output_count; @@ -367,7 +380,7 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, IAllocatorUniquePtr input_data_buffer(nullptr, [](T*) {}); const CudaT* input_data = reinterpret_cast(input.Data()); if (calculate_sqt) { - input_data_buffer = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream); + input_data_buffer = AllocateScratchBuffer(gpu_allocator, kernel, input_count, compute_stream); input_data = reinterpret_cast(input_data_buffer.get()); fast_divmod tmp_div; Impl_Mul(stream, static_cast(SimpleBroadcast::NoBroadcast), nullptr, @@ -386,7 +399,7 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, const auto buffer_size_bytes = compute_reduce_matrix_columns_buffer_size(m, n); auto buffer = buffer_size_bytes == 0 ? nullptr - : IAllocator::MakeUniquePtr(gpu_allocator, buffer_size_bytes, false, ort_stream); + : AllocateScratchBuffer(gpu_allocator, kernel, buffer_size_bytes, compute_stream); ORT_RETURN_IF_ERROR(reduce_matrix_columns(stream, input_data, reinterpret_cast(output.MutableData()), m, n, buffer.get(), buffer_size_bytes)); @@ -423,7 +436,7 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, if ((ReduceTensorIndices == CUDNN_REDUCE_TENSOR_FLATTENED_INDICES && std::is_same::value) || (ReduceTensorIndices == CUDNN_REDUCE_TENSOR_NO_INDICES && std::is_same::value)) { // ArgMax/ArgMin with FP16 are not supported by cudnn, so convert input to fp32 then call cudnn - temp_X = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream); + temp_X = AllocateScratchBuffer(gpu_allocator, kernel, input_count, compute_stream); Impl_Cast(stream, reinterpret_cast(input.Data()), temp_X.get(), input_shape.Size()); } else { cudnn_type_X = CudnnTensor::GetDataType(); @@ -448,20 +461,20 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, input_tensor, output_tensor, &workspace_bytes)); auto workspace_cuda = workspace_bytes == 0 ? nullptr - : IAllocator::MakeUniquePtr(gpu_allocator, workspace_bytes, false, ort_stream); + : AllocateScratchBuffer(gpu_allocator, kernel, workspace_bytes, compute_stream); size_t indices_bytes = 0; CUDNN_RETURN_IF_ERROR(cudnnGetReductionIndicesSize(cudnn_handle, reduce_desc, input_tensor, output_tensor, &indices_bytes)); auto indices_cuda = indices_bytes == 0 ? nullptr - : IAllocator::MakeUniquePtr(gpu_allocator, indices_bytes, false, ort_stream); + : AllocateScratchBuffer(gpu_allocator, kernel, indices_bytes, compute_stream); if (ReduceTensorIndices == CUDNN_REDUCE_TENSOR_NO_INDICES) { IAllocatorUniquePtr input_data_buffer(nullptr, [](T*) {}); CudaT* input_data = nullptr; if (calculate_sqt) { - input_data_buffer = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream); + input_data_buffer = AllocateScratchBuffer(gpu_allocator, kernel, input_count, compute_stream); input_data = reinterpret_cast(input_data_buffer.get()); fast_divmod tmp_div; Impl_Mul(stream, @@ -490,7 +503,7 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, input_tensor, output_tensor, &indices_bytes_max)); auto indices_cuda_max = indices_bytes_max == 0 ? nullptr - : IAllocator::MakeUniquePtr(gpu_allocator, indices_bytes_max, false, ort_stream); + : AllocateScratchBuffer(gpu_allocator, kernel, indices_bytes_max, compute_stream); auto* p_output = reinterpret_cast(output.template MutableData()); CUDNN_RETURN_IF_ERROR(cudnnReduceTensor( cudnn_handle, reduce_max_desc, indices_cuda_max.get(), indices_bytes_max, @@ -501,11 +514,11 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, // Exp(X-ReduceMax) const TensorShape output_shape(output_dims); - auto exp_result_buffer = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream); + auto exp_result_buffer = AllocateScratchBuffer(gpu_allocator, kernel, input_count, compute_stream); auto exp_result = exp_result_buffer.get(); auto log_sum_result_buffer = output_count == 0 ? nullptr - : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream); + : AllocateScratchBuffer(gpu_allocator, kernel, output_count, compute_stream); auto log_sum_result = log_sum_result_buffer.get(); BinaryElementwisePreparation prepare; ORT_RETURN_IF_ERROR(prepare.BinaryElementwiseBroadcastPrepareHelper(input_shape, output_shape, input_shape)); @@ -575,7 +588,7 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, if (temp_X) { auto temp_output = output_count == 0 ? nullptr - : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream); + : AllocateScratchBuffer(gpu_allocator, kernel, output_count, compute_stream); CUDNN_RETURN_IF_ERROR(cudnnReduceTensor( cudnn_handle, reduce_desc, indices_cuda.get(), indices_bytes, workspace_cuda.get(), workspace_bytes, @@ -603,7 +616,7 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, if (temp_X) { auto temp_output = output_count == 0 ? nullptr - : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream); + : AllocateScratchBuffer(gpu_allocator, kernel, output_count, compute_stream); CUDNN_RETURN_IF_ERROR(cudnnReduceTensor( cudnn_handle, reduce_desc, indices_cuda.get(), indices_bytes, workspace_cuda.get(), workspace_bytes, @@ -612,7 +625,7 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, } else { auto temp_output = output_count == 0 ? nullptr - : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream); + : AllocateScratchBuffer(gpu_allocator, kernel, output_count, compute_stream); CUDNN_RETURN_IF_ERROR(cudnnReduceTensor( cudnn_handle, reduce_desc, indices_cuda.get(), indices_bytes, workspace_cuda.get(), workspace_bytes, @@ -637,27 +650,27 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, } template Status ReduceComputeCore( - const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + const AllocatorPtr& gpu_allocator, const CudaKernel* kernel, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, gsl::span axes, bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, - Stream* ort_stream, cudnnHandle_t cudnn_handle, + cudaStream_t cuda_stream, void* compute_stream, cudnnHandle_t cudnn_handle, const TensorShape* input_shape_override); template Status ReduceComputeCore( - const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + const AllocatorPtr& gpu_allocator, const CudaKernel* kernel, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, gsl::span axes, bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, - Stream* ort_stream, cudnnHandle_t cudnn_handle, + cudaStream_t cuda_stream, void* compute_stream, cudnnHandle_t cudnn_handle, const TensorShape* input_shape_override); template Status ReduceComputeCore( - const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + const AllocatorPtr& gpu_allocator, const CudaKernel* kernel, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, gsl::span axes, bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, - Stream* ort_stream, cudnnHandle_t cudnn_handle, + cudaStream_t cuda_stream, void* compute_stream, cudnnHandle_t cudnn_handle, const TensorShape* input_shape_override); template @@ -704,9 +717,9 @@ Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, cudnnRe ORT_RETURN_IF_ERROR(PrepareForReduce(X, keepdims_, axes, prepare_reduce_metadata)); Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims); const bool fast_reduction = fast_reduction_ && !ctx->GetUseDeterministicCompute(); - return ReduceComputeCore(Info().GetAllocator(OrtMemType::OrtMemTypeDefault), *X, prepare_reduce_metadata, *Y, cudnn_reduce_op, axes, - calculate_log_, calculate_sqt_, log_sum_exp_, fast_reduction, ctx->GetComputeStream(), - GetCudnnHandleOrDefault(ctx->GetComputeStream())); + return ReduceComputeCore(AllocatorPtr{}, this, *X, prepare_reduce_metadata, *Y, cudnn_reduce_op, axes, + calculate_log_, calculate_sqt_, log_sum_exp_, fast_reduction, + Stream(ctx), GetComputeStream(ctx), GetCudnnHandle(ctx)); } #define SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(T) \ @@ -768,7 +781,7 @@ Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, cudnnRe CudnnReduceDescriptor reduce_desc; \ \ cudnnDataType_t cudnn_type_X = CUDNN_DATA_FLOAT; \ - IAllocatorUniquePtr temp_X = GetScratchBuffer(input_count, ctx->GetComputeStream()); \ + IAllocatorUniquePtr temp_X = GetScratchBuffer(input_count, GetComputeStream(ctx)); \ Impl_Cast(Stream(ctx), reinterpret_cast(X->Data()), temp_X.get(), X->Shape().Size()); \ \ ORT_RETURN_IF_ERROR(reduce_desc.Set(cudnn_reduce_op, cudnn_type_X, CUDNN_REDUCE_TENSOR_NO_INDICES)); \ @@ -778,12 +791,12 @@ Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, cudnnRe cudnnGetReductionIndicesSize(GetCudnnHandle(ctx), reduce_desc, input_tensor, output_tensor, &indices_bytes)); \ CUDNN_RETURN_IF_ERROR( \ cudnnGetReductionWorkspaceSize(GetCudnnHandle(ctx), reduce_desc, input_tensor, output_tensor, &workspace_bytes)); \ - IAllocatorUniquePtr indices_cuda = GetScratchBuffer(indices_bytes, ctx->GetComputeStream()); \ - IAllocatorUniquePtr workspace_cuda = GetScratchBuffer(workspace_bytes, ctx->GetComputeStream()); \ + IAllocatorUniquePtr indices_cuda = GetScratchBuffer(indices_bytes, GetComputeStream(ctx)); \ + IAllocatorUniquePtr workspace_cuda = GetScratchBuffer(workspace_bytes, GetComputeStream(ctx)); \ \ const auto one = Consts::One; \ const auto zero = Consts::Zero; \ - auto temp_Y = GetScratchBuffer(output_count, ctx->GetComputeStream()); \ + auto temp_Y = GetScratchBuffer(output_count, GetComputeStream(ctx)); \ CUDNN_RETURN_IF_ERROR(cudnnReduceTensor(GetCudnnHandle(ctx), reduce_desc, indices_cuda.get(), indices_bytes, \ workspace_cuda.get(), workspace_bytes, &one, input_tensor, temp_X.get(), \ &zero, output_tensor, temp_Y.get())); \ @@ -796,7 +809,6 @@ SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(int32_t) SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(int64_t) SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(int8_t) SPECIALIZED_REDUCEKERNEL_COMPUTEIMPL(uint8_t) - namespace ReductionOps { template @@ -818,8 +830,10 @@ std::unique_ptr ReduceCompute(const AllocatorPtr& gpu_allocator, cudnnRe auto output = Tensor::Create(input.DataType(), prepare_reduce_metadata.squeezed_output_dims, allocator); - status = ReduceComputeCore(gpu_allocator, input, prepare_reduce_metadata, *output, cudnn_reduce_op, axes, - calculate_log, calculate_sqt, log_sum_exp, fast_reduction, stream, cudnn_handle, + status = ReduceComputeCore(gpu_allocator, nullptr, input, prepare_reduce_metadata, *output, cudnn_reduce_op, axes, + calculate_log, calculate_sqt, log_sum_exp, fast_reduction, + stream ? static_cast(stream->GetHandle()) : nullptr, + static_cast(stream), cudnn_handle, input_shape_override); if (!status.IsOK()) { @@ -923,6 +937,5 @@ REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL2, float, 17, 18) REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL2, double, 17, 18) REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL2, BFloat16, 17, 18) REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL2, int32_t, 17, 18) - } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.h b/onnxruntime/core/providers/cuda/reduction/reduction_ops.h index ca1f89b7333f3..5c8fd1c2711cc 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.h +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.h @@ -236,11 +236,11 @@ Status PrepareForReduce(const Tensor* X, const TensorShape* input_shape_override = nullptr); template -Status ReduceComputeCore(const AllocatorPtr& allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, +Status ReduceComputeCore(const AllocatorPtr& allocator, const CudaKernel* kernel, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, gsl::span axes, bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, - Stream* ort_stream, cudnnHandle_t cudnn_handle, + cudaStream_t cuda_stream, void* compute_stream, cudnnHandle_t cudnn_handle, const TensorShape* input_shape_override = nullptr); // CUDA's reduction descriptor cudnnReduceTensorDescriptor_t is a pointer so diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc index 4718e59e5a042..e7c8f52950141 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc @@ -88,7 +88,9 @@ Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, cons size_t& reorganized_w_data_size_in_bytes, IAllocatorUniquePtr& reorganized_w_data, CudnnFilterDescriptor& target_w_desc, - CudnnRNN& rnn_desc, onnxruntime::Stream* ort_stream) const { + CudnnRNN& rnn_desc, + void* alloc_stream, cudaStream_t cuda_stream, + cudnnHandle_t cudnn_handle) const { typedef typename ToCudaType::MappedType CudaT; int64_t input_size = W->Shape()[2]; // RNN W[num_directions_, hidden_size_, input_size] @@ -107,20 +109,18 @@ Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, cons // Prepare the weight data reorganized_w_data_size_in_bytes = w_size * sizeof(T); - reorganized_w_data = GetScratchBuffer(reorganized_w_data_size_in_bytes, ort_stream); + reorganized_w_data = GetScratchBuffer(reorganized_w_data_size_in_bytes, alloc_stream); // In many cases, this allocation is bigger than needed, leaving part of // the buffer uninitialized. non-zero garbage data leads to wrong result // in call to cudnnRNNForwardInference() // TODO! refine allocation size for each case. - cudaStream_t cuda_stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; CUDA_RETURN_IF_ERROR(cudaMemsetAsync(reorganized_w_data.get(), 0, reorganized_w_data_size_in_bytes, cuda_stream)); const T* W_data = W->Data(); const T* R_data = R->Data(); const T* B_data = B == nullptr ? nullptr : B->Data(); - cudnnHandle_t cudnn_handle = GetCudnnHandleOrDefault(ort_stream); ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(cudnn_handle, rnn_desc, reorganized_w_data_size_in_bytes, reorganized_w_data.get(), W_data, R_data, B_data, cuda_stream)); @@ -157,11 +157,11 @@ Status CudnnRnnBase::CacheCudnnRnnWeights(const OpKernelInfo& info) { if (get_B) { ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B, w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_, - tmp_rnn_desc, nullptr)); + tmp_rnn_desc, nullptr, nullptr, DefaultCudnnHandle())); } else { ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr, w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_, - tmp_rnn_desc, nullptr)); + tmp_rnn_desc, nullptr, nullptr, DefaultCudnnHandle())); } cudaStreamSynchronize(nullptr); @@ -238,11 +238,11 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // Prior to cuDNN 8.9.1 the sequence lens buffer must be passed to cudnnRNNForward and thus is must // be copied to the GPU always. - ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(GetComputeStream(ctx))); // Starting with cuDNN 8.9.1 the sequence lens buffer is ignored by cudnnRNNForward and thus it must // be copied to the GPU only for the ReverseBySequence kernels. // if (reverse_) { - // ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream())); + // ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(GetComputeStream(ctx))); // } // optional outputs @@ -257,7 +257,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { const T* x_data = X->Data(); if (reverse_) { // reverse input data - x_reversed_data = GetScratchBuffer(seq_length * batch_size * input_size, ctx->GetComputeStream()); + x_reversed_data = GetScratchBuffer(seq_length * batch_size * input_size, GetComputeStream(ctx)); ReverseBySequence(Stream(ctx), gsl::narrow_cast(seq_length), sequence_lens_buffer.GpuPtr(), @@ -280,7 +280,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { if (Y != nullptr) { y_data = Y->MutableData(); } else { - y_alloc_data = GetScratchBuffer(output_size, ctx->GetComputeStream()); + y_alloc_data = GetScratchBuffer(output_size, GetComputeStream(ctx)); y_data = y_alloc_data.get(); } @@ -307,7 +307,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { const Tensor& W = *ctx->Input(RNN_Input_Index::W); const Tensor& R = *ctx->Input(RNN_Input_Index::R); ORT_RETURN_IF_ERROR(ReorganizeWeights(&W, &R, B, w_data_size_in_bytes, w_data, w_desc, - rnn_desc, ctx->GetComputeStream())); + rnn_desc, GetComputeStream(ctx), Stream(ctx), GetCudnnHandle(ctx))); } CudnnDataTensor x_desc1; @@ -329,8 +329,8 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { CUDNN_RETURN_IF_ERROR(cudnnGetRNNTempSpaceSizes(GetCudnnHandle(ctx), rnn_desc, CUDNN_FWD_MODE_INFERENCE, x_desc1, &workspace_bytes, &reservespace_bytes)); - auto workspace_cuda = GetScratchBuffer(workspace_bytes, ctx->GetComputeStream()); - auto reservespace_cuda = GetScratchBuffer(reservespace_bytes, ctx->GetComputeStream()); + auto workspace_cuda = GetScratchBuffer(workspace_bytes, GetComputeStream(ctx)); + auto reservespace_cuda = GetScratchBuffer(reservespace_bytes, GetComputeStream(ctx)); CUDNN_RETURN_IF_ERROR(cudnnRNNForward(GetCudnnHandle(ctx), rnn_desc, @@ -357,7 +357,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // Mask on output for 0 sequence batches if (zero_seq_count > 0) { // Mask on output for 0 sequence batches - SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); + SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, GetComputeStream(ctx), Stream(ctx)); } return Status::OK(); } @@ -365,7 +365,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { IAllocatorUniquePtr y_reorganized_data; if (reverse_ || num_directions_ == 2) { // reverse output - y_reorganized_data = GetScratchBuffer(output_size, ctx->GetComputeStream()); + y_reorganized_data = GetScratchBuffer(output_size, GetComputeStream(ctx)); if (reverse_) { // reverse output data ReverseBySequence(Stream(ctx), @@ -397,7 +397,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // Mask on output for 0 sequence batches if (zero_seq_count > 0) { - SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); + SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, GetComputeStream(ctx), Stream(ctx)); } return Status::OK(); @@ -409,13 +409,12 @@ void CudnnRnnBase::SetZeroSequences(const int64_t zero_seq_index_cache_size, T* y_data, T* y_h_data, T* y_c_data, - onnxruntime::Stream* ort_stream) const { + void* alloc_stream, cudaStream_t cuda_stream) const { typedef typename ToCudaType::MappedType CudaT; CudaAsyncBuffer zero_seq_index_cache_async_buffer(this, zero_seq_index_cache_size); memcpy(zero_seq_index_cache_async_buffer.CpuPtr(), zero_seq_index_cache.data(), zero_seq_index_cache_size * sizeof(int32_t)); - ORT_THROW_IF_ERROR(zero_seq_index_cache_async_buffer.CopyToGpu(ort_stream)); - cudaStream_t cuda_stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; + ORT_THROW_IF_ERROR(zero_seq_index_cache_async_buffer.CopyToGpu(alloc_stream)); MaskZeroSequences(cuda_stream, gsl::narrow_cast(hidden_size_), reinterpret_cast(y_data), diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h index 7b827f8d04593..b7a3d67b45e93 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h @@ -138,7 +138,8 @@ class CudnnRnnBase : public CudaKernel { IAllocatorUniquePtr& target_w_data, CudnnFilterDescriptor& target_w_desc, CudnnRNN& rnn_desc, - onnxruntime::Stream* ort_stream) const; + void* alloc_stream, cudaStream_t cuda_stream, + cudnnHandle_t cudnn_handle) const; Status SetWeightBias(const cudnnHandle_t handle, const cudnnRNNDescriptor_t rnn_desc, @@ -156,7 +157,7 @@ class CudnnRnnBase : public CudaKernel { T* y_data, T* y_h_data, T* y_c_data, - onnxruntime::Stream* cuda_stream) const; + void* alloc_stream, cudaStream_t cuda_stream) const; protected: // W_lin_layer_id_ & R_lin_layer_id_ are set in Constructor diff --git a/onnxruntime/core/providers/cuda/shared_inc/integer_gemm.h b/onnxruntime/core/providers/cuda/shared_inc/integer_gemm.h index ea1f28f734c25..7ef4b7218611b 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/integer_gemm.h +++ b/onnxruntime/core/providers/cuda/shared_inc/integer_gemm.h @@ -19,6 +19,7 @@ Status GemmInt8(int m, int32_t* c, int ldc, const CudaKernel* cuda_kernel, - onnxruntime::Stream* stream); + void* alloc_stream, cudaStream_t cuda_stream, + cublasHandle_t cublas_handle); } -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/compress.cc b/onnxruntime/core/providers/cuda/tensor/compress.cc index a75f24341fc47..87c6f068dacd0 100644 --- a/onnxruntime/core/providers/cuda/tensor/compress.cc +++ b/onnxruntime/core/providers/cuda/tensor/compress.cc @@ -48,7 +48,7 @@ Status Compress::ComputeInternal(OpKernelContext* ctx) const { int64_t compress_input_length = has_axis_ ? input_dimensions[axis] : input_size; int64_t valid_condition_length = compress_input_length < condition_length ? compress_input_length : condition_length; - auto condition_cumulative_sum_buffer = GetScratchBuffer(gsl::narrow(valid_condition_length), ctx->GetComputeStream()); + auto condition_cumulative_sum_buffer = GetScratchBuffer(gsl::narrow(valid_condition_length), GetComputeStream(ctx)); auto condition_cumulative_sum = condition_cumulative_sum_buffer.get(); size_t temp_storage_bytes = 0; @@ -58,7 +58,7 @@ Status Compress::ComputeInternal(OpKernelContext* ctx) const { gsl::narrow(valid_condition_length), temp_storage_bytes)); - auto temp_buffer = GetScratchBuffer(temp_storage_bytes, ctx->GetComputeStream()); + auto temp_buffer = GetScratchBuffer(temp_storage_bytes, GetComputeStream(ctx)); auto d_temp_storage = temp_buffer.get(); CUDA_RETURN_IF_ERROR(CompressInclusivePrefixSum(Stream(ctx), d_temp_storage, diff --git a/onnxruntime/core/providers/cuda/tensor/concat.cc b/onnxruntime/core/providers/cuda/tensor/concat.cc index 4a390541009d9..711f30e7a57ca 100644 --- a/onnxruntime/core/providers/cuda/tensor/concat.cc +++ b/onnxruntime/core/providers/cuda/tensor/concat.cc @@ -33,7 +33,7 @@ ONNX_OPERATOR_KERNEL_EX(Concat, Concat); Status Concat::ComputeInternal(OpKernelContext* ctx) const { - auto input_count = Node().InputArgCount().front(); + auto input_count = ctx->InputCount(); // Hold pointers to the input tensors to be used in the PrepareForCompute() step InlinedTensorsVector input_tensors; @@ -43,7 +43,7 @@ Status Concat::ComputeInternal(OpKernelContext* ctx) const { } Prepare p; - ORT_RETURN_IF_ERROR(PrepareForCompute(ctx, input_tensors, p)); + ORT_RETURN_IF_ERROR(PrepareForComputeImpl(ctx, input_tensors, p)); // Return at this point if output tensor is going to be empty if (p.output_num_elements == 0) @@ -76,7 +76,7 @@ Status Concat::ComputeInternal(OpKernelContext* ctx) const { Stream(ctx), element_bytes, block_size_including_axis_dim, block_size_inside_axis_dim, concat_sizes[0], p.output_tensor->MutableDataRaw(), input_ptr_array, static_cast(p.output_num_elements))); } else { - ORT_RETURN_IF_ERROR(input_ptr.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(input_ptr.CopyToGpu(GetComputeStream(ctx))); ORT_RETURN_IF_ERROR(ConcatSameConcatDimImpl( Stream(ctx), element_bytes, block_size_including_axis_dim, block_size_inside_axis_dim, concat_sizes[0], p.output_tensor->MutableDataRaw(), input_ptr.GpuPtr(), static_cast(p.output_num_elements))); @@ -89,10 +89,10 @@ Status Concat::ComputeInternal(OpKernelContext* ctx) const { concat_sizes_range[i] += concat_sizes_range[i - 1]; } CudaAsyncBuffer concat_sizes_range_gpu(this, concat_sizes_range); - ORT_RETURN_IF_ERROR(concat_sizes_gpu.CopyToGpu(ctx->GetComputeStream())); - ORT_RETURN_IF_ERROR(axis_dimension_input_output_mapping_gpu.CopyToGpu(ctx->GetComputeStream())); - ORT_RETURN_IF_ERROR(concat_sizes_range_gpu.CopyToGpu(ctx->GetComputeStream())); - ORT_RETURN_IF_ERROR(input_ptr.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(concat_sizes_gpu.CopyToGpu(GetComputeStream(ctx))); + ORT_RETURN_IF_ERROR(axis_dimension_input_output_mapping_gpu.CopyToGpu(GetComputeStream(ctx))); + ORT_RETURN_IF_ERROR(concat_sizes_range_gpu.CopyToGpu(GetComputeStream(ctx))); + ORT_RETURN_IF_ERROR(input_ptr.CopyToGpu(GetComputeStream(ctx))); ORT_RETURN_IF_ERROR(ConcatImpl(Stream(ctx), element_bytes, block_size_including_axis_dim, block_size_inside_axis_dim, concat_sizes_gpu.GpuPtr(), concat_sizes_range_gpu.GpuPtr(), axis_dimension_input_output_mapping_gpu.GpuPtr(), p.output_tensor->MutableDataRaw(), diff --git a/onnxruntime/core/providers/cuda/tensor/gather.cc b/onnxruntime/core/providers/cuda/tensor/gather.cc index b7170872b0c0f..0f1ce81099b56 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather.cc +++ b/onnxruntime/core/providers/cuda/tensor/gather.cc @@ -46,7 +46,7 @@ ONNX_OPERATOR_KERNEL_EX( Status Gather::ComputeInternal(OpKernelContext* context) const { Prepare p; - ORT_RETURN_IF_ERROR(PrepareForCompute(context, p)); + ORT_RETURN_IF_ERROR(PrepareForComputeImpl(context, p)); const TensorShape& input_shape = p.input_tensor->Shape(); diff --git a/onnxruntime/core/providers/cuda/tensor/gather_nd.cc b/onnxruntime/core/providers/cuda/tensor/gather_nd.cc index 0be8bab97c005..7933cbc53148c 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_nd.cc +++ b/onnxruntime/core/providers/cuda/tensor/gather_nd.cc @@ -40,7 +40,8 @@ Status CheckBatchDimensionsMatch( template Status GatherNDBase::PrepareCompute( - onnxruntime::Stream* stream, + void* alloc_stream, + cudaStream_t cuda_stream, const int64_t batch_dims, const TensorShape& input_shape, const TensorShape& indices_shape, @@ -54,7 +55,6 @@ Status GatherNDBase::PrepareCompute( const auto num_batches = input_shape.SizeToDimension(batch_dims); const auto input_batch_stride = input_shape.SizeFromDimension(batch_dims); const auto num_slices_per_batch = num_slices / num_batches; - cudaStream_t cuda_stream = stream ? static_cast(stream->GetHandle()) : nullptr; const TIndex* const indices_data = indices_tensor->Data(); @@ -67,14 +67,14 @@ Status GatherNDBase::PrepareCompute( } } - auto sizes_from_slice_dims_buffer = GetScratchBuffer(sizes_from_slice_dims.size(), stream); + auto sizes_from_slice_dims_buffer = GetScratchBuffer(sizes_from_slice_dims.size(), alloc_stream); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( sizes_from_slice_dims_buffer.get(), sizes_from_slice_dims.data(), sizes_from_slice_dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice, cuda_stream)); - input_slice_offsets_buffer = GetScratchBuffer(num_slices, stream); + input_slice_offsets_buffer = GetScratchBuffer(num_slices, alloc_stream); TArray input_dims(input_shape.GetDims()); @@ -180,7 +180,7 @@ Status GatherND::ComputeInternal(OpKernelContext* context) const { int64_t num_slices; int64_t slice_size; IAllocatorUniquePtr input_slice_offsets_buffer; - ORT_RETURN_IF_ERROR(PrepareCompute(context->GetComputeStream(), + ORT_RETURN_IF_ERROR(PrepareCompute(GetComputeStream(context), Stream(context), batch_dims_, input_shape, indices_shape, indices_tensor, num_slices, slice_size, input_slice_offsets_buffer)); diff --git a/onnxruntime/core/providers/cuda/tensor/gather_nd.h b/onnxruntime/core/providers/cuda/tensor/gather_nd.h index 0d63c8a159783..6973aadbe73ae 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_nd.h +++ b/onnxruntime/core/providers/cuda/tensor/gather_nd.h @@ -23,7 +23,8 @@ class GatherNDBase : public CudaKernel { protected: template Status PrepareCompute( - onnxruntime::Stream* stream, + void* alloc_stream, + cudaStream_t cuda_stream, const int64_t batch_dims, const TensorShape& input_shape, const TensorShape& indices_shape, diff --git a/onnxruntime/core/providers/cuda/tensor/identity_op.cc b/onnxruntime/core/providers/cuda/tensor/identity_op.cc index 5120a661ef971..a92dfdb29a0aa 100644 --- a/onnxruntime/core/providers/cuda/tensor/identity_op.cc +++ b/onnxruntime/core/providers/cuda/tensor/identity_op.cc @@ -5,32 +5,6 @@ namespace onnxruntime { namespace cuda { -ONNX_OPERATOR_VERSIONED_KERNEL_EX( - Dropout, - kOnnxDomain, - 7, 9, - kCudaExecutionProvider, - (*KernelDefBuilder::Create()) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}) - .Alias(0, 0), - IdentityOp); - -ONNX_OPERATOR_VERSIONED_KERNEL_EX( - Dropout, - kOnnxDomain, - 10, - 11, - kCudaExecutionProvider, - (*KernelDefBuilder::Create()) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}) - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .Alias(0, 0), - IdentityOp); - ONNX_OPERATOR_VERSIONED_KERNEL_EX( Identity, kOnnxDomain, @@ -51,13 +25,24 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .Alias(0, 0), IdentityOp); +// From opset 14 onward the ONNX spec's type constraint is "V" which includes +// both Tensor and TensorSequence types. In the plugin EP build TensorSeq is +// an incomplete type, so we register only the Tensor subset. +#ifdef BUILD_CUDA_EP_AS_PLUGIN +#define IDENTITY_V_TYPES DataTypeImpl::AllFixedSizeTensorTypes() +#define IDENTITY_V_TYPES_IRv9 DataTypeImpl::AllFixedSizeTensorTypes() +#else +#define IDENTITY_V_TYPES DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes() +#define IDENTITY_V_TYPES_IRv9 DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypesIRv9() +#endif + ONNX_OPERATOR_VERSIONED_KERNEL_EX( Identity, kOnnxDomain, 14, 18, kCudaExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypes()) + .TypeConstraint("V", IDENTITY_V_TYPES) .Alias(0, 0), IdentityOp); @@ -67,7 +52,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 19, 20, kCudaExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypesIRv9()) + .TypeConstraint("V", IDENTITY_V_TYPES_IRv9) .Alias(0, 0), IdentityOp); @@ -77,7 +62,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 21, 22, kCudaExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypesIRv9()) + .TypeConstraint("V", IDENTITY_V_TYPES_IRv9) .Alias(0, 0), IdentityOp); @@ -87,7 +72,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 23, 24, kCudaExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypesIRv9()) + .TypeConstraint("V", IDENTITY_V_TYPES_IRv9) .Alias(0, 0), IdentityOp); @@ -97,8 +82,11 @@ ONNX_OPERATOR_KERNEL_EX( 25, kCudaExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorAndSequenceTensorTypesIRv9()) + .TypeConstraint("V", IDENTITY_V_TYPES_IRv9) .Alias(0, 0), IdentityOp); + +#undef IDENTITY_V_TYPES +#undef IDENTITY_V_TYPES_IRv9 } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/identity_op.h b/onnxruntime/core/providers/cuda/tensor/identity_op.h index 775bc9d1ec924..de338f61126aa 100644 --- a/onnxruntime/core/providers/cuda/tensor/identity_op.h +++ b/onnxruntime/core/providers/cuda/tensor/identity_op.h @@ -15,8 +15,10 @@ class IdentityOp final : public CudaKernel { } Status ComputeInternal(OpKernelContext* context) const override { +#ifndef BUILD_CUDA_EP_AS_PLUGIN auto X_ml_type = context->InputType(0); if (X_ml_type->IsTensorType()) { +#endif const Tensor* X = context->Input(0); if (nullptr == X) { return Status(common::ONNXRUNTIME, common::FAIL, @@ -51,6 +53,7 @@ class IdentityOp final : public CudaKernel { CUDA_RETURN_IF_ERROR(cudaMemsetAsync(mask_data, 0, mask->SizeInBytes(), Stream(context))); } } +#ifndef BUILD_CUDA_EP_AS_PLUGIN } else if (X_ml_type->IsTensorSequenceType()) { const TensorSeq* X = context->Input(0); ORT_ENFORCE(X != nullptr, "IdentityOp cuda: input tensor is missing."); @@ -83,6 +86,7 @@ class IdentityOp final : public CudaKernel { return Status(common::ONNXRUNTIME, common::FAIL, "IdentityOp cuda: unsupported input type."); } +#endif return Status::OK(); } }; diff --git a/onnxruntime/core/providers/cuda/tensor/nonzero_op.cc b/onnxruntime/core/providers/cuda/tensor/nonzero_op.cc index 142bcb7eeb512..98bbff9855284 100644 --- a/onnxruntime/core/providers/cuda/tensor/nonzero_op.cc +++ b/onnxruntime/core/providers/cuda/tensor/nonzero_op.cc @@ -63,13 +63,13 @@ Status NonZero::ComputeInternal(OpKernelContext* context) const { auto x_data = reinterpret_cast::MappedType*>(x->Data()); const int number_of_blocks = NonZeroCalcBlockCount(x_size); - auto prefix_buffer = GetScratchBuffer(number_of_blocks, context->GetComputeStream()); + auto prefix_buffer = GetScratchBuffer(number_of_blocks, GetComputeStream(context)); int* prefix_counts = prefix_buffer.get(); CUDA_RETURN_IF_ERROR(NonZeroCountEachBlock(Stream(context), x_data, x_size, prefix_counts)); size_t temp_storage_bytes = 0; CUDA_RETURN_IF_ERROR(NonZeroCalcPrefixSumTempStorageBytes(Stream(context), prefix_counts, number_of_blocks, temp_storage_bytes)); - auto temp_buffer = GetScratchBuffer(temp_storage_bytes, context->GetComputeStream()); + auto temp_buffer = GetScratchBuffer(temp_storage_bytes, GetComputeStream(context)); auto d_temp_storage = temp_buffer.get(); CUDA_RETURN_IF_ERROR(NonZeroInclusivePrefixSum(Stream(context), d_temp_storage, temp_storage_bytes, prefix_counts, number_of_blocks)); diff --git a/onnxruntime/core/providers/cuda/tensor/pad.cc b/onnxruntime/core/providers/cuda/tensor/pad.cc index 3dd50c1c03cbf..73c8433bbefc8 100644 --- a/onnxruntime/core/providers/cuda/tensor/pad.cc +++ b/onnxruntime/core/providers/cuda/tensor/pad.cc @@ -115,6 +115,53 @@ namespace cuda { using PadsVector = PadBase::PadsVector; +// In the plugin build, PadBase::ComputePads is not accessible because it +// depends on CPU provider internals. ComputePadsImpl is a minimal inline +// equivalent. Keep in sync with PadBase::ComputePads in pad.h. +template +static void ComputePadsLocal(KernelContextType& ctx, + size_t data_rank, + gsl::span pads_data, + PadsVector& pads) { +#ifdef BUILD_CUDA_EP_AS_PLUGIN + PadBase::ComputePadsImpl(ctx, data_rank, pads_data, pads); +#else + PadBase::ComputePads(ctx, data_rank, pads_data, pads); +#endif +} + +// In the plugin build, PadBase::HandleDimValueZero lives in CPU provider code +// that cannot be linked into the plugin. Inline the same validation here. +// Keep in sync with PadBase::HandleDimValueZero in pad.h. +static Status HandleDimValueZeroLocal(const Mode& mode, + const TensorShape& input_shape, + const TensorShape& output_shape) { +#ifdef BUILD_CUDA_EP_AS_PLUGIN + switch (mode) { + case Mode::Constant: + break; + case Mode::Edge: + case Mode::Reflect: { + for (size_t i = 0, end = input_shape.NumDimensions(); i < end; ++i) { + if (input_shape[i] == 0 && output_shape[i] > 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Cannot use '", mode == Mode::Edge ? "edge" : "reflect", + "' mode to pad dimension with a value of 0. Input shape:", + input_shape); + } + } + break; + } + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected mode of ", static_cast(mode)); + } + + return Status::OK(); +#else + return PadBase::HandleDimValueZero(mode, input_shape, output_shape); +#endif +} + static bool IsNCHWInputWithPaddingAlongHAndW(size_t input_rank, const TArray& lower_pads, const TArray& upper_pads) { @@ -226,7 +273,7 @@ Status Pad::ComputeInternal(OpKernelContext* ctx) const { // this is expected for constant mode only, otherwise the output is empty // no error if (input_shape.Size() == 0) { - ORT_RETURN_IF_ERROR(PadBase::HandleDimValueZero(mode_, input_shape, output_shape)); + ORT_RETURN_IF_ERROR(HandleDimValueZeroLocal(mode_, input_shape, output_shape)); if (mode_ == Mode::Constant) { const int64_t output_size = output_shape.Size(); if (output_size > 0) { diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc b/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc index 60f1a82605d26..01c7229783b33 100644 --- a/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc +++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc @@ -7,6 +7,50 @@ #include "core/providers/cuda/shared_inc/cuda_utils.h" #include "core/providers/cpu/tensor/utils.h" +#ifdef BUILD_CUDA_EP_AS_PLUGIN +// PLUGIN BUILD ADAPTATION: SCATTER_ND_VALIDATE_SHAPES is defined in the CPU +// provider (scatter_nd.h) which cannot be linked into the plugin. Inline the +// same validation logic here. Keep in sync with ScatterND::ValidateShapes. +namespace onnxruntime { +namespace scatter_nd_plugin { +inline Status ValidateShapes(const TensorShape& input_shape, + const TensorShape& indice_shape, + const TensorShape& update_shape) { + auto input_rank = input_shape.NumDimensions(); + auto indice_rank = indice_shape.NumDimensions(); + auto update_rank = update_shape.NumDimensions(); + if (input_rank == 0 || indice_rank == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input tensor and indices tensor must have rank larger than 0. ", + "input shape: ", input_shape, ", indices shape: ", indice_shape); + } + auto last_indice_dimension = indice_shape[indice_rank - 1]; + if (last_indice_dimension > static_cast(input_rank)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "last dimension of indices must not be larger than rank of input tensor"); + } + auto expected_update_rank = input_rank + indice_rank - 1 - static_cast(last_indice_dimension); + if (update_rank != expected_update_rank) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "update tensor shape does not match expected shape"); + } + if (indice_shape.Slice(0, indice_rank - 1) != update_shape.Slice(0, indice_rank - 1)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "update tensor shape mismatch with indices tensor shape"); + } + if (input_shape.Slice(onnxruntime::narrow(last_indice_dimension)) != update_shape.Slice(indice_rank - 1)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "update tensor shape mismatch with input tensor shape"); + } + return Status::OK(); +} +} // namespace scatter_nd_plugin +} // namespace onnxruntime +#define SCATTER_ND_VALIDATE_SHAPES onnxruntime::scatter_nd_plugin::ValidateShapes +#else +#define SCATTER_ND_VALIDATE_SHAPES onnxruntime::ScatterND::ValidateShapes +#endif + namespace onnxruntime { namespace cuda { @@ -50,7 +94,7 @@ template static Status InitializeElementCountsAndInputDimsSpanOrGpu(int64_t last_index_dimension, const TensorShape& input_shape, ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims, CudaKernel::CudaAsyncBuffer& element_counts_and_input_dims_gpu, - KernelContextType* context) { + KernelContextType* stream) { TensorPitches input_strides(input_shape); if (last_index_dimension < 6) { @@ -66,7 +110,7 @@ static Status InitializeElementCountsAndInputDimsSpanOrGpu(int64_t last_index_di element_counts_and_input_dims_gpu.CpuPtr()[i] = input_strides[i]; element_counts_and_input_dims_gpu.CpuPtr()[i + last_index_dimension] = input_shape[i]; } - ORT_RETURN_IF_ERROR(element_counts_and_input_dims_gpu.CopyToGpu(context->GetComputeStream())); + ORT_RETURN_IF_ERROR(element_counts_and_input_dims_gpu.CopyToGpu(stream)); element_counts_and_input_dims.gpu_ptr = element_counts_and_input_dims_gpu.GpuPtr(); } return Status::OK(); @@ -82,7 +126,7 @@ Status ScatterNDDisjointAndNoReduction::ComputeInternal(OpKernelContext* context const auto& updates_shape = updates_tensor->Shape(); // Validate input shapes - ORT_RETURN_IF_ERROR(onnxruntime::ScatterND::ValidateShapes(input_shape, indices_shape, updates_shape)); + ORT_RETURN_IF_ERROR(SCATTER_ND_VALIDATE_SHAPES(input_shape, indices_shape, updates_shape)); auto* output_tensor = context->Output(0, input_shape); @@ -111,7 +155,7 @@ Status ScatterNDDisjointAndNoReduction::ComputeInternal(OpKernelContext* context ORT_RETURN_IF_ERROR(InitializeElementCountsAndInputDimsSpanOrGpu(last_index_dimension, input_shape, element_counts_and_input_dims, element_counts_and_input_dims_gpu, - context)); + GetComputeStream(context))); ORT_RETURN_IF_ERROR(ScatterNDImpl( Stream(context), @@ -137,7 +181,7 @@ Status ScatterNDWithAtomicReduction::ComputeInternal(OpKernelContext* context) c const auto& updates_shape = updates_tensor->Shape(); // Validate input shapes - ORT_RETURN_IF_ERROR(onnxruntime::ScatterND::ValidateShapes(input_shape, indices_shape, updates_shape)); + ORT_RETURN_IF_ERROR(SCATTER_ND_VALIDATE_SHAPES(input_shape, indices_shape, updates_shape)); auto* output_tensor = context->Output(0, input_shape); @@ -163,7 +207,7 @@ Status ScatterNDWithAtomicReduction::ComputeInternal(OpKernelContext* context) c ORT_RETURN_IF_ERROR(InitializeElementCountsAndInputDimsSpanOrGpu(last_index_dimension, input_shape, element_counts_and_input_dims, element_counts_and_input_dims_gpu, - context)); + GetComputeStream(context))); switch (reduction_) { case ScatterNDReduction::None: { diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd.h b/onnxruntime/core/providers/cuda/tensor/scatter_nd.h index 6d8bbe6f463fd..b46d896331c1d 100644 --- a/onnxruntime/core/providers/cuda/tensor/scatter_nd.h +++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd.h @@ -7,7 +7,9 @@ #include "core/providers/shared_library/provider_api.h" #include "core/providers/cuda/cuda_kernel.h" #include "core/providers/cuda/tensor/scatter_nd_kind.h" +#ifndef BUILD_CUDA_EP_AS_PLUGIN #include "core/providers/cpu/tensor/scatter_nd.h" +#endif namespace onnxruntime { namespace cuda { diff --git a/onnxruntime/core/providers/cuda/tensor/slice.cc b/onnxruntime/core/providers/cuda/tensor/slice.cc index db285ba547b6a..34de6eeac3ea2 100644 --- a/onnxruntime/core/providers/cuda/tensor/slice.cc +++ b/onnxruntime/core/providers/cuda/tensor/slice.cc @@ -177,10 +177,10 @@ Status Slice::ComputeInternal(OpKernelContext* ctx) const { if (dynamic) { TensorShapeVector input_starts, input_ends, input_axes, input_steps; ORT_RETURN_IF_ERROR(FillInputVectors(ctx, input_starts, input_ends, input_axes, input_steps)); - ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps, compute_metadata)); + ORT_RETURN_IF_ERROR(SliceBase::PrepareForCompute(input_starts, input_ends, input_axes, input_steps, compute_metadata)); } else { - ORT_RETURN_IF_ERROR(PrepareForCompute(StartsAttribute(), EndsAttribute(), AxesAttribute(), compute_metadata)); + ORT_RETURN_IF_ERROR(SliceBase::PrepareForCompute(StartsAttribute(), EndsAttribute(), AxesAttribute(), compute_metadata)); } TensorShape output_shape(compute_metadata.output_dims_); @@ -212,8 +212,8 @@ template Status Slice::FillInputVectors(OpKernelContext* ctx, TensorShapeVector& input_starts, TensorShapeVector& input_ends, TensorShapeVector& input_axes, TensorShapeVector& input_steps) const { - return FillVectorsFromInput(*ctx->Input(1), *ctx->Input(2), ctx->Input(3), - ctx->Input(4), input_starts, input_ends, input_axes, input_steps); + return SliceBase::FillVectorsFromInput(*ctx->Input(1), *ctx->Input(2), ctx->Input(3), + ctx->Input(4), input_starts, input_ends, input_axes, input_steps); } template @@ -258,12 +258,7 @@ Status FuncSlice( SliceOp::PrepareForComputeMetadata compute_metadata(input_dimensions); - ORT_RETURN_IF_ERROR( - SliceOp::PrepareForComputeHelper(starts_span, ends_span, axes_span, steps_span, compute_metadata)); - - ORT_RETURN_IF_ERROR(SliceBase::FlattenOutputDims(compute_metadata.input_dimensions_, compute_metadata.output_dims_, compute_metadata.starts_, - compute_metadata.ends_, compute_metadata.steps_, compute_metadata.p_flattened_input_dims_, - compute_metadata.p_flattened_output_dims_)); + ORT_RETURN_IF_ERROR(SliceBase::PrepareForCompute(starts_span, ends_span, axes_span, steps_span, compute_metadata)); TensorShape output_shape(compute_metadata.output_dims_); diff --git a/onnxruntime/core/providers/cuda/tensor/slice.h b/onnxruntime/core/providers/cuda/tensor/slice.h index 1a3ccd11cb1b9..050206f47217b 100644 --- a/onnxruntime/core/providers/cuda/tensor/slice.h +++ b/onnxruntime/core/providers/cuda/tensor/slice.h @@ -23,9 +23,11 @@ Status Impl(cudaStream_t stream, template class Slice : public CudaKernel, public SliceBase { public: - explicit Slice(const OpKernelInfo& info) : CudaKernel(info), SliceBase(info, dynamic, CudaProviderTag{}) {} + explicit Slice(const OpKernelInfo& info) : CudaKernel(info), + SliceBase(info, dynamic, CudaProviderTag{}) {} - Status ComputeInternal(OpKernelContext* ctx) const override; + Status + ComputeInternal(OpKernelContext* ctx) const override; private: virtual const Tensor* GetSlicedOrUnslicedTensor(OpKernelContext* ctx) const; diff --git a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h index 8780d9b365005..3a054175db9da 100644 --- a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h +++ b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h @@ -4,24 +4,134 @@ #pragma once #include "core/providers/cuda/cuda_kernel.h" +#ifndef BUILD_CUDA_EP_AS_PLUGIN #include "core/providers/cpu/tensor/space_depth_ops.h" +#endif namespace onnxruntime { namespace cuda { +#ifdef BUILD_CUDA_EP_AS_PLUGIN +// PLUGIN BUILD ADAPTATION: SpaceDepthBase (in cpu/tensor/space_depth_ops.h) +// cannot be included because it pulls in core/framework/op_kernel.h which +// conflicts with the adapter types. This inline namespace reimplements the +// validation and dimension-calculation logic. Keep in sync with SpaceDepthBase. +namespace detail { + +template +Status InputValidationsAndOutputDimsCalc(int64_t blocksize, + const Tensor& input, + int64_t& batch, + int64_t& input_depth, int64_t& input_height, int64_t& input_width, + int64_t& output_depth, int64_t& output_height, int64_t& output_width, + bool is_space_to_depth) { + const TensorShape& input_shape = input.Shape(); + + if (input_shape.NumDimensions() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "SpaceDepth ops require a 4-D input. Provided rank: ", + input_shape.NumDimensions()); + } + + batch = input_shape[0]; + if constexpr (IsNHWC) { + input_depth = input_shape[3]; + input_height = input_shape[1]; + input_width = input_shape[2]; + } else { + input_depth = input_shape[1]; + input_height = input_shape[2]; + input_width = input_shape[3]; + } + + if (is_space_to_depth) { + if ((input_height % blocksize) != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "SpaceToDepth requires input height to be a multiple of block_size"); + } + if ((input_width % blocksize) != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "SpaceToDepth requires input width to be a multiple of block_size"); + } + output_depth = input_depth * blocksize * blocksize; + output_height = input_height / blocksize; + output_width = input_width / blocksize; + } else { + if ((input_depth % (blocksize * blocksize) != 0)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DepthToSpace requires input depth to be a multiple of (block_size * block_size)"); + } + output_depth = input_depth / blocksize / blocksize; + output_height = input_height * blocksize; + output_width = input_width * blocksize; + } + + return Status::OK(); +} + +} // namespace detail +#endif // BUILD_CUDA_EP_AS_PLUGIN + template -class SpaceToDepth final : public CudaKernel, SpaceDepthBase { +class SpaceToDepth final : public CudaKernel +#ifndef BUILD_CUDA_EP_AS_PLUGIN + , + SpaceDepthBase +#endif +{ public: - explicit SpaceToDepth(const OpKernelInfo& info) : CudaKernel(info), SpaceDepthBase(info) { + explicit SpaceToDepth(const OpKernelInfo& info) + : CudaKernel(info) +#ifndef BUILD_CUDA_EP_AS_PLUGIN + , + SpaceDepthBase(info) +#endif + { +#ifdef BUILD_CUDA_EP_AS_PLUGIN + // Plugin builds cannot inherit from SpaceDepthBase, so extract the + // blocksize attribute directly from OpKernelInfo. + ORT_ENFORCE(info.GetAttr("blocksize", &blocksize_).IsOK(), + "Attribute blocksize is not set."); +#endif } Status ComputeInternal(OpKernelContext* context) const override; + +#ifdef BUILD_CUDA_EP_AS_PLUGIN + protected: + template + Status InputValidationsAndOutputDimsCalc(const Tensor& input, + int64_t& batch, + int64_t& input_depth, int64_t& input_height, int64_t& input_width, + int64_t& output_depth, int64_t& output_height, int64_t& output_width, + bool is_space_to_depth) const { + return detail::InputValidationsAndOutputDimsCalc( + blocksize_, input, batch, input_depth, input_height, input_width, + output_depth, output_height, output_width, is_space_to_depth); + } + + int64_t blocksize_; +#endif }; template -class DepthToSpace final : public CudaKernel, SpaceDepthBase { +class DepthToSpace final : public CudaKernel +#ifndef BUILD_CUDA_EP_AS_PLUGIN + , + SpaceDepthBase +#endif +{ public: - explicit DepthToSpace(const OpKernelInfo& info) : CudaKernel(info), SpaceDepthBase(info) { + explicit DepthToSpace(const OpKernelInfo& info) + : CudaKernel(info) +#ifndef BUILD_CUDA_EP_AS_PLUGIN + , + SpaceDepthBase(info) +#endif + { +#ifdef BUILD_CUDA_EP_AS_PLUGIN + // Plugin builds cannot inherit from SpaceDepthBase, so extract the + // blocksize attribute directly from OpKernelInfo. + ORT_ENFORCE(info.GetAttr("blocksize", &blocksize_).IsOK(), + "Attribute blocksize is not set."); +#endif std::string mode; // if mode doesn't exist, then it is the default "DCR" mode // (or) it is an opset < 11 model for which the only mode is "DCR" mode @@ -38,6 +148,22 @@ class DepthToSpace final : public CudaKernel, SpaceDepthBase { private: bool is_dcr_ = true; + +#ifdef BUILD_CUDA_EP_AS_PLUGIN + protected: + template + Status InputValidationsAndOutputDimsCalc(const Tensor& input, + int64_t& batch, + int64_t& input_depth, int64_t& input_height, int64_t& input_width, + int64_t& output_depth, int64_t& output_height, int64_t& output_width, + bool is_space_to_depth) const { + return detail::InputValidationsAndOutputDimsCalc( + blocksize_, input, batch, input_depth, input_height, input_width, + output_depth, output_height, output_width, is_space_to_depth); + } + + int64_t blocksize_; +#endif }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/tensor/split.cc b/onnxruntime/core/providers/cuda/tensor/split.cc index ca82387600085..06b0c7e50f919 100644 --- a/onnxruntime/core/providers/cuda/tensor/split.cc +++ b/onnxruntime/core/providers/cuda/tensor/split.cc @@ -42,6 +42,66 @@ ONNX_OPERATOR_KERNEL_EX(Split, .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), Split_18); +// Self-contained split preparation that replaces SplitBase::PrepareForCompute. +// The base class method is not available in plugin builds because SplitBase +// depends on CPU provider internals. Keep logic in sync with SplitBase. +Status SplitKernel::PrepareForComputeLocal(const TensorShape& input_shape, + int num_outputs, + int64_t& axis, + int& before_dims, + int& after_dims_including_split_axis, + int& after_dims_excluding_split, + std::vector& split_sizes) const { + auto input_dims = input_shape.GetDims(); + const auto num_dimensions = gsl::narrow_cast(input_shape.NumDimensions()); + axis = HandleNegativeAxis(axis_, num_dimensions); + const int64_t split_dim_size = input_dims[onnxruntime::narrow(axis)]; + + before_dims = gsl::narrow_cast(input_shape.SizeToDimension(onnxruntime::narrow(axis))); + after_dims_including_split_axis = gsl::narrow_cast(input_shape.SizeFromDimension(onnxruntime::narrow(axis))); + after_dims_excluding_split = (axis + 1 == num_dimensions) + ? 1 + : gsl::narrow_cast(input_shape.SizeFromDimension(onnxruntime::narrow(axis + 1))); + + if (num_outputs_ != -1) { + if (num_outputs_ > split_dim_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid num_outputs value of ", num_outputs_, + ". Size of dimension being split is ", split_dim_size); + } + + int64_t size = (split_dim_size + num_outputs_ - 1) / num_outputs_; + int64_t remainder = split_dim_size % size; + + split_sizes = std::vector(num_outputs, size); + if (remainder) { + split_sizes.back() = remainder; + } + } + + if (split_sizes.empty()) { + if (split_dim_size % static_cast(num_outputs) != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input cannot be split evenly on selected axis. Input shape=", input_shape, + " Axis=", axis_, " NumOutputs=", num_outputs); + } + split_sizes = std::vector(static_cast(num_outputs), split_dim_size / num_outputs); + } else { + int64_t split_size_sum = split_size_sum_; + if (split_size_sum == -1) { + split_size_sum = std::accumulate(split_sizes.cbegin(), split_sizes.cend(), 0LL); + } + if (split_sizes.size() != static_cast(num_outputs) || split_size_sum != split_dim_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Cannot split using values in 'split' attribute. Axis=", axis_, + " Input shape=", input_shape, + " NumOutputs=", num_outputs, + " Num entries in 'split' (must equal number of outputs) was ", split_sizes.size(), + " Sum of sizes in 'split' (must equal size of selected axis) was ", split_size_sum); + } + } + + return Status::OK(); +} + Status SplitKernel::ComputeInternal(OpKernelContext* ctx) const { const Tensor* input_tensor = ctx->Input(0); ORT_ENFORCE(input_tensor); @@ -63,13 +123,13 @@ Status SplitKernel::ComputeInternal(OpKernelContext* ctx) const { split_sizes.assign(split_sizes_.begin(), split_sizes_.end()); } - ORT_RETURN_IF_ERROR(PrepareForCompute(input_shape, - num_outputs, - axis, - before_dims, - block_size_including_axis_dim, - block_size_inside_axis_dim, - split_sizes)); + ORT_RETURN_IF_ERROR(PrepareForComputeLocal(input_shape, + num_outputs, + axis, + before_dims, + block_size_including_axis_dim, + block_size_inside_axis_dim, + split_sizes)); auto input_data = input_tensor->DataRaw(); @@ -129,23 +189,23 @@ Status SplitKernel::ComputeInternal(OpKernelContext* ctx) const { block_size_inside_axis_dim, split_sizes[0], num_outputs, input_data, output_ptr_array, static_cast(input_shape.Size()))); } else { - ORT_RETURN_IF_ERROR(output_ptr.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(output_ptr.CopyToGpu(GetComputeStream(ctx))); ORT_RETURN_IF_ERROR(SplitSameSplitDimImpl(Stream(ctx), element_size, block_size_including_axis_dim, block_size_inside_axis_dim, split_sizes[0], num_outputs, input_data, output_ptr.GpuPtr(), static_cast(input_shape.Size()))); } } else { - ORT_RETURN_IF_ERROR(output_ptr.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(output_ptr.CopyToGpu(GetComputeStream(ctx))); CudaAsyncBuffer split_sizes_gpu(this, split_sizes); - ORT_RETURN_IF_ERROR(split_sizes_gpu.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(split_sizes_gpu.CopyToGpu(GetComputeStream(ctx))); std::vector split_sizes_range(split_sizes); for (size_t i = 1; i < split_sizes_range.size(); ++i) { split_sizes_range[i] += split_sizes_range[i - 1]; } CudaAsyncBuffer split_sizes_range_gpu(this, split_sizes_range); - ORT_RETURN_IF_ERROR(split_sizes_range_gpu.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(split_sizes_range_gpu.CopyToGpu(GetComputeStream(ctx))); CudaAsyncBuffer axis_dimension_input_output_mapping_gpu(this, axis_dimension_input_output_mapping); - ORT_RETURN_IF_ERROR(axis_dimension_input_output_mapping_gpu.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(axis_dimension_input_output_mapping_gpu.CopyToGpu(GetComputeStream(ctx))); ORT_RETURN_IF_ERROR(SplitImpl(Stream(ctx), element_size, block_size_including_axis_dim, block_size_inside_axis_dim, split_sizes_gpu.GpuPtr(), split_sizes_range_gpu.GpuPtr(), axis_dimension_input_output_mapping_gpu.GpuPtr(), num_outputs, input_data, diff --git a/onnxruntime/core/providers/cuda/tensor/split.h b/onnxruntime/core/providers/cuda/tensor/split.h index 00d80f65b79c0..6820d58d9c953 100644 --- a/onnxruntime/core/providers/cuda/tensor/split.h +++ b/onnxruntime/core/providers/cuda/tensor/split.h @@ -13,6 +13,15 @@ class SplitKernel : public CudaKernel, public SplitBase { SplitKernel(const OpKernelInfo& info, uint32_t opset) : CudaKernel(info), SplitBase(info, opset) {} Status ComputeInternal(OpKernelContext* context) const override; + + private: + Status PrepareForComputeLocal(const TensorShape& input_shape, + int num_outputs, + int64_t& axis, + int& before_dims, + int& after_dims_including_split_axis, + int& after_dims_excluding_split, + std::vector& split_sizes) const; }; // versions 2, 11 and 13 diff --git a/onnxruntime/core/providers/cuda/tensor/tile.cc b/onnxruntime/core/providers/cuda/tensor/tile.cc index a0dd78b5e60e1..b07ca4f61cca0 100644 --- a/onnxruntime/core/providers/cuda/tensor/tile.cc +++ b/onnxruntime/core/providers/cuda/tensor/tile.cc @@ -9,6 +9,45 @@ using namespace onnxruntime::common; namespace onnxruntime { namespace cuda { +namespace { + +#ifdef BUILD_CUDA_EP_AS_PLUGIN +// PLUGIN BUILD ADAPTATION: TileOp::IsTileMemcpy (CPU provider) cannot be +// linked into the plugin. Reimplement the memcpy fast-path check here. +// Keep in sync with TileOp::IsTileMemcpy in cpu/tensor/tile.cc. +bool IsTileMemcpyForPlugin(const TensorShape& input_shape, + const int64_t* repeats, + size_t rank, + /*out*/ bool& is_batched_memcpy, + /*out*/ size_t& num_of_elements_per_batch, + /*out*/ size_t& num_of_copies_per_batch, + /*out*/ size_t& num_of_batch_copies) { + for (int64_t i = static_cast(rank) - 1; i >= 0; --i) { + if (repeats[i] != 1) { + if (input_shape.SizeToDimension(onnxruntime::narrow(i)) == 1) { + num_of_copies_per_batch = 1; + for (int64_t j = 0; j <= i; ++j) { + num_of_copies_per_batch *= onnxruntime::narrow(repeats[onnxruntime::narrow(j)]); + } + is_batched_memcpy = false; + return true; + } else if (i == 1) { + num_of_elements_per_batch = static_cast(input_shape.SizeFromDimension(1)); + num_of_copies_per_batch = onnxruntime::narrow(repeats[onnxruntime::narrow(i)]); + num_of_batch_copies = onnxruntime::narrow(repeats[0]); + is_batched_memcpy = true; + return true; + } else { + break; + } + } + } + return false; +} +#endif + +} // namespace + ONNX_OPERATOR_VERSIONED_KERNEL_EX( Tile, kOnnxDomain, @@ -109,6 +148,15 @@ Status Tile::ComputeInternal(OpKernelContext* ctx) const { size_t num_of_elements_per_batch = 1; size_t num_of_copies_per_batch = 1; size_t num_of_batch_copies = 1; +#ifdef BUILD_CUDA_EP_AS_PLUGIN + if (IsTileMemcpyForPlugin(input_shape, + repeats, + input_rank, + is_batched_memcpy, + num_of_elements_per_batch, + num_of_copies_per_batch, + num_of_batch_copies)) { +#else if (TileOp::IsTileMemcpy(input_shape, repeats, input_rank, @@ -116,6 +164,7 @@ Status Tile::ComputeInternal(OpKernelContext* ctx) const { num_of_elements_per_batch, num_of_copies_per_batch, num_of_batch_copies)) { +#endif if (!is_batched_memcpy) { switch (element_size) { CASE_TILE_MEMCPY(float); diff --git a/onnxruntime/core/providers/cuda/tensor/transpose.cc b/onnxruntime/core/providers/cuda/tensor/transpose.cc index 51aa46df18bc8..7df9238b3acb1 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose.cc +++ b/onnxruntime/core/providers/cuda/tensor/transpose.cc @@ -99,7 +99,8 @@ Status Transpose::DoTranspose(const Transpose& transpose_kernel, onnxruntime::Stream* ort_stream, const gsl::span& permutations, const Tensor& input, Tensor& output) { cudaStream_t cuda_stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; - const cublasHandle_t cublas_handle = transpose_kernel.GetCublasHandleOrDefault(ort_stream); + const cublasHandle_t cublas_handle = + ort_stream ? transpose_kernel.GetCublasHandle(ort_stream) : transpose_kernel.DefaultCublasHandle(); return Transpose::DoTranspose(transpose_kernel.GetDeviceProp(), cuda_stream, cublas_handle, diff --git a/onnxruntime/core/providers/cuda/tensor/unsqueeze.cc b/onnxruntime/core/providers/cuda/tensor/unsqueeze.cc index 411f94f31c7ba..64f951c70fe15 100644 --- a/onnxruntime/core/providers/cuda/tensor/unsqueeze.cc +++ b/onnxruntime/core/providers/cuda/tensor/unsqueeze.cc @@ -6,6 +6,61 @@ namespace onnxruntime { namespace cuda { +namespace { + +// PLUGIN BUILD ADAPTATION: PrepareCompute() is inherited from UnsqueezeBase +// in the non-plugin build, but the base class cannot be used in plugin builds +// because it depends on core/framework/op_kernel.h internals. This standalone +// function reimplements the same axes-parsing and output-shape computation. +Status PrepareComputeForPlugin(OpKernelContext* ctx, UnsqueezeBase::Prepare& p, const TensorShapeVector& axes_attr) { + const auto* input = ctx->Input(0); + ORT_ENFORCE(input != nullptr); + auto& input_tensor = *input; + + TensorShapeVector axes; + size_t num_inputs = static_cast(ctx->InputCount()); + if (num_inputs == 2) { + const Tensor* axes_tensor = ctx->Input(1); + ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 0 || + axes_tensor->Shape().NumDimensions() == 1, + "An axes tensor must be a scalar or a 1-D tensor."); + auto data_span = axes_tensor->DataAsSpan(); + axes.assign(data_span.begin(), data_span.end()); + } else { + axes.assign(axes_attr.begin(), axes_attr.end()); + } + + TensorShapeVector output_dims(axes.size() + input_tensor.Shape().NumDimensions(), 0); + for (int64_t axis : axes) { + axis = HandleNegativeAxis(axis, onnxruntime::narrow(output_dims.size())); + if (axis < 0 || axis >= static_cast(output_dims.size())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'axes' has an out of range axis"); + } + + auto axis_index = onnxruntime::narrow(axis); + if (output_dims[axis_index] != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'axes' has a duplicate axis"); + } + output_dims[axis_index] = 1; + } + + auto begin = input_tensor.Shape().GetDims().begin(); + for (auto& axis_size : output_dims) { + if (axis_size == 0) { + axis_size = *begin++; + } + } + + TensorShape output_shape(output_dims); + p.output_tensor = ctx->Output(0, output_shape); + ORT_ENFORCE(p.output_tensor != nullptr); + p.input_tensor = &input_tensor; + return Status::OK(); +} + +} // namespace + ONNX_OPERATOR_VERSIONED_KERNEL_EX( Unsqueeze, kOnnxDomain, @@ -85,7 +140,11 @@ ONNX_OPERATOR_KERNEL_EX( Status Unsqueeze::ComputeInternal(OpKernelContext* ctx) const { Prepare p; +#ifdef BUILD_CUDA_EP_AS_PLUGIN + ORT_RETURN_IF_ERROR(PrepareComputeForPlugin(ctx, p, axes_)); +#else ORT_RETURN_IF_ERROR(PrepareCompute(ctx, p)); +#endif const void* input = p.input_tensor->DataRaw(); void* output = p.output_tensor->MutableDataRaw(); diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.cc b/onnxruntime/core/providers/cuda/tensor/upsample.cc index e2c08618264dd..36e89cd38e72b 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.cc +++ b/onnxruntime/core/providers/cuda/tensor/upsample.cc @@ -42,14 +42,15 @@ REGISTER_VERSIONED_TYPED_KERNEL(uint8_t, 9, 9); template Upsample::Upsample(const OpKernelInfo& info) : UpsampleBase(info), CudaKernel(info) { - if (UpsampleBase::antialias_) { - // Copy the table on DEVICE +#ifndef BUILD_CUDA_EP_AS_PLUGIN + if (antialias_) { const uint8_t* lookup_table = GetLookupTableShared(); auto alloc = info.GetAllocator(OrtMemTypeDefault); shared_lookup_table_ondevice_ = IAllocator::MakeUniquePtr(std::move(alloc), kLookupTableSize); CUDA_CALL_THROW(cudaMemcpyAsync(shared_lookup_table_ondevice_.get(), lookup_table, kLookupTableSize, cudaMemcpyHostToDevice, nullptr)); } +#endif } template @@ -104,8 +105,21 @@ Status Upsample::BaseCompute(OpKernelContext* context, } if (antialias_) { +#ifdef BUILD_CUDA_EP_AS_PLUGIN + // Plugin builds copy the lookup table to a per-call scratch buffer because + // plugin kernels cannot safely hold persistent device pointers across sessions. + // Non-plugin builds cache the table in a member IAllocator::UniquePtr. + const uint8_t* lookup_table = GetLookupTableShared(); + auto shared_lookup_table_ondevice_buffer = GetScratchBuffer(kLookupTableSize, GetComputeStream(context)); + CUDA_CALL_THROW(cudaMemcpyAsync(shared_lookup_table_ondevice_buffer.get(), lookup_table, kLookupTableSize, + cudaMemcpyHostToDevice, Stream(context))); + const auto* shared_lookup_table_ondevice = shared_lookup_table_ondevice_buffer.get(); +#else + const auto* shared_lookup_table_ondevice = shared_lookup_table_ondevice_.get(); +#endif + TempSpaceAllocateFunc allocate_temp_space = [&](size_t bytes_size) { - return GetScratchBuffer(bytes_size, context->GetComputeStream()); + return GetScratchBuffer(bytes_size, GetComputeStream(context)); }; std::optional extrapolation_value; @@ -170,7 +184,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, extrapolation_value, exclude_outside_, allocate_temp_space, - shared_lookup_table_ondevice_.get(), + shared_lookup_table_ondevice, reinterpret_cast(X->Data()), reinterpret_cast(Y->MutableData()), output_count); @@ -213,7 +227,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, extrapolation_value, exclude_outside_, allocate_temp_space, - shared_lookup_table_ondevice_.get(), + shared_lookup_table_ondevice, reinterpret_cast(X->Data()), reinterpret_cast(Y->MutableData()), output_count); @@ -259,7 +273,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, extrapolation_value, exclude_outside_, allocate_temp_space, - shared_lookup_table_ondevice_.get(), + shared_lookup_table_ondevice, reinterpret_cast(X->Data()), reinterpret_cast(Y->MutableData()), output_count); @@ -274,7 +288,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, TArray scales_vals(scales); size_t temp_buffer_size = CalcResizeBufferSize(mode_, output_dims); - auto dims_mapping_buffer = GetScratchBuffer(temp_buffer_size, context->GetComputeStream()); + auto dims_mapping_buffer = GetScratchBuffer(temp_buffer_size, GetComputeStream(context)); void* dims_mapping = reinterpret_cast(dims_mapping_buffer.get()); ResizeImpl(Stream(context), mode_, rank, input_shape, output_shape, input_strides, output_div_pitches, scales_vals, roi_vals, @@ -341,7 +355,7 @@ Status Upsample::ComputeInternal(OpKernelContext* context) const { InlinedVector scales_array(input_dims.size()); // opset < 10 - if (OpKernel::Node().InputDefs().size() == 1) { + if (context->InputCount() == 1) { // Compute output shape from scales attributes and input dims scales_array = scales_; diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.h b/onnxruntime/core/providers/cuda/tensor/upsample.h index 50597e0fba1b9..152862da0fdbd 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.h +++ b/onnxruntime/core/providers/cuda/tensor/upsample.h @@ -19,8 +19,10 @@ class Upsample : public UpsampleBase, public CudaKernel { Status BaseCompute(OpKernelContext* context, gsl::span roi, gsl::span scales, gsl::span output_dims) const; +#ifndef BUILD_CUDA_EP_AS_PLUGIN private: IAllocatorUniquePtr shared_lookup_table_ondevice_; +#endif }; } // namespace cuda diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 0e3ecdf2385c7..2ba0d24c55fef 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -7,6 +7,18 @@ // switching providers to be runnable as shared libraries. The interfaces will become more tightly integrated into the core code. #pragma once + +// When building the CUDA EP as a plugin (BUILD_CUDA_EP_AS_PLUGIN), +// skip all SHARED_PROVIDER type redefinitions. The adapter header (ep/adapters.h) +// provides its own facade types, and the SHARED_PROVIDER bridge would conflict. +#ifdef BUILD_CUDA_EP_AS_PLUGIN + +// Plugin build: provider_api.h is a complete no-op. We do NOT define +// SHARED_PROVIDER so that #ifndef SHARED_PROVIDER guards in framework +// headers (op_kernel.h, etc.) remain active. + +#else // !BUILD_CUDA_EP_AS_PLUGIN — normal SHARED_PROVIDER path + #define SHARED_PROVIDER 1 #ifdef _WIN32 @@ -307,7 +319,7 @@ constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider"; constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider"; template -using IAllocatorUniquePtr = std::unique_ptr>; +using IAllocatorUniquePtr = std::unique_ptr >; inline OrtStatus* CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept { return g_host->CreateStatus(code, msg); } @@ -418,7 +430,7 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT2; } -inline std::vector> +inline std::vector > CreateSupportedPartitions(const GraphViewer& graph_viewer, const std::unordered_set& supported_nodes, const std::unordered_set& stop_ops, @@ -466,7 +478,7 @@ inline Status ConvertInMemoryDataToInline(Graph& graph, const std::string& name) } // namespace graph_utils namespace QDQ { -inline std::pair>, std::unordered_map> +inline std::pair >, std::unordered_map > GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) { return g_host->QDQ__GetAllNodeUnits(graph_viewer, logger); } @@ -511,3 +523,5 @@ inline T* Initializer::data() { #define LOGS_DEFAULT(severity) \ LOGS_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime) + +#endif // !BUILD_CUDA_EP_AS_PLUGIN diff --git a/onnxruntime/python/onnxruntime_pybind_module.cc b/onnxruntime/python/onnxruntime_pybind_module.cc index e1c883f960dde..ebe027c9efa95 100644 --- a/onnxruntime/python/onnxruntime_pybind_module.cc +++ b/onnxruntime/python/onnxruntime_pybind_module.cc @@ -13,6 +13,7 @@ #include "core/session/provider_bridge_ort.h" #include "core/framework/provider_options.h" #include "core/platform/env.h" +#include "core/common/inlined_containers.h" namespace onnxruntime { namespace python { @@ -113,7 +114,30 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { throw pybind11::import_error(st.ErrorMessage()); // move it out of shared method since training build has a little different behavior. m.def( - "get_available_providers", []() -> const std::vector& { return GetAvailableExecutionProviderNames(); }, + "get_available_providers", []() -> std::vector { + auto available = GetAvailableExecutionProviderNames(); +#if !defined(ORT_MINIMAL_BUILD) + const auto& ep_devices = GetEnv().GetOrtEpDevices(); + available.reserve(available.size() + ep_devices.size()); + + InlinedHashSet existing; + existing.reserve(available.size() + ep_devices.size()); + for (const auto& ep_name : available) { + existing.insert(ep_name); + } + + for (const OrtEpDevice* ep_device : ep_devices) { + if (!ep_device) { + continue; + } + + if (existing.insert(ep_device->ep_name).second) { + available.push_back(ep_device->ep_name); + } + } +#endif + return available; + }, "Return list of available Execution Providers in this installed version of Onnxruntime. " "The order of elements represents the default priority order of Execution Providers " "from highest to lowest."); diff --git a/onnxruntime/python/onnxruntime_pybind_schema.cc b/onnxruntime/python/onnxruntime_pybind_schema.cc index 0f51794f12e48..1b5b888110e27 100644 --- a/onnxruntime/python/onnxruntime_pybind_schema.cc +++ b/onnxruntime/python/onnxruntime_pybind_schema.cc @@ -4,6 +4,10 @@ #include "python/onnxruntime_pybind_state_common.h" #include "core/framework/kernel_registry.h" +#if !defined(ORT_MINIMAL_BUILD) +#include "core/session/utils.h" +#include "core/session/abi_devices.h" +#endif #include namespace py = pybind11; @@ -93,6 +97,58 @@ void addGlobalSchemaFunctions(pybind11::module& m) { return result; }, "Return a vector of KernelDef for all registered OpKernels"); + +#if !defined(ORT_MINIMAL_BUILD) + m.def( + "get_registered_ep_kernel_defs", + [](const std::string& ep_name) -> std::vector { + std::vector result; + auto& env = GetEnv(); + + // Collect all OrtEpDevice pointers matching the requested EP name. + std::vector selected_devices; + for (const OrtEpDevice* device : env.GetOrtEpDevices()) { + if (device && device->ep_name == ep_name) { + selected_devices.push_back(device); + break; // one device is sufficient to create the factory and query kernels + } + } + + if (selected_devices.empty()) { + throw std::runtime_error( + "No devices found for EP '" + ep_name + + "'. Ensure the plugin EP library is registered via register_execution_provider_library()."); + } + + // Create a factory for the plugin EP. + std::unique_ptr factory; + auto status = CreateIExecutionProviderFactoryForEpDevices(env, selected_devices, factory); + if (!status.IsOK()) { + throw std::runtime_error("Failed to create EP factory for '" + ep_name + "': " + status.ToString()); + } + + // Create an EP instance with default session options. + OrtSessionOptions ort_session_options{}; + const auto& logger = *env.GetLoggingManager()->DefaultLogger().ToExternal(); + auto provider = factory->CreateProvider(ort_session_options, logger); + if (!provider) { + throw std::runtime_error("Failed to create EP instance for '" + ep_name + "'."); + } + + // Extract kernel defs from the EP's kernel registry. + auto kernel_registry = provider->GetKernelRegistry(); + if (kernel_registry) { + for (const auto& entry : kernel_registry->GetKernelCreateMap()) { + result.emplace_back(*(entry.second.kernel_def)); + } + } + + return result; + }, + py::arg("ep_name"), + "Return a vector of KernelDef for a dynamically registered plugin EP.\n" + "The EP must be loaded first via register_execution_provider_library()."); +#endif } void addOpKernelSubmodule(py::module& m) { diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 937a96a619822..212272647f8e3 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -573,6 +573,71 @@ void RegisterNvTensorRTRtxPluginsAsCustomOps(PySessionOptions& so, const Provide } #endif +#if !defined(ORT_MINIMAL_BUILD) +// Find a registered plugin EP device matching the given EP name and optional device_id from provider options. +// Returns nullptr if no matching device is found. +static const OrtEpDevice* FindRegisteredPluginEpDevice( + const std::string& ep_name, + const ProviderOptions* provider_options) { + const auto& ep_devices = GetEnv().GetOrtEpDevices(); + if (ep_devices.empty()) { + return nullptr; + } + + bool has_requested_device_id = false; + int requested_device_id = 0; + if (provider_options != nullptr) { + if (const auto device_id_it = provider_options->find("device_id"); device_id_it != provider_options->end()) { + try { + requested_device_id = std::stoi(device_id_it->second); + has_requested_device_id = requested_device_id >= 0; + } catch (const std::exception&) { + LOGS_DEFAULT(WARNING) << "Invalid device_id value '" << device_id_it->second + << "' in provider options for EP '" << ep_name << "'; ignoring."; + } + } + } + + for (const OrtEpDevice* ep_device : ep_devices) { + if (!ep_device || ep_device->ep_name != ep_name) { + continue; + } + + if (has_requested_device_id) { + Ort::ConstEpDevice current_device(ep_device); + std::optional current_device_id{}; + if (const char* device_id = current_device.EpOptions().GetValue("device_id"); device_id != nullptr) { + try { + current_device_id = std::stoi(device_id); + } catch (const std::exception&) { + } + } + + if (!current_device_id.has_value()) { + if (const char* device_id = current_device.EpMetadata().GetValue("cuda_device_id"); device_id != nullptr) { + try { + current_device_id = std::stoi(device_id); + } catch (const std::exception&) { + } + } + } + + if (!current_device_id.has_value()) { + current_device_id = static_cast(current_device.Device().DeviceId()); + } + + if (*current_device_id != requested_device_id) { + continue; + } + } + + return ep_device; + } + + return nullptr; +} +#endif + /** * Creates an IExecutionProviderFactory instance of the specified type. * @param session_options The session options. @@ -585,6 +650,48 @@ static std::shared_ptr CreateExecutionProviderFactory const SessionOptions& session_options, const std::string& type, const ProviderOptionsMap& provider_options_map) { +#if !defined(ORT_MINIMAL_BUILD) + auto get_registered_plugin_ep_devices = [&]() -> InlinedVector { + InlinedVector selected_devices; + + const ProviderOptions* provider_options = nullptr; + if (const auto provider_it = provider_options_map.find(type); provider_it != provider_options_map.end()) { + provider_options = &provider_it->second; + } + + const OrtEpDevice* selected_device = FindRegisteredPluginEpDevice(type, provider_options); + if (selected_device == nullptr) { + if (provider_options != nullptr) { + if (const auto device_id_it = provider_options->find("device_id"); device_id_it != provider_options->end()) { + LOGS_DEFAULT(WARNING) << "No registered plugin EP device found for '" << type + << "' with device_id=" << device_id_it->second; + } + } + return selected_devices; + } + + selected_devices.push_back(selected_device); + return selected_devices; + }; + + auto try_create_registered_plugin_factory = [&]() -> std::shared_ptr { + auto selected_devices = get_registered_plugin_ep_devices(); + if (selected_devices.empty()) { + return nullptr; + } + + std::unique_ptr ep_factory; + const auto status = onnxruntime::CreateIExecutionProviderFactoryForEpDevices(GetEnv(), selected_devices, ep_factory); + if (!status.IsOK()) { + LOGS_DEFAULT(WARNING) << "Failed to create dynamic EP factory for '" << type + << "' from registered EP devices: " << status; + return nullptr; + } + + return std::shared_ptr(std::move(ep_factory)); + }; +#endif + if (type == kCpuExecutionProvider) { return onnxruntime::CPUProviderFactoryCreator::Create( session_options.enable_cpu_mem_arena); @@ -1203,6 +1310,13 @@ static std::shared_ptr CreateExecutionProviderFactory << " to ensure all dependencies are met."; #endif } else { +#if !defined(ORT_MINIMAL_BUILD) + // Try EPs dynamically registered via register_execution_provider_library(). + if (auto ep_factory = try_create_registered_plugin_factory(); ep_factory) { + return ep_factory; + } +#endif + // check whether it is a dynamic load EP: const auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { @@ -1241,7 +1355,48 @@ std::unique_ptr CreateExecutionProviderInstance(const Sessio const ProviderOptionsMap& provider_options_map) { auto ep_factory = CreateExecutionProviderFactoryInstance(session_options, type, provider_options_map); if (ep_factory) { - return ep_factory->CreateProvider(); + const auto& default_logger = GetEnv().GetLoggingManager()->DefaultLogger(); + OrtSessionOptions ort_session_options; + ort_session_options.value = session_options; + +#if !defined(ORT_MINIMAL_BUILD) + auto add_registered_plugin_ep_options_to_session = [&]() -> Status { + const ProviderOptions* provider_options = nullptr; + if (const auto provider_it = provider_options_map.find(type); provider_it != provider_options_map.end()) { + provider_options = &provider_it->second; + } + + if (provider_options == nullptr || provider_options->empty()) { + return Status::OK(); + } + + const OrtEpDevice* selected_device = FindRegisteredPluginEpDevice(type, provider_options); + if (selected_device == nullptr) { + return Status::OK(); + } + + InlinedVector selected_devices; + selected_devices.push_back(selected_device); + + std::vector ep_option_keys; + std::vector ep_option_vals; + ep_option_keys.reserve(provider_options->size()); + ep_option_vals.reserve(provider_options->size()); + for (const auto& [key, val] : *provider_options) { + ep_option_keys.push_back(key.c_str()); + ep_option_vals.push_back(val.c_str()); + } + + return AddEpOptionsToSessionOptions(selected_devices, ep_option_keys, ep_option_vals, ort_session_options.value); + }; + + auto status = add_registered_plugin_ep_options_to_session(); + if (!status.IsOK()) { + ORT_THROW("Error applying registered plugin EP options: ", status); + } +#endif + + return ep_factory->CreateProvider(ort_session_options, *default_logger.ToExternal()); } return nullptr; } @@ -1263,6 +1418,31 @@ static Status AddExplicitEpFactory(PySessionOptions& py_sess_options, const std: return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to add provider of type '", provider_type, "' to SessionOptions. Provider configuration is not supported."); } + +#if !defined(ORT_MINIMAL_BUILD) + if (!provider_options.empty()) { + const OrtEpDevice* selected_device = FindRegisteredPluginEpDevice(provider_type, &provider_options); + if (selected_device != nullptr) { + InlinedVector selected_devices; + selected_devices.push_back(selected_device); + + std::vector ep_option_keys; + std::vector ep_option_vals; + ep_option_keys.reserve(provider_options.size()); + ep_option_vals.reserve(provider_options.size()); + for (const auto& [key, val] : provider_options) { + ep_option_keys.push_back(key.c_str()); + ep_option_vals.push_back(val.c_str()); + } + + ORT_RETURN_IF_ERROR(AddEpOptionsToSessionOptions(selected_devices, + ep_option_keys, + ep_option_vals, + py_sess_options.value)); + } + } +#endif + py_sess_options.provider_factories.push_back(std::move(ep_factory)); return Status::OK(); } diff --git a/onnxruntime/test/framework/dynamic_plugin_ep_test.cc b/onnxruntime/test/framework/dynamic_plugin_ep_test.cc new file mode 100644 index 0000000000000..be2225ee66b80 --- /dev/null +++ b/onnxruntime/test/framework/dynamic_plugin_ep_test.cc @@ -0,0 +1,138 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP) + +#include "core/framework/execution_provider.h" +#include "test/unittest_util/test_dynamic_plugin_ep.h" + +#include +#include + +#include "test/util/include/asserts.h" + +#if defined(USE_CUDA) && defined(ORT_USE_EP_API_ADAPTERS) +#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/providers/cuda/plugin/cuda_kernel_adapter.h" +#endif + +namespace onnxruntime::test { + +namespace dynamic_plugin_ep_test_infra = onnxruntime::test::dynamic_plugin_ep_infra; + +TEST(DynamicPluginEpInfraTest, ParseInitializationConfigParsesOptionalFields) { + constexpr std::string_view kConfigJson = R"json( +{ + "ep_library_registration_name": "CudaPluginExecutionProvider", + "ep_library_path": "/tmp/libonnxruntime_providers_cuda_plugin.so", + "selected_ep_device_indices": [0, 2], + "default_ep_options": { + "ep.cuda.use_tf32": "1", + "ep.cuda.prefer_nhwc_layout": "1" + }, + "tests_to_skip": [ + "CudaTests.SkipMe", + "GraphTests.SkipMeToo" + ] +} +)json"; + + dynamic_plugin_ep_test_infra::InitializationConfig config{}; + ASSERT_STATUS_OK(dynamic_plugin_ep_test_infra::ParseInitializationConfig(kConfigJson, config)); + + EXPECT_EQ(config.ep_library_registration_name, "CudaPluginExecutionProvider"); + EXPECT_EQ(config.ep_library_path, "/tmp/libonnxruntime_providers_cuda_plugin.so"); + EXPECT_TRUE(config.selected_ep_name.empty()); + EXPECT_THAT(config.selected_ep_device_indices, ::testing::ElementsAre(0u, 2u)); + EXPECT_THAT(config.default_ep_options, + ::testing::UnorderedElementsAre( + ::testing::Pair("ep.cuda.prefer_nhwc_layout", "1"), + ::testing::Pair("ep.cuda.use_tf32", "1"))); + EXPECT_THAT(config.tests_to_skip, + ::testing::ElementsAre("CudaTests.SkipMe", "GraphTests.SkipMeToo")); +} + +TEST(DynamicPluginEpInfraTest, ParseInitializationConfigDefaultsUnsetOptionalFields) { + constexpr std::string_view kConfigJson = R"json( +{ + "ep_library_registration_name": "ExamplePluginEP", + "ep_library_path": "/tmp/libexample_plugin_ep.so", + "selected_ep_name": "ExampleExecutionProvider" +} +)json"; + + dynamic_plugin_ep_test_infra::InitializationConfig config{}; + ASSERT_STATUS_OK(dynamic_plugin_ep_test_infra::ParseInitializationConfig(kConfigJson, config)); + + EXPECT_EQ(config.ep_library_registration_name, "ExamplePluginEP"); + EXPECT_EQ(config.ep_library_path, "/tmp/libexample_plugin_ep.so"); + EXPECT_EQ(config.selected_ep_name, "ExampleExecutionProvider"); + EXPECT_TRUE(config.selected_ep_device_indices.empty()); + EXPECT_TRUE(config.default_ep_options.empty()); + EXPECT_TRUE(config.tests_to_skip.empty()); +} + +TEST(DynamicPluginEpInfraTest, ParseInitializationConfigRejectsMissingRequiredFields) { + constexpr std::string_view kConfigJson = R"json( +{ + "ep_library_registration_name": "CudaPluginExecutionProvider" +} +)json"; + + dynamic_plugin_ep_test_infra::InitializationConfig config{}; + ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(dynamic_plugin_ep_test_infra::ParseInitializationConfig(kConfigJson, config), + "JSON parse error"); +} + +TEST(DynamicPluginEpInfraTest, UninitializedStateReturnsSafeDefaults) { + dynamic_plugin_ep_test_infra::Shutdown(); + + EXPECT_FALSE(dynamic_plugin_ep_test_infra::IsInitialized()); + EXPECT_EQ(dynamic_plugin_ep_test_infra::MakeEp(), nullptr); + EXPECT_FALSE(dynamic_plugin_ep_test_infra::GetEpName().has_value()); + EXPECT_TRUE(dynamic_plugin_ep_test_infra::GetTestsToSkip().empty()); + + dynamic_plugin_ep_test_infra::Shutdown(); + + EXPECT_FALSE(dynamic_plugin_ep_test_infra::IsInitialized()); + EXPECT_FALSE(dynamic_plugin_ep_test_infra::GetEpName().has_value()); + EXPECT_TRUE(dynamic_plugin_ep_test_infra::GetTestsToSkip().empty()); +} + +#if defined(USE_CUDA) && defined(ORT_USE_EP_API_ADAPTERS) +TEST(DynamicPluginEpInfraTest, CudaKernelAdapterRuntimeConfigExposesFuseConvBiasAndSdpaKernel) { + onnxruntime::CUDAExecutionProvider provider{"CudaPluginExecutionProvider"}; + auto& config = onnxruntime::cuda::detail::GetCudaKernelAdapterRuntimeConfigForProvider(&provider); + config.fuse_conv_bias = true; + config.sdpa_kernel = static_cast(onnxruntime::contrib::attention::AttentionBackend::MATH); + + EXPECT_TRUE(provider.IsFuseConvBias()); + + const auto* attention_kernel_options = provider.GetAttentionKernelOptions(); + EXPECT_TRUE(attention_kernel_options->UseUnfusedAttention()); + EXPECT_FALSE(attention_kernel_options->UseFlashAttention()); + EXPECT_FALSE(attention_kernel_options->UseEfficientAttention()); + EXPECT_FALSE(attention_kernel_options->UseCudnnFlashAttention()); +} + +TEST(DynamicPluginEpInfraTest, CudaKernelAdapterTryBytesForCountDetectsOverflow) { + size_t bytes = 0; + EXPECT_FALSE(onnxruntime::cuda::detail::TryBytesForCount(std::numeric_limits::max(), 2, bytes)); +} + +TEST(DynamicPluginEpInfraTest, CudaKernelAdapterTryBytesForCountPreservesRawByteCounts) { + size_t bytes = 0; + ASSERT_TRUE(onnxruntime::cuda::detail::TryBytesForCount(123, 0, bytes)); + EXPECT_EQ(bytes, size_t{123}); +} + +TEST(DynamicPluginEpInfraTest, CudaKernelAdapterTryBytesForCountNormalCase) { + size_t bytes = 0; + ASSERT_TRUE(onnxruntime::cuda::detail::TryBytesForCount(10, 4, bytes)); + EXPECT_EQ(bytes, size_t{40}); +} +#endif + +} // namespace onnxruntime::test + +#endif // !defined(ORT_MINIMAL_BUILD) && defined(ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP) diff --git a/onnxruntime/test/python/transformers/cuda_plugin_ep_helper.py b/onnxruntime/test/python/transformers/cuda_plugin_ep_helper.py new file mode 100644 index 0000000000000..ebba84ccd0c27 --- /dev/null +++ b/onnxruntime/test/python/transformers/cuda_plugin_ep_helper.py @@ -0,0 +1,194 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import logging +import os +import sys +from importlib.metadata import PackageNotFoundError, distribution +from pathlib import Path + +import torch + +import onnxruntime as onnxrt + +CUDA_PLUGIN_EP_NAME = "CudaPluginExecutionProvider" +enable_debug_print = False +logger = logging.getLogger(__name__) + + +class _CudaPluginRegistrationState: + attempted = False + registered = False + + +def should_test_with_cuda_plugin_ep(default_value: bool = True) -> bool: + return os.getenv("ORT_TEST_CUDA_PLUGIN_EP", "1" if default_value else "0") == "1" + + +def _get_package_root(package_name: str, directory_name: str | None = None): + root_directory_name = directory_name or package_name + try: + dist = distribution(package_name) + files = dist.files or [] + + for file in files: + if file.name.endswith("__init__.py") and root_directory_name in file.parts: + return file.locate().parent + + if not directory_name: + for file in files: + if file.name.endswith("__init__.py"): + return file.locate().parent + except PackageNotFoundError: + # Some test environments only have an in-tree build, not an installed wheel. + pass + + return None + + +def _is_cuda_plugin_ep_built() -> bool: + build_info = onnxrt.get_build_info() + if ", cuda-plugin-ep=" in build_info: + return True + + ep_lib_path = os.environ.get("ORT_CUDA_PLUGIN_PATH", "") + if ep_lib_path and os.path.exists(ep_lib_path): + return True + + detected_path = _get_default_cuda_plugin_ep_path() + return bool(detected_path and os.path.exists(detected_path)) + + +def _get_cuda_plugin_library_name() -> str: + if sys.platform == "win32": + return "onnxruntime_providers_cuda_plugin.dll" + + if sys.platform == "darwin": + return "libonnxruntime_providers_cuda_plugin.dylib" + + return "libonnxruntime_providers_cuda_plugin.so" + + +def _get_default_cuda_plugin_ep_path() -> str | None: + library_name = _get_cuda_plugin_library_name() + + # 1) Match currently imported onnxruntime module first to avoid ABI mismatch. + loaded_onnxruntime_root = Path(onnxrt.__file__).resolve().parent + loaded_candidate = loaded_onnxruntime_root / "capi" / library_name + if loaded_candidate.exists(): + return str(loaded_candidate) + + # 2) Installed wheel location. + for package_name in ("onnxruntime-gpu", "onnxruntime"): + package_root = _get_package_root(package_name, "onnxruntime") + if package_root: + candidate = os.path.join(str(package_root), "capi", library_name) + if os.path.exists(candidate): + return candidate + + # 3) In-tree build location fallback. Search under the repo build dir so we + # can handle different platforms/configurations without hard-coding Release/.so. + # This assumes that user only builds in one configuration. + # Recommend to use ORT_CUDA_PLUGIN_PATH if building in multiple configurations. + repo_root = Path(__file__).resolve().parents[4] + build_root = repo_root / "build" + if not build_root.exists(): + return None + + matches = [path for path in build_root.rglob(library_name) if "CMakeFiles" not in path.parts] + if matches: + + def _sort_key(path: Path) -> tuple[int, int, str]: + path_str = str(path) + if "Release" in path.parts: + config_rank = 0 + elif "RelWithDebInfo" in path.parts: + config_rank = 1 + elif "Debug" in path.parts: + config_rank = 2 + else: + config_rank = 3 + + return (config_rank, len(path.parts), path_str) + + return str(sorted(matches, key=_sort_key)[0]) + + return None + + +def ensure_cuda_plugin_ep_registered(default_test_with_cuda_plugin_ep: bool = True) -> bool: + if _CudaPluginRegistrationState.registered: + return True + + if not should_test_with_cuda_plugin_ep(default_test_with_cuda_plugin_ep): + return False + + if not _is_cuda_plugin_ep_built(): + return False + + ep_lib_path = os.environ.get("ORT_CUDA_PLUGIN_PATH", "") + if not ep_lib_path: + detected_path = _get_default_cuda_plugin_ep_path() + ep_lib_path = detected_path if detected_path else "" + + if not ep_lib_path or not os.path.exists(ep_lib_path): + if enable_debug_print: + print(f"CUDA Plugin EP library not found: {ep_lib_path}") + return False + + _CudaPluginRegistrationState.attempted = True + + try: + onnxrt.register_execution_provider_library(CUDA_PLUGIN_EP_NAME, ep_lib_path) + _CudaPluginRegistrationState.registered = True + except Exception as e: + if "already registered" in str(e).lower(): + _CudaPluginRegistrationState.registered = True + else: + try: + providers = {device.ep_name for device in onnxrt.get_ep_devices()} + except Exception: + providers = set() + + _CudaPluginRegistrationState.registered = CUDA_PLUGIN_EP_NAME in providers + + if enable_debug_print and not _CudaPluginRegistrationState.registered: + print(f"Failed to register CUDA Plugin EP from {ep_lib_path}: {e}") + + return _CudaPluginRegistrationState.registered + + +def resolve_cuda_plugin_ep(ep: str, default_test_with_cuda_plugin_ep: bool = True) -> str: + # Keep all existing test call-sites unchanged: they pass CUDA EP, + # and we transparently route to plugin EP when it is built and loadable. + if ep == "CUDAExecutionProvider" and ensure_cuda_plugin_ep_registered(default_test_with_cuda_plugin_ep): + if _is_plugin_provider_type_available(): + return CUDA_PLUGIN_EP_NAME + + if enable_debug_print: + print(f"{CUDA_PLUGIN_EP_NAME} is not exposed in available provider types. Falling back to {ep}.") + return ep + + +def get_cuda_provider_name() -> str | None: + if not torch.cuda.is_available(): + return None + + resolved_provider = resolve_cuda_plugin_ep("CUDAExecutionProvider") + available_providers = onnxrt.get_available_providers() + + if resolved_provider in available_providers: + return resolved_provider + + if "CUDAExecutionProvider" in available_providers: + return "CUDAExecutionProvider" + + return None + + +def _is_plugin_provider_type_available() -> bool: + try: + return CUDA_PLUGIN_EP_NAME in onnxrt.get_available_providers() + except Exception as e: + logger.warning("Failed to query available providers while checking %s availability: %s", CUDA_PLUGIN_EP_NAME, e) + return False diff --git a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py new file mode 100644 index 0000000000000..f1ce5b3d187ea --- /dev/null +++ b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py @@ -0,0 +1,1861 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import tempfile +import unittest + +import numpy as np +import torch +import torch.nn.functional as F +from cuda_plugin_ep_helper import CUDA_PLUGIN_EP_NAME, ensure_cuda_plugin_ep_registered, should_test_with_cuda_plugin_ep +from onnx import OperatorSetIdProto, TensorProto, helper, save + +import onnxruntime as onnxrt + +try: + import faulthandler + + faulthandler.enable() +except ImportError: + # faulthandler is optional in some Python runtimes used by CI. + pass + + +TEST_PASS = "PASS" +TEST_SKIP = "SKIP" +TEST_FAIL = "FAIL" +EP_GRAPH_ASSIGNMENT_CONFIG_KEY = "session.record_ep_graph_assignment_info" + + +def require_cuda_plugin_ep(): + if not should_test_with_cuda_plugin_ep(): + raise unittest.SkipTest("CUDA plugin EP is not enabled for testing") + + if not ensure_cuda_plugin_ep_registered(): + raise unittest.SkipTest("CUDA plugin EP is not built or could not be registered") + + +def get_cuda_plugin_device(): + return get_cuda_plugin_devices()[0] + + +def get_cuda_plugin_devices(): + require_cuda_plugin_ep() + + try: + devices = onnxrt.get_ep_devices() + except Exception as exc: + raise unittest.SkipTest(f"Failed to enumerate CUDA plugin EP devices: {exc}") from exc + + plugin_devices = [device for device in devices if device.ep_name == CUDA_PLUGIN_EP_NAME] + if not plugin_devices: + raise unittest.SkipTest("CUDA plugin EP registered, but no plugin devices were enumerated") + + return plugin_devices + + +def get_cuda_plugin_device_by_id(device_id: int): + expected_device_id = str(device_id) + for device in get_cuda_plugin_devices(): + if device.ep_options.get("device_id") == expected_device_id: + return device + if device.ep_metadata.get("cuda_device_id") == expected_device_id: + return device + + raise unittest.SkipTest(f"CUDA plugin EP device_id={device_id} is not available in this environment") + + +def _create_session_options(session_config=None): + sess_options = onnxrt.SessionOptions() + if session_config: + for key, value in session_config.items(): + sess_options.add_session_config_entry(key, value) + + # Require graph-assignment data so the tests validate that nodes actually run on the plugin. + sess_options.add_session_config_entry(EP_GRAPH_ASSIGNMENT_CONFIG_KEY, "1") + return sess_options + + +def _format_assigned_node(node): + domain = node.domain or "ai.onnx" + if node.name: + return f"{domain}::{node.op_type}:{node.name}" + return f"{domain}::{node.op_type}" + + +def _get_assigned_nodes(session, ep_name): + assignment_info = list(session.get_provider_graph_assignment_info()) + assigned_nodes = [] + for subgraph in assignment_info: + if subgraph.ep_name == ep_name: + assigned_nodes.extend(subgraph.get_nodes()) + + return assigned_nodes, assignment_info + + +def _format_assignment_summary(assignment_info): + if not assignment_info: + return "" + + summaries = [] + for subgraph in assignment_info: + node_summary = ", ".join(_format_assigned_node(node) for node in subgraph.get_nodes()) or "" + summaries.append(f"{subgraph.ep_name}[{node_summary}]") + + return "; ".join(summaries) + + +def create_add_model(model_path): + # Create a simple Add model: Y = A + B + node_def = helper.make_node("Add", ["A", "B"], ["Y"]) + graph_def = helper.make_graph( + [node_def], + "test-model-add", + [ + helper.make_tensor_value_info("A", TensorProto.FLOAT, [3, 2]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [3, 2]), + ], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 2])], + ) + model_def = helper.make_model(graph_def, producer_name="onnx-example") + save(model_def, model_path) + + +def create_matmul_model(model_path): + # Create a simple MatMul model: Y = A @ B + node_def = helper.make_node("MatMul", ["A", "B"], ["Y"]) + graph_def = helper.make_graph( + [node_def], + "test-model-matmul", + [ + helper.make_tensor_value_info("A", TensorProto.FLOAT, [3, 4]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [4, 5]), + ], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 5])], + ) + model_def = helper.make_model(graph_def, producer_name="onnx-example") + save(model_def, model_path) + + +def create_gemm_model(model_path, alpha=1.0, beta=1.0, transA=0, transB=0): + # Create a simple Gemm model: Y = alpha*A*B + beta*C + node_def = helper.make_node("Gemm", ["A", "B", "C"], ["Y"], alpha=alpha, beta=beta, transA=transA, transB=transB) + + m = 3 + k = 4 + n = 5 + shape_a = [m, k] if transA == 0 else [k, m] + shape_b = [k, n] if transB == 0 else [n, k] + shape_c = [n] # Test broadcast + + graph_def = helper.make_graph( + [node_def], + "test-model-gemm", + [ + helper.make_tensor_value_info("A", TensorProto.FLOAT, shape_a), + helper.make_tensor_value_info("B", TensorProto.FLOAT, shape_b), + helper.make_tensor_value_info("C", TensorProto.FLOAT, shape_c), + ], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [m, n])], + ) + model_def = helper.make_model(graph_def, producer_name="onnx-example") + save(model_def, model_path) + + +def create_conv_model(model_path): + # Create a simple Conv model: Y = Conv(X, W) + node_def = helper.make_node("Conv", ["X", "W"], ["Y"], pads=[1, 1, 1, 1], strides=[1, 1], dilations=[1, 1], group=1) + graph_def = helper.make_graph( + [node_def], + "test-model-conv", + [ + helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2, 4, 4]), + helper.make_tensor_value_info("W", TensorProto.FLOAT, [3, 2, 3, 3]), + ], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 4, 4])], + ) + opset = OperatorSetIdProto() + opset.version = 11 + model_def = helper.make_model(graph_def, producer_name="onnx-example", opset_imports=[opset]) + save(model_def, model_path) + + +def create_batch_norm_model(model_path): + """Create a BatchNormalization model for NHWC testing.""" + num_channels = 3 + node_def = helper.make_node( + "BatchNormalization", + ["X", "scale", "B", "input_mean", "input_var"], + ["Y"], + epsilon=1e-5, + ) + # scale, B, mean, var are 1D tensors of shape [num_channels] + scale_init = helper.make_tensor( + "scale", TensorProto.FLOAT, [num_channels], np.ones(num_channels, dtype=np.float32).tolist() + ) + bias_init = helper.make_tensor( + "B", TensorProto.FLOAT, [num_channels], np.zeros(num_channels, dtype=np.float32).tolist() + ) + mean_init = helper.make_tensor( + "input_mean", TensorProto.FLOAT, [num_channels], np.zeros(num_channels, dtype=np.float32).tolist() + ) + var_init = helper.make_tensor( + "input_var", TensorProto.FLOAT, [num_channels], np.ones(num_channels, dtype=np.float32).tolist() + ) + + graph_def = helper.make_graph( + [node_def], + "test-model-batchnorm", + [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, num_channels, 4, 4])], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, num_channels, 4, 4])], + initializer=[scale_init, bias_init, mean_init, var_init], + ) + opset = OperatorSetIdProto() + opset.version = 15 + model_def = helper.make_model(graph_def, producer_name="onnx-example", opset_imports=[opset]) + save(model_def, model_path) + + +def create_maxpool_model(model_path): + """Create a MaxPool model for NHWC testing.""" + node_def = helper.make_node( + "MaxPool", + ["X"], + ["Y"], + kernel_shape=[2, 2], + strides=[2, 2], + ) + graph_def = helper.make_graph( + [node_def], + "test-model-maxpool", + [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 4, 4])], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 2, 2])], + ) + opset = OperatorSetIdProto() + opset.version = 12 + model_def = helper.make_model(graph_def, producer_name="onnx-example", opset_imports=[opset]) + save(model_def, model_path) + + +def create_avgpool_model(model_path): + """Create an AveragePool model for NHWC testing.""" + node_def = helper.make_node( + "AveragePool", + ["X"], + ["Y"], + kernel_shape=[2, 2], + strides=[2, 2], + ) + graph_def = helper.make_graph( + [node_def], + "test-model-avgpool", + [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 4, 4])], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 2, 2])], + ) + opset = OperatorSetIdProto() + opset.version = 12 + model_def = helper.make_model(graph_def, producer_name="onnx-example", opset_imports=[opset]) + save(model_def, model_path) + + +def make_bias_dropout_model(): + """Create a deterministic BiasDropout model by forcing inference mode.""" + node = helper.make_node( + "BiasDropout", + ["X", "bias", "residual", "ratio", "training_mode"], + ["Y", ""], + domain="com.microsoft", + ) + graph = helper.make_graph( + [node], + "test-BiasDropout", + [ + helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 4]), + helper.make_tensor_value_info("bias", TensorProto.FLOAT, [4]), + helper.make_tensor_value_info("residual", TensorProto.FLOAT, [2, 4]), + helper.make_tensor_value_info("ratio", TensorProto.FLOAT, []), + helper.make_tensor_value_info("training_mode", TensorProto.BOOL, []), + ], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4])], + ) + opset_onnx = OperatorSetIdProto() + opset_onnx.version = 13 + opset_ms = OperatorSetIdProto() + opset_ms.domain = "com.microsoft" + opset_ms.version = 1 + return helper.make_model(graph, opset_imports=[opset_onnx, opset_ms]) + + +def run_operator_test( + target_device, model_creator, inputs, expected_fn, ep_name=CUDA_PLUGIN_EP_NAME, session_config=None +): + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp: + model_path = tmp.name + try: + model_creator(model_path) + sess_options = _create_session_options(session_config) + sess_options.add_provider_for_devices([target_device], {}) + sess = onnxrt.InferenceSession(model_path, sess_options=sess_options) + + active_providers = sess.get_providers() + assigned_nodes, assignment_info = _get_assigned_nodes(sess, ep_name) + if not assigned_nodes: + print( + f"FAILURE: {ep_name} was assigned no nodes. Providers: {active_providers}. " + f"Assignments: {_format_assignment_summary(assignment_info)}" + ) + return False + + print( + f"(Session created with {active_providers}; assigned nodes: " + f"{', '.join(_format_assigned_node(node) for node in assigned_nodes)})", + flush=True, + ) + res = sess.run(None, inputs) + expected = expected_fn(inputs) + np.testing.assert_allclose(res[0], expected, rtol=1e-3, atol=1e-3) + return True + finally: + if os.path.exists(model_path): + os.remove(model_path) + + +def run_provider_options_test(provider_options, expect_plugin_provider=True): + require_cuda_plugin_ep() + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp: + model_path = tmp.name + try: + create_add_model(model_path) + providers = [(CUDA_PLUGIN_EP_NAME, provider_options), "CPUExecutionProvider"] + sess = onnxrt.InferenceSession(model_path, sess_options=_create_session_options(), providers=providers) + active_providers = sess.get_providers() + assigned_nodes, assignment_info = _get_assigned_nodes(sess, CUDA_PLUGIN_EP_NAME) + + if expect_plugin_provider and not assigned_nodes: + print( + f"FAILURE: {CUDA_PLUGIN_EP_NAME} was assigned no nodes. Providers: {active_providers}. " + f"Assignments: {_format_assignment_summary(assignment_info)}" + ) + return False + if not expect_plugin_provider and assigned_nodes: + print( + f"FAILURE: {CUDA_PLUGIN_EP_NAME} unexpectedly owned nodes. " + f"Assignments: {_format_assignment_summary(assignment_info)}" + ) + return False + + a = np.random.rand(3, 2).astype(np.float32) + b = np.random.rand(3, 2).astype(np.float32) + res = sess.run(None, {"A": a, "B": b}) + np.testing.assert_allclose(res[0], a + b, rtol=1e-3, atol=1e-3) + return True + except Exception as e: + if expect_plugin_provider: + print(f"FAIL ({e})") + return False + + print(f"Expected failure for provider options {provider_options}: {e}") + return True + finally: + if os.path.exists(model_path): + os.remove(model_path) + + +def _expected_conv(inputs): + return F.conv2d(torch.from_numpy(inputs["X"]), torch.from_numpy(inputs["W"]), padding=1).numpy() + + +_NHWC_CONFIG = {"ep.cuda.prefer_nhwc_layout": "1"} + + +def _expected_batchnorm(inputs): + return inputs["X"] / np.sqrt(1.0 + 1e-5) + + +def _make_simple_model(op_type, inputs_info, outputs_info, attrs=None, opset=13, domain=""): + """Helper to create a simple single-node ONNX model. + + Args: + op_type: ONNX op type string + inputs_info: list of (name, elem_type, shape) tuples + outputs_info: list of (name, elem_type, shape) tuples + attrs: dict of node attributes + opset: opset version + domain: op domain (empty string for default ONNX domain) + """ + input_names = [info[0] for info in inputs_info] + output_names = [info[0] for info in outputs_info] + node = helper.make_node(op_type, input_names, output_names, domain=domain, **(attrs or {})) + graph = helper.make_graph( + [node], + f"test-{op_type}", + [helper.make_tensor_value_info(n, t, s) for n, t, s in inputs_info], + [helper.make_tensor_value_info(n, t, s) for n, t, s in outputs_info], + ) + opset_import = [OperatorSetIdProto()] + opset_import[0].version = opset + if domain: + ms_opset = OperatorSetIdProto() + ms_opset.domain = domain + ms_opset.version = 1 + opset_import.append(ms_opset) + model = helper.make_model(graph, opset_imports=opset_import) + return model + + +def _run_model_test( + target_device, op_name, model, feed_dict, expected_fn, ep_name=CUDA_PLUGIN_EP_NAME, rtol=1e-3, atol=1e-3 +): + """Run a single op test: save model, create session, run, compare.""" + with tempfile.NamedTemporaryFile(suffix=f"_{op_name}.onnx", delete=False) as tmp: + model_path = tmp.name + try: + save(model, model_path) + sess_options = _create_session_options() + sess_options.add_provider_for_devices([target_device], {}) + sess = onnxrt.InferenceSession(model_path, sess_options=sess_options) + active_providers = sess.get_providers() + assigned_nodes, assignment_info = _get_assigned_nodes(sess, ep_name) + if not assigned_nodes: + print( + f"{TEST_FAIL} ({ep_name} was assigned no nodes; providers={active_providers}; " + f"assignments={_format_assignment_summary(assignment_info)})" + ) + return TEST_FAIL + res = sess.run(None, feed_dict) + expected = expected_fn(feed_dict) + if isinstance(expected, (list, tuple)): + if len(res) != len(expected): + raise AssertionError(f"{op_name} produced {len(res)} outputs, expected {len(expected)}") + + for r, e in zip(res, expected, strict=True): + np.testing.assert_allclose(r, e, rtol=rtol, atol=atol) + else: + np.testing.assert_allclose(res[0], expected, rtol=rtol, atol=atol) + return TEST_PASS + except Exception as e: + print(f"{TEST_FAIL} ({e})") + return TEST_FAIL + finally: + if os.path.exists(model_path): + os.remove(model_path) + + +class TestCudaPluginEP(unittest.TestCase): + # ---- Registration tests (verify nodes run on the plugin EP) ---- + + def test_registration_add(self): + target_device = get_cuda_plugin_device() + inputs = {"A": np.random.rand(3, 2).astype(np.float32), "B": np.random.rand(3, 2).astype(np.float32)} + result = run_operator_test(target_device, create_add_model, inputs, lambda feed: feed["A"] + feed["B"]) + self.assertTrue(result, "Add plugin registration test failed") + + def test_registration_matmul(self): + target_device = get_cuda_plugin_device() + inputs = {"A": np.random.rand(3, 4).astype(np.float32), "B": np.random.rand(4, 5).astype(np.float32)} + result = run_operator_test(target_device, create_matmul_model, inputs, lambda feed: feed["A"] @ feed["B"]) + self.assertTrue(result, "MatMul plugin registration test failed") + + def test_registration_gemm(self): + target_device = get_cuda_plugin_device() + inputs = { + "A": np.random.rand(3, 4).astype(np.float32), + "B": np.random.rand(4, 5).astype(np.float32), + "C": np.random.rand(5).astype(np.float32), + } + result = run_operator_test( + target_device, + lambda model_path: create_gemm_model(model_path, alpha=2.0, beta=0.5), + inputs, + lambda feed: 2.0 * (feed["A"] @ feed["B"]) + 0.5 * feed["C"], + ) + self.assertTrue(result, "Gemm plugin registration test failed") + + def test_registration_conv(self): + target_device = get_cuda_plugin_device() + inputs = { + "X": np.random.rand(1, 2, 4, 4).astype(np.float32), + "W": np.random.rand(3, 2, 3, 3).astype(np.float32), + } + result = run_operator_test(target_device, create_conv_model, inputs, _expected_conv) + self.assertTrue(result, "Conv plugin registration test failed") + + # ---- Provider options tests ---- + + def test_provider_options_valid(self): + result = run_provider_options_test({"device_id": "0", "use_tf32": "0"}, expect_plugin_provider=True) + self.assertTrue(result, "Provider options with valid device_id/use_tf32 failed") + + def test_provider_options_invalid_device(self): + result = run_provider_options_test({"device_id": "999"}, expect_plugin_provider=False) + self.assertTrue(result, "Provider options with invalid device_id failed") + + def test_provider_options_second_device(self): + plugin_devices = get_cuda_plugin_devices() + if len(plugin_devices) < 2: + self.skipTest("Multi-GPU CUDA plugin EP test requires at least two plugin devices") + + target_device = get_cuda_plugin_device_by_id(1) + + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp: + model_path = tmp.name + try: + create_add_model(model_path) + providers = [(CUDA_PLUGIN_EP_NAME, {"device_id": "1"}), "CPUExecutionProvider"] + sess = onnxrt.InferenceSession(model_path, sess_options=_create_session_options(), providers=providers) + + active_providers = sess.get_providers() + assigned_nodes, assignment_info = _get_assigned_nodes(sess, CUDA_PLUGIN_EP_NAME) + self.assertTrue( + assigned_nodes, + f"{CUDA_PLUGIN_EP_NAME} was assigned no nodes. Providers: {active_providers}. " + f"Assignments: {_format_assignment_summary(assignment_info)}", + ) + + provider_options = sess.get_provider_options() + self.assertEqual( + provider_options[CUDA_PLUGIN_EP_NAME].get("device_id"), + "1", + f"Expected provider option device_id=1, got {provider_options[CUDA_PLUGIN_EP_NAME]}", + ) + self.assertEqual(target_device.ep_options.get("device_id"), "1") + + a = np.random.rand(3, 2).astype(np.float32) + b = np.random.rand(3, 2).astype(np.float32) + res = sess.run(None, {"A": a, "B": b}) + np.testing.assert_allclose(res[0], a + b, rtol=1e-3, atol=1e-3) + finally: + if os.path.exists(model_path): + os.remove(model_path) + + # ---- NHWC layout tests ---- + + def test_nhwc_conv(self): + target_device = get_cuda_plugin_device() + inputs = { + "X": np.random.rand(1, 2, 4, 4).astype(np.float32), + "W": np.random.rand(3, 2, 3, 3).astype(np.float32), + } + result = run_operator_test( + target_device, create_conv_model, inputs, _expected_conv, session_config=_NHWC_CONFIG + ) + self.assertTrue(result, "Conv (NHWC) plugin test failed") + + def test_nhwc_batch_normalization(self): + target_device = get_cuda_plugin_device() + inputs = {"X": np.random.rand(1, 3, 4, 4).astype(np.float32)} + result = run_operator_test( + target_device, create_batch_norm_model, inputs, _expected_batchnorm, session_config=_NHWC_CONFIG + ) + self.assertTrue(result, "BatchNormalization (NHWC) plugin test failed") + + def test_nhwc_maxpool(self): + target_device = get_cuda_plugin_device() + inputs = {"X": np.random.rand(1, 3, 4, 4).astype(np.float32)} + result = run_operator_test( + target_device, + create_maxpool_model, + inputs, + lambda feed: F.max_pool2d(torch.from_numpy(feed["X"]), kernel_size=2, stride=2).numpy(), + session_config=_NHWC_CONFIG, + ) + self.assertTrue(result, "MaxPool (NHWC) plugin test failed") + + def test_nhwc_avgpool(self): + target_device = get_cuda_plugin_device() + inputs = {"X": np.random.rand(1, 3, 4, 4).astype(np.float32)} + result = run_operator_test( + target_device, + create_avgpool_model, + inputs, + lambda feed: F.avg_pool2d(torch.from_numpy(feed["X"]), kernel_size=2, stride=2).numpy(), + session_config=_NHWC_CONFIG, + ) + self.assertTrue(result, "AveragePool (NHWC) plugin test failed") + + # ---- Standard op tests ---- + + def test_op_reshape(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + model = _make_simple_model( + "Reshape", [("X", f_dtype, [2, 3, 4]), ("shape", TensorProto.INT64, [2])], [("Y", f_dtype, [6, 4])] + ) + model.graph.initializer.append(helper.make_tensor("shape", TensorProto.INT64, [2], [6, 4])) + x = np.random.rand(2, 3, 4).astype(np.float32) + result = _run_model_test(target_device, "Reshape", model, {"X": x}, lambda f: f["X"].reshape(6, 4)) + self.assertEqual(result, TEST_PASS, "Reshape plugin op test failed") + + def test_op_split(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("Split", ["X", "split"], ["Y1", "Y2"], axis=0) + graph = helper.make_graph( + [node], + "test-Split", + [helper.make_tensor_value_info("X", f_dtype, [6, 4])], + [ + helper.make_tensor_value_info("Y1", f_dtype, [3, 4]), + helper.make_tensor_value_info("Y2", f_dtype, [3, 4]), + ], + ) + opset = OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + model.graph.initializer.append(helper.make_tensor("split", TensorProto.INT64, [2], [3, 3])) + x = np.random.rand(6, 4).astype(np.float32) + result = _run_model_test(target_device, "Split", model, {"X": x}, lambda f: [f["X"][:3], f["X"][3:]]) + self.assertEqual(result, TEST_PASS, "Split plugin op test failed") + + def test_op_concat(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + model = _make_simple_model( + "Concat", [("A", f_dtype, [2, 3]), ("B", f_dtype, [2, 3])], [("Y", f_dtype, [4, 3])], attrs={"axis": 0} + ) + a = np.random.rand(2, 3).astype(np.float32) + b = np.random.rand(2, 3).astype(np.float32) + result = _run_model_test( + target_device, "Concat", model, {"A": a, "B": b}, lambda f: np.concatenate([f["A"], f["B"]], axis=0) + ) + self.assertEqual(result, TEST_PASS, "Concat plugin op test failed") + + def test_op_gather(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + model = _make_simple_model( + "Gather", + [("X", f_dtype, [5, 4]), ("indices", TensorProto.INT64, [3])], + [("Y", f_dtype, [3, 4])], + attrs={"axis": 0}, + opset=13, + ) + x = np.random.rand(5, 4).astype(np.float32) + idx = np.array([0, 2, 4], dtype=np.int64) + result = _run_model_test( + target_device, "Gather", model, {"X": x, "indices": idx}, lambda f: f["X"][f["indices"]] + ) + self.assertEqual(result, TEST_PASS, "Gather plugin op test failed") + + def test_op_unsqueeze(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("Unsqueeze", ["X", "axes"], ["Y"]) + graph = helper.make_graph( + [node], + "test-Unsqueeze", + [helper.make_tensor_value_info("X", f_dtype, [3, 4])], + [helper.make_tensor_value_info("Y", f_dtype, [1, 3, 4])], + ) + opset = OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + model.graph.initializer.append(helper.make_tensor("axes", TensorProto.INT64, [1], [0])) + x = np.random.rand(3, 4).astype(np.float32) + result = _run_model_test(target_device, "Unsqueeze", model, {"X": x}, lambda f: np.expand_dims(f["X"], 0)) + self.assertEqual(result, TEST_PASS, "Unsqueeze plugin op test failed") + + def test_op_tile(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("Tile", ["X", "repeats"], ["Y"]) + graph = helper.make_graph( + [node], + "test-Tile", + [helper.make_tensor_value_info("X", f_dtype, [2, 3])], + [helper.make_tensor_value_info("Y", f_dtype, [4, 9])], + ) + opset = OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + model.graph.initializer.append(helper.make_tensor("repeats", TensorProto.INT64, [2], [2, 3])) + x = np.random.rand(2, 3).astype(np.float32) + result = _run_model_test(target_device, "Tile", model, {"X": x}, lambda f: np.tile(f["X"], (2, 3))) + self.assertEqual(result, TEST_PASS, "Tile plugin op test failed") + + def test_op_cumsum(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("CumSum", ["X", "axis"], ["Y"]) + graph = helper.make_graph( + [node], + "test-CumSum", + [helper.make_tensor_value_info("X", f_dtype, [3, 4])], + [helper.make_tensor_value_info("Y", f_dtype, [3, 4])], + ) + opset = OperatorSetIdProto() + opset.version = 14 + model = helper.make_model(graph, opset_imports=[opset]) + model.graph.initializer.append(helper.make_tensor("axis", TensorProto.INT64, [], [1])) + x = np.random.rand(3, 4).astype(np.float32) + result = _run_model_test(target_device, "CumSum", model, {"X": x}, lambda f: np.cumsum(f["X"], axis=1)) + self.assertEqual(result, TEST_PASS, "CumSum plugin op test failed") + + def test_op_constant_of_shape(self): + target_device = get_cuda_plugin_device() + node = helper.make_node( + "ConstantOfShape", ["shape"], ["Y"], value=helper.make_tensor("value", TensorProto.FLOAT, [1], [3.14]) + ) + graph = helper.make_graph( + [node], + "test-ConstantOfShape", + [helper.make_tensor_value_info("shape", TensorProto.INT64, [2])], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, None)], + ) + opset = OperatorSetIdProto() + opset.version = 9 + model = helper.make_model(graph, opset_imports=[opset]) + result = _run_model_test( + target_device, + "ConstantOfShape", + model, + {"shape": np.array([2, 3], dtype=np.int64)}, + lambda f: np.full((2, 3), 3.14, dtype=np.float32), + ) + self.assertEqual(result, TEST_PASS, "ConstantOfShape plugin op test failed") + + def test_op_space_to_depth(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + model = _make_simple_model( + "SpaceToDepth", + [("X", f_dtype, [1, 2, 4, 4])], + [("Y", f_dtype, [1, 8, 2, 2])], + attrs={"blocksize": 2}, + opset=13, + ) + x = np.random.rand(1, 2, 4, 4).astype(np.float32) + + def expected(f): + inp = f["X"] + b, c, h, w = inp.shape + bs = 2 + tmp = inp.reshape(b, c, h // bs, bs, w // bs, bs) + tmp = tmp.transpose(0, 3, 5, 1, 2, 4) + return tmp.reshape(b, c * bs * bs, h // bs, w // bs) + + result = _run_model_test(target_device, "SpaceToDepth", model, {"X": x}, expected) + self.assertEqual(result, TEST_PASS, "SpaceToDepth plugin op test failed") + + def test_op_pad(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("Pad", ["X", "pads", "constant_value"], ["Y"]) + graph = helper.make_graph( + [node], + "test-Pad", + [helper.make_tensor_value_info("X", f_dtype, [2, 3])], + [helper.make_tensor_value_info("Y", f_dtype, [4, 5])], + ) + opset = OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + model.graph.initializer.append(helper.make_tensor("pads", TensorProto.INT64, [4], [1, 1, 1, 1])) + model.graph.initializer.append(helper.make_tensor("constant_value", TensorProto.FLOAT, [], [0.0])) + x = np.random.rand(2, 3).astype(np.float32) + result = _run_model_test( + target_device, "Pad", model, {"X": x}, lambda f: np.pad(f["X"], ((1, 1), (1, 1)), constant_values=0) + ) + self.assertEqual(result, TEST_PASS, "Pad plugin op test failed") + + def test_op_slice(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"]) + graph = helper.make_graph( + [node], + "test-Slice", + [helper.make_tensor_value_info("X", f_dtype, [4, 6])], + [helper.make_tensor_value_info("Y", f_dtype, [2, 4])], + ) + opset = OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + model.graph.initializer.append(helper.make_tensor("starts", TensorProto.INT64, [2], [1, 1])) + model.graph.initializer.append(helper.make_tensor("ends", TensorProto.INT64, [2], [3, 5])) + model.graph.initializer.append(helper.make_tensor("axes", TensorProto.INT64, [2], [0, 1])) + x = np.random.rand(4, 6).astype(np.float32) + result = _run_model_test(target_device, "Slice", model, {"X": x}, lambda f: f["X"][1:3, 1:5]) + self.assertEqual(result, TEST_PASS, "Slice plugin op test failed") + + def test_op_resize(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("Resize", ["X", "", "scales"], ["Y"], mode="nearest") + graph = helper.make_graph( + [node], + "test-Resize", + [helper.make_tensor_value_info("X", f_dtype, [1, 1, 2, 2])], + [helper.make_tensor_value_info("Y", f_dtype, [1, 1, 4, 4])], + ) + opset = OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + model.graph.initializer.append(helper.make_tensor("scales", TensorProto.FLOAT, [4], [1.0, 1.0, 2.0, 2.0])) + x = np.random.rand(1, 1, 2, 2).astype(np.float32) + result = _run_model_test( + target_device, "Resize", model, {"X": x}, lambda f: np.repeat(np.repeat(f["X"], 2, axis=2), 2, axis=3) + ) + self.assertEqual(result, TEST_PASS, "Resize plugin op test failed") + + def test_op_sum_variadic(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + model = _make_simple_model( + "Sum", + [("A", f_dtype, [3, 4]), ("B", f_dtype, [3, 4]), ("C", f_dtype, [3, 4])], + [("Y", f_dtype, [3, 4])], + opset=13, + ) + a = np.random.rand(3, 4).astype(np.float32) + b = np.random.rand(3, 4).astype(np.float32) + c = np.random.rand(3, 4).astype(np.float32) + result = _run_model_test( + target_device, "Sum_variadic", model, {"A": a, "B": b, "C": c}, lambda f: f["A"] + f["B"] + f["C"] + ) + self.assertEqual(result, TEST_PASS, "Sum_variadic plugin op test failed") + + # ---- CPU base class op tests ---- + + def test_op_upsample(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("Upsample", ["X", "scales"], ["Y"], mode="nearest") + graph = helper.make_graph( + [node], + "test-Upsample", + [helper.make_tensor_value_info("X", f_dtype, [1, 1, 2, 2])], + [helper.make_tensor_value_info("Y", f_dtype, [1, 1, 4, 4])], + ) + opset = OperatorSetIdProto() + opset.version = 9 + model = helper.make_model(graph, opset_imports=[opset]) + model.graph.initializer.append(helper.make_tensor("scales", TensorProto.FLOAT, [4], [1.0, 1.0, 2.0, 2.0])) + x = np.random.rand(1, 1, 2, 2).astype(np.float32) + result = _run_model_test( + target_device, "Upsample", model, {"X": x}, lambda f: np.repeat(np.repeat(f["X"], 2, axis=2), 2, axis=3) + ) + self.assertEqual(result, TEST_PASS, "Upsample plugin op test failed") + + def test_op_depth_to_space(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + model = _make_simple_model( + "DepthToSpace", + [("X", f_dtype, [1, 8, 2, 2])], + [("Y", f_dtype, [1, 2, 4, 4])], + attrs={"blocksize": 2, "mode": "DCR"}, + opset=13, + ) + x = np.random.rand(1, 8, 2, 2).astype(np.float32) + + def expected(f): + inp = f["X"] + b, c, h, w = inp.shape + bs = 2 + return ( + inp.reshape(b, bs, bs, c // (bs * bs), h, w) + .transpose(0, 3, 4, 1, 5, 2) + .reshape(b, c // (bs * bs), h * bs, w * bs) + ) + + result = _run_model_test(target_device, "DepthToSpace", model, {"X": x}, expected) + self.assertEqual(result, TEST_PASS, "DepthToSpace plugin op test failed") + + # ---- Contrib op tests ---- + + def test_op_fast_gelu(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("FastGelu", ["X"], ["Y"], domain="com.microsoft") + graph = helper.make_graph( + [node], + "test-FastGelu", + [helper.make_tensor_value_info("X", f_dtype, [2, 4])], + [helper.make_tensor_value_info("Y", f_dtype, [2, 4])], + ) + opset_onnx = OperatorSetIdProto() + opset_onnx.version = 13 + opset_ms = OperatorSetIdProto() + opset_ms.domain = "com.microsoft" + opset_ms.version = 1 + model = helper.make_model(graph, opset_imports=[opset_onnx, opset_ms]) + x = np.random.rand(2, 4).astype(np.float32) + + def expected(f): + v = f["X"] + return v * (1.0 / (1.0 + np.exp(-1.702 * v))) + + result = _run_model_test(target_device, "FastGelu", model, {"X": x}, expected, rtol=1e-2, atol=1e-2) + self.assertEqual(result, TEST_PASS, "FastGelu plugin op test failed") + + def test_op_bias_dropout(self): + target_device = get_cuda_plugin_device() + model = make_bias_dropout_model() + x = np.random.rand(2, 4).astype(np.float32) + bias = np.random.rand(4).astype(np.float32) + residual = np.random.rand(2, 4).astype(np.float32) + ratio = np.array(0.5, dtype=np.float32) + training_mode = np.array(False, dtype=np.bool_) + feed = {"X": x, "bias": bias, "residual": residual, "ratio": ratio, "training_mode": training_mode} + result = _run_model_test( + target_device, "BiasDropout", model, feed, lambda f: f["X"] + f["bias"] + f["residual"] + ) + self.assertEqual(result, TEST_PASS, "BiasDropout plugin op test failed") + + def test_op_dropout_opset7(self): + """Dropout opset 7-9: simple in/out, no mask. Verifies old-version registration in dropout.cc.""" + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("Dropout", ["X"], ["Y"], ratio=0.0) + graph = helper.make_graph( + [node], + "test-Dropout-opset7", + [helper.make_tensor_value_info("X", f_dtype, [2, 4])], + [helper.make_tensor_value_info("Y", f_dtype, [2, 4])], + ) + opset = OperatorSetIdProto() + opset.version = 7 + model = helper.make_model(graph, opset_imports=[opset]) + x = np.random.rand(2, 4).astype(np.float32) + result = _run_model_test(target_device, "Dropout_opset7", model, {"X": x}, lambda f: f["X"]) + self.assertEqual(result, TEST_PASS, "Dropout opset 7 plugin op test failed") + + def test_op_dropout_opset10(self): + """Dropout opset 10-11: data + optional mask output. Verifies old-version registration in dropout.cc.""" + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("Dropout", ["X"], ["Y", "mask"]) + graph = helper.make_graph( + [node], + "test-Dropout-opset10", + [helper.make_tensor_value_info("X", f_dtype, [2, 4])], + [ + helper.make_tensor_value_info("Y", f_dtype, [2, 4]), + helper.make_tensor_value_info("mask", TensorProto.BOOL, [2, 4]), + ], + ) + opset = OperatorSetIdProto() + opset.version = 10 + model = helper.make_model(graph, opset_imports=[opset]) + x = np.random.rand(2, 4).astype(np.float32) + result = _run_model_test( + target_device, + "Dropout_opset10", + model, + {"X": x}, + lambda f: [f["X"], np.zeros((2, 4), dtype=bool)], + ) + self.assertEqual(result, TEST_PASS, "Dropout opset 10 plugin op test failed") + + def test_op_dequantize_linear_opset21(self): + """DequantizeLinear opset 21 uses TWO_TYPED_KERNEL_EX — verifies the new adapter macro.""" + target_device = get_cuda_plugin_device() + node = helper.make_node("DequantizeLinear", ["x", "x_scale", "x_zero_point"], ["y"]) + x_data = np.array([0, 1, 2, 3, 4, 5], dtype=np.uint8) + scale_data = np.array(0.5, dtype=np.float32) + zp_data = np.array(2, dtype=np.uint8) + graph = helper.make_graph( + [node], + "test-DequantizeLinear-opset21", + [ + helper.make_tensor_value_info("x", TensorProto.UINT8, [6]), + helper.make_tensor_value_info("x_scale", TensorProto.FLOAT, []), + helper.make_tensor_value_info("x_zero_point", TensorProto.UINT8, []), + ], + [helper.make_tensor_value_info("y", TensorProto.FLOAT, [6])], + ) + opset = OperatorSetIdProto() + opset.version = 21 + model = helper.make_model(graph, opset_imports=[opset]) + feed = {"x": x_data, "x_scale": scale_data, "x_zero_point": zp_data} + result = _run_model_test( + target_device, + "DequantizeLinear_opset21", + model, + feed, + lambda f: (f["x"].astype(np.float32) - f["x_zero_point"].astype(np.float32)) * f["x_scale"], + ) + self.assertEqual(result, TEST_PASS, "DequantizeLinear opset 21 plugin op test failed") + + def test_op_quantize_linear_opset21(self): + """QuantizeLinear opset 21 uses TWO_TYPED_KERNEL_EX — verifies the new adapter macro.""" + target_device = get_cuda_plugin_device() + node = helper.make_node("QuantizeLinear", ["x", "y_scale", "y_zero_point"], ["y"]) + x_data = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 2.5], dtype=np.float32) + scale_data = np.array(0.5, dtype=np.float32) + zp_data = np.array(0, dtype=np.uint8) + graph = helper.make_graph( + [node], + "test-QuantizeLinear-opset21", + [ + helper.make_tensor_value_info("x", TensorProto.FLOAT, [6]), + helper.make_tensor_value_info("y_scale", TensorProto.FLOAT, []), + helper.make_tensor_value_info("y_zero_point", TensorProto.UINT8, []), + ], + [helper.make_tensor_value_info("y", TensorProto.UINT8, [6])], + ) + opset = OperatorSetIdProto() + opset.version = 21 + model = helper.make_model(graph, opset_imports=[opset]) + feed = {"x": x_data, "y_scale": scale_data, "y_zero_point": zp_data} + result = _run_model_test( + target_device, + "QuantizeLinear_opset21", + model, + feed, + lambda f: np.clip( + np.round(f["x"] / f["y_scale"]).astype(np.float32) + f["y_zero_point"].astype(np.float32), + 0, + 255, + ).astype(np.uint8), + atol=1, + ) + self.assertEqual(result, TEST_PASS, "QuantizeLinear opset 21 plugin op test failed") + + def test_op_gather_block_quantized(self): + """GatherBlockQuantized uses THREE_TYPED_KERNEL_EX — verifies the new adapter macro.""" + target_device = get_cuda_plugin_device() + # GatherBlockQuantized: gathers rows from a block-quantized weight matrix. + # data shape [4, 16] (uint8), scales shape [4, 1] (float), indices [2] (int64) + # bits=8, block_size=16 (must be >= 16 and power of 2), quantize_axis=last + node = helper.make_node( + "GatherBlockQuantized", + ["data", "indices", "scales"], + ["output"], + domain="com.microsoft", + gather_axis=0, + quantize_axis=1, + block_size=16, + bits=8, + ) + data = np.random.randint(0, 255, size=(4, 16), dtype=np.uint8) + scales = np.random.rand(4, 1).astype(np.float32) * 0.1 + 0.01 + indices = np.array([0, 2], dtype=np.int64) + graph = helper.make_graph( + [node], + "test-GatherBlockQuantized", + [ + helper.make_tensor_value_info("data", TensorProto.UINT8, [4, 16]), + helper.make_tensor_value_info("indices", TensorProto.INT64, [2]), + helper.make_tensor_value_info("scales", TensorProto.FLOAT, [4, 1]), + ], + [helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, 16])], + ) + opset_onnx = OperatorSetIdProto() + opset_onnx.version = 21 + opset_ms = OperatorSetIdProto() + opset_ms.domain = "com.microsoft" + opset_ms.version = 1 + model = helper.make_model(graph, opset_imports=[opset_onnx, opset_ms]) + feed = {"data": data, "indices": indices, "scales": scales} + + def expected(f): + # Gather rows [0, 2], then dequantize: float_val = uint8_val * scale + gathered_data = f["data"][f["indices"]] # [2, 16] + gathered_scales = f["scales"][f["indices"]] # [2, 1] + return gathered_data.astype(np.float32) * gathered_scales + + result = _run_model_test(target_device, "GatherBlockQuantized", model, feed, expected, rtol=1e-2, atol=1e-2) + self.assertEqual(result, TEST_PASS, "GatherBlockQuantized plugin op test failed") + + def test_op_skip_layer_norm(self): + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + hidden_size = 8 + node = helper.make_node( + "SkipLayerNormalization", + ["X", "skip", "gamma", "beta"], + ["Y", "", "", "input_skip_bias_sum"], + domain="com.microsoft", + epsilon=1e-5, + ) + graph = helper.make_graph( + [node], + "test-SkipLayerNorm", + [ + helper.make_tensor_value_info("X", f_dtype, [2, hidden_size]), + helper.make_tensor_value_info("skip", f_dtype, [2, hidden_size]), + helper.make_tensor_value_info("gamma", f_dtype, [hidden_size]), + helper.make_tensor_value_info("beta", f_dtype, [hidden_size]), + ], + [ + helper.make_tensor_value_info("Y", f_dtype, [2, hidden_size]), + helper.make_tensor_value_info("input_skip_bias_sum", f_dtype, None), + ], + ) + opset_onnx = OperatorSetIdProto() + opset_onnx.version = 13 + opset_ms = OperatorSetIdProto() + opset_ms.domain = "com.microsoft" + opset_ms.version = 1 + model = helper.make_model(graph, opset_imports=[opset_onnx, opset_ms]) + x = np.random.rand(2, hidden_size).astype(np.float32) + skip = np.random.rand(2, hidden_size).astype(np.float32) + gamma = np.ones(hidden_size, dtype=np.float32) + beta = np.zeros(hidden_size, dtype=np.float32) + + def expected(f): + added = f["X"] + f["skip"] + mean = added.mean(axis=-1, keepdims=True) + var = added.var(axis=-1, keepdims=True) + normed = (added - mean) / np.sqrt(var + 1e-5) + return [normed * f["gamma"] + f["beta"], added] + + result = _run_model_test( + target_device, + "SkipLayerNorm", + model, + {"X": x, "skip": skip, "gamma": gamma, "beta": beta}, + expected, + rtol=1e-2, + atol=1e-2, + ) + self.assertEqual(result, TEST_PASS, "SkipLayerNorm plugin op test failed") + + # ---- Tests for previously-excluded ops (identity, crop, dynamicslice) ---- + + def test_op_identity(self): + """Identity op: previously excluded from plugin due to TensorSeq; now Tensor-only.""" + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + model = _make_simple_model("Identity", [("X", f_dtype, [3, 4])], [("Y", f_dtype, [3, 4])]) + x = np.random.rand(3, 4).astype(np.float32) + result = _run_model_test(target_device, "Identity", model, {"X": x}, lambda f: f["X"]) + self.assertEqual(result, TEST_PASS, "Identity plugin op test failed") + + def test_op_identity_opset25(self): + """Identity opset 25: highest opset, uses V type constraint (Tensor subset in plugin).""" + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + model = _make_simple_model("Identity", [("X", f_dtype, [2, 5])], [("Y", f_dtype, [2, 5])], opset=25) + x = np.random.rand(2, 5).astype(np.float32) + result = _run_model_test(target_device, "Identity_opset25", model, {"X": x}, lambda f: f["X"]) + self.assertEqual(result, TEST_PASS, "Identity opset 25 plugin op test failed") + + def test_op_crop(self): + """Crop (opset 1, contrib): previously excluded from plugin.""" + target_device = get_cuda_plugin_device() + f_dtype = TensorProto.FLOAT + node = helper.make_node("Crop", ["input"], ["output"], border=[1, 1, 1, 1]) + graph = helper.make_graph( + [node], + "test-Crop", + [helper.make_tensor_value_info("input", f_dtype, [1, 1, 4, 4])], + [helper.make_tensor_value_info("output", f_dtype, [1, 1, 2, 2])], + ) + opset = OperatorSetIdProto() + opset.version = 1 + model = helper.make_model(graph, opset_imports=[opset]) + x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4) + result = _run_model_test( + target_device, + "Crop", + model, + {"input": x}, + lambda f: f["input"][:, :, 1:3, 1:3], + ) + self.assertEqual(result, TEST_PASS, "Crop plugin op test failed") + + def test_plugin_ep_claims_key_ops(self): + """Session-based probing: verify the plugin EP claims key ops via graph assignment.""" + target_device = get_cuda_plugin_device() + + # Representative ops the plugin EP must claim (op_type, domain, opset, inputs, outputs, attrs). + # One representative per major op family; ops already covered by dedicated test_registration_* + # or test_op_* tests (Add, MatMul, Gemm, Conv, …) are intentionally excluded here. + probe_specs = [ + # binary elementwise (Sub — Add is tested by test_registration_add) + ( + "Sub", + "", + 13, + [("A", TensorProto.FLOAT, [2, 4]), ("B", TensorProto.FLOAT, [2, 4])], + [("Y", TensorProto.FLOAT, [2, 4])], + None, + ), + # unary activation + ("Relu", "", 13, [("X", TensorProto.FLOAT, [2, 4])], [("Y", TensorProto.FLOAT, [2, 4])], None), + # reduction-style + ("Softmax", "", 13, [("X", TensorProto.FLOAT, [2, 4])], [("Y", TensorProto.FLOAT, [2, 4])], {"axis": -1}), + # data-movement + ( + "Transpose", + "", + 13, + [("X", TensorProto.FLOAT, [2, 4])], + [("Y", TensorProto.FLOAT, [4, 2])], + {"perm": [1, 0]}, + ), + # type-dispatch + ( + "Cast", + "", + 13, + [("X", TensorProto.FLOAT, [2, 4])], + [("Y", TensorProto.FLOAT16, [2, 4])], + {"to": int(TensorProto.FLOAT16)}, + ), + # second unary + ("Sigmoid", "", 13, [("X", TensorProto.FLOAT, [2, 4])], [("Y", TensorProto.FLOAT, [2, 4])], None), + # cuDNN: ConvTranspose (Conv already tested by test_registration_conv) + ( + "ConvTranspose", + "", + 13, + [("X", TensorProto.FLOAT, [1, 2, 3, 3]), ("W", TensorProto.FLOAT, [2, 3, 3, 3])], + [("Y", TensorProto.FLOAT, [1, 3, 5, 5])], + None, + ), + # cuDNN: LRN (local response normalization) + ( + "LRN", + "", + 13, + [("X", TensorProto.FLOAT, [1, 2, 4, 4])], + [("Y", TensorProto.FLOAT, [1, 2, 4, 4])], + {"size": 3}, + ), + ] + + claimed = [] + not_claimed = [] + errors = [] + + for op_type, domain, opset, inputs_info, outputs_info, attrs in probe_specs: + model = _make_simple_model(op_type, inputs_info, outputs_info, attrs=attrs, opset=opset, domain=domain) + with tempfile.NamedTemporaryFile(suffix=f"_probe_{op_type}.onnx", delete=False) as tmp: + model_path = tmp.name + try: + save(model, model_path) + sess_options = _create_session_options() + sess_options.graph_optimization_level = onnxrt.GraphOptimizationLevel.ORT_DISABLE_ALL + sess_options.add_provider_for_devices([target_device], {}) + sess = onnxrt.InferenceSession(model_path, sess_options=sess_options) + assigned_nodes, _ = _get_assigned_nodes(sess, CUDA_PLUGIN_EP_NAME) + if assigned_nodes: + claimed.append(op_type) + else: + not_claimed.append(op_type) + except Exception as e: + errors.append((op_type, str(e)[:120])) + finally: + if os.path.exists(model_path): + os.remove(model_path) + + # All probed ops should be claimed by the plugin EP + self.assertFalse( + not_claimed, + f"Plugin EP did not claim these key ops: {not_claimed}", + ) + self.assertFalse( + errors, + f"Errors probing ops: {errors}", + ) + self.assertGreater(len(claimed), 0, "No ops were claimed at all") + + # ---- Newly-included ops that previously lacked tests ---- + + def test_op_einsum(self): + """Test Einsum op (recently un-excluded from plugin build).""" + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Einsum", + [("A", TensorProto.FLOAT, [2, 3]), ("B", TensorProto.FLOAT, [3, 4])], + [("Y", TensorProto.FLOAT, [2, 4])], + attrs={"equation": "ij,jk->ik"}, + opset=12, + ) + feed = {"A": np.random.rand(2, 3).astype(np.float32), "B": np.random.rand(3, 4).astype(np.float32)} + result = _run_model_test(target_device, "Einsum", model, feed, lambda f: f["A"] @ f["B"]) + self.assertEqual(result, TEST_PASS, "Einsum test failed") + + def test_op_einsum_batch(self): + """Test Einsum op with batch matrix multiply.""" + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Einsum", + [("A", TensorProto.FLOAT, [2, 3, 4]), ("B", TensorProto.FLOAT, [2, 4, 5])], + [("Y", TensorProto.FLOAT, [2, 3, 5])], + attrs={"equation": "bij,bjk->bik"}, + opset=12, + ) + feed = {"A": np.random.rand(2, 3, 4).astype(np.float32), "B": np.random.rand(2, 4, 5).astype(np.float32)} + result = _run_model_test(target_device, "Einsum_batch", model, feed, lambda f: np.matmul(f["A"], f["B"])) + self.assertEqual(result, TEST_PASS, "Einsum batch test failed") + + def test_op_softmax(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Softmax", + [("X", TensorProto.FLOAT, [2, 5])], + [("Y", TensorProto.FLOAT, [2, 5])], + attrs={"axis": 1}, + opset=13, + ) + feed = {"X": np.random.rand(2, 5).astype(np.float32)} + + def expected(f): + x = f["X"] + e = np.exp(x - np.max(x, axis=1, keepdims=True)) + return e / np.sum(e, axis=1, keepdims=True) + + result = _run_model_test(target_device, "Softmax", model, feed, expected) + self.assertEqual(result, TEST_PASS, "Softmax test failed") + + def test_op_relu(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Relu", + [("X", TensorProto.FLOAT, [3, 4])], + [("Y", TensorProto.FLOAT, [3, 4])], + opset=14, + ) + feed = {"X": np.random.randn(3, 4).astype(np.float32)} + result = _run_model_test(target_device, "Relu", model, feed, lambda f: np.maximum(f["X"], 0)) + self.assertEqual(result, TEST_PASS, "Relu test failed") + + def test_op_sigmoid(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Sigmoid", + [("X", TensorProto.FLOAT, [3, 4])], + [("Y", TensorProto.FLOAT, [3, 4])], + opset=13, + ) + feed = {"X": np.random.randn(3, 4).astype(np.float32)} + result = _run_model_test(target_device, "Sigmoid", model, feed, lambda f: 1.0 / (1.0 + np.exp(-f["X"]))) + self.assertEqual(result, TEST_PASS, "Sigmoid test failed") + + def test_op_tanh(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Tanh", + [("X", TensorProto.FLOAT, [3, 4])], + [("Y", TensorProto.FLOAT, [3, 4])], + opset=13, + ) + feed = {"X": np.random.randn(3, 4).astype(np.float32)} + result = _run_model_test(target_device, "Tanh", model, feed, lambda f: np.tanh(f["X"])) + self.assertEqual(result, TEST_PASS, "Tanh test failed") + + def test_op_transpose(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Transpose", + [("X", TensorProto.FLOAT, [2, 3, 4])], + [("Y", TensorProto.FLOAT, [4, 2, 3])], + attrs={"perm": [2, 0, 1]}, + opset=13, + ) + feed = {"X": np.random.rand(2, 3, 4).astype(np.float32)} + result = _run_model_test(target_device, "Transpose", model, feed, lambda f: np.transpose(f["X"], (2, 0, 1))) + self.assertEqual(result, TEST_PASS, "Transpose test failed") + + def test_op_cast(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Cast", + [("X", TensorProto.FLOAT, [3, 4])], + [("Y", TensorProto.FLOAT16, [3, 4])], + attrs={"to": int(TensorProto.FLOAT16)}, + opset=13, + ) + feed = {"X": np.random.rand(3, 4).astype(np.float32)} + result = _run_model_test( + target_device, "Cast", model, feed, lambda f: f["X"].astype(np.float16), rtol=1e-2, atol=1e-2 + ) + self.assertEqual(result, TEST_PASS, "Cast test failed") + + def test_op_where(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Where", + [ + ("cond", TensorProto.BOOL, [3, 4]), + ("X", TensorProto.FLOAT, [3, 4]), + ("Y", TensorProto.FLOAT, [3, 4]), + ], + [("out", TensorProto.FLOAT, [3, 4])], + opset=16, + ) + cond = np.random.randint(0, 2, size=(3, 4)).astype(bool) + x = np.random.rand(3, 4).astype(np.float32) + y = np.random.rand(3, 4).astype(np.float32) + feed = {"cond": cond, "X": x, "Y": y} + result = _run_model_test(target_device, "Where", model, feed, lambda f: np.where(f["cond"], f["X"], f["Y"])) + self.assertEqual(result, TEST_PASS, "Where test failed") + + def test_op_flatten(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Flatten", + [("X", TensorProto.FLOAT, [2, 3, 4])], + [("Y", TensorProto.FLOAT, [2, 12])], + attrs={"axis": 1}, + opset=13, + ) + feed = {"X": np.random.rand(2, 3, 4).astype(np.float32)} + result = _run_model_test(target_device, "Flatten", model, feed, lambda f: f["X"].reshape(2, 12)) + self.assertEqual(result, TEST_PASS, "Flatten test failed") + + def test_op_argmax(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "ArgMax", + [("X", TensorProto.FLOAT, [3, 5])], + [("Y", TensorProto.INT64, [3, 1])], + attrs={"axis": 1, "keepdims": 1}, + opset=13, + ) + feed = {"X": np.random.rand(3, 5).astype(np.float32)} + result = _run_model_test( + target_device, "ArgMax", model, feed, lambda f: np.argmax(f["X"], axis=1).reshape(3, 1) + ) + self.assertEqual(result, TEST_PASS, "ArgMax test failed") + + def test_op_topk(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "TopK", + [("X", TensorProto.FLOAT, [3, 6]), ("K", TensorProto.INT64, [1])], + [("values", TensorProto.FLOAT, [3, 3]), ("indices", TensorProto.INT64, [3, 3])], + attrs={"axis": 1}, + opset=11, + ) + x = np.random.rand(3, 6).astype(np.float32) + k = np.array([3], dtype=np.int64) + feed = {"X": x, "K": k} + + def expected(f): + idx = np.argsort(-f["X"], axis=1)[:, :3] + vals = np.take_along_axis(f["X"], idx, axis=1) + return [vals, idx] + + result = _run_model_test(target_device, "TopK", model, feed, expected) + self.assertEqual(result, TEST_PASS, "TopK test failed") + + def test_op_layer_normalization(self): + """Test LayerNormalization — critical for transformer models.""" + target_device = get_cuda_plugin_device() + normalized_shape = 8 + model = _make_simple_model( + "LayerNormalization", + [ + ("X", TensorProto.FLOAT, [2, 3, normalized_shape]), + ("scale", TensorProto.FLOAT, [normalized_shape]), + ("bias", TensorProto.FLOAT, [normalized_shape]), + ], + [("Y", TensorProto.FLOAT, [2, 3, normalized_shape])], + attrs={"axis": -1, "epsilon": 1e-5}, + opset=17, + ) + scale = np.ones(normalized_shape, dtype=np.float32) + bias = np.zeros(normalized_shape, dtype=np.float32) + scale_init = helper.make_tensor("scale", TensorProto.FLOAT, [normalized_shape], scale.tolist()) + bias_init = helper.make_tensor("bias", TensorProto.FLOAT, [normalized_shape], bias.tolist()) + model.graph.initializer.append(scale_init) + model.graph.initializer.append(bias_init) + + x = np.random.rand(2, 3, normalized_shape).astype(np.float32) + feed = {"X": x} + + def expected(f): + x = f["X"] + mean = np.mean(x, axis=-1, keepdims=True) + var = np.var(x, axis=-1, keepdims=True) + return (x - mean) / np.sqrt(var + 1e-5) * scale + bias + + result = _run_model_test(target_device, "LayerNormalization", model, feed, expected) + self.assertEqual(result, TEST_PASS, "LayerNormalization test failed") + + def test_op_instance_normalization(self): + target_device = get_cuda_plugin_device() + n_channels = 3 + model = _make_simple_model( + "InstanceNormalization", + [ + ("X", TensorProto.FLOAT, [1, n_channels, 4, 4]), + ("scale", TensorProto.FLOAT, [n_channels]), + ("B", TensorProto.FLOAT, [n_channels]), + ], + [("Y", TensorProto.FLOAT, [1, n_channels, 4, 4])], + attrs={"epsilon": 1e-5}, + opset=6, + ) + scale = np.ones(n_channels, dtype=np.float32) + bias = np.zeros(n_channels, dtype=np.float32) + model.graph.initializer.append(helper.make_tensor("scale", TensorProto.FLOAT, [n_channels], scale.tolist())) + model.graph.initializer.append(helper.make_tensor("B", TensorProto.FLOAT, [n_channels], bias.tolist())) + + x = np.random.rand(1, n_channels, 4, 4).astype(np.float32) + feed = {"X": x} + + def expected(f): + x = f["X"] + result = np.empty_like(x) + for c in range(n_channels): + ch = x[0, c] + mean = ch.mean() + var = ch.var() + result[0, c] = (ch - mean) / np.sqrt(var + 1e-5) * scale[c] + bias[c] + return result + + result = _run_model_test(target_device, "InstanceNormalization", model, feed, expected) + self.assertEqual(result, TEST_PASS, "InstanceNormalization test failed") + + def test_op_conv_transpose(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "ConvTranspose", + [ + ("X", TensorProto.FLOAT, [1, 3, 4, 4]), + ("W", TensorProto.FLOAT, [3, 2, 3, 3]), + ], + [("Y", TensorProto.FLOAT, [1, 2, 6, 6])], + attrs={"kernel_shape": [3, 3], "strides": [1, 1], "pads": [0, 0, 0, 0]}, + opset=11, + ) + x = np.random.rand(1, 3, 4, 4).astype(np.float32) + w = np.random.rand(3, 2, 3, 3).astype(np.float32) + feed = {"X": x, "W": w} + + def expected(f): + return F.conv_transpose2d(torch.from_numpy(f["X"]), torch.from_numpy(f["W"])).numpy() + + result = _run_model_test(target_device, "ConvTranspose", model, feed, expected) + self.assertEqual(result, TEST_PASS, "ConvTranspose test failed") + + def test_op_reduce_mean(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "ReduceMean", + [("X", TensorProto.FLOAT, [3, 4, 5])], + [("Y", TensorProto.FLOAT, [3, 1, 5])], + attrs={"axes": [1], "keepdims": 1}, + opset=13, + ) + feed = {"X": np.random.rand(3, 4, 5).astype(np.float32)} + result = _run_model_test( + target_device, "ReduceMean", model, feed, lambda f: np.mean(f["X"], axis=1, keepdims=True) + ) + self.assertEqual(result, TEST_PASS, "ReduceMean test failed") + + def test_op_reduce_sum(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "ReduceSum", + [("X", TensorProto.FLOAT, [3, 4, 5]), ("axes", TensorProto.INT64, [1])], + [("Y", TensorProto.FLOAT, [3, 1, 5])], + attrs={"keepdims": 1}, + opset=13, + ) + axes_init = helper.make_tensor("axes", TensorProto.INT64, [1], [1]) + model.graph.initializer.append(axes_init) + feed = {"X": np.random.rand(3, 4, 5).astype(np.float32)} + result = _run_model_test( + target_device, "ReduceSum", model, feed, lambda f: np.sum(f["X"], axis=1, keepdims=True) + ) + self.assertEqual(result, TEST_PASS, "ReduceSum test failed") + + def test_op_gather_nd(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "GatherND", + [ + ("data", TensorProto.FLOAT, [2, 3, 4]), + ("indices", TensorProto.INT64, [2, 1]), + ], + [("Y", TensorProto.FLOAT, [2, 4])], + attrs={"batch_dims": 1}, + opset=12, + ) + data = np.random.rand(2, 3, 4).astype(np.float32) + indices = np.array([[1], [2]], dtype=np.int64) + feed = {"data": data, "indices": indices} + + def expected(f): + return np.array([f["data"][0, 1], f["data"][1, 2]]) + + result = _run_model_test(target_device, "GatherND", model, feed, expected) + self.assertEqual(result, TEST_PASS, "GatherND test failed") + + def test_op_scatter_elements(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "ScatterElements", + [ + ("data", TensorProto.FLOAT, [3, 3]), + ("indices", TensorProto.INT64, [2, 3]), + ("updates", TensorProto.FLOAT, [2, 3]), + ], + [("Y", TensorProto.FLOAT, [3, 3])], + attrs={"axis": 0}, + opset=16, + ) + data = np.zeros((3, 3), dtype=np.float32) + indices = np.array([[1, 0, 2], [0, 2, 1]], dtype=np.int64) + updates = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) + feed = {"data": data, "indices": indices, "updates": updates} + + def expected(f): + result = f["data"].copy() + for i in range(2): + for j in range(3): + result[f["indices"][i, j], j] = f["updates"][i, j] + return result + + result = _run_model_test(target_device, "ScatterElements", model, feed, expected) + self.assertEqual(result, TEST_PASS, "ScatterElements test failed") + + def test_op_onehot(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "OneHot", + [ + ("indices", TensorProto.INT64, [4]), + ("depth", TensorProto.INT64, [1]), + ("values", TensorProto.FLOAT, [2]), + ], + [("Y", TensorProto.FLOAT, [4, 6])], + attrs={"axis": 1}, + opset=11, + ) + indices = np.array([0, 2, 4, 5], dtype=np.int64) + depth = np.array([6], dtype=np.int64) + values = np.array([0.0, 1.0], dtype=np.float32) + feed = {"indices": indices, "depth": depth, "values": values} + + def expected(f): + result = np.zeros((4, 6), dtype=np.float32) + for i, idx in enumerate(f["indices"]): + result[i, idx] = 1.0 + return result + + result = _run_model_test(target_device, "OneHot", model, feed, expected) + self.assertEqual(result, TEST_PASS, "OneHot test failed") + + # NOTE: Range is excluded — it runs on CPU (shape computation op, not claimed by CUDA EP). + + def test_op_non_zero(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "NonZero", + [("X", TensorProto.FLOAT, [3, 4])], + [("Y", TensorProto.INT64, None)], + opset=13, + ) + x = np.array([[1, 0, 3, 0], [0, 5, 0, 7], [0, 0, 0, 10]], dtype=np.float32) + feed = {"X": x} + result = _run_model_test(target_device, "NonZero", model, feed, lambda f: np.array(np.nonzero(f["X"]))) + self.assertEqual(result, TEST_PASS, "NonZero test failed") + + def test_op_grid_sample(self): + target_device = get_cuda_plugin_device() + n, c, h, w = 1, 1, 4, 4 + model = _make_simple_model( + "GridSample", + [ + ("X", TensorProto.FLOAT, [n, c, h, w]), + ("grid", TensorProto.FLOAT, [n, 2, 2, 2]), + ], + [("Y", TensorProto.FLOAT, [n, c, 2, 2])], + attrs={"mode": "bilinear", "padding_mode": "zeros", "align_corners": 0}, + opset=16, + ) + x = np.random.rand(n, c, h, w).astype(np.float32) + grid = np.random.rand(n, 2, 2, 2).astype(np.float32) * 2 - 1 # in [-1, 1] + feed = {"X": x, "grid": grid} + + def expected(f): + return F.grid_sample( + torch.from_numpy(f["X"]), + torch.from_numpy(f["grid"]), + mode="bilinear", + padding_mode="zeros", + align_corners=False, + ).numpy() + + result = _run_model_test(target_device, "GridSample", model, feed, expected, rtol=1e-3, atol=1e-3) + self.assertEqual(result, TEST_PASS, "GridSample test failed") + + def test_op_gelu(self): + """Test Gelu contrib op — important for transformer models.""" + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Gelu", + [("X", TensorProto.FLOAT, [2, 8])], + [("Y", TensorProto.FLOAT, [2, 8])], + domain="com.microsoft", + opset=13, + ) + feed = {"X": np.random.randn(2, 8).astype(np.float32)} + + def expected(f): + return torch.nn.functional.gelu(torch.from_numpy(f["X"])).numpy() + + result = _run_model_test(target_device, "Gelu", model, feed, expected) + self.assertEqual(result, TEST_PASS, "Gelu test failed") + + def test_op_bias_gelu(self): + """Test BiasGelu contrib op.""" + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "BiasGelu", + [("X", TensorProto.FLOAT, [2, 8]), ("bias", TensorProto.FLOAT, [8])], + [("Y", TensorProto.FLOAT, [2, 8])], + domain="com.microsoft", + opset=13, + ) + bias = np.random.randn(8).astype(np.float32) + model.graph.initializer.append(helper.make_tensor("bias", TensorProto.FLOAT, [8], bias.tolist())) + feed = {"X": np.random.randn(2, 8).astype(np.float32)} + + def expected(f): + x = torch.from_numpy(f["X"]) + torch.from_numpy(bias) + return torch.nn.functional.gelu(x).numpy() + + result = _run_model_test(target_device, "BiasGelu", model, feed, expected) + self.assertEqual(result, TEST_PASS, "BiasGelu test failed") + + def test_op_fused_matmul(self): + """Test FusedMatMul contrib op (MatMul with alpha).""" + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "FusedMatMul", + [("A", TensorProto.FLOAT, [3, 4]), ("B", TensorProto.FLOAT, [4, 5])], + [("Y", TensorProto.FLOAT, [3, 5])], + attrs={"alpha": 2.0}, + domain="com.microsoft", + opset=13, + ) + feed = {"A": np.random.rand(3, 4).astype(np.float32), "B": np.random.rand(4, 5).astype(np.float32)} + result = _run_model_test(target_device, "FusedMatMul", model, feed, lambda f: 2.0 * (f["A"] @ f["B"])) + self.assertEqual(result, TEST_PASS, "FusedMatMul test failed") + + def test_op_trilu(self): + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "Trilu", + [("X", TensorProto.FLOAT, [4, 4])], + [("Y", TensorProto.FLOAT, [4, 4])], + attrs={"upper": 1}, + opset=14, + ) + feed = {"X": np.random.rand(4, 4).astype(np.float32)} + result = _run_model_test(target_device, "Trilu", model, feed, lambda f: np.triu(f["X"])) + self.assertEqual(result, TEST_PASS, "Trilu test failed") + + def test_op_matmul_integer(self): + """Test MatMulInteger — used in INT8 quantization pipelines.""" + target_device = get_cuda_plugin_device() + model = _make_simple_model( + "MatMulInteger", + [ + ("A", TensorProto.INT8, [3, 4]), + ("B", TensorProto.INT8, [4, 5]), + ], + [("Y", TensorProto.INT32, [3, 5])], + opset=10, + ) + a = np.random.randint(-128, 127, size=(3, 4)).astype(np.int8) + b = np.random.randint(-128, 127, size=(4, 5)).astype(np.int8) + feed = {"A": a, "B": b} + result = _run_model_test( + target_device, + "MatMulInteger", + model, + feed, + lambda f: f["A"].astype(np.int32) @ f["B"].astype(np.int32), + ) + self.assertEqual(result, TEST_PASS, "MatMulInteger test failed") + + # ---- MemcpyFromHost / MemcpyToHost tests ---- + # These tests explicitly place MemcpyFromHost/MemcpyToHost nodes in the graph + # to directly exercise the plugin-side copy kernels. + + def test_memcpy_from_host_explicit(self): + """Explicit MemcpyFromHost node: CPU input → GPU copy → Relu on GPU.""" + target_device = get_cuda_plugin_device() + # X (CPU) → MemcpyFromHost → X_gpu → Relu → Y + copy_node = helper.make_node("MemcpyFromHost", ["X"], ["X_gpu"]) + relu_node = helper.make_node("Relu", ["X_gpu"], ["Y"]) + graph = helper.make_graph( + [copy_node, relu_node], + "test-explicit-memcpy-from-host", + [helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4])], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 4])], + ) + opset = OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + feed = {"X": np.random.randn(3, 4).astype(np.float32)} + result = _run_model_test( + target_device, + "MemcpyFromHost_explicit", + model, + feed, + lambda f: np.maximum(f["X"], 0), + ) + self.assertEqual(result, TEST_PASS, "Explicit MemcpyFromHost test failed") + + def test_memcpy_to_host_explicit(self): + """Explicit MemcpyToHost node: GPU Add → GPU-to-CPU copy → output.""" + target_device = get_cuda_plugin_device() + # A, B → Add (GPU) → sum_gpu → MemcpyToHost → Y (CPU) + add_node = helper.make_node("Add", ["A", "B"], ["sum_gpu"]) + copy_node = helper.make_node("MemcpyToHost", ["sum_gpu"], ["Y"]) + graph = helper.make_graph( + [add_node, copy_node], + "test-explicit-memcpy-to-host", + [ + helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [2, 3]), + ], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])], + ) + opset = OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + feed = { + "A": np.random.rand(2, 3).astype(np.float32), + "B": np.random.rand(2, 3).astype(np.float32), + } + result = _run_model_test( + target_device, + "MemcpyToHost_explicit", + model, + feed, + lambda f: f["A"] + f["B"], + ) + self.assertEqual(result, TEST_PASS, "Explicit MemcpyToHost test failed") + + def test_memcpy_roundtrip_explicit(self): + """Explicit both directions: CPU → MemcpyFromHost → Relu (GPU) → MemcpyToHost → CPU.""" + target_device = get_cuda_plugin_device() + copy_in = helper.make_node("MemcpyFromHost", ["X"], ["X_gpu"]) + relu_node = helper.make_node("Relu", ["X_gpu"], ["relu_out"]) + copy_out = helper.make_node("MemcpyToHost", ["relu_out"], ["Y"]) + graph = helper.make_graph( + [copy_in, relu_node, copy_out], + "test-explicit-memcpy-roundtrip", + [helper.make_tensor_value_info("X", TensorProto.FLOAT, [4, 5])], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4, 5])], + ) + opset = OperatorSetIdProto() + opset.version = 13 + model = helper.make_model(graph, opset_imports=[opset]) + feed = {"X": np.random.randn(4, 5).astype(np.float32)} + result = _run_model_test( + target_device, + "MemcpyRoundtrip_explicit", + model, + feed, + lambda f: np.maximum(f["X"], 0), + ) + self.assertEqual(result, TEST_PASS, "Explicit MemcpyFromHost→Relu→MemcpyToHost roundtrip test failed") + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index 5ff0572c927c6..23c47e84c1630 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -20,6 +20,7 @@ import numpy import torch +from cuda_plugin_ep_helper import get_cuda_provider_name, resolve_cuda_plugin_ep from einops import rearrange, repeat # --- ONNX and Torch/Numpy Dtype Mappings --- @@ -34,7 +35,7 @@ from packaging import version from parameterized import parameterized -from onnxruntime import InferenceSession, SessionOptions, get_available_providers, get_build_info +from onnxruntime import InferenceSession, SessionOptions, get_build_info from onnxruntime import __version__ as ort_version # Set seed for reproducibility @@ -456,7 +457,7 @@ def gqa_prompt_func( new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=[ep]) + ort_session = InferenceSession(onnx_model_str, sess_options, providers=[resolve_cuda_plugin_ep(ep)]) io_binding = ort_session.io_binding() # Determine input device for binding @@ -492,8 +493,9 @@ def gqa_prompt_func( # total_sequence_length is INT32 [1] # Schema requires this to be on CPU (OrtMemTypeCPUInput) - tsl = torch.tensor([config.q_sequence_length], dtype=torch.int32, device="cpu") - bind_tensor(io_binding, "total_sequence_length", tsl, "cpu", TensorProto.INT32) + cpu_device = torch.device("cpu") + tsl = torch.tensor([config.q_sequence_length], dtype=torch.int32, device=cpu_device) + bind_tensor(io_binding, "total_sequence_length", tsl, cpu_device, TensorProto.INT32) # 5. Optional inputs if cos is not None: @@ -616,7 +618,7 @@ def gqa_past_func( sess_options = SessionOptions() # sess_options.log_severity_level = 0 - ort_session = InferenceSession(onnx_model_str, sess_options, providers=[ep]) + ort_session = InferenceSession(onnx_model_str, sess_options, providers=[resolve_cuda_plugin_ep(ep)]) io_binding = ort_session.io_binding() # Common inputs @@ -653,8 +655,10 @@ def gqa_past_func( seqlens_k_int32 = seqlens_k.to(dtype=torch.int32, device=device) bind_tensor(io_binding, "seqlens_k", seqlens_k_int32, device, TensorProto.INT32) - tsl = torch.tensor([total_seq_len], dtype=torch.int32, device=device) - bind_tensor(io_binding, "total_sequence_length", tsl, device, TensorProto.INT32) + # GroupQueryAttention expects total_sequence_length as CPU input. + cpu_device = torch.device("cpu") + tsl = torch.tensor([total_seq_len], dtype=torch.int32, device=cpu_device) + bind_tensor(io_binding, "total_sequence_length", tsl, cpu_device, TensorProto.INT32) # 5. Optional inputs if cos is not None: @@ -1926,7 +1930,7 @@ def gqa_cuda_quantized_test_cases(is_past: bool): def has_cuda_provider(): - return "CUDAExecutionProvider" in get_available_providers() + return get_cuda_provider_name() is not None def has_cuda_device(min_capability: int = 80): @@ -2342,7 +2346,7 @@ def test_gqa_rope_separate_qkv_bug(self): The bug caused q_out to be nullptr when unpacking separate QKV with only Q rotation (standard GQA), leading to unrotated Q being used in Attention. """ - if "CUDAExecutionProvider" not in get_available_providers(): + if not has_cuda_provider(): self.skipTest("CUDA required") # Config that triggers the path: Prompt phase, Separate QKV inputs, RoPE enabled @@ -2381,7 +2385,7 @@ def test_gqa_int8_large_seq_batch4(self): Regression test for batch_size=4 + max_seq_len=8192 + int8 KV cache crash. This reproduces a CUDA illegal memory access due to scratch size under-allocation. """ - if "CUDAExecutionProvider" not in get_available_providers(): + if not has_cuda_provider(): self.skipTest("CUDA required") # Config that triggers the crash: batch=4, large max_seq_len, int8 kv diff --git a/onnxruntime/test/python/transformers/test_moe_cuda.py b/onnxruntime/test/python/transformers/test_moe_cuda.py index c09d8bacf1fa2..4b9f4e3634a9b 100644 --- a/onnxruntime/test/python/transformers/test_moe_cuda.py +++ b/onnxruntime/test/python/transformers/test_moe_cuda.py @@ -17,6 +17,7 @@ import numpy import torch import torch.nn.functional as F +from cuda_plugin_ep_helper import get_cuda_provider_name from onnx import TensorProto, helper from parameterized import parameterized from torch import nn @@ -28,10 +29,19 @@ onnxruntime.preload_dlls() + # Determine the execution provider and device based on CUDA availability. -use_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers() and torch.cuda.is_available() +cuda_provider = get_cuda_provider_name() +use_cuda = cuda_provider is not None device = torch.device("cuda:0" if use_cuda else "cpu") -ort_provider = ["CUDAExecutionProvider"] if use_cuda else ["CPUExecutionProvider"] + + +def get_ort_provider(): + if not use_cuda: + return ["CPUExecutionProvider"] + + return [cuda_provider] + torch.manual_seed(42) numpy.random.seed(42) @@ -586,11 +596,12 @@ def create_ort_session(self, moe_onnx_graph): sess_options = SessionOptions() sess_options.log_severity_level = 2 + providers = get_ort_provider() try: - ort_session = InferenceSession(moe_onnx_graph, sess_options, providers=ort_provider) + ort_session = InferenceSession(moe_onnx_graph, sess_options, providers=providers) except Exception as e: - print(f"Failed to create ONNX Runtime session with provider {ort_provider}: {e}") + print(f"Failed to create ONNX Runtime session with provider {providers}: {e}") print("Skipping ONNX Runtime execution for this test case.") return None @@ -1403,7 +1414,7 @@ def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits): def has_bf16_moe(): - if "CUDAExecutionProvider" not in onnxruntime.get_available_providers() or not torch.cuda.is_available(): + if not use_cuda or not torch.cuda.is_available(): return False major, _ = torch.cuda.get_device_capability() return major >= 8 diff --git a/onnxruntime/test/unittest_util/base_tester.cc b/onnxruntime/test/unittest_util/base_tester.cc index 1f744df14cfb8..6622960a57680 100644 --- a/onnxruntime/test/unittest_util/base_tester.cc +++ b/onnxruntime/test/unittest_util/base_tester.cc @@ -41,6 +41,13 @@ void DebugTrap() { } #endif +bool ShouldRouteCudaToDynamicPluginEp(const std::optional& dynamic_plugin_ep_name) { + // Route CUDA requests to the CUDA plugin EP when unit test main has initialized + // dynamic plugin EP infrastructure with the CUDA plugin registration. + return dynamic_plugin_ep_name.has_value() && + *dynamic_plugin_ep_name == dynamic_plugin_ep_infra::kCudaPluginExecutionProviderName; +} + } // namespace BaseTester::~BaseTester() { @@ -689,9 +696,10 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, #endif const auto dynamic_plugin_ep_name = dynamic_plugin_ep_infra::GetEpName(); + const bool route_cuda_to_dynamic_plugin_ep = ShouldRouteCudaToDynamicPluginEp(dynamic_plugin_ep_name); std::optional> provider_types_including_dynamic_plugin_ep{}; - if (dynamic_plugin_ep_name.has_value()) { + if (dynamic_plugin_ep_name.has_value() && !route_cuda_to_dynamic_plugin_ep) { ORT_ENFORCE(std::find(all_provider_types.begin(), all_provider_types.end(), *dynamic_plugin_ep_name) == all_provider_types.end(), "Dynamic plugin EP name conflicts with a known EP name: ", *dynamic_plugin_ep_name); @@ -716,7 +724,7 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, if (provider_type == onnxruntime::kCpuExecutionProvider) execution_provider = DefaultCpuExecutionProvider(); else if (provider_type == onnxruntime::kCudaExecutionProvider) - execution_provider = DefaultCudaExecutionProvider(); + execution_provider = route_cuda_to_dynamic_plugin_ep ? dynamic_plugin_ep_infra::MakeEp() : DefaultCudaExecutionProvider(); #ifdef ENABLE_CUDA_NHWC_OPS else if (provider_type == onnxruntime::kCudaNHWCExecutionProvider) execution_provider = DefaultCudaNHWCExecutionProvider(); @@ -812,11 +820,20 @@ void BaseTester::ExecuteModelForEps( bool allow_released_onnx_opset_only, size_t* number_of_pre_packed_weights_counter, size_t* number_of_shared_pre_packed_weights_counter) { + const auto dynamic_plugin_ep_name = dynamic_plugin_ep_infra::GetEpName(); + const bool route_cuda_to_dynamic_plugin_ep = ShouldRouteCudaToDynamicPluginEp(dynamic_plugin_ep_name); + for (auto& entry : execution_providers) { // Be noted, entry in execution providers passed in OpTester will be std::moved in the first BaseTester::Run(), // To make the error more obvious to debug (instead of a segment fault), we do check explicitly here. ASSERT_TRUE(entry) << "Execution provider entry invalid."; + if (route_cuda_to_dynamic_plugin_ep && entry->Type() == kCudaExecutionProvider) { + auto plugin_ep = dynamic_plugin_ep_infra::MakeEp(); + ASSERT_TRUE(plugin_ep) << "Failed to create CUDA plugin EP while routing from CUDAExecutionProvider."; + entry = std::move(plugin_ep); + } + if (entry->Type() == kDmlExecutionProvider) { sess_options.enable_mem_pattern = false; sess_options.execution_mode = ExecutionMode::ORT_SEQUENTIAL; diff --git a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h index 0962df8e35308..b5bf075a25447 100644 --- a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h +++ b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h @@ -29,6 +29,8 @@ namespace test { // unit testing infrastructure. namespace dynamic_plugin_ep_infra { +inline constexpr std::string_view kCudaPluginExecutionProviderName{"CudaPluginExecutionProvider"}; + // Note: `Initialize()` and `Shutdown()` are not thread-safe. // They should be called before and after calls to most of the other functions in this namespace. // The exception to this is `ParseInitializationConfig()`, which may be called before `Initialize()`. diff --git a/orttraining/orttraining/training_ops/cuda/reduction/reduction_ops.cc b/orttraining/orttraining/training_ops/cuda/reduction/reduction_ops.cc index ff8f5f81e4e37..0c59efaee62c8 100644 --- a/orttraining/orttraining/training_ops/cuda/reduction/reduction_ops.cc +++ b/orttraining/orttraining/training_ops/cuda/reduction/reduction_ops.cc @@ -58,8 +58,9 @@ Status ReduceKernel::ComputeImplEx(OpKernelContext* ctx, cudnn Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims); const bool fast_reduction = fast_reduction_ && !ctx->GetUseDeterministicCompute(); - return ReduceComputeCore(Info().GetAllocator(OrtMemType::OrtMemTypeDefault), *X, prepare_reduce_metadata, *Y, cudnn_reduce_op, axes, - calculate_log_, calculate_sqt_, log_sum_exp_, fast_reduction, ctx->GetComputeStream()); + return ReduceComputeCore(Info().GetAllocator(OrtMemType::OrtMemTypeDefault), this, *X, prepare_reduce_metadata, *Y, cudnn_reduce_op, axes, + calculate_log_, calculate_sqt_, log_sum_exp_, fast_reduction, + Stream(ctx), GetComputeStream(ctx), GetCudnnHandle(ctx)); } template <> diff --git a/tools/ci_build/cuda_plugin_parity_report.py b/tools/ci_build/cuda_plugin_parity_report.py new file mode 100755 index 0000000000000..1dffb5ea1292b --- /dev/null +++ b/tools/ci_build/cuda_plugin_parity_report.py @@ -0,0 +1,737 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +""" +CUDA EP Plugin Registration Parity Report + +Compares kernel registrations between the bundled CUDA EP and the plugin CUDA EP +by statically parsing source files or interrogating the kernel registry at runtime. +Produces a report showing which ops are in both builds, only in bundled, or only in plugin. + +Usage: + # Static parse mode (default): + python tools/ci_build/cuda_plugin_parity_report.py [--repo-root /path/to/onnxruntime] + + # Runtime inquiry mode: + python tools/ci_build/cuda_plugin_parity_report.py --runtime [--plugin-ep-lib /path/to/libonnxruntime_providers_cuda_plugin.so] +""" + +import argparse +import os +import re +import sys +from collections import defaultdict +from pathlib import Path + +# Regex patterns for kernel registration macros +# These macros define kernel classes and are the source of truth for op registrations. +KERNEL_EX_PATTERNS = [ + # ONNX_OPERATOR_KERNEL_EX(name, domain, ver, provider, builder, ...) + re.compile( + r"ONNX_OPERATOR_KERNEL_EX\s*\(\s*" + r"(\w+)\s*,\s*" # name + r"(\w+)\s*,\s*" # domain + r"(\d+)\s*,\s*" # version + r"(\w+)\s*," # provider + ), + # ONNX_OPERATOR_TYPED_KERNEL_EX(name, domain, ver, type, provider, builder, ...) + re.compile( + r"ONNX_OPERATOR_TYPED_KERNEL_EX\s*\(\s*" + r"(\w+)\s*,\s*" # name + r"(\w+)\s*,\s*" # domain + r"(\d+)\s*,\s*" # version + r"(\w+)\s*,\s*" # type + r"(\w+)\s*," # provider + ), + # ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, domain, start_ver, end_ver, provider, builder, ...) + re.compile( + r"ONNX_OPERATOR_VERSIONED_KERNEL_EX\s*\(\s*" + r"(\w+)\s*,\s*" # name + r"(\w+)\s*,\s*" # domain + r"(\d+)\s*,\s*" # start_version + r"(\d+)\s*,\s*" # end_version + r"(\w+)\s*," # provider + ), + # ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, domain, start_ver, end_ver, type, provider, builder, ...) + re.compile( + r"ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX\s*\(\s*" + r"(\w+)\s*,\s*" # name + r"(\w+)\s*,\s*" # domain + r"(\d+)\s*,\s*" # start_version + r"(\d+)\s*,\s*" # end_version + r"(\w+)\s*,\s*" # type + r"(\w+)\s*," # provider + ), +] + +# Patterns for contrib ops (CUDA_MS_OP macros expand to ONNX_OPERATOR macros internally) +# Just match the ONNX_OPERATOR_*_KERNEL_EX patterns since the CUDA_MS_OP macros are wrappers. + +# Terminal kernel registration macro names (op_name at arg 0, domain at arg 1 in all variants) +_TERMINAL_KERNEL_MACROS = { + "ONNX_OPERATOR_KERNEL_EX", + "ONNX_OPERATOR_TYPED_KERNEL_EX", + "ONNX_OPERATOR_VERSIONED_KERNEL_EX", + "ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX", + "ONNX_OPERATOR_TWO_TYPED_KERNEL_EX", + "ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX", +} + + +def _preprocess_content(file_path): + """Read a file and preprocess: strip comments, join continuation lines.""" + try: + content = Path(file_path).read_text(errors="replace") + except OSError: + return None + content = re.sub(r"//.*?$", "", content, flags=re.MULTILINE) + content = re.sub(r"/\*.*?\*/", "", content, flags=re.DOTALL) + content = re.sub(r"\\\s*\n\s*", " ", content) + return content + + +def _strip_define_bodies(content): + """Replace #define lines with blanks so regex only matches non-define code.""" + lines = content.split("\n") + return "\n".join("" if re.match(r"\s*#\s*define\b", line) else line for line in lines) + + +def _parse_macro_args_at(text, pos): + """Parse balanced-parentheses argument list starting at *pos* (must be '('). + Returns a list of argument strings, or None on failure.""" + if pos >= len(text) or text[pos] != "(": + return None + depth = 0 + args = [] + arg_start = pos + 1 + for i in range(pos, len(text)): + ch = text[i] + if ch == "(": + if depth == 0: + arg_start = i + 1 + depth += 1 + elif ch == ")": + depth -= 1 + if depth == 0: + arg = text[arg_start:i].strip() + if arg or args: # keep empty trailing arg if there were prior args + args.append(arg) + return args + elif ch == "," and depth == 1: + args.append(text[arg_start:i].strip()) + arg_start = i + 1 + return None + + +def _parse_macro_definitions(content): + """Parse ``#define NAME(params) body`` statements. + Returns ``{name: (param_name_list, body_string)}``.""" + macros = {} + for m in re.finditer(r"^#\s*define\s+(\w+)\s*\(([^)]*)\)\s*(.*)", content, re.MULTILINE): + name = m.group(1) + params = [p.strip() for p in m.group(2).split(",") if p.strip() and p.strip() != "..."] + body = m.group(3).strip() + macros[name] = (params, body) + return macros + + +def _find_calls_in_text(text, target_names): + """Find macro invocations in *text* whose name is in *target_names*. + Returns list of ``(macro_name, [arg, ...])``.""" + results = [] + for m in re.finditer(r"\b(\w+)\s*\(", text): + name = m.group(1) + if name in target_names: + args = _parse_macro_args_at(text, m.end() - 1) + if args is not None: + results.append((name, args)) + return results + + +def _resolve_through_chain(info_field, call_args, parent_params): + """Resolve an ``('param', idx)`` or ``('literal', val)`` through one macro-call level.""" + kind, value = info_field + if kind == "literal": + return info_field + if kind == "param": + idx = value + if idx < len(call_args): + arg_val = call_args[idx].strip() + if arg_val in parent_params: + return ("param", parent_params.index(arg_val)) + return ("literal", arg_val) + return info_field + + +def _resolve_kernel_macros(macros): + """Determine which macros transitively expand to ``ONNX_OPERATOR_*_KERNEL_EX`` + and how their parameters map to (op_name, domain). + + Returns ``{macro_name: [(op_name_info, domain_info), ...]}`` where each + ``*_info`` is ``('literal', str_value)`` or ``('param', int_index)``. + """ + kernel_macros = {} # macro_name -> [(op_info, dom_info), ...] + + # Phase 1: macros whose body directly contains a terminal KERNEL_EX call + for macro_name, (params, body) in macros.items(): + calls = _find_calls_in_text(body, _TERMINAL_KERNEL_MACROS) + entries = set() + for _call_name, call_args in calls: + if len(call_args) >= 2: + a0, a1 = call_args[0].strip(), call_args[1].strip() + op_info = ("param", params.index(a0)) if a0 in params else ("literal", a0) + dom_info = ("param", params.index(a1)) if a1 in params else ("literal", a1) + entries.add((op_info, dom_info)) + if entries: + kernel_macros[macro_name] = list(entries) + + # Phase 2: iteratively resolve higher-level wrapper macros + changed = True + for _ in range(20): # bounded iteration + if not changed: + break + changed = False + for macro_name, (params, body) in macros.items(): + if macro_name in kernel_macros: + continue + calls = _find_calls_in_text(body, set(kernel_macros.keys())) + entries = set() + for call_name, call_args in calls: + for child_op, child_dom in kernel_macros[call_name]: + new_op = _resolve_through_chain(child_op, call_args, params) + new_dom = _resolve_through_chain(child_dom, call_args, params) + entries.add((new_op, new_dom)) + if entries: + kernel_macros[macro_name] = list(entries) + changed = True + + return kernel_macros + + +def _extract_wrapper_registrations(content, file_path): + """Extract registrations from wrapper macros that expand to KERNEL_EX.""" + macros = _parse_macro_definitions(content) + if not macros: + return [] + + kernel_macros = _resolve_kernel_macros(macros) + if not kernel_macros: + return [] + + non_define = _strip_define_bodies(content) + invocations = _find_calls_in_text(non_define, set(kernel_macros.keys())) + + registrations = [] + seen = set() + for call_name, call_args in invocations: + for op_info, dom_info in kernel_macros[call_name]: + # Resolve op_name + if op_info[0] == "literal": + op_name = op_info[1] + elif op_info[0] == "param" and op_info[1] < len(call_args): + op_name = call_args[op_info[1]].strip() + else: + continue + + # Resolve domain + if dom_info[0] == "literal": + domain = dom_info[1] + elif dom_info[0] == "param" and dom_info[1] < len(call_args): + domain = call_args[dom_info[1]].strip() + else: + domain = "kOnnxDomain" + + # Filter out C++ types / param names that aren't valid op names + if not op_name or not op_name[0].isupper(): + continue + + key = (op_name, domain, file_path) + if key not in seen: + seen.add(key) + registrations.append((op_name, domain, 0, file_path)) + + return registrations + + +def extract_kernel_registrations(file_path): + """Extract (op_name, domain, since_version) tuples from kernel registration macros in a file.""" + content = _preprocess_content(file_path) + if content is None: + return [] + + registrations = [] + + # Phase 1: Direct ONNX_OPERATOR_*_KERNEL_EX patterns outside #define bodies + non_define = _strip_define_bodies(content) + for pattern in KERNEL_EX_PATTERNS: + for m in pattern.finditer(non_define): + groups = m.groups() + op_name = groups[0] + domain = groups[1] + since_version = int(groups[2]) + registrations.append((op_name, domain, since_version, str(file_path))) + + # Phase 2: Wrapper macros that expand to ONNX_OPERATOR_*_KERNEL_EX + registrations.extend(_extract_wrapper_registrations(content, str(file_path))) + + return registrations + + +def parse_registration_table(file_path, table_func_name): + """Parse the registration table function to extract op names referenced in BuildKernelCreateInfo calls.""" + registrations = set() + try: + content = Path(file_path).read_text(errors="replace") + except OSError: + return registrations + + # Find the function + func_start = content.find(f"{table_func_name}") + if func_start < 0: + return registrations + + # Extract class names from BuildKernelCreateInfo entries + # Pattern: ONNX_OPERATOR_*_KERNEL_CLASS_NAME(provider, domain, ver, [type,] name) + class_name_patterns = [ + # ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name) + re.compile(r"ONNX_OPERATOR_KERNEL_CLASS_NAME\s*\(\s*\w+\s*,\s*(\w+)\s*,\s*(\d+)\s*,\s*(\w+)\s*\)"), + # ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name) + re.compile( + r"ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME\s*\(\s*\w+\s*,\s*(\w+)\s*,\s*(\d+)\s*,\s*\w+\s*,\s*(\w+)\s*\)" + ), + # ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, start, end, name) + re.compile( + r"ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME\s*\(\s*\w+\s*,\s*(\w+)\s*,\s*(\d+)\s*,\s*\d+\s*,\s*(\w+)\s*\)" + ), + # ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, start, end, type, name) + re.compile( + r"ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME\s*\(\s*\w+\s*,\s*(\w+)\s*,\s*(\d+)\s*,\s*\d+\s*,\s*\w+\s*,\s*(\w+)\s*\)" + ), + ] + + # Also handle CUDA_MS_OP_CLASS_NAME and CUDA_MS_OP_TYPED_CLASS_NAME + ms_op_patterns = [ + # CUDA_MS_OP_CLASS_NAME(ver, name) -> domain is kMSDomain + re.compile(r"CUDA_MS_OP_CLASS_NAME\s*\(\s*(\d+)\s*,\s*(\w+)\s*\)"), + # CUDA_MS_OP_TYPED_CLASS_NAME(ver, type, name) + re.compile(r"CUDA_MS_OP_TYPED_CLASS_NAME\s*\(\s*(\d+)\s*,\s*\w+\s*,\s*(\w+)\s*\)"), + ] + + # Scan from function start to end of file (conservative) + search_region = content[func_start:] + + for pattern in class_name_patterns: + for m in pattern.finditer(search_region): + domain, version, name = m.group(1), int(m.group(2)), m.group(3) + registrations.add((name, domain, version)) + + for pattern in ms_op_patterns: + for m in pattern.finditer(search_region): + version, name = int(m.group(1)), m.group(2) + registrations.add((name, "kMSDomain", version)) + + return registrations + + +def get_excluded_files(cmake_path, repo_root): + """Parse the plugin CMake file to get regex exclusion patterns.""" + exclusion_patterns = [] + try: + content = Path(cmake_path).read_text() + except OSError: + return exclusion_patterns + + # Match: list(FILTER CC_SRCS EXCLUDE REGEX "pattern") + # or: list(FILTER CU_SRCS EXCLUDE REGEX "pattern") + for m in re.finditer(r'list\s*\(\s*FILTER\s+\w+\s+EXCLUDE\s+REGEX\s+"([^"]+)"\s*\)', content): + pat = m.group(1) + # Only keep non-commented lines + line_start = content.rfind("\n", 0, m.start()) + 1 + line = content[line_start : m.start()] + if "#" not in line: + exclusion_patterns.append(pat) + + return exclusion_patterns + + +def find_kernel_files(base_dirs, extensions=(".cc",)): + """Find all kernel source files in the given directories.""" + files = [] + for base_dir in base_dirs: + for ext in extensions: + for path in Path(base_dir).rglob(f"*{ext}"): + files.append(str(path)) + return sorted(files) + + +def is_excluded(file_path, exclusion_patterns): + """Check if a file path matches any exclusion pattern.""" + return any(re.search(pat, file_path) for pat in exclusion_patterns) + + +def generate_report(repo_root): + """Generate the full parity report.""" + repo_root = Path(repo_root) + + # Paths + cuda_ep_cc = repo_root / "onnxruntime/core/providers/cuda/cuda_execution_provider.cc" + cuda_nhwc_cc = repo_root / "onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc" + contrib_cc = repo_root / "onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc" + plugin_cmake = repo_root / "cmake/onnxruntime_providers_cuda_plugin.cmake" + + # 1. Parse bundled EP registration tables + bundled_standard = parse_registration_table(cuda_ep_cc, "RegisterCudaKernels") + bundled_nhwc = parse_registration_table(cuda_nhwc_cc, "RegisterCudaKernels") # NHWC uses same function name pattern + bundled_contrib = parse_registration_table(contrib_cc, "RegisterCudaContribKernels") + + # 2. Get exclusion patterns from plugin CMake + exclusion_patterns = get_excluded_files(plugin_cmake, repo_root) + + # 3. Scan all CUDA kernel source files + core_cuda_dir = repo_root / "onnxruntime/core/providers/cuda" + contrib_cuda_dir = repo_root / "onnxruntime/contrib_ops/cuda" + + all_cc_files = find_kernel_files([core_cuda_dir, contrib_cuda_dir]) + + # 4. Categorize files and extract registrations + plugin_registrations = [] # (op, domain, ver, file) tuples - compiled into plugin + excluded_registrations = [] # (op, domain, ver, file) tuples - excluded from plugin + + for f in all_cc_files: + regs = extract_kernel_registrations(f) + if not regs: + continue + if is_excluded(f, exclusion_patterns): + excluded_registrations.extend(regs) + else: + plugin_registrations.extend(regs) + + # 5. Build op sets for comparison + plugin_ops = set() + for op, domain, ver, _ in plugin_registrations: + plugin_ops.add((op, domain, ver)) + + excluded_ops = set() + for op, domain, ver, _ in excluded_registrations: + excluded_ops.add((op, domain, ver)) + + # Unique op names (ignoring version/type variants) + plugin_op_names = {(op, domain) for op, domain, _ in plugin_ops} + excluded_op_names = {(op, domain) for op, domain, _ in excluded_ops} + bundled_op_names = set() + for ops_set in [bundled_standard, bundled_nhwc, bundled_contrib]: + for op, domain, _ver in ops_set: + bundled_op_names.add((op, domain)) + + # 6. Generate report + report = [] + report.append("=" * 70) + report.append("CUDA EP Plugin — Kernel Registration Parity Report") + report.append("=" * 70) + report.append("") + + report.append("## Summary") + report.append(" NOTE: Plugin macro counts may undercount due to nested macro") + report.append(" expansion (e.g., BINARY_OP_VERSIONED_UZILHFD wraps multiple") + report.append(" ONNX_OPERATOR_TYPED_KERNEL_EX calls). Bundled table counts") + report.append(" from RegisterCudaKernels/RegisterCudaContribKernels are accurate.") + report.append("") + report.append(" Bundled EP registration table entries:") + report.append(f" Standard ops: {len(bundled_standard)}") + report.append(f" NHWC ops: {len(bundled_nhwc)}") + report.append(f" Contrib ops: {len(bundled_contrib)}") + report.append(f" Total: {len(bundled_standard) + len(bundled_nhwc) + len(bundled_contrib)}") + report.append("") + report.append(" Plugin kernel macro invocations (in compiled .cc files):") + report.append(f" Total: {len(plugin_registrations)}") + report.append(" Excluded kernel macro invocations:") + report.append(f" Total: {len(excluded_registrations)}") + report.append("") + report.append(" Unique op names (op, domain):") + report.append(f" In plugin: {len(plugin_op_names)}") + report.append(f" Excluded: {len(excluded_op_names)}") + report.append(f" In bundled: {len(bundled_op_names)}") + report.append("") + + # Plugin-only ops (in plugin but not in bundled table — likely already handled) + plugin_only = plugin_op_names - bundled_op_names + if plugin_only: + report.append(f" Plugin-only op names (not in bundled table): {len(plugin_only)}") + for op, domain in sorted(plugin_only): + report.append(f" - {op} ({domain})") + report.append("") + + # Bundled-only ops (in bundled but not in plugin+excluded) + all_source_ops = plugin_op_names | excluded_op_names + bundled_only = bundled_op_names - all_source_ops + if bundled_only: + report.append(f" Bundled-only op names (not in any .cc file KERNEL_EX): {len(bundled_only)}") + for op, domain in sorted(bundled_only): + report.append(f" - {op} ({domain})") + report.append("") + + # Coverage ratio + if bundled_op_names: + coverage = len(plugin_op_names & bundled_op_names) / len(bundled_op_names) * 100 + report.append(f" Plugin coverage: {coverage:.1f}% of bundled unique op names") + report.append("") + + # 7. Excluded ops detail + report.append("## Excluded Ops by Category") + report.append("") + + # Group excluded by file/directory + excluded_by_dir = defaultdict(list) + for op, domain, ver, filepath in excluded_registrations: + # Extract a short category from the path + rel_path = str(filepath).replace(str(repo_root) + "/", "") + parts = rel_path.split("/") + # Find the most descriptive sub-directory + if "contrib_ops" in rel_path: + idx = parts.index("cuda") if "cuda" in parts else 0 + category = "/".join(parts[idx + 1 : -1]) or parts[-1] + elif "core/providers/cuda" in rel_path: + idx = [i for i, p in enumerate(parts) if p == "cuda"][-1] + category = "/".join(parts[idx + 1 : -1]) or parts[-1] + else: + category = "other" + excluded_by_dir[category].append((op, domain, ver, rel_path)) + + for category in sorted(excluded_by_dir): + entries = excluded_by_dir[category] + unique_ops = {(op, domain) for op, domain, _, _ in entries} + report.append(f" [{category}] ({len(entries)} registrations, {len(unique_ops)} unique ops)") + for op, domain in sorted(unique_ops): + report.append(f" - {op} ({domain})") + report.append("") + + report.append("## Active CMake Exclusion Patterns") + for i, pat in enumerate(exclusion_patterns, 1): + report.append(f" {i:2d}. {pat}") + report.append("") + + return "\n".join(report) + + +# ============================================================================ +# Runtime-based parity report (uses the actual kernel registries) +# ============================================================================ + +# Map C++ domain constants to the string forms used by the bundled EP table parser. +_DOMAIN_CONST_TO_STRING = { + "kOnnxDomain": "", + "kMSDomain": "com.microsoft", + "kMSInternalNHWCDomain": "com.microsoft.internal.nhwc", + "kPytorchAtenDomain": "com.microsoft.pytorch.aten", +} + +# Reverse: runtime domain string -> constant name for display. +_DOMAIN_STRING_TO_CONST = {v: k for k, v in _DOMAIN_CONST_TO_STRING.items()} +_DOMAIN_STRING_TO_CONST["ai.onnx"] = "kOnnxDomain" + + +def _runtime_domain_display(domain_str): + """Convert a runtime domain string (e.g. '' or 'com.microsoft') to the constant name used in reports.""" + return _DOMAIN_STRING_TO_CONST.get(domain_str, domain_str or "kOnnxDomain") + + +def _kernel_defs_to_op_names(kernel_defs): + """Convert a list of KernelDef objects to a set of (op_name, domain_display) tuples.""" + op_names = set() + for kd in kernel_defs: + domain = _runtime_domain_display(kd.domain) + op_names.add((kd.op_name, domain)) + return op_names + + +def generate_runtime_report(plugin_ep_name, plugin_lib_path, bundled_ep_name="CUDAExecutionProvider"): + """Generate a parity report by querying actual kernel registries at runtime.""" + import onnxruntime as ort # noqa: PLC0415 + import onnxruntime.capi.onnxruntime_pybind11_state as rtpy # noqa: PLC0415 + + # 1. Get bundled EP kernel defs + bundled_defs = [kd for kd in rtpy.get_all_opkernel_def() if kd.provider == bundled_ep_name] + bundled_op_names = _kernel_defs_to_op_names(bundled_defs) + + # 2. Register the plugin EP + try: + ort.register_execution_provider_library(plugin_ep_name, plugin_lib_path) + except Exception as e: + if "already registered" not in str(e).lower(): + print(f"Error: failed to register plugin EP '{plugin_ep_name}': {e}", file=sys.stderr) + sys.exit(1) + + # 3. Get plugin EP kernel defs via the C++ registry query API + if not hasattr(rtpy, "get_registered_ep_kernel_defs"): + raise RuntimeError( + "get_registered_ep_kernel_defs is not available. " + "Rebuild onnxruntime with the pybind change in onnxruntime_pybind_schema.cc." + ) + + plugin_defs = rtpy.get_registered_ep_kernel_defs(plugin_ep_name) + plugin_op_names = _kernel_defs_to_op_names(plugin_defs) + method = "kernel registry query" + + # 4. Build report + report = [] + report.append("=" * 70) + report.append("CUDA EP Plugin — Runtime Kernel Parity Report") + report.append("=" * 70) + report.append("") + + report.append(f"## Summary (Runtime — {method})") + report.append(f" Bundled EP ({bundled_ep_name}):") + report.append(f" Total kernel registrations: {len(bundled_defs)}") + report.append(f" Unique op names (op, domain): {len(bundled_op_names)}") + report.append("") + report.append(f" Plugin EP ({plugin_ep_name}):") + if plugin_defs is not None: + report.append(f" Total kernel registrations: {len(plugin_defs)}") + report.append(f" Unique op names (op, domain): {len(plugin_op_names)}") + report.append("") + + plugin_only = plugin_op_names - bundled_op_names + if plugin_only: + report.append(f" Plugin-only op names (not in bundled): {len(plugin_only)}") + for op, domain in sorted(plugin_only): + report.append(f" - {op} ({domain})") + report.append("") + + bundled_only = bundled_op_names - plugin_op_names + if bundled_only: + report.append(f" Bundled-only op names (not in plugin): {len(bundled_only)}") + for op, domain in sorted(bundled_only): + report.append(f" - {op} ({domain})") + report.append("") + + common = bundled_op_names & plugin_op_names + if bundled_op_names: + coverage = len(common) / len(bundled_op_names) * 100 + report.append( + f" Plugin coverage: {coverage:.1f}% of bundled unique op names ({len(common)}/{len(bundled_op_names)})" + ) + report.append("") + + # 5. Detailed version/type comparison (only with registry API) + if plugin_defs is not None: + bundled_by_op = defaultdict(list) + for kd in bundled_defs: + key = (kd.op_name, _runtime_domain_display(kd.domain)) + bundled_by_op[key].append(kd) + + plugin_by_op = defaultdict(list) + for kd in plugin_defs: + key = (kd.op_name, _runtime_domain_display(kd.domain)) + plugin_by_op[key].append(kd) + + version_gaps = [] + for op_key in sorted(common): + b_versions = {kd.version_range for kd in bundled_by_op[op_key]} + p_versions = {kd.version_range for kd in plugin_by_op[op_key]} + missing = b_versions - p_versions + if missing: + version_gaps.append((op_key, missing)) + + if version_gaps: + report.append("## Version Gaps (op present but some version ranges missing in plugin)") + for (op, domain), missing in version_gaps: + ranges = ", ".join(f"[{v[0]}, {v[1]}]" if v[1] < 2147483647 else f"{v[0]}+" for v in sorted(missing)) + report.append(f" {op} ({domain}): missing versions {ranges}") + report.append("") + + return "\n".join(report) + + +# _probe_plugin_ops and _make_probe_model removed — session-based probing +# has been moved to test_cuda_plugin_ep.py (test_plugin_ep_claims_key_ops). +# The runtime report now requires the C++ get_registered_ep_kernel_defs API. + + +def main(): + parser = argparse.ArgumentParser(description="CUDA EP Plugin Registration Parity Report") + parser.add_argument("--repo-root", default=None, help="Path to onnxruntime repo root") + parser.add_argument( + "--runtime", + action="store_true", + help="Use runtime kernel registry queries instead of static source analysis. " + "Requires a built onnxruntime with the plugin EP library available.", + ) + parser.add_argument( + "--plugin-ep-name", + default="CudaPluginExecutionProvider", + help="Name of the plugin EP (default: CudaPluginExecutionProvider)", + ) + parser.add_argument( + "--plugin-ep-lib", + default=None, + help="Path to the plugin EP shared library. Auto-detected from build dir if not specified.", + ) + args = parser.parse_args() + + if args.runtime: + # Auto-detect plugin library path if not specified + lib_path = args.plugin_ep_lib + if lib_path is None: + lib_path = _auto_detect_plugin_lib(args.repo_root) + if lib_path is None: + print( + "Error: could not auto-detect plugin EP library. Use --plugin-ep-lib to specify the path.", + file=sys.stderr, + ) + sys.exit(1) + + report = generate_runtime_report(args.plugin_ep_name, lib_path) + else: + if args.repo_root: + repo_root = args.repo_root + else: + script_dir = Path(__file__).resolve().parent + repo_root = script_dir.parent.parent + if not (Path(repo_root) / "onnxruntime").exists(): + print("Error: Could not find repo root. Use --repo-root flag.", file=sys.stderr) + sys.exit(1) + + report = generate_report(repo_root) + + print(report) + + +def _auto_detect_plugin_lib(repo_root): + """Try to find the plugin EP shared library in common build directories.""" + if repo_root is None: + script_dir = Path(__file__).resolve().parent + repo_root = script_dir.parent.parent + + repo_root = Path(repo_root) + lib_name = "libonnxruntime_providers_cuda_plugin.so" + + # Check ORT_CUDA_PLUGIN_PATH env var first + env_path = os.environ.get("ORT_CUDA_PLUGIN_PATH") + if env_path and Path(env_path).exists(): + return env_path + + # Search common build directories (pick the newest if multiple exist) + build_root = repo_root / "build" + if build_root.is_dir(): + candidates = sorted(build_root.rglob(lib_name), key=lambda p: p.stat().st_mtime, reverse=True) + if candidates: + return str(candidates[0]) + + # Check onnxruntime package installation + try: + import onnxruntime # noqa: PLC0415 + + pkg_dir = Path(onnxruntime.__file__).parent / "capi" + candidate = pkg_dir / lib_name + if candidate.exists(): + return str(candidate) + except ImportError: + # onnxruntime is not installed in the current environment. Return None here and the script will fail later. + pass + + return None + + +if __name__ == "__main__": + main() diff --git a/tools/python/gen_opkernel_doc.py b/tools/python/gen_opkernel_doc.py index f6f9f21396859..b5a4507c1f0b3 100644 --- a/tools/python/gen_opkernel_doc.py +++ b/tools/python/gen_opkernel_doc.py @@ -58,7 +58,42 @@ def expand_providers(provider_filter: [str]): return providers -def main(output_path: pathlib.Path, provider_filter: [str]): +def load_plugin_ep_kernel_defs(plugin_eps): + """Register plugin EP libraries and return their kernel defs. + + Args: + plugin_eps: list of "name:path" strings, e.g. + ["CudaPluginExecutionProvider:/path/to/lib.so"] + + Returns: + list of KernelDef objects from the plugin EPs. + """ + if not plugin_eps: + return [] + + import onnxruntime as ort # noqa: PLC0415 + + defs = [] + for spec in plugin_eps: + if ":" not in spec: + print(f"Warning: invalid --plugin-ep spec '{spec}' (expected NAME:PATH), skipping") + continue + ep_name, lib_path = spec.split(":", 1) + try: + ort.register_execution_provider_library(ep_name, lib_path) + except Exception as e: + if "already registered" not in str(e).lower(): + print(f"Warning: failed to register plugin EP '{ep_name}': {e}") + continue + try: + defs.extend(rtpy.get_registered_ep_kernel_defs(ep_name)) + except Exception as e: + print(f"Warning: failed to get kernel defs for plugin EP '{ep_name}': {e}") + + return defs + + +def main(output_path: pathlib.Path, provider_filter: [str], plugin_eps=None): providers = expand_providers(provider_filter) with open(output_path, "w", newline="", encoding="utf-8") as fout: @@ -109,6 +144,13 @@ def main(output_path: pathlib.Path, provider_filter: [str]): domain = "ai.onnx" index[op.provider][domain][op.op_name].append(op) + # Include kernel defs from plugin EPs + for op in load_plugin_ep_kernel_defs(plugin_eps): + domain = op.domain + if not domain: + domain = "ai.onnx" + index[op.provider][domain][op.op_name].append(op) + # TOC fout.write("## Execution Providers\n\n") for provider in sorted(index.keys()): @@ -166,6 +208,14 @@ def main(output_path: pathlib.Path, provider_filter: [str]): "'ExecutionProvider' is automatically appended as needed. " "e.g. `--providers cpu cuda` will match CPUExecutionProvider and CUDAExecutionProvider.", ) + parser.add_argument( + "--plugin-ep", + nargs="+", + dest="plugin_eps", + help="Register plugin EP libraries and include their kernels. " + "Each entry is NAME:PATH, e.g. " + "'CudaPluginExecutionProvider:/path/to/libonnxruntime_providers_cuda_plugin.so'.", + ) parser.add_argument( "--output_path", help="output markdown file path", @@ -175,4 +225,4 @@ def main(output_path: pathlib.Path, provider_filter: [str]): ) args = parser.parse_args() - main(args.output_path, args.providers) + main(args.output_path, args.providers, args.plugin_eps)