diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index 996e0d816d51a..70e8ea7e2792f 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -22,7 +22,6 @@ jobs: strategy: matrix: vcpkg_option: [novcpkg, vcpkg] - wgsl_template: [static, dynamic] env: OrtPackageId: Microsoft.ML.OnnxRuntime OnnxRuntimeBuildDirectory: ${{ github.workspace }} @@ -124,7 +123,6 @@ jobs: --build_nodejs ` --build_java ` --use_webgpu ` - --wgsl_template ${{ matrix.wgsl_template }} ` ${{ matrix.vcpkg_option == 'vcpkg' && '--use_vcpkg' || '' }} ` --cmake_extra_defines ` onnxruntime_BUILD_UNIT_TESTS=ON ` diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index b01110b2a4a03..fb4238731ffc3 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -151,7 +151,6 @@ option(onnxruntime_DISABLE_SPARSE_TENSORS "Disable sparse tensors data types" OF option(onnxruntime_DISABLE_OPTIONAL_TYPE "Disable optional type" OFF) option(onnxruntime_DISABLE_FLOAT8_TYPES "Disable float 8 types" OFF) option(onnxruntime_MINIMAL_BUILD "Exclude as much as possible from the build. Support ORT format models. No support for ONNX format models." OFF) -option(onnxruntime_CLIENT_PACKAGE_BUILD "Enables default settings that are more appropriate for client/on-device workloads." OFF) cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "NOT onnxruntime_ENABLE_PYTHON;NOT onnxruntime_USE_CUDA" OFF) # For now onnxruntime_DISABLE_EXCEPTIONS will only work with onnxruntime_MINIMAL_BUILD, more changes (ONNX, non-CPU EP, ...) are required to run this standalone cmake_dependent_option(onnxruntime_DISABLE_EXCEPTIONS "Disable exception handling. Requires onnxruntime_MINIMAL_BUILD currently." ON "onnxruntime_MINIMAL_BUILD;NOT onnxruntime_ENABLE_PYTHON" OFF) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 6d517003fa6b6..59d99ade131cd 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -95,11 +95,6 @@ if (onnxruntime_MINIMAL_BUILD) endif() endif() -# ORT build with default settings more appropriate for client/on-device workloads. 
-if (onnxruntime_CLIENT_PACKAGE_BUILD) - add_compile_definitions(ORT_CLIENT_PACKAGE_BUILD) -endif() - if (onnxruntime_ENABLE_LTO) include(CheckIPOSupported) check_ipo_supported(RESULT ipo_enabled OUTPUT ipo_output) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 228906030d14c..e8f6bbe895d29 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -774,24 +774,13 @@ if (onnxruntime_USE_WEBGPU) endif() if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND onnxruntime_WGSL_TEMPLATE STREQUAL "dynamic") - if(onnxruntime_USE_VCPKG) - find_package(unofficial-duktape CONFIG REQUIRED) - add_library(duktape_static ALIAS unofficial::duktape::duktape) - else() - onnxruntime_fetchcontent_declare( - duktape - URL ${DEP_URL_duktape} - URL_HASH SHA1=${DEP_SHA1_duktape} - EXCLUDE_FROM_ALL - ) - onnxruntime_fetchcontent_makeavailable(duktape) - - if(NOT TARGET duktape_static) - add_library(duktape_static STATIC "${duktape_SOURCE_DIR}/src/duktape.c") - target_compile_features(duktape_static PRIVATE c_std_99) - target_include_directories(duktape_static INTERFACE $) - endif() - endif() + onnxruntime_fetchcontent_declare( + duktape + URL ${DEP_URL_duktape} + URL_HASH SHA1=${DEP_SHA1_duktape} + EXCLUDE_FROM_ALL + ) + onnxruntime_fetchcontent_makeavailable(duktape) endif() endif() diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 47e7779d93b33..f8f5546ae9465 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -31,7 +31,6 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/eltwise.cpp ${MLAS_SRC_DIR}/erf.cpp ${MLAS_SRC_DIR}/compute.cpp - ${MLAS_SRC_DIR}/dequantize.cpp ${MLAS_SRC_DIR}/quantize.cpp ${MLAS_SRC_DIR}/qgemm_kernel_default.cpp ${MLAS_SRC_DIR}/qladd.cpp diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index 4184e0b049afc..69c81a5ec7b9d 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -72,9 +72,10 @@ endif() # TensorRT 10 GA onwards, the TensorRT libraries will have major version appended to the end on Windows, - # for example, nvinfer_10.dll, nvonnxparser_10.dll ... + # for example, nvinfer_10.dll, nvinfer_plugin_10.dll, nvonnxparser_10.dll ... 
if (WIN32 AND TRT_GREATER_OR_EQUAL_TRT_10_GA) set(NVINFER_LIB "nvinfer_${NV_TENSORRT_MAJOR}") + set(NVINFER_PLUGIN_LIB "nvinfer_plugin_${NV_TENSORRT_MAJOR}") set(PARSER_LIB "nvonnxparser_${NV_TENSORRT_MAJOR}") endif() @@ -82,11 +83,15 @@ set(NVINFER_LIB "nvinfer") endif() + if (NOT NVINFER_PLUGIN_LIB) + set(NVINFER_PLUGIN_LIB "nvinfer_plugin") + endif() + if (NOT PARSER_LIB) set(PARSER_LIB "nvonnxparser") endif() - MESSAGE(STATUS "Looking for ${NVINFER_LIB}") + MESSAGE(STATUS "Looking for ${NVINFER_LIB} and ${NVINFER_PLUGIN_LIB}") find_library(TENSORRT_LIBRARY_INFER ${NVINFER_LIB} HINTS ${TENSORRT_ROOT} @@ -96,6 +101,14 @@ MESSAGE(STATUS "Can't find ${NVINFER_LIB}") endif() + find_library(TENSORRT_LIBRARY_INFER_PLUGIN ${NVINFER_PLUGIN_LIB} + HINTS ${TENSORRT_ROOT} + PATH_SUFFIXES lib lib64 lib/x64) + + if (NOT TENSORRT_LIBRARY_INFER_PLUGIN) + MESSAGE(STATUS "Can't find ${NVINFER_PLUGIN_LIB}") + endif() + if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) MESSAGE(STATUS "Looking for ${PARSER_LIB}") @@ -107,7 +120,7 @@ MESSAGE(STATUS "Can't find ${PARSER_LIB}") endif() - set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_NVONNXPARSER}) + set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN} ${TENSORRT_LIBRARY_NVONNXPARSER}) MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}") else() if (TRT_GREATER_OR_EQUAL_TRT_10_GA) @@ -140,7 +153,7 @@ endif() # Static libraries are just nvonnxparser_static on all platforms set(onnxparser_link_libs nvonnxparser_static) - set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER}) + set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN}) MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}") endif() @@ -148,7 +161,7 @@ # nvonnxparser_static is linked against tensorrt libraries in onnx-tensorrt # See https://github.com/onnx/onnx-tensorrt/blob/8af13d1b106f58df1e98945a5e7c851ddb5f0791/CMakeLists.txt#L121 # However, starting from TRT 10 GA, nvonnxparser_static doesn't link against tensorrt libraries. - # Therefore, the above code finds ${TENSORRT_LIBRARY_INFER}. + # Therefore, the above code finds ${TENSORRT_LIBRARY_INFER} and ${TENSORRT_LIBRARY_INFER_PLUGIN}. 
if(onnxruntime_CUDA_MINIMAL) set(trt_link_libs ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) else() diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index 2865ad33b39f4..5b80b1262464d 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -172,12 +172,10 @@ file(MAKE_DIRECTORY ${WGSL_GENERATED_DIR}) # Find all WGSL template input files - file(GLOB_RECURSE WGSL_TEMPLATE_FILES - "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.wgsl.template" - "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/*.wgsl.template") + file(GLOB_RECURSE WGSL_TEMPLATE_FILES "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.wgsl.template") # Set wgsl-gen command line options as a list - set(WGSL_GEN_OPTIONS "-i" "${ONNXRUNTIME_ROOT}/core/providers/webgpu/" "-i" "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/" "--output" "${WGSL_GENERATED_DIR}" "-I" "wgsl_template_gen/" "--preserve-code-ref" "--verbose") + set(WGSL_GEN_OPTIONS "-i" "../" "--output" "${WGSL_GENERATED_DIR}" "-I" "wgsl_template_gen/" "--preserve-code-ref" "--verbose") if (onnxruntime_WGSL_TEMPLATE STREQUAL "static") if (CMAKE_BUILD_TYPE STREQUAL "Debug") list(APPEND WGSL_GEN_OPTIONS "--generator" "static-cpp-literal") @@ -209,9 +207,10 @@ # Add the generated directory to include paths target_include_directories(onnxruntime_providers_webgpu PRIVATE ${WGSL_GENERATED_ROOT}) elseif(onnxruntime_WGSL_TEMPLATE STREQUAL "dynamic") + add_library(duktape_static STATIC "${duktape_SOURCE_DIR}/src/duktape.c") + target_compile_features(duktape_static PRIVATE c_std_99) target_link_libraries(onnxruntime_providers_webgpu duktape_static) - onnxruntime_add_include_to_target(onnxruntime_providers_webgpu duktape_static) - + target_include_directories(onnxruntime_providers_webgpu PRIVATE ${duktape_SOURCE_DIR}/src) # Define the path to the generated templates.js file target_compile_definitions(onnxruntime_providers_webgpu PRIVATE "ORT_WGSL_TEMPLATES_JS_PATH=\"${WGSL_GENERATED_TEMPLATES_JS}\"") diff --git a/cmake/vcpkg.json b/cmake/vcpkg.json index 373ecec440921..7c6b2fed36d1b 100644 --- a/cmake/vcpkg.json +++ b/cmake/vcpkg.json @@ -43,6 +43,7 @@ "ms-gsl", "nlohmann-json", "onnx", + "optional-lite", { "name": "protobuf", "version>=": "3.21.12" @@ -93,10 +94,6 @@ "webgpu-ep": { "description": "Build with WebGPU EP", "dependencies": [] - }, - "webgpu-ep-wgsl-template-dynamic": { - "description": "Build with WebGPU EP with dynamic WGSL template code generator", - "dependencies": ["duktape"] } }, "overrides": [ @@ -107,10 +104,6 @@ { "name": "flatbuffers", "version": "23.5.26" - }, - { - "name": "duktape", - "version": "2.7.0#2" } ] } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/Tests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/Tests.cs index 6e6190b8227b8..c28830ec72157 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/Tests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/Tests.cs @@ -40,12 +40,10 @@ public void RunPlatformUnitTest() var serializedResultSummary = _app.Invoke(_getResultsBackdoorMethodName)?.ToString(); Assert.IsNotEmpty(serializedResultSummary, "Test results were not returned"); - // Fix security issue (overflow with too much nesting): GHSA-5crp-9r3c-p9vr - JsonConvert.DefaultSettings = () => new JsonSerializerSettings { MaxDepth = 128 }; var testSummary = 
JsonConvert.DeserializeObject<TestResultSummary>(serializedResultSummary); Assert.AreEqual(testSummary.Failed, 0, $"{testSummary.Failed} tests failed"); _app.Screenshot("Post-testing"); } } -} +} \ No newline at end of file diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/TestResultProcessor.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/TestResultProcessor.cs index 625cc2c54055c..8419d261e4a41 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/TestResultProcessor.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/TestResultProcessor.cs @@ -45,9 +45,8 @@ public TestResultSummary GetResults() public string GetSerializedResults() { var resultSummary = GetResults(); - JsonConvert.DefaultSettings = () => new JsonSerializerSettings { MaxDepth = 128 }; var serializedResultSummary = JsonConvert.SerializeObject(resultSummary, Formatting.Indented); return serializedResultSummary; } } -} +} \ No newline at end of file diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index f3dcde1abe37a..b80918e6615e1 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2545,8 +2545,6 @@ This version of the operator has been available since version 1 of the 'com.micr
left_window_size for local attention (like Mistral). Default value is -1 meaning unused.
num_heads : int (required)
Number of attention heads for q
-
qk_output : int
-
Output values of QK matrix multiplication before (1) or after (2) softmax normalization. Default value is 0 (don't output).
rotary_interleaved : int
Rotate using interleaved pattern. Default value is 0 (False).
scale : float
@@ -2557,7 +2555,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Softcap value for attention weights. Default value is 0.
-#### Inputs (7 - 12)
+#### Inputs (7 - 11)
query : T
@@ -2582,11 +2580,9 @@ This version of the operator has been available since version 1 of the 'com.micr
2D tensor with shape (batch_size, sequence_length). When processing the first prompt the kernel uses only the first element
attention_bias (optional) : T
additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
-
head_sink (optional) : T
-
1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.
-#### Outputs (3 - 4)
+#### Outputs
output : T
@@ -2595,8 +2591,6 @@ This version of the operator has been available since version 1 of the 'com.micr
present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
present_value : T
present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
-
output_qk (optional) : T
-
Values of QK matrix multiplication, either before or after softmax normalization
#### Type Constraints diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index fa6c731231405..1ffcabee8cc10 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -538,7 +538,7 @@ Do not modify directly.* |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| @@ -942,7 +942,7 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -1420,7 +1420,7 @@ Do not modify directly.* |FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index c18a42cc1bbc1..54e03a31fceef 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -952,12 +952,9 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi return const_cast(this)->GetNodeArg(name); } - // Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding mutable NodeArg + // search this and up through any parent_graph_ instance for a NodeArg NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name); - // Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding const NodeArg - const NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const; - /** Gets a mutable NodeArg by name. Creates a new NodeArg that is owned by this Graph if not found. @param name The NodeArg name. @param[in] p_arg_type Optional TypeProto to use if the NodeArg needs to be created. diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h deleted file mode 100644 index 37665542f614f..0000000000000 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ /dev/null @@ -1,718 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -/* - SUMMARY: - Utilities to serialize an OrtGraph into an ONNX GraphProto or ModelProto. Can be used by execution provider - implementations that need to convert an OrtGraph instance into an ONNX protobuf model. - - Users may copy this file and modify as needed. - - USAGE: - This is a header-only implementation that includes both the function declarations and definitions. Copy this file - into a project that links with both ONNX Runtime and ONNX. - - Define the ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL preprocessor macro before the #include statement in exactly one C++ - file to define the implementation. Example: - - #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL - #include "ort_graph_to_proto.h" - - Other compilation units that depend on these utilities should include this file without defining the - preprocessor macro. - - Example program snippets are shown below. Refer to the function declarations for detailed usage information. 
- - EXAMPLE SNIPPET (initializers stored within TensorProto): - - ```C++ - #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL - #include "ort_graph_to_proto.h" - - OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, - OrtEpGraphSupportInfo* graph_support_info) { - onnx::GraphProto graph_proto; - OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto); - - // graph_proto stores initializers internally - } - ``` - - EXAMPLE SNIPPET (large initializers stored in external file): - - ```C++ - #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL - #include "ort_graph_to_proto.h" - - OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, - OrtEpGraphSupportInfo* graph_support_info) { - std::string external_file_path = "weights.bin"; - std::ofstream out_file(external_file_path, std::ios::binary); - - auto handle_initializer_data = [&external_file_path, &out_file](const OrtValueInfo* value_info, - const void* data, size_t bytes, - bool& is_external, std::string& location, - int64_t& offset) -> Ort::Status { - // OrtValueInfo* could be used to query initializer's name, type, shape, consumers, etc. - (void)value_info; - - if (bytes <= 127) { - is_external = false; // Keep small initializers stored inside the TensorProto. - return Ort::Status{nullptr}; - } - - offset = out_file.tellp(); - location = external_file_path; - out_file.write(static_cast(data), bytes); - out_file.flush(); - is_external = true; // True if is external initializer - return Ort::Status{nullptr}; - } - - ONNX_NAMESPACE::GraphProto graph_proto; - OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto, handle_initializer_data); - - // graph_proto stores large initializers in an external file - } - ``` -*/ - -#ifndef INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ -#define INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ - -#include -#include "core/session/onnxruntime_cxx_api.h" -#include "onnx/onnx_pb.h" - -namespace OrtEpUtils { - -/// -/// Signature of user-provided function to handle initializer data. Called by OrtGraphToProto() for every initializer. -/// -/// If the function sets the `is_external` output parameter to false, OrtGraphToProto() stores initializer data -/// within the TensorProto as raw_data. -/// -/// Otherwise, if the function sets `is_external` to true, OrtGraphToProto() assumes that this function stores the -/// initializer data in a file. In this case, OrtGraphToProto() configures the corresponding TensorProto to point the -/// location and offset returned via the `location` and `offset` output parameters. -/// -/// It is recommended to keep small initializers with byte size <= 127 stored inline the TensorProto to ensure -/// ONNX shape inference works correctly with the serialized ONNX model. -/// -/// OrtValueInfo for the initializer. Can be used to query name, type, shape, -/// and consumer nodes. -/// Opaque pointer to the initializer data. -/// Size in bytes of the initializer data. -/// Output parameter set to true if the initializer data is stored externally. The -/// implementer is responsible for writing the initializer data to file. If set to false, -/// the initializer will be stored within the TensorProto. -/// Output parameter set to the location (e.g., file) into which the initializer is stored -/// by the implementer of this function. Ignored if `is_external` is set to false. -/// Output parameter set to the offset (e.g., file offset) into which the initializer is stored -/// by the implementer of this function. 
Ignored if `is_external` is set to false. -/// An Ort::Status indicating success or an error. Serialization exits if this returns an error. -using HandleInitializerDataFunc = std::function; - -/// -/// Serializes the provided OrtGraph to a onnx::GraphProto. -/// Allows the caller to provide a function that specifies whether an initializer should be stored -/// within a TensorProto, written to a file, or remain as an in-memory external initializer (not valid ONNX). -/// -/// OrtGraph instance to serialize. -/// Destination GraphProto into which to serialize the input OrtGraph. -/// Optional function called to allow the user to determine -/// where the initializer data is stored. -/// An Ort::Status indicating success or an error. -Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, - onnx::GraphProto& graph_proto, - HandleInitializerDataFunc handle_initializer_data_func = nullptr); - -/// -/// Serializes the provided top-level OrtGraph to a onnx::ModelProto. -/// Allows the caller to provide a function that specifies whether an initializer should be stored -/// within a TensorProto, written to a file, or remain as an in-memory external initializer (not valid ONNX). -/// -/// OrtGraph instance to serialize. -/// Destination ModelProto into which to serialize the input OrtGraph. -/// Optional function called to allow the user to determine -/// where the initializer data is stored. -/// An Ort::Status indicating success or an error. -Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, - onnx::ModelProto& model_proto, - HandleInitializerDataFunc handle_initializer_data_func = nullptr); -} // namespace OrtEpUtils - -// End of header -#endif // INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ - -// -// IMPLEMENTATION BELOW -// -#ifdef ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL - -#include -#include -#include -#include -#include -#include - -#define ORT_EP_UTILS_C_RETURN_IF_ERROR(fn) \ - do { \ - OrtStatus* _status = (fn); \ - if (_status != nullptr) { \ - return Ort::Status{_status}; \ - } \ - } while (0) - -#define ORT_EP_UTILS_CXX_RETURN_IF_ERROR(fn) \ - do { \ - Ort::Status _status = (fn); \ - if (!_status.IsOK()) { \ - return _status; \ - } \ - } while (0) - -#define ORT_EP_UTILS_C_RETURN_IF(cond, ort_api, msg) \ - do { \ - if ((cond)) { \ - return Ort::Status{(ort_api).CreateStatus(ORT_FAIL, (msg))}; \ - } \ - } while (0) - -namespace OrtEpUtils { - -static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, - bool get_symbolic_dims, - /*out*/ ONNXTensorElementDataType& elem_type, - /*out*/ std::vector& dims, - /*out*/ std::vector& symbolic_dims); -static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); - -Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, - onnx::GraphProto& graph_proto, - HandleInitializerDataFunc handle_initializer_data_func) { - const OrtApi& ort_api = Ort::GetApi(); - - // - // Set GraphProto metadata - // - const char* graph_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetName(&ort_graph, &graph_name)); - graph_proto.set_name(graph_name); - graph_proto.set_doc_string("Serialized from OrtGraph"); - - // - // Set GraphProto inputs and outputs - // - size_t num_graph_inputs = 0; - size_t num_graph_outputs = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumInputs(&ort_graph, &num_graph_inputs)); - 
ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOutputs(&ort_graph, &num_graph_outputs)); - - std::vector graph_inputs(num_graph_inputs); - std::vector graph_outputs(num_graph_outputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetInputs(&ort_graph, graph_inputs.data(), graph_inputs.size())); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOutputs(&ort_graph, graph_outputs.data(), graph_outputs.size())); - - for (const OrtValueInfo* ort_value_info : graph_inputs) { - onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_input()->Add(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); - } - - for (const OrtValueInfo* ort_value_info : graph_outputs) { - onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_output()->Add(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); - } - - // - // Set GraphProto nodes, value_infos, and initializers. - // - - // Use std::maps to store OrtValueInfos for GraphProto.value_info and GraphProto.initializer. - // A std::map maintains its elements in a stable ordering. - std::map value_infos; // For GraphProto.value_info - std::map initializer_value_infos; // For GraphProto.initializer - - // Helper function to collect an OrtValueInfo into `value_infos` or `initializer_value_infos`. - // Optionally returns the OrtValueInfo name to the caller. - auto collect_value_info = [&ort_api, &value_infos, - &initializer_value_infos](const OrtValueInfo& ort_value_info, - /*out*/ const char** value_name_out = nullptr) -> Ort::Status { - const char* value_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); - - if (value_name_out != nullptr) { - *value_name_out = value_name; - } - - if (value_infos.count(value_name) != 0 || initializer_value_infos.count(value_name) != 0) { - return Ort::Status{nullptr}; // Already processed this OrtValueInfo. - } - - bool is_required_graph_input = false; - bool is_optional_graph_input = false; - bool is_graph_output = false; - bool is_constant_initializer = false; - bool is_from_outer_scope = false; - - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsRequiredGraphInput(&ort_value_info, &is_required_graph_input)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsOptionalGraphInput(&ort_value_info, &is_optional_graph_input)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsGraphOutput(&ort_value_info, &is_graph_output)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsConstantInitializer(&ort_value_info, &is_constant_initializer)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsFromOuterScope(&ort_value_info, &is_from_outer_scope)); - - // Don't add graph inputs or graph outputs to GraphProto's list of value_infos. - // Do add initializers (constant and non-constant) to GraphProto's list of initializer tensors. - // For values defined in an outer scope, just add the value info but not the initializer. - if (is_from_outer_scope) { - value_infos.emplace(value_name, &ort_value_info); - } else if (is_optional_graph_input) { - initializer_value_infos.emplace(value_name, &ort_value_info); - } else if (is_constant_initializer) { - value_infos.emplace(value_name, &ort_value_info); - initializer_value_infos.emplace(value_name, &ort_value_info); - } else if (!is_required_graph_input && !is_graph_output) { - value_infos.emplace(value_name, &ort_value_info); // This is an internal OrtValueInfo. 
- } - - return Ort::Status{nullptr}; - }; - - size_t num_nodes = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(&ort_graph, &num_nodes)); - - std::vector nodes(num_nodes); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNodes(&ort_graph, nodes.data(), nodes.size())); - - // Loop through all nodes (topological order): add NodeProto instances to GraphProto and track OrtValueInfos - // that will be stored in GraphProto.value_info and GraphProto.initializer. - for (size_t i = 0; i < num_nodes; i++) { - const OrtNode* ort_node = nodes[i]; - onnx::NodeProto* node_proto = graph_proto.add_node(); - - const char* node_name = nullptr; - const char* node_domain = nullptr; - const char* node_op_type = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetName(ort_node, &node_name)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetDomain(ort_node, &node_domain)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOperatorType(ort_node, &node_op_type)); - - node_proto->set_name(node_name); - node_proto->set_domain(node_domain); - node_proto->set_op_type(node_op_type); - - size_t num_inputs = 0; - size_t num_implicit_inputs = 0; - size_t num_outputs = 0; - size_t num_attrs = 0; - size_t num_subgraphs = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumInputs(ort_node, &num_inputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumImplicitInputs(ort_node, &num_implicit_inputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(ort_node, &num_outputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumAttributes(ort_node, &num_attrs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumSubgraphs(ort_node, &num_subgraphs)); - - // Handle node attributes - if (num_attrs > 0) { - std::vector ort_attrs(num_attrs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetAttributes(ort_node, ort_attrs.data(), ort_attrs.size())); - - for (const OrtOpAttr* ort_attr : ort_attrs) { - OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; - - Ort::Status status{ort_api.OpAttr_GetType(ort_attr, &attr_type)}; - if (!status.IsOK()) { - // This is an attribute type that ORT does not support via ReadOpAttr(), like subgraphs, so skip it. - // Can use Node_GetSubgraphs to get subgraphs. 
- continue; - } - - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); - } - } - - // Handle node subgraphs - if (num_subgraphs > 0) { - std::vector ort_subgraphs(num_subgraphs); - std::vector subgraph_attr_names(num_subgraphs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetSubgraphs(ort_node, ort_subgraphs.data(), ort_subgraphs.size(), - subgraph_attr_names.data())); - - for (size_t subgraph_idx = 0; subgraph_idx < num_subgraphs; subgraph_idx++) { - const OrtGraph* ort_subgraph = ort_subgraphs[subgraph_idx]; - const char* subgraph_attr_name = subgraph_attr_names[subgraph_idx]; - - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); - onnx::GraphProto* subgraph_proto = attr_proto->mutable_g(); - - attr_proto->set_name(subgraph_attr_name); - attr_proto->set_type(onnx::AttributeProto_AttributeType_GRAPH); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(*ort_subgraph, *subgraph_proto)); - } - } - - // Handle node inputs - if (num_inputs > 0) { - std::vector ort_inputs(num_inputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetInputs(ort_node, ort_inputs.data(), ort_inputs.size())); - - for (const OrtValueInfo* ort_value_info : ort_inputs) { - if (ort_value_info == nullptr) { - // missing optional input. - node_proto->add_input(""); - continue; - } - - const char* value_name = nullptr; - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); - - node_proto->add_input(value_name); - } - } - - // Handle implicit inputs to this node. - if (num_implicit_inputs > 0) { - std::vector ort_implicit_inputs(num_implicit_inputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetImplicitInputs(ort_node, ort_implicit_inputs.data(), - ort_implicit_inputs.size())); - - for (const OrtValueInfo* ort_value_info : ort_implicit_inputs) { - assert(ort_value_info != nullptr); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, /*value_name_out*/ nullptr)); - } - } - - // Handle node outputs - if (num_outputs > 0) { - std::vector ort_outputs(num_outputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOutputs(ort_node, ort_outputs.data(), ort_outputs.size())); - - for (const OrtValueInfo* ort_value_info : ort_outputs) { - if (ort_value_info == nullptr) { - // missing optional output. - node_proto->add_output(""); - continue; - } - - const char* value_name = nullptr; - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); - - node_proto->add_output(value_name); - } - } - } - - // Add value_infos to GraphProto as ValueInfoProto objects. - for (const std::pair& entry : value_infos) { - onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_value_info()->Add(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*entry.second, *value_info_proto)); - } - - // Add initializers to GraphProto as TensorProto objects. - for (const std::pair& entry : initializer_value_infos) { - const OrtValueInfo* initializer_value_info = entry.second; - std::string initializer_name = std::string{entry.first}; // Need a null-terminated string. 
- std::vector initializer_dims; - std::vector initializer_sym_dims; - ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(*initializer_value_info, /*get_sym_dims*/ false, - initializer_elem_type, initializer_dims, - initializer_sym_dims)); - - onnx::TensorProto* tensor_proto = graph_proto.add_initializer(); - tensor_proto->set_name(initializer_name); - tensor_proto->set_data_type(initializer_elem_type); - - auto* tensor_proto_dims = tensor_proto->mutable_dims(); - for (int64_t dim : initializer_dims) { - tensor_proto_dims->Add(dim); - } - - const OrtValue* ort_value = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_GetInitializerValue(initializer_value_info, &ort_value)); - - const void* data = nullptr; - size_t data_bytes = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorData(ort_value, &data)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorSizeInBytes(ort_value, &data_bytes)); - - std::string ext_location; - int64_t ext_offset = 0; - bool is_external = false; - - if (handle_initializer_data_func != nullptr) { - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(handle_initializer_data_func(initializer_value_info, data, data_bytes, - is_external, ext_location, ext_offset)); - } - - if (is_external) { - tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL); - auto* ext_data_entries = tensor_proto->mutable_external_data(); - onnx::StringStringEntryProto* location_entry = ext_data_entries->Add(); - onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add(); - - location_entry->set_key("location"); - location_entry->set_value(ext_location); - offset_entry->set_key("offset"); - offset_entry->set_value(std::to_string(ext_offset)); - } else { - // User wants to store data inline the TensorProto's raw_data - tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT); - tensor_proto->set_raw_data(data, data_bytes); - } - } - - return Ort::Status{nullptr}; -} - -Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, - onnx::ModelProto& model_proto, - HandleInitializerDataFunc handle_initializer_data_func) { - const OrtApi& ort_api = Ort::GetApi(); - - // Check that OrtGraph is a top-level graph (no parent node). - const OrtNode* parent_node = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetParentNode(&ort_graph, &parent_node)); - ORT_EP_UTILS_C_RETURN_IF(parent_node != nullptr, ort_api, "Cannot serialize nested OrtGraph into a ModelProto"); - - // Set model description. - model_proto.set_doc_string("Serialized from OrtGraph"); - model_proto.set_producer_name("ort_ep_utils::OrtGraphToProto"); - - // Set ir version. - int64_t ir_version = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOnnxIRVersion(&ort_graph, &ir_version)); - model_proto.set_ir_version(ir_version); - - // Set operator sets. 
- size_t num_operator_sets = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOperatorSets(&ort_graph, &num_operator_sets)); - ORT_EP_UTILS_C_RETURN_IF(num_operator_sets == 0, ort_api, "OrtGraph should have at least one operator set."); - - std::vector domains(num_operator_sets, nullptr); - std::vector opset_versions(num_operator_sets); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOperatorSets(&ort_graph, domains.data(), opset_versions.data(), - num_operator_sets)); - - auto* operator_sets = model_proto.mutable_opset_import(); - - for (size_t i = 0; i < num_operator_sets; ++i) { - onnx::OperatorSetIdProto* operator_set = operator_sets->Add(); - operator_set->set_domain(domains[i]); - operator_set->set_version(opset_versions[i]); - } - - model_proto.clear_graph(); - onnx::GraphProto* graph_proto = model_proto.mutable_graph(); - - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(ort_graph, *graph_proto, handle_initializer_data_func)); - - return Ort::Status{nullptr}; -} - -static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, - bool get_symbolic_dims, - /*out*/ ONNXTensorElementDataType& elem_type, - /*out*/ std::vector& dims, - /*out*/ std::vector& symbolic_dims) { - const OrtApi& ort_api = Ort::GetApi(); - - const OrtTypeInfo* ort_type_info = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoTypeInfo(&ort_value_info, &ort_type_info)); - - ONNXType ort_onnx_type = ONNX_TYPE_UNKNOWN; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetOnnxTypeFromTypeInfo(ort_type_info, &ort_onnx_type)); - ORT_EP_UTILS_C_RETURN_IF(ort_onnx_type != ONNX_TYPE_TENSOR, ort_api, "Expected OrtValueInfo to represent a Tensor"); - - const OrtTensorTypeAndShapeInfo* ort_type_shape = nullptr; - ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.CastTypeInfoToTensorInfo(ort_type_info, &ort_type_shape)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorElementType(ort_type_shape, &ort_elem_type)); - - size_t num_dims = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensionsCount(ort_type_shape, &num_dims)); - - std::vector ort_dims(num_dims, 0); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensions(ort_type_shape, ort_dims.data(), ort_dims.size())); - - elem_type = ort_elem_type; - dims = std::move(ort_dims); - - if (get_symbolic_dims) { - std::vector ort_dim_syms(num_dims, nullptr); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetSymbolicDimensions(ort_type_shape, ort_dim_syms.data(), - ort_dim_syms.size())); - - symbolic_dims.reserve(num_dims); - for (const char* sym_dim : ort_dim_syms) { - symbolic_dims.push_back(sym_dim); - } - } - - return Ort::Status{nullptr}; -} - -// Create an onnx::ValueInfoProto from an OrtValueInfo (name, type, shape). -static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, - onnx::ValueInfoProto& value_info_proto) { - const OrtApi& ort_api = Ort::GetApi(); - - std::vector ort_dims; - std::vector ort_dim_syms; - ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - - // We currently only support ONNX tensors. Support for other types (e.g., ONNX_TYPE_SEQUENCE) can be added later. 
- ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(ort_value_info, /*get_sym_dims*/ true, - ort_elem_type, ort_dims, ort_dim_syms)); - - const char* value_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); - value_info_proto.set_name(value_name); - - onnx::TypeProto_Tensor* type_proto_tensor = value_info_proto.mutable_type()->mutable_tensor_type(); - type_proto_tensor->set_elem_type(ort_elem_type); - - onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape(); - - for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) { - onnx::TensorShapeProto_Dimension* dim_proto = shape_proto->add_dim(); - - if (ort_dims[dim_idx] >= 0) { - dim_proto->set_dim_value(ort_dims[dim_idx]); - } else { - const std::string& dim_param = ort_dim_syms[dim_idx]; - - // If dim_param is empty, leave dim_proto with neither the dim_value or dim_param set, - // which represents an unknown dimension. - if (!dim_param.empty()) { - dim_proto->set_dim_param(dim_param); - } - } - } - - return Ort::Status{nullptr}; -} - -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { - const OrtApi& ort_api = Ort::GetApi(); - - const char* attr_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetName(&ort_attr, &attr_name)); - attr_proto.set_name(attr_name); - - size_t total_attr_bytes = 0; - OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetType(&ort_attr, &attr_type)); - - switch (attr_type) { - case OrtOpAttrType::ORT_OP_ATTR_INT: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_INT); - - int64_t i_val = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &i_val, sizeof(i_val), &total_attr_bytes)); - attr_proto.set_i(i_val); - break; - } - case OrtOpAttrType::ORT_OP_ATTR_INTS: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_INTS); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::vector i_vals(total_attr_bytes / sizeof(int64_t)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, i_vals.data(), total_attr_bytes, - &total_attr_bytes)); - - auto* ints = attr_proto.mutable_ints(); - for (int64_t val : i_vals) { - ints->Add(val); - } - break; - } - case OrtOpAttrType::ORT_OP_ATTR_FLOAT: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOAT); - - float f_val = 0.0f; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &f_val, sizeof(f_val), &total_attr_bytes)); - attr_proto.set_f(f_val); - break; - } - case OrtOpAttrType::ORT_OP_ATTR_FLOATS: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOATS); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::vector f_vals(total_attr_bytes / sizeof(float)); - - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, f_vals.data(), total_attr_bytes, - &total_attr_bytes)); - - auto* floats = attr_proto.mutable_floats(); - for (float val : f_vals) { - floats->Add(val); - } - break; - } - case OrtOpAttrType::ORT_OP_ATTR_STRING: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_STRING); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. 
- Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::string* str = attr_proto.mutable_s(); - - str->resize(total_attr_bytes, '\0'); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, str->data(), total_attr_bytes, - &total_attr_bytes)); - - str->resize(total_attr_bytes - 1); // remove extra ending terminating '\0' character. - break; - } - case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_STRINGS); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::vector chars(total_attr_bytes, '\0'); - - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, chars.data(), total_attr_bytes, - &total_attr_bytes)); - - auto* strs = attr_proto.mutable_strings(); - - // Strings are all in a single buffer, each separated with a '\0'. - // Extract each string and add it to the STRINGS attribute array. - char* at = chars.data(); - char* end = at + chars.size(); - - while (at < end) { - char* str_begin = at; - - while (*at && at < end) { - at++; - } - - strs->Add()->assign(str_begin, at - str_begin); - if (at < end) { - assert(*at == '\0'); - at++; // Skip '\0' to get to the beginning of the next string. - } - } - - break; - } - default: { - std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); - return Ort::Status(err_msg.c_str(), ORT_FAIL); - } - } - - return Ort::Status{nullptr}; -} - -} // namespace OrtEpUtils -#endif // ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 82e782112974f..86c0b60db2bc4 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -66,7 +66,6 @@ extern "C" { #define _In_reads_(X) #define _Inout_updates_(X) #define _Out_writes_(X) -#define _Out_writes_opt_(X) #define _Inout_updates_all_(X) #define _Out_writes_bytes_all_(X) #define _Out_writes_all_(X) @@ -4750,8 +4749,6 @@ struct OrtApi { * \param[in] len Number of bytes allowed to store in data * \param[out] out Number of bytes required to save the data when the call failed, or the real number of bytes saved to data on success * - * \note Does not support reading graph attributes. Refer to Node_GetSubgraphs. - * * \since Version 1.17. */ ORT_API2_STATUS(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out); @@ -5571,45 +5568,6 @@ struct OrtApi { */ ORT_API2_STATUS(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version); - /** \brief Returns the number of operator sets that the graph's model uses. - * - * \note An operator set is uniquely identified by the (domain, opset_version) pair. All models must have at - * least one entry that specifies which entry of the ONNX operator set is used. The ONNX domain is represented by - * an empty string. - * - * \param[in] graph The OrtGraph instance. - * \param[out] num_operator_sets Output parameter set to the number of operator sets that the graph's model uses. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.23. 
- */ - ORT_API2_STATUS(Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets); - - /** \brief Returns the operator sets that the graph's model uses. - * - * \note An operator set is uniquely identified by the (domain, opset_version) pair. All models must have at - * least one entry that specifies which entry of the ONNX operator set is used. The ONNX domain is represented by - * an empty string. - * - * \param[in] graph The OrtGraph instance. - * \param[out] domains Pre-allocated array of `num_operator_sets` elements that is filled with - * null-terminated domain names. - * \param[out] opset_versions Pre-allocated array of `num_operator_sets` elements that is filled with - * the opset version of the corresponding domain in the `domains` array. - * \param[in] num_operator_sets The size of the `domains` and `opset_versions` arrays. - * Typical usage sets this to the result of Graph_GetNumOperatorSets(). - * An error status is returned if `num_operator_sets` is less than the actual number - * of operator sets. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.23. - */ - ORT_API2_STATUS(Graph_GetOperatorSets, _In_ const OrtGraph* graph, - _Out_writes_(num_operator_sets) const char** domains, - _Out_writes_(num_operator_sets) int64_t* opset_versions, _In_ size_t num_operator_sets); - /** \brief Returns the number of graph inputs. * * \note The count includes initializers that are included in the list of graph inputs. @@ -5748,24 +5706,6 @@ struct OrtApi { */ ORT_API2_STATUS(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); - /** \brief Returns an OrtGraph that contains a subset of nodes in the source OrtGraph. - * - * Note: - * The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference - * the same underlying graph. - * - * \param[in] src_graph The source OrtGraph instance. - * \param[in] nodes A subset of the nodes/OrtNodes in 'graph'. - * \param[in] num_nodes Number of nodes. - * \param[out] dst_sub_graph An OrtGraph created from a given set of nodes. Must be released by calling ReleaseGraph. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.23. - */ - ORT_API2_STATUS(Graph_GetGraphView, _In_ const OrtGraph* src_graph, _In_ const OrtNode** nodes, - _In_ size_t num_nodes, _Outptr_ OrtGraph** dst_graph); - /// @} /// \name OrtNode @@ -5993,24 +5933,20 @@ struct OrtApi { /** \brief Get the subgraphs, as OrtGraph instances, contained by the given node. * - * \note Only certain operator types (e.g., If and Loop) contain nested subgraphs. ONNX nodes store subgraphs in - * their attributes, however, this function must be used to obtain subgraphs from an OrtNode. + * \note Only certain operator types (e.g., If and Loop) contain nested subgraphs. * * \param[in] node The OrtNode instance. * \param[out] subgraphs Pre-allocated array of `num_subgraphs` elements that is filled with the node's subgraphs. * \param[in] num_subgraphs The size of the `num_subgraphs` array. * Typical usage sets this to the result of Node_GetNumSubgraphs(). An error status is * returned if `num_subgraphs` is less than the number of node subgraphs. - * \param[out] attribute_names Optional pre-allocated array of `num_subgraphs` elements that is filled with the - * attribute names that correspond to the subgraphs. Ignored if set to NULL. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. 
*/ ORT_API2_STATUS(Node_GetSubgraphs, _In_ const OrtNode* node, - _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs, - _Out_writes_opt_(num_subgraphs) const char** attribute_names); + _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs); /** \brief Get the node's parent OrtGraph instance. * @@ -6026,19 +5962,6 @@ struct OrtApi { */ ORT_API2_STATUS(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph); - /** \brief Returns the execution provider name that this node is assigned to run on. - * Returns NULL if the node has not been assigned to any execution provider yet. - * For plugin execution providers, the name is the one returned by OrtEp::GetName. - * - * \param[in] node The OrtNode instance. - * \param[out] out Output execution provider type and can be NULL if node has not been assigned. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.23. - */ - ORT_API2_STATUS(Node_GetEpName, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out); - /// @} /// \name OrtRunOptions @@ -6887,24 +6810,6 @@ struct OrtCompileApi { */ ORT_API2_STATUS(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_compile_options, size_t flags); - - /** Sets information related to EP context binary file. - * - * EP uses this information to decide the location and context binary file name. - * Used while compiling model with input and output in memory buffer - * - * \param[in] model_compile_options The OrtModelCompilationOptions instance. - * \param[in] output_directory Null terminated string of the path (wchar on Windows, char otherwise). - * \param[in] model_name Null terminated string of the model name (wchar on Windows, char otherwise). - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.23. 
- */ - ORT_API2_STATUS(ModelCompilationOptions_SetEpContextBinaryInformation, - _In_ OrtModelCompilationOptions* model_compile_options, - _In_ const ORTCHAR_T* output_directory, - _In_ const ORTCHAR_T* model_name); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index d1b08f127fa2a..c59baa59c91a5 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1161,8 +1161,6 @@ struct ModelCompilationOptions : detail::Base { size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer - ModelCompilationOptions& SetEpContextBinaryInformation(const ORTCHAR_T* output_directory, - const ORTCHAR_T* model_name); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextBinaryInformation ModelCompilationOptions& SetFlags(size_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags }; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index ba5d53e6c2dd0..612adc81d3309 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -819,15 +819,6 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelPath( return *this; } -inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextBinaryInformation( - const ORTCHAR_T* output_directory, const ORTCHAR_T* model_name) { - Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetEpContextBinaryInformation( - this->p_, - output_directory, - model_name)); - return *this; -} - inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelExternalInitializersFile( const ORTCHAR_T* file_path, size_t initializer_size_threshold) { Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelExternalInitializersFile( diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 5d00ce4940d02..44c7bb6ee424a 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -358,7 +358,7 @@ struct OrtEp { * * \since Version 1.22. */ - ORT_API_T(const char*, GetName, _In_ const OrtEp* this_ptr); + const char*(ORT_API_CALL* GetName)(_In_ const OrtEp* this_ptr); /** \brief Get information about the nodes supported by the OrtEp instance. * @@ -376,8 +376,8 @@ struct OrtEp { * * \since Version 1.23. */ - ORT_API2_STATUS(GetCapability, _In_ OrtEp* this_ptr, _In_ const OrtGraph* graph, - _Inout_ OrtEpGraphSupportInfo* graph_support_info); + OrtStatus*(ORT_API_CALL* GetCapability)(_In_ OrtEp* this_ptr, _In_ const OrtGraph* graph, + _Inout_ OrtEpGraphSupportInfo* graph_support_info); /** \brief Compile OrtGraph instances assigned to the OrtEp. Implementer must set a OrtNodeComputeInfo instance * for each OrtGraph in order to define its computation function. @@ -416,10 +416,10 @@ struct OrtEp { * * \since Version 1.23. 
*/ - ORT_API2_STATUS(Compile, _In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, - _In_ const OrtNode** fused_nodes, _In_ size_t count, - _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, - _Out_writes_(count) OrtNode** ep_context_nodes); + OrtStatus*(ORT_API_CALL* Compile)(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, + _In_ const OrtNode** fused_nodes, _In_ size_t count, + _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, + _Out_writes_(count) OrtNode** ep_context_nodes); /** \brief Release OrtNodeComputeInfo instances. * @@ -429,9 +429,9 @@ struct OrtEp { * * \since Version 1.23. */ - ORT_API_T(void, ReleaseNodeComputeInfos, _In_ OrtEp* this_ptr, - OrtNodeComputeInfo** node_compute_infos, - _In_ size_t num_node_compute_infos); + void(ORT_API_CALL* ReleaseNodeComputeInfos)(_In_ OrtEp* this_ptr, + OrtNodeComputeInfo** node_compute_infos, + _In_ size_t num_node_compute_infos); /** \brief Get the EP's preferred data layout. * @@ -445,7 +445,8 @@ struct OrtEp { * * \since Version 1.23. */ - ORT_API2_STATUS(GetPreferredDataLayout, _In_ OrtEp* this_ptr, _Out_ OrtEpDataLayout* preferred_data_layout); + OrtStatus*(ORT_API_CALL* GetPreferredDataLayout)(_In_ OrtEp* this_ptr, + _Out_ OrtEpDataLayout* preferred_data_layout); /** \brief Given an op with domain `domain` and type `op_type`, determine whether an associated node's data layout * should be converted to `target_data_layout`. @@ -469,10 +470,11 @@ struct OrtEp { * * \since Version 1.23. */ - ORT_API2_STATUS(ShouldConvertDataLayoutForOp, _In_ OrtEp* this_ptr, - _In_z_ const char* domain, _In_z_ const char* op_type, - _In_ OrtEpDataLayout target_data_layout, - _Outptr_ int* should_convert); + OrtStatus*(ORT_API_CALL* ShouldConvertDataLayoutForOp)(_In_ OrtEp* this_ptr, + _In_z_ const char* domain, + _In_z_ const char* op_type, + _In_ OrtEpDataLayout target_data_layout, + _Outptr_ int* should_convert); /** \brief Set dynamic options on this EP. * @@ -490,10 +492,10 @@ struct OrtEp { * * \since Version 1.23. */ - ORT_API2_STATUS(SetDynamicOptions, _In_ OrtEp* this_ptr, - _In_reads_(num_options) const char* const* option_keys, - _In_reads_(num_options) const char* const* option_values, - _In_ size_t num_options); + OrtStatus*(ORT_API_CALL* SetDynamicOptions)(_In_ OrtEp* this_ptr, + _In_reads_(num_options) const char* const* option_keys, + _In_reads_(num_options) const char* const* option_values, + _In_ size_t num_options); /** \brief Called by ORT to notify the EP of the start of a run. * @@ -506,7 +508,8 @@ struct OrtEp { * * \since Version 1.23. */ - ORT_API2_STATUS(OnRunStart, _In_ OrtEp* this_ptr, _In_ const OrtRunOptions* run_options); + OrtStatus*(ORT_API_CALL* OnRunStart)(_In_ OrtEp* this_ptr, + _In_ const OrtRunOptions* run_options); /** \brief Called by ORT to notify the EP of the end of a run. * @@ -521,7 +524,9 @@ struct OrtEp { * * \since Version 1.23. */ - ORT_API2_STATUS(OnRunEnd, _In_ OrtEp* this_ptr, _In_ const OrtRunOptions* run_options, _In_ bool sync_stream); + OrtStatus*(ORT_API_CALL* OnRunEnd)(_In_ OrtEp* this_ptr, + _In_ const OrtRunOptions* run_options, + _In_ bool sync_stream); }; /** \brief The function signature that ORT will call to create OrtEpFactory instances. @@ -581,7 +586,7 @@ struct OrtEpFactory { * * \since Version 1.22. */ - ORT_API_T(const char*, GetName, const OrtEpFactory* this_ptr); + const char*(ORT_API_CALL* GetName)(const OrtEpFactory* this_ptr); /** \brief Get the name of vendor who owns the execution provider that the factory creates. 
* @@ -592,7 +597,7 @@ struct OrtEpFactory { * * \since Version 1.22. */ - ORT_API_T(const char*, GetVendor, const OrtEpFactory* this_ptr); // return EP vendor + const char*(ORT_API_CALL* GetVendor)(const OrtEpFactory* this_ptr); // return EP vendor /** \brief Get information from the execution provider about OrtHardwareDevice support. * @@ -611,12 +616,12 @@ struct OrtEpFactory { * * \since Version 1.22. */ - ORT_API2_STATUS(GetSupportedDevices, _In_ OrtEpFactory* this_ptr, - _In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_ size_t num_devices, - _Inout_ OrtEpDevice** ep_devices, - _In_ size_t max_ep_devices, - _Out_ size_t* num_ep_devices); + OrtStatus*(ORT_API_CALL* GetSupportedDevices)(_In_ OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _Inout_ OrtEpDevice** ep_devices, + _In_ size_t max_ep_devices, + _Out_ size_t* num_ep_devices); /** \brief Function to create an OrtEp instance for use in a Session. * @@ -642,12 +647,12 @@ struct OrtEpFactory { * * \since Version 1.22. */ - ORT_API2_STATUS(CreateEp, _In_ OrtEpFactory* this_ptr, - _In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, - _In_ size_t num_devices, - _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, _Outptr_ OrtEp** ep); + OrtStatus*(ORT_API_CALL* CreateEp)(_In_ OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, + _In_ size_t num_devices, + _In_ const OrtSessionOptions* session_options, + _In_ const OrtLogger* logger, _Outptr_ OrtEp** ep); /** \brief Release the OrtEp instance. * @@ -656,18 +661,7 @@ struct OrtEpFactory { * * \since Version 1.22. */ - ORT_API_T(void, ReleaseEp, OrtEpFactory* this_ptr, struct OrtEp* ep); - - /** \brief Get the vendor id who owns the execution provider that the factory creates. - * - * This is typically the PCI vendor ID. See https://pcisig.com/membership/member-companies - * - * \param[in] this_ptr The OrtEpFactory instance. - * \return vendor_id The vendor ID of the execution provider the factory creates. - * - * \since Version 1.23. - */ - ORT_API_T(uint32_t, GetVendorId, const OrtEpFactory* this_ptr); + void(ORT_API_CALL* ReleaseEp)(OrtEpFactory* this_ptr, struct OrtEp* ep); /** \brief Get the version of the execution provider that the factory creates. * @@ -681,7 +675,7 @@ struct OrtEpFactory { * * \since Version 1.23. */ - ORT_API_T(const char*, GetVersion, _In_ const OrtEpFactory* this_ptr); + const char*(ORT_API_CALL* GetVersion)(_In_ const OrtEpFactory* this_ptr); /** \brief Create an OrtAllocator for the given OrtMemoryInfo. 
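For illustration, a minimal sketch of how a plugin EP factory might bind these members now that they are declared as plain ORT_API_CALL function pointers rather than through the ORT_API_T/ORT_API2_STATUS macros. The factory type, function names, and string literals below are hypothetical and not part of this diff; only the two simplest members are shown, and the remaining pointers (GetSupportedDevices, CreateEp, ReleaseEp, GetVersion) would be bound the same way.

#include "onnxruntime_c_api.h"  // assumed to bring in the plugin EP declarations; header layout may differ per install

struct MyEpFactory : OrtEpFactory {  // hypothetical factory implementation
  MyEpFactory() {
    // Bind the raw function pointers declared on OrtEpFactory.
    GetName = GetNameImpl;
    GetVendor = GetVendorImpl;
  }

  static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* /*this_ptr*/) noexcept {
    return "example_ep";  // hypothetical EP name
  }

  static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* /*this_ptr*/) noexcept {
    return "Example Vendor";  // hypothetical vendor name
  }
};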
* diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 314cf76cc8044..97e53e6acee5a 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -148,9 +148,7 @@ static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = " // Configure whether to allow the inter_op/intra_op threads spinning a number of times before blocking // "0": thread will block if found no job to run -// "1": thread will spin a number of times before blocking -// The default is "0" when ORT is built with "ORT_CLIENT_PACKAGE_BUILD" and "1" otherwise. -// Thread spinning is disabled by default for client/on-device workloads to reduce cpu utilization and improve power efficiency. +// "1": default, thread will spin a number of times before blocking static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning"; static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning"; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index c2085342efd80..5a837fd1e0bfa 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -98,7 +98,7 @@ const calculateInputIndicesImpl = ( `fn calculateInputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} { var input_indices: ${input.type.indices}; var carry = 0u; - for (var i = ${inputShape.length - 1}; i >= 0; i--) { + for (var i = ${inputShape.length}; i >= 0; i--) { let input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)}; let steps_i = ${getElementAt('uniforms.steps', 'i', inputShape.length)}; let signs_i = ${getElementAt('uniforms.signs', 'i', inputShape.length)}; diff --git a/js/web/script/pull-prebuilt-wasm-artifacts.ts b/js/web/script/pull-prebuilt-wasm-artifacts.ts index 87008f51ff4b9..c3300f7272bb9 100644 --- a/js/web/script/pull-prebuilt-wasm-artifacts.ts +++ b/js/web/script/pull-prebuilt-wasm-artifacts.ts @@ -38,6 +38,7 @@ Usage: Options: -d --debug specify the debug build type of the artifacts to download. -l --latest if set, will always use the latest build, even if it is not completed yet. + --webgpu-ep if set, will use the webgpu EP wasm build instead of the default(JSEP) one. -h --help print this message and exit `; @@ -80,8 +81,9 @@ try { // The following code checks both the command line arguments and the npm_config_* environment variables to get the correct values. const debug = args.debug || process.env.npm_config_d || process.env.npm_config_debug; const latest = args.latest || process.env.npm_config_l || process.env.npm_config_latest; +const webgpuEp = args['webgpu-ep'] || process.env.npm_config_webgpu_ep; -const folderName = debug ? 'Debug_wasm' : 'Release_wasm'; +const folderName = (debug ? 'Debug_wasm' : 'Release_wasm') + (webgpuEp ? 
'_webgpu' : ''); const allowImcomplete = latest; const run = args._[0]; // The first non-option argument @@ -149,17 +151,13 @@ async function downloadArtifactsForRun(run: any): Promise { if (!fs.existsSync(WASM_FOLDER)) { fs.mkdirSync(WASM_FOLDER); } else { + // TODO: revise artifacts download + const filesToDelete = ['ort-wasm-simd-threaded.jsep.mjs', 'ort-wasm-simd-threaded.jsep.wasm']; + if (!folderName.endsWith('_webgpu')) { + filesToDelete.push('ort-wasm-simd-threaded.mjs', 'ort-wasm-simd-threaded.wasm'); + } fs.readdirSync(WASM_FOLDER).forEach((file) => { - if ( - [ - 'ort-wasm-simd-threaded.jsep.mjs', - 'ort-wasm-simd-threaded.jsep.wasm', - 'ort-wasm-simd-threaded.jsep.mjs', - 'ort-wasm-simd-threaded.jsep.wasm', - 'ort-wasm-simd-threaded.mjs', - 'ort-wasm-simd-threaded.wasm', - ].includes(file) - ) { + if (filesToDelete.includes(file)) { const filePath = path.join(WASM_FOLDER, file); console.log(`Deleting old file: ${filePath}`); fs.unlinkSync(filePath); diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 80d374d3f0b25..243f611da49e1 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -53,12 +53,6 @@ enum AttentionKernelType { AttentionKernel_Default }; -enum class QKOutputType : int { - NO_OUTPUT = 0, - BEFORE_SOFTMAX = 1, - AFTER_SOFTMAX = 2 -}; - constexpr bool LAYOUT_BSNH = false; constexpr bool LAYOUT_BNSH = true; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index aef47edd5fcd2..ac32a4445f3ca 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -17,13 +17,13 @@ namespace onnxruntime { namespace contrib { template -inline void ComputeSmoothSoftmaxInplace(T* score, int D, float sink, ThreadPool* tp) { - MlasComputeSoftmax(score, score, 1, D, false, true, sink, tp); +inline void ComputeSmoothSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) { + MlasComputeSoftmax(score, score, N, D, false, true, tp); } template inline void ComputeAttentionSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) { - MlasComputeSoftmax(score, score, N, D, false, false, 0.0f, tp); + MlasComputeSoftmax(score, score, N, D, false, false, tp); } template diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 0d5117709c18a..c79508cbae273 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -35,8 +35,6 @@ class GQAAttentionBase { use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1; local_window_size_ = has_local ? 
static_cast(info.GetAttrOrDefault("local_window_size", -1)) : -1; - - qk_output_ = static_cast(info.GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))); } int num_heads_; // number of attention heads of Q @@ -46,7 +44,6 @@ class GQAAttentionBase { bool do_rotary_; // whether or not to use rotary embeddings bool rotary_interleaved_; int local_window_size_; - int qk_output_; bool use_smooth_softmax_; @@ -54,14 +51,12 @@ class GQAAttentionBase { Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH const T* K, // K data with shape BxN_kvxSxH const T* V, // V data with shape BxN_kvxSxH - const T* head_sink, // Head sink for smooth softmax, nullptr if not used const Tensor* attention_bias, // Attention bias to add to QxK' const Tensor* past_key, // past K input tensor (if not using past state) const Tensor* past_value, // past V input tensor (if not using past state) Tensor* output, // output tensor Tensor* present_key, // present K output tensor (if separating present KV) Tensor* present_value, // present V output tensor (if separating present KV) - Tensor* output_qk, // output QK buffer const Tensor* seqlens_k, // past sequence lengths tensor GroupQueryAttentionParameters& parameters, // attention parameters AllocatorPtr allocator, // allocator for temporary tensors @@ -69,7 +64,6 @@ class GQAAttentionBase { const bool is_prompt = parameters.is_first_prompt; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - const int total_sequence_length = parameters.total_sequence_length; const int head_size = parameters.head_size; const int hidden_size = parameters.hidden_size; const bool packed_qkv = parameters.is_packed_qkv; @@ -85,7 +79,8 @@ class GQAAttentionBase { // Compute the attention score. bool gqa_mlas_supported = MlasGQASupported(CblasNoTrans, CblasTrans) && MlasGQASupported(CblasNoTrans, CblasNoTrans); - size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * (gqa_mlas_supported ? sizeof(T) : sizeof(float)); + size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * + (gqa_mlas_supported ? sizeof(T) : sizeof(float)); auto attention_probs = allocator->Alloc(bytes); BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator)); @@ -101,13 +96,11 @@ class GQAAttentionBase { const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; - T* output_qk_buffer = output_qk != nullptr ? output_qk->MutableData() : nullptr; - if (gqa_mlas_supported) { - ComputeAttentionProbs(static_cast(attention_probs), Q, k, head_sink, seqlens_k->Data(), attention_bias_data, - batch_size, sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache, - seqlen_present_kv_cache, head_size, past_key_data, present_key_data, output_qk_buffer, - past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); + ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_bias_data, + batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, + head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, + tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? 
Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; @@ -117,10 +110,10 @@ class GQAAttentionBase { hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); } else { - ComputeAttentionProbs(static_cast(attention_probs), Q, k, head_sink, seqlens_k->Data(), attention_bias_data, - batch_size, sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache, - seqlen_present_kv_cache, head_size, past_key_data, present_key_data, output_qk_buffer, - past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); + ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_bias_data, + batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, + head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, + tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; @@ -143,19 +136,16 @@ class GQAAttentionBase { void ComputeAttentionProbs(U* attention_probs, // output buffer with size BxNxSxT const T* Q, // Q data. Its size is BxNxSxH const T* K, // k data. Its size is BxNxLxH - const T* head_sink, // for smooth softmax. Its size is N. const int32_t* seqlens_k, // total - 1 sequence lengths tensor const T* attention_bias, // optional attention bias const size_t batch_size, // batch size of self-attention const size_t sequence_length, // sequence length of self-attention (S) - const size_t total_sequence_length, // total sequence length (T) const gsl::span attention_bias_shape, // shape of the attention bias const size_t past_buffer_sequence_length, // sequence length of past state const size_t present_buffer_sequence_length, // sequence length of present state const size_t head_size, // head size of self-attention const T* past_key, // past key only T* present_key, // present key only - T* output_qk, // output QK buffer const bool past_present_share_buffer, // whether present key and value share the same buffer const bool packed_qkv, // whether Q, K, V are packed const bool is_prompt, // whether it is prompt @@ -207,11 +197,6 @@ class GQAAttentionBase { const ptrdiff_t output_offset = SafeInt(i) * sequence_length * present_buffer_sequence_length; U* output = attention_probs + output_offset; - T* output_qk_thread = nullptr; - if (output_qk != nullptr) { - const ptrdiff_t output_qk_offset = SafeInt(sequence_length) * total_sequence_length * (batch_index * num_heads_ + head_index); - output_qk_thread = output_qk + output_qk_offset; - } // Compute attention bias offset based on the batch and head indexes // Attention bias is of shape (B or 1, H or 1, S, T) so handle broadcasting @@ -325,6 +310,12 @@ class GQAAttentionBase { } } + if (use_smooth_softmax_) { + ComputeSmoothSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); + } else { + ComputeAttentionSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); + } + // set causal [seq_causal_length, total_seqlen) to 0.f for (size_t total_seq_id = seq_causal_length; total_seq_id < total_seqlen; total_seq_id++) { if constexpr (std::is_same::value) { @@ -334,30 +325,11 @@ class GQAAttentionBase { } } - if (qk_output_ == static_cast(QKOutputType::BEFORE_SOFTMAX)) { - WriteOutputQKHeadChunk(output_qk_thread, output_softmax, total_sequence_length); - } - - 
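For reference, a scalar sketch of the smooth-softmax variant restored by the hunk above: ComputeSmoothSoftmaxInplace(score, /*N=*/1, /*D=*/window_size, nullptr) now forwards to MlasComputeSoftmax(score, score, N, D, /*LogSoftmax=*/false, /*SmoothSoftmax=*/true, tp) with no per-head sink value, so the implicit sink logit is the constant 0 and contributes a single exp(0 - max) term to the denominator. The helper name below is illustrative, not part of the diff.

#include <algorithm>
#include <cmath>

// Normalizes one row of attention scores the way the restored smooth softmax does.
void SmoothSoftmaxRowReference(float* score, int d) {
  float maximum = *std::max_element(score, score + d);
  maximum = std::max(maximum, 0.0f);               // the implicit zero-valued sink logit
  float accumulation = std::exp(0.0f - maximum);   // sink contribution to the denominator
  for (int i = 0; i < d; ++i) {
    score[i] = std::exp(score[i] - maximum);
    accumulation += score[i];
  }
  for (int i = 0; i < d; ++i) {
    score[i] /= accumulation;
  }
}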
if (use_smooth_softmax_ || head_sink != nullptr) { - float sink = (head_sink != nullptr) ? static_cast(head_sink[head_index]) : 0.0f; - ComputeSmoothSoftmaxInplace(output_softmax + start_offset, static_cast(window_size), sink, nullptr); - } else { - ComputeAttentionSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); - } - - if (qk_output_ == static_cast(QKOutputType::AFTER_SOFTMAX)) { - WriteOutputQKHeadChunk(output_qk_thread, output_softmax, total_sequence_length); - } - output_softmax += present_buffer_sequence_length; if (attention_bias_thread != nullptr) { attention_bias_thread += attention_total_seqlen; } - - if (output_qk_thread != nullptr) { - output_qk_thread += total_sequence_length; - } } } }); @@ -483,20 +455,6 @@ class GQAAttentionBase { SafeInt(sequence_length) * batch_size * num_heads_ * head_size); } } - - template - void WriteOutputQKHeadChunk(T* output_qk, const U* attention_probs, size_t total_sequence_length) const { - if (output_qk == nullptr) { - return; - } - - if constexpr (std::is_same_v) { - std::memcpy(output_qk, attention_probs, SafeInt(total_sequence_length) * sizeof(T)); - } else { - static_assert(std::is_same_v && std::is_same_v); - MlasConvertFloatToHalfBuffer(static_cast(attention_probs), output_qk, total_sequence_length); - } - } }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index eb1560ac8e341..a912bd6e6b43c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -95,11 +95,6 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { Tensor* present_k = context->Output(1, present_k_shape); Tensor* present_v = context->Output(2, present_v_shape); - std::vector output_qk_shape{static_cast(batch_size), static_cast(num_heads_), static_cast(parameters.sequence_length), static_cast(parameters.total_sequence_length)}; - Tensor* output_qk = context->Output(3, output_qk_shape); - - ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckOutputs(output_qk, qk_output_)); - AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); @@ -211,12 +206,10 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - const T* head_sink_data = (head_sink != nullptr) ? head_sink->Data() : nullptr; - // Compute the attention score and apply the score to V return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? 
nullptr : V.Get().Data(), - head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v, - output_qk, seqlens_k, parameters, allocator, context); + attention_bias, past_key, past_value, output, present_k, present_v, + seqlens_k, parameters, allocator, context); } } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index f01ce985658aa..0f66119540b03 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -398,37 +398,6 @@ Status CheckCustomAttentionInputs(const T* position_ids, return Status::OK(); } -template -Status CheckOutputs(const T* output_qk, int qk_output) { - const bool is_valid_qk_output = qk_output == static_cast(QKOutputType::NO_OUTPUT) || - qk_output == static_cast(QKOutputType::BEFORE_SOFTMAX) || - qk_output == static_cast(QKOutputType::AFTER_SOFTMAX); - if (!is_valid_qk_output) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "qk_output attribute received unsupported value ", qk_output); - } - - if (qk_output != static_cast(QKOutputType::NO_OUTPUT) && output_qk == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "qk_output attribute was configured but output buffer was not provided"); - } - - return Status::OK(); -} - -inline Status CheckNoQKOutput(int num_outputs, int qk_output) { - if (num_outputs > 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "output_qk optional output is not supported"); - } - - if (qk_output != static_cast(QKOutputType::NO_OUTPUT)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "qk_output attribute is not supported"); - } - - return Status::OK(); -} - } // namespace group_query_attention_helper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 9cb93cbcd3f32..68c4b01d2db20 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -109,12 +109,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { parameters.do_rotary = do_rotary_; parameters.rotary_interleaved = rotary_interleaved_; - // The current GQA CUDA implementation will never be able to have a QK output. - // GQA CUDA uses either flash attention or memory efficient attention. Neither kernel supports returning the QK output. 
- ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( - context->OutputCount(), - static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); - if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index 09a6550549614..85aef55908506 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -213,10 +213,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { parameters.do_rotary = do_rotary_; parameters.rotary_interleaved = rotary_interleaved_; - ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( - context->OutputCount(), - static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); - if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 1f039177b0a21..f3334b13dc645 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -178,10 +178,6 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& head_sink, params)); - ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( - context.OutputCount(), - static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); - WebgpuAttentionParameters parameters(params); TensorShapeVector output_shape(3); output_shape[0] = static_cast(parameters.batch_size_); diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index c4667d53c0674..8ea593f107833 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -170,7 +170,7 @@ std::string CPUIDInfo::GetX86Vendor(int32_t* data) { uint32_t CPUIDInfo::GetVendorId(const std::string& vendor) { if (vendor == "GenuineIntel") return 0x8086; - if (vendor == "AuthenticAMD") return 0x1022; + if (vendor == "GenuineAMD") return 0x1022; if (vendor.find("Qualcomm") == 0) return 'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24); if (vendor.find("NV") == 0) return 0x10DE; return 0; diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index 47fbe08da41ff..c3dd9321ebb0b 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -247,11 +247,8 @@ struct OrtNode { /// Gets the node's subgraphs (e.g., subgraphs contained by an If or Loop node). /// /// Buffer into which to copy the subgraphs. - /// Optional buffer into which to copy the attribute name for each subgraph. - /// If set, must point to a buffer with the same number of elements as `subgraphs`. /// A status indicating success or an error. - virtual onnxruntime::Status GetSubgraphs(gsl::span subgraphs, - const char** opt_attribute_names) const = 0; + virtual onnxruntime::Status GetSubgraphs(gsl::span subgraphs) const = 0; /// /// Gets the node's parent graph, which is the graph that contains this node. 
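For illustration, a sketch of how a caller might use the narrowed Node_GetSubgraphs signature from the public C API after this change: the caller supplies only a buffer of OrtGraph pointers, with no parallel output array of attribute names. Node_GetNumSubgraphs is assumed here as the existing counterpart used to size the buffer, and the helper name is illustrative.

#include <vector>
#include "onnxruntime_c_api.h"  // assumes the ORT public headers are on the include path

OrtStatus* CollectNodeSubgraphs(const OrtApi& api, const OrtNode* node,
                                std::vector<const OrtGraph*>& subgraphs) {
  size_t num_subgraphs = 0;
  if (OrtStatus* status = api.Node_GetNumSubgraphs(node, &num_subgraphs)) {
    return status;  // propagate the failure to the caller
  }
  if (num_subgraphs == 0) {
    return nullptr;  // nothing to fetch (e.g., node is not an If/Loop)
  }

  subgraphs.resize(num_subgraphs);
  // After this change there is no second output array for per-subgraph attribute names.
  return api.Node_GetSubgraphs(node, subgraphs.data(), subgraphs.size());
}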
@@ -283,23 +280,6 @@ struct OrtGraph { /// The model's ONNX IR version. virtual int64_t GetOnnxIRVersion() const = 0; - /// - /// Gets the number of operator sets (domain, opset version) the graph's model relies on. - /// - /// Output parameter set to the number of operator sets. - /// A status indicating success or an error. - virtual onnxruntime::Status GetNumOperatorSets(size_t& num_operator_sets) const = 0; - - /// - /// Gets the operator sets the graph's model relies on. An operator set is uniquely identified by a - /// (domain, opset version) pair. - /// - /// Buffer into which to copy the domains. - /// Buffer into which to copy the opset version for each domain. - /// A status indicating success or an error. - virtual onnxruntime::Status GetOperatorSets(gsl::span domains, - gsl::span opset_versions) const = 0; - /// /// Returns the number of graph inputs, including initializers that appear in the list of graph inputs. /// diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index e2b17aa84d2b1..f2757c2c96471 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -6,7 +6,6 @@ #include "core/graph/contrib_ops/quantization_defs.h" #include "core/graph/contrib_ops/onnx_function_util.h" #include "core/graph/contrib_ops/shape_inference_functions.h" -#include "contrib_ops/cpu/bert/attention_common.h" // Suppress a warning: global initializer calls a non-constexpr function 'symbol' which is from // ONNX_OPERATOR_SET_SCHEMA_EX macro and only happens in debug build #if defined(_WIN32) && !defined(NDEBUG) @@ -233,8 +232,7 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c // Type and shape inference for group query attention and sparse attention. 
void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index = -1, - int use_max_past_present_buffer = -1, - int output_qk_index = -1) { + int use_max_past_present_buffer = -1) { ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); int64_t kv_sequence_length = -1; @@ -279,20 +277,13 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte } } - if (ctx.getNumOutputs() >= 3) { // has present output + if (ctx.getNumOutputs() > 1) { // has present output // copy the type from query to present key ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 1); // copy the type from query to present value ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 2); - int64_t total_sequence_length_value = 0; - const auto* total_sequence_length_data = ctx.getInputData(6); - if (total_sequence_length_data != nullptr) { - const auto& data = ParseData(total_sequence_length_data); - total_sequence_length_value = static_cast(data[0]); - } - if (past_key_index >= 0 && hasInputShape(ctx, past_key_index)) { auto& past_shape = getInputShape(ctx, past_key_index); auto& past_dims = past_shape.dim(); @@ -308,25 +299,30 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); } else if (use_max_past_present_buffer == 0) { if (kv_sequence_length > 0 && past_dims[2].has_dim_value()) { - const int64_t present_sequence_length = kv_sequence_length + past_dims[2].dim_value(); + int64_t total_sequence_length = kv_sequence_length + past_dims[2].dim_value(); ONNX_NAMESPACE::TensorShapeProto present_shape; for (auto& dim : past_dims) { *present_shape.add_dim() = dim; } - // shape of present key/value is (batch_size, kv_num_heads, present_sequence_length, head_size) - present_shape.mutable_dim(2)->set_dim_value(present_sequence_length); + // shape of present key/value is (batch_size, kv_num_heads, total_sequence_length, head_size) + present_shape.mutable_dim(2)->set_dim_value(total_sequence_length); updateOutputShape(ctx, 1, present_shape); updateOutputShape(ctx, 2, present_shape); } } else if (use_max_past_present_buffer == -1) { - if (total_sequence_length_value > 0 && past_dims[2].has_dim_value()) { + const auto* total_sequence_length_data = ctx.getInputData(6); + if (total_sequence_length_data != nullptr && past_dims[2].has_dim_value()) { + int64_t total_sequence_length_value = 0; + const auto& data = ParseData(total_sequence_length_data); + total_sequence_length_value = static_cast(data[0]); + // present_sequence_length = max(past_sequence_length, total_sequence_length) - const int64_t present_sequence_length = total_sequence_length_value > past_dims[2].dim_value() - ? total_sequence_length_value - : past_dims[2].dim_value(); + int64_t present_sequence_length = total_sequence_length_value > past_dims[2].dim_value() + ? 
total_sequence_length_value + : past_dims[2].dim_value(); ONNX_NAMESPACE::TensorShapeProto present_shape; for (auto& dim : past_dims) { @@ -340,50 +336,19 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte updateOutputShape(ctx, 2, present_shape); } } - - if (output_qk_index >= 0) { - const bool did_supply_qk_buffer = ctx.hasOutput(output_qk_index); - const int64_t qk_output_type = getAttribute(ctx, "qk_output", static_cast(QKOutputType::NO_OUTPUT)); - - if (qk_output_type == static_cast(QKOutputType::NO_OUTPUT) && did_supply_qk_buffer) { - fail_shape_inference("Output QK buffer was provided but qk_output attribute was not configured"); - } - - if (qk_output_type != static_cast(QKOutputType::NO_OUTPUT) && !did_supply_qk_buffer) { - fail_shape_inference("Output QK buffer was not provided but qk_output attribute was configured"); - } - - int64_t num_heads = getAttribute(ctx, "num_heads", 0); - if (did_supply_qk_buffer && hasInputShape(ctx, 0) && total_sequence_length_value > 0 && num_heads > 0) { - ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, output_qk_index); - - auto& query_shape = getInputShape(ctx, 0); - auto& query_dims = query_shape.dim(); - - if (query_dims[0].has_dim_value() && query_dims[1].has_dim_value()) { - ONNX_NAMESPACE::TensorShapeProto output_qk_shape; - *output_qk_shape.add_dim() = query_dims[0]; // batch_size - output_qk_shape.add_dim()->set_dim_value(num_heads); // num_heads - *output_qk_shape.add_dim() = query_dims[1]; // sequence_length - output_qk_shape.add_dim()->set_dim_value(total_sequence_length_value); // total_sequence_length - updateOutputShape(ctx, output_qk_index, output_qk_shape); - } - } - } } } } -void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index, int qk_output_index) { +void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) { // TODO(aciddelgado): propagate output shapes depending if kv-share buffer is on or not constexpr int use_max_past_present_buffer = -1; - BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer, qk_output_index); + BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer); } void SparseAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) { constexpr int use_max_past_present_buffer = 1; - constexpr int qk_output_index = -1; - BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer, qk_output_index); + BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer); } constexpr const char* Attention_ver1_doc = R"DOC( @@ -1162,10 +1127,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Use a smooth factor in softmax.", AttributeProto::INT, static_cast(-1)) - .Attr("qk_output", - "Output values of QK matrix multiplication before (1) or after (2) softmax normalization. Default value is 0 (don't output).", - AttributeProto::INT, - static_cast(QKOutputType::NO_OUTPUT)) .Input(0, "query", "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape" @@ -1223,11 +1184,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)", "T", OpSchema::Optional) - .Input(11, - "head_sink", - "1D tensor with shape (num_heads). 
Each head has a smooth factor adding to the denominator of softmax.", - "T", - OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", @@ -1244,15 +1200,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", "T") - .Output(3, - "output_qk", - "Values of QK matrix multiplication, either before or after softmax normalization", - "T", - OpSchema::Optional) .TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, "Constrain input and output to float tensors.") .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to int tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - GroupQueryAttentionTypeAndShapeInference(ctx, 3, 3); + GroupQueryAttentionTypeAndShapeInference(ctx, 3); })); constexpr const char* PagedAttention_ver1_doc = R"DOC( diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index f57543416a68f..698c7422a1e2a 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -129,12 +129,11 @@ Status EpNode::Create(const Node& node, const EpGraph* ep_graph, ConvertNodeArgsToValueInfos(ep_graph, value_infos_map, node_implicit_inputs, ep_node_implicit_inputs); - std::unordered_map> subgraphs_map = node.GetAttributeNameToSubgraphMap(); - ep_node_subgraphs.reserve(subgraphs_map.size()); + std::vector> node_subgraphs = node.GetSubgraphs(); + ep_node_subgraphs.reserve(node_subgraphs.size()); - for (const auto& [attr_name, subgraph] : subgraphs_map) { + for (gsl::not_null subgraph : node_subgraphs) { SubgraphState subgraph_state; - subgraph_state.attribute_name = attr_name; subgraph_state.subgraph_viewer = std::make_unique(*subgraph); ORT_RETURN_IF_ERROR(EpGraph::Create(*subgraph_state.subgraph_viewer, subgraph_state.ep_subgraph)); subgraph_state.ep_subgraph->SetParentNode(ep_node.get()); @@ -234,17 +233,12 @@ Status EpNode::GetNumSubgraphs(size_t& num_subgraphs) const { return Status::OK(); } -Status EpNode::GetSubgraphs(gsl::span subgraphs, - const char** opt_attribute_names) const { +Status EpNode::GetSubgraphs(gsl::span dst) const { const size_t num_subgraphs = subgraphs_.size(); - ORT_RETURN_IF_ERROR((CheckCopyDestination("node subgraphs", num_subgraphs, subgraphs))); + ORT_RETURN_IF_ERROR((CheckCopyDestination("node attributes", num_subgraphs, dst))); for (size_t i = 0; i < num_subgraphs; ++i) { - subgraphs[i] = subgraphs_[i].ep_subgraph.get(); - - if (opt_attribute_names) { - opt_attribute_names[i] = subgraphs_[i].attribute_name.c_str(); - } + dst[i] = subgraphs_[i].ep_subgraph.get(); } return Status::OK(); @@ -276,10 +270,6 @@ const OrtOpAttr* EpNode::GetAttribute(const std::string& name) const { } } -const std::string& EpNode::GetEpName() const { - return node_.GetExecutionProviderType(); -} - // // EpValueInfo // @@ -509,34 +499,10 @@ void EpGraph::IndexToEpNodeMap::SetEpNode(NodeIndex node_index, EpNode* ep_node) EpGraph::EpGraph(const GraphViewer& graph_viewer, PrivateTag) : OrtGraph(OrtGraphIrApi::kEpApi), graph_viewer_(graph_viewer) {} -EpGraph::EpGraph(std::unique_ptr graph_viewer, - std::unique_ptr indexed_sub_graph, - PrivateTag) - : OrtGraph(OrtGraphIrApi::kEpApi), - graph_viewer_(*graph_viewer.get()), - owned_graph_viewer_(std::move(graph_viewer)), - owned_indexed_sub_graph_(std::move(indexed_sub_graph)) {} - // Static class function to create a std::unique_ptr. 
Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result) { auto ep_graph = std::make_unique(graph_viewer, PrivateTag{}); - return CreateImpl(std::move(ep_graph), graph_viewer, result); -} - -// Static class function to create a std::unique_ptr. -Status EpGraph::Create(std::unique_ptr src_graph_viewer, - std::unique_ptr src_indexed_sub_graph, - /*out*/ std::unique_ptr& result) { - auto& graph_viewer = *src_graph_viewer.get(); - auto ep_graph = std::make_unique(std::move(src_graph_viewer), - std::move(src_indexed_sub_graph), - PrivateTag{}); - - return CreateImpl(std::move(ep_graph), graph_viewer, result); -} - -Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result) { AllocatorPtr initializer_allocator = CPUAllocator::DefaultInstance(); std::unordered_map> value_infos_map; @@ -694,43 +660,6 @@ const std::string& EpGraph::GetName() const { return graph_viewer_.Name(); } int64_t EpGraph::GetOnnxIRVersion() const { return graph_viewer_.GetOnnxIRVersion(); } -Status EpGraph::GetNumOperatorSets(size_t& num_operator_sets) const { - num_operator_sets = graph_viewer_.DomainToVersionMap().size(); - return Status::OK(); -} - -Status EpGraph::GetOperatorSets(gsl::span domains, - gsl::span opset_versions) const { - const std::unordered_map& domain_to_version = graph_viewer_.DomainToVersionMap(); - size_t num_operator_sets = domain_to_version.size(); - - ORT_RETURN_IF_ERROR((CheckCopyDestination("operator set domains", num_operator_sets, domains))); - ORT_RETURN_IF_ERROR((CheckCopyDestination("operator set versions", num_operator_sets, opset_versions))); - - // Collect (domain, version) pairs and sort them by domain to ensure user always gets a stable ordering. - std::vector> pairs; - pairs.reserve(num_operator_sets); - - for (const auto& [domain, version] : domain_to_version) { - pairs.emplace_back(domain.c_str(), version); - } - - std::sort(pairs.begin(), pairs.end(), - [](const std::pair& a, const std::pair& b) -> bool { - return std::strcmp(a.first, b.first) < 0; - }); - - // Copy sorted (domain, version) pairs into the destination buffers. - size_t index = 0; - for (const auto& [domain_c_str, version] : pairs) { - domains[index] = domain_c_str; - opset_versions[index] = version; - index++; - } - - return Status::OK(); -} - size_t EpGraph::GetNumInputs() const { return inputs_.size(); } diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index d3921e051e18a..4240f5636b7ae 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -111,7 +111,6 @@ struct EpNode : public OrtNode { struct SubgraphState { SubgraphState() = default; SubgraphState(SubgraphState&& other) = default; - std::string attribute_name; std::unique_ptr subgraph_viewer; // The graph_viewer wrapped by EpGraph below. std::unique_ptr ep_subgraph; }; @@ -183,8 +182,7 @@ struct EpNode : public OrtNode { Status GetNumSubgraphs(size_t& num_subgraphs) const override; // Gets the subgraphs contained by this node. - Status GetSubgraphs(gsl::span subgraphs, - const char** opt_attribute_names) const override; + Status GetSubgraphs(gsl::span subgraphs) const override; // Gets this node's parent graph, which is the graph that directly contains this node. Status GetGraph(const OrtGraph*& parent_graph) const override; @@ -208,9 +206,6 @@ struct EpNode : public OrtNode { // Helper that gets the node's attributes by name. 
const OrtOpAttr* GetAttribute(const std::string& name) const; - // Helper that gets the execution provider name that this node is assigned to run on. - const std::string& GetEpName() const; - private: // Back pointer to containing graph. Useful when traversing through nested subgraphs. // Will be nullptr if the EpNode was created without an owning graph. @@ -254,32 +249,15 @@ struct EpGraph : public OrtGraph { public: EpGraph(const GraphViewer& graph_viewer, PrivateTag); - EpGraph(std::unique_ptr graph_viewer, - std::unique_ptr indexed_sub_graph, - PrivateTag); /// /// Creates an instance of EpGraph, which wraps a GraphViewer. - /// This call is used when creating an EpGraph from a GraphViewer instance. The GraphViewer instance is not onwed by this EpGraph. /// /// /// /// static Status Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result); - /// - /// Creates an instance of EpGraph, which wraps a GraphViewer. - /// This call is used when creating an EpGraph from a subset of nodes in another EpGraph. - /// In this case, due to the implementation of OrtApis::Graph_GetGraphView, the new EpGraph instance - /// must take ownership of both the GraphViewer and IndexedSubGraph. - /// - /// - /// - /// - static Status Create(std::unique_ptr graph_viewer, - std::unique_ptr indexed_sub_graph, - /*out*/ std::unique_ptr& result); - // Defines ToExternal() and ToInternal() functions to convert between OrtGraph and EpGraph. DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(OrtGraph, EpGraph, OrtGraphIrApi::kEpApi) @@ -293,14 +271,6 @@ struct EpGraph : public OrtGraph { // Returns the model's ONNX IR version. int64_t GetOnnxIRVersion() const override; - // Gets the number of operator sets that the graph's model uses. - Status GetNumOperatorSets(size_t& num_operator_sets) const override; - - // Gets the operator sets that the graph's model uses. An operator set is uniquely identified by a - // (domain, opset version) pair. - Status GetOperatorSets(gsl::span domains, - gsl::span opset_versions) const override; - // Get the number of graph inputs, including initializers that are listed as graph inputs. size_t GetNumInputs() const override; @@ -351,22 +321,9 @@ struct EpGraph : public OrtGraph { const OrtValue* GetInitializerValue(std::string_view name) const; private: - /// - /// The real implementation of creating an EpGraph instance. - /// Please use one of the above 'Create' functions that internally call this function, and avoid calling this function directly. 
- /// - /// - /// - /// - /// - static Status CreateImpl(std::unique_ptr ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result); - const GraphViewer& graph_viewer_; const EpNode* parent_node_ = nullptr; - std::unique_ptr owned_graph_viewer_ = nullptr; - std::unique_ptr owned_indexed_sub_graph_ = nullptr; - std::vector> nodes_; IndexToEpNodeMap index_to_ep_node_; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 4d3091520d876..ca40bad2b4250 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1818,10 +1818,6 @@ NodeArg* Graph::GetNodeArgIncludingParentGraphs(const std::string& node_arg_name return node_arg; } -const NodeArg* Graph::GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const { - return const_cast(this)->GetNodeArgIncludingParentGraphs(node_arg_name); -} - void Graph::ReverseDFSFrom(gsl::span from, const std::function& enter, const std::function& leave, diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index 948ebaa5f7e15..1842c2b4a0d1f 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -168,15 +168,7 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) filtered_node_inputs_including_initializers_.reserve(metadef->inputs.size()); for (const auto& input : metadef->inputs) { - // NodeArgs from the current scope or any outer scopes should be handled correctly. - // - // There is an edge case where the model consists of a graph with subgraphs nested across three levels. - // In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer). - // When constructing a new GraphViewer for the second- and third-layer subgraphs, - // the second-layer graph may not have the corresponding value_info for that first-layer input, - // because the second-layer graph itself doesn't consume it. - // Therefore, when working within the second-layer graph, we need to search outer scopes for the missing value_info. - const auto* nodearg = graph.GetNodeArgIncludingParentGraphs(input); + const auto* nodearg = graph.GetNodeArg(input); ORT_ENFORCE(nodearg, "Mismatch between Graph and IndexedSubGraph. Input not found:", input); filtered_node_inputs_including_initializers_.push_back(nodearg); if (!graph.IsInitializedTensor(input)) { @@ -185,7 +177,7 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) } for (const auto& output : metadef->outputs) { - const auto* nodearg = graph.GetNodeArgIncludingParentGraphs(output); + const auto* nodearg = graph.GetNodeArg(output); ORT_ENFORCE(nodearg, "Mismatch between Graph and IndexedSubGraph. 
Output not found:", output); filtered_node_outputs_.push_back(nodearg); } diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index 6e7e17374bb59..6330a42c115db 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -136,8 +136,7 @@ struct ModelEditorNode : public OrtNode { "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); } - Status GetSubgraphs(gsl::span /*subgraphs*/, - const char** /*opt_attribute_names*/) const override { + Status GetSubgraphs(gsl::span /*subgraphs*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); } @@ -177,17 +176,6 @@ struct ModelEditorGraph : public OrtGraph { return ONNX_NAMESPACE::Version::IR_VERSION; } - Status GetNumOperatorSets(size_t& /*num_operator_sets*/) const override { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "OrtModelEditorApi does not support getting the graph's operator sets."); - } - - Status GetOperatorSets(gsl::span /*domains*/, - gsl::span /*opset_versions*/) const override { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "OrtModelEditorApi does not support getting the graph's operator sets."); - } - size_t GetNumInputs() const override { return inputs.size(); } Status GetInputs(gsl::span /*result*/) const override { diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 4d85c35461825..3575e30721af7 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1020,7 +1020,6 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, - float Sink, MLAS_THREADPOOL* ThreadPool ); @@ -1224,21 +1223,6 @@ MlasQuantizeLinearS4( int8_t ZeroPoint ); -// -// Linear dequantization routines. 
-// - -template -void -MLASCALL -MlasDequantizeLinear( - const InputType* Input, - float* Output, - size_t N, - float Scale, - InputType ZeroPoint - ); - /** * @brief Requantize a block of the intermediate buffer to the output buffer, * optionally adding the supplied bias diff --git a/onnxruntime/core/mlas/lib/compute.cpp b/onnxruntime/core/mlas/lib/compute.cpp index 669c73d2b9c06..96a2398796777 100644 --- a/onnxruntime/core/mlas/lib/compute.cpp +++ b/onnxruntime/core/mlas/lib/compute.cpp @@ -74,7 +74,6 @@ struct MLAS_SOFTMAX_WORK_BLOCK { ptrdiff_t ThreadCountN; bool LogSoftmax; bool SmoothSoftmax; - float Sink; const T* Input; T* Output; size_t N; @@ -851,7 +850,6 @@ Return Value: const size_t D = WorkBlock->D; const bool LogSoftmax = WorkBlock->LogSoftmax; const bool SmoothSoftmax = WorkBlock->SmoothSoftmax; - const float Sink = WorkBlock->Sink; const float* Input = WorkBlock->Input + n * D; float* Output = WorkBlock->Output + n * D; @@ -882,11 +880,10 @@ Return Value: #else float Maximum = MlasReduceMaximumF32Kernel(Input, D); #endif - if (SmoothSoftmax && Sink > Maximum) { - Maximum = Sink; - } - float NegativeMaximum = -Maximum; + if (SmoothSoftmax && NegativeMaximum > 0.0f) { + NegativeMaximum = 0.0f; + } // // Compute the exponential function for each element of the row (save to Temp if provided) and @@ -900,7 +897,7 @@ Return Value: #endif if (SmoothSoftmax) { - Accumulation += expf(Sink + NegativeMaximum); + Accumulation += expf(NegativeMaximum); } if (LogSoftmax) { @@ -1017,7 +1014,6 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, - float Sink, MLAS_THREADPOOL* ThreadPool ) /*++ @@ -1043,8 +1039,6 @@ Routine Description: SmoothSoftmax - Supplies true if a smooth factor is used in softmax operation. - Sink - Supplies the smooth factor to use in the softmax operation. - ThreadPool - Supplies the thread pool object to use, else nullptr if the base library threading support should be used. @@ -1066,7 +1060,6 @@ Return Value: WorkBlock.Output = Output; WorkBlock.N = N; WorkBlock.D = D; - WorkBlock.Sink = Sink; // // Compute the number of target threads given the complexity of the softmax @@ -1104,7 +1097,6 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, - float Sink, MLAS_THREADPOOL* ThreadPool ); @@ -1118,7 +1110,6 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, - float Sink, MLAS_THREADPOOL* ThreadPool ); diff --git a/onnxruntime/core/mlas/lib/dequantize.cpp b/onnxruntime/core/mlas/lib/dequantize.cpp deleted file mode 100644 index 175d3f668ac39..0000000000000 --- a/onnxruntime/core/mlas/lib/dequantize.cpp +++ /dev/null @@ -1,395 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - dequantize.cpp - -Abstract: - - This module implements routines to dequantize buffers. - - The dequantization formula as specified in the ONNX operator documentation is: - - Output = (Input - ZeroPoint) * Scale - ---*/ - -#include "mlasi.h" - -// -// DequantizeLinear reference implementation using the C++ runtime. -// - -template -static -MLAS_FORCEINLINE -void -MlasDequantizeLinearRefImpl( - const InputType* Input, - float* Output, - size_t N, - float Scale, - InputType ZeroPoint - ) -/*++ - -Routine Description: - - This routine quantizes the input buffer using the supplied quantization - parameters. - -Arguments: - - Input - Supplies the input buffer with quantized data. - - Output - Supplies the output buffer. - - N - Supplies the number of elements to process. 
- - Scale - Supplies the quantization scale. - - ZeroPoint - Supplies the quantization zero point value. - -Return Value: - - None. - ---*/ -{ - int32_t ZeroPointS32 = static_cast(ZeroPoint); - - for (size_t n = 0; n < N; n++) { - Output[n] = static_cast(static_cast(Input[n]) - ZeroPointS32) * Scale; - } -} - -#if defined(MLAS_SSE2_INTRINSICS) -// Implementation for Intel SSE 2. Refer to the Intel Intrisics Guide: -// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html - -void -MLASCALL -MlasDequantizeLinearS8Kernel( - const int8_t* Input, - float* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ - const __m128 ScaleVector = MlasBroadcastFloat32x4(Scale); - const __m128i ZeroPointS16Vector = _mm_set1_epi16(static_cast(ZeroPoint)); // Broadcast zp to 8 int16s - const __m128i Zeros = _mm_setzero_si128(); - - while (N >= 16) { - // Load a vector of 16 int8s: [0 ... 15] - __m128i VectorS8 = _mm_loadu_si128(reinterpret_cast(Input)); - - // Sign-extend into 2 vectors of 8 int16s - __m128i SignMaskS8 = _mm_cmpgt_epi8(Zeros, VectorS8); // 0xFF for every negative byte in VectorS8 - __m128i VectorS16_0 = _mm_unpacklo_epi8(VectorS8, SignMaskS8); // [0 ... 7] - __m128i VectorS16_1 = _mm_unpackhi_epi8(VectorS8, SignMaskS8); // [8 ... 15] - - // Subtract the zero-points in int16 domain. - VectorS16_0 = _mm_sub_epi16(VectorS16_0, ZeroPointS16Vector); - VectorS16_1 = _mm_sub_epi16(VectorS16_1, ZeroPointS16Vector); - - // Sign-extend into 4 vectors of 4 int32s - __m128i SignMaskS16_0 = _mm_cmpgt_epi16(Zeros, VectorS16_0); - __m128i VectorS32_0 = _mm_unpacklo_epi16(VectorS16_0, SignMaskS16_0); // [0 ... 3] - __m128i VectorS32_1 = _mm_unpackhi_epi16(VectorS16_0, SignMaskS16_0); // [4 ... 7] - - __m128i SignMaskS16_1 = _mm_cmpgt_epi16(Zeros, VectorS16_1); - __m128i VectorS32_2 = _mm_unpacklo_epi16(VectorS16_1, SignMaskS16_1); // [8 ... 11] - __m128i VectorS32_3 = _mm_unpackhi_epi16(VectorS16_1, SignMaskS16_1); // [12 ... 15] - - // Cast each int32x4 to float and multiply by the scale vector. - __m128 VectorF32_0 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_0), ScaleVector); - __m128 VectorF32_1 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_1), ScaleVector); - __m128 VectorF32_2 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_2), ScaleVector); - __m128 VectorF32_3 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_3), ScaleVector); - - // Store each int32x4 into the output. - _mm_storeu_ps(Output + 0, VectorF32_0); - _mm_storeu_ps(Output + 4, VectorF32_1); - _mm_storeu_ps(Output + 8, VectorF32_2); - _mm_storeu_ps(Output + 12, VectorF32_3); - - Input += 16; - Output += 16; - N -= 16; - } - - // Handle leftover elements (< 16) with the scalar reference implementation. - MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); -} - -void -MLASCALL -MlasDequantizeLinearU8Kernel( - const uint8_t* Input, - float* Output, - size_t N, - float Scale, - uint8_t ZeroPoint - ) -{ - const __m128 ScaleVector = MlasBroadcastFloat32x4(Scale); - const __m128i ZeroPointS16Vector = _mm_set1_epi16(static_cast(ZeroPoint)); // Broadcast zp to 8 int16s - const __m128i Zeros = _mm_setzero_si128(); - - while (N >= 16) { - // Load a vector of 16 uint8s: [0 ... 15] - __m128i VectorU8 = _mm_loadu_si128(reinterpret_cast(Input)); - - // Zero-extend into 2 vectors of 8 uint16s - __m128i VectorU16_0 = _mm_unpacklo_epi8(VectorU8, Zeros); // [0 ... 7] - __m128i VectorU16_1 = _mm_unpackhi_epi8(VectorU8, Zeros); // [8 ... 15] - - // Subtract the zero-points as uint16s. 
Due to two's compliment, negative results can be reinterpreted as int16 - __m128i VectorS16_0 = _mm_sub_epi16(VectorU16_0, ZeroPointS16Vector); - __m128i VectorS16_1 = _mm_sub_epi16(VectorU16_1, ZeroPointS16Vector); - - // Sign-extend into 4 vectors of 4 int32s - __m128i SignMaskS16_0 = _mm_cmpgt_epi16(Zeros, VectorS16_0); - __m128i VectorS32_0 = _mm_unpacklo_epi16(VectorS16_0, SignMaskS16_0); // [0 ... 3] - __m128i VectorS32_1 = _mm_unpackhi_epi16(VectorS16_0, SignMaskS16_0); // [4 ... 7] - - __m128i SignMaskS16_1 = _mm_cmpgt_epi16(Zeros, VectorS16_1); - __m128i VectorS32_2 = _mm_unpacklo_epi16(VectorS16_1, SignMaskS16_1); // [8 ... 11] - __m128i VectorS32_3 = _mm_unpackhi_epi16(VectorS16_1, SignMaskS16_1); // [12 ... 15] - - // Cast each int32x4 to float and multiply by the scale vector. - __m128 VectorF32_0 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_0), ScaleVector); - __m128 VectorF32_1 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_1), ScaleVector); - __m128 VectorF32_2 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_2), ScaleVector); - __m128 VectorF32_3 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_3), ScaleVector); - - // Store each int32x4 into the output. - _mm_storeu_ps(Output + 0, VectorF32_0); - _mm_storeu_ps(Output + 4, VectorF32_1); - _mm_storeu_ps(Output + 8, VectorF32_2); - _mm_storeu_ps(Output + 12, VectorF32_3); - - Input += 16; - Output += 16; - N -= 16; - } - - // Handle leftover elements (< 16) with the scalar reference implementation. - MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); -} - -template<> -void -MLASCALL -MlasDequantizeLinear( - const int8_t* Input, - float* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ -#if defined(MLAS_TARGET_AMD64) - GetMlasPlatform().DequantizeLinearS8Kernel( -#else - MlasDequantizeLinearS8Kernel( -#endif - Input, Output, N, Scale, ZeroPoint); -} - -template<> -void -MLASCALL -MlasDequantizeLinear( - const uint8_t* Input, - float* Output, - size_t N, - float Scale, - uint8_t ZeroPoint - ) -{ -#if defined(MLAS_TARGET_AMD64) - GetMlasPlatform().DequantizeLinearU8Kernel( -#else - MlasDequantizeLinearU8Kernel( -#endif - Input, Output, N, Scale, ZeroPoint); -} -#elif defined(MLAS_NEON64_INTRINSICS) -// Implementation for ARM64 NEON. Refer to the ARM instrinsics guide: -// https://developer.arm.com/architectures/instruction-sets/intrinsics/ - -void -MLASCALL -MlasDequantizeLinearS8Kernel( - const int8_t* Input, - float* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ - const float32x4_t ScaleVector = MlasBroadcastFloat32x4(Scale); - const int16x8_t ZeroPointVector = vdupq_n_s16(ZeroPoint); // Broadcast ZeroPoint (sign-extended to 16bits) - - while (N >= 16) { - // Load a vector of 16 int8s: [0 ... 15] - int8x16_t VectorS8 = vld1q_s8(Input); - - // Sign-extend into 2 vectors of 8 int16s - int16x8_t VectorS16_0 = vmovl_s8(vget_low_s8(VectorS8)); // [0 ... 7] - int16x8_t VectorS16_1 = vmovl_s8(vget_high_s8(VectorS8)); // [8 ... 15] - - // Subtract the zero-points in int16 domain. - VectorS16_0 = vsubq_s16(VectorS16_0, ZeroPointVector); - VectorS16_1 = vsubq_s16(VectorS16_1, ZeroPointVector); - - // Sign-extend into 4 vectors of 4 int32s - int32x4_t VectorS32_0 = vmovl_s16(vget_low_s16(VectorS16_0)); // [0 ... 3] - int32x4_t VectorS32_1 = vmovl_s16(vget_high_s16(VectorS16_0)); // [4 ... 7] - int32x4_t VectorS32_2 = vmovl_s16(vget_low_s16(VectorS16_1)); // [8 ... 11] - int32x4_t VectorS32_3 = vmovl_s16(vget_high_s16(VectorS16_1)); // [12 ... 15] - - // Cast each int32x4 to float and multiply by the scale vector. 
- float32x4_t VectorF32_0 = vmulq_f32(vcvtq_f32_s32(VectorS32_0), ScaleVector); - float32x4_t VectorF32_1 = vmulq_f32(vcvtq_f32_s32(VectorS32_1), ScaleVector); - float32x4_t VectorF32_2 = vmulq_f32(vcvtq_f32_s32(VectorS32_2), ScaleVector); - float32x4_t VectorF32_3 = vmulq_f32(vcvtq_f32_s32(VectorS32_3), ScaleVector); - - // Store each int32x4 into the output. - vst1q_f32(Output + 0, VectorF32_0); - vst1q_f32(Output + 4, VectorF32_1); - vst1q_f32(Output + 8, VectorF32_2); - vst1q_f32(Output + 12, VectorF32_3); - - N -= 16; - Input += 16; - Output += 16; - } - - // Handle leftover elements (< 16) with the scalar reference implementation. - MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); -} - -void -MLASCALL -MlasDequantizeLinearU8Kernel( - const uint8_t* Input, - float* Output, - size_t N, - float Scale, - uint8_t ZeroPoint - ) -{ - const float32x4_t ScaleVector = MlasBroadcastFloat32x4(Scale); - const uint8x8_t ZeroPointVector = vdup_n_u8(ZeroPoint); // Broadcast ZeroPoint to 8 uint8s - - while (N >= 16) { - // Load a vector of 16 uint8s: [0 ... 15] - uint8x16_t VectorU8 = vld1q_u8(Input); - - // Subtract zero-point. The vsubl_u8 instruction zero-extends its arguments to uint16 first. - // The reinterpret from uint16x8 to int16x8 is actually a NOP. - int16x8_t VectorS16_0 = vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(VectorU8), ZeroPointVector)); // [0 ... 7] - int16x8_t VectorS16_1 = vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(VectorU8), ZeroPointVector)); // [8 ... 15] - - // Sign-extend into 4 vectors of 4 int32s - int32x4_t VectorS32_0 = vmovl_s16(vget_low_s16(VectorS16_0)); // [0 ... 3] - int32x4_t VectorS32_1 = vmovl_s16(vget_high_s16(VectorS16_0)); // [4 ... 7] - int32x4_t VectorS32_2 = vmovl_s16(vget_low_s16(VectorS16_1)); // [8 ... 11] - int32x4_t VectorS32_3 = vmovl_s16(vget_high_s16(VectorS16_1)); // [12 ... 15] - - // Cast each int32x4 to float and multiply by the scale vector. - float32x4_t VectorF32_0 = vmulq_f32(vcvtq_f32_s32(VectorS32_0), ScaleVector); - float32x4_t VectorF32_1 = vmulq_f32(vcvtq_f32_s32(VectorS32_1), ScaleVector); - float32x4_t VectorF32_2 = vmulq_f32(vcvtq_f32_s32(VectorS32_2), ScaleVector); - float32x4_t VectorF32_3 = vmulq_f32(vcvtq_f32_s32(VectorS32_3), ScaleVector); - - // Store each int32x4 into the output. - vst1q_f32(Output + 0, VectorF32_0); - vst1q_f32(Output + 4, VectorF32_1); - vst1q_f32(Output + 8, VectorF32_2); - vst1q_f32(Output + 12, VectorF32_3); - - N -= 16; - Input += 16; - Output += 16; - } - - // Handle leftover elements (< 16) with the scalar reference implementation. - MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); -} - -template<> -void -MLASCALL -MlasDequantizeLinear( - const int8_t* Input, - float* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ) -{ - MlasDequantizeLinearS8Kernel(Input, Output, N, Scale, ZeroPoint); -} - -template<> -void -MLASCALL -MlasDequantizeLinear( - const uint8_t* Input, - float* Output, - size_t N, - float Scale, - uint8_t ZeroPoint - ) -{ - MlasDequantizeLinearU8Kernel(Input, Output, N, Scale, ZeroPoint); -} -#else -// Implementation that uses the scalar reference implementation. 
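For orientation, here is a minimal, self-contained sketch of the scalar dequantization rule that the deleted SSE and NEON kernels above vectorize (Output = (Input - ZeroPoint) * Scale). The helper name and sample values below are illustrative only and are not part of the MLAS API:

    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    // Hypothetical standalone helper; mirrors the scalar reference path, not the MLAS entry point.
    static void DequantizeLinearScalar(const int8_t* input, float* output, size_t n,
                                       float scale, int8_t zero_point) {
      const int32_t zp = static_cast<int32_t>(zero_point);
      for (size_t i = 0; i < n; ++i) {
        output[i] = static_cast<float>(static_cast<int32_t>(input[i]) - zp) * scale;
      }
    }

    int main() {
      const int8_t input[3] = {12, 2, -6};
      float output[3];
      DequantizeLinearScalar(input, output, 3, /*scale=*/0.5f, /*zero_point=*/2);
      assert(output[0] == 5.0f && output[1] == 0.0f && output[2] == -4.0f);  // (x - 2) * 0.5
      return 0;
    }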
- -template -void -MLASCALL -MlasDequantizeLinear( - const InputType* Input, - float* Output, - size_t N, - float Scale, - InputType ZeroPoint - ) -{ - MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); -} - -template -void -MLASCALL -MlasDequantizeLinear( - const int8_t* Input, - float* Output, - size_t N, - float Scale, - int8_t ZeroPoint - ); - -template -void -MLASCALL -MlasDequantizeLinear( - const uint8_t* Input, - float* Output, - size_t N, - float Scale, - uint8_t ZeroPoint - ); - -#endif diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 0879d1b0ba510..0af3cd2e33b02 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -747,24 +747,6 @@ void float Scale, int8_t ZeroPoint); -typedef -void -(MLASCALL MLAS_DEQUANTIZE_LINEAR_U8_KERNEL)( - const uint8_t* Input, - float* Output, - size_t N, - float Scale, - uint8_t ZeroPoint); - -typedef -void -(MLASCALL MLAS_DEQUANTIZE_LINEAR_S8_KERNEL)( - const int8_t* Input, - float* Output, - size_t N, - float Scale, - int8_t ZeroPoint); - template struct MLAS_QUANT_KERNEL { @@ -921,8 +903,6 @@ extern "C" { MLAS_QUANTIZE_LINEAR_S4_KERNEL MlasQuantizeLinearS4Kernel; MLAS_QUANTIZE_LINEAR_U4_KERNEL MlasQuantizeLinearU4Kernel; #if defined(MLAS_TARGET_AMD64) - MLAS_DEQUANTIZE_LINEAR_S8_KERNEL MlasDequantizeLinearS8Kernel; - MLAS_DEQUANTIZE_LINEAR_U8_KERNEL MlasDequantizeLinearU8Kernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernelFma3; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelFma3; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelAvx512F; @@ -1266,8 +1246,6 @@ struct MLAS_PLATFORM { MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel; MLAS_QUANTIZE_LINEAR_S4_KERNEL* QuantizeLinearS4Kernel; MLAS_QUANTIZE_LINEAR_U4_KERNEL* QuantizeLinearU4Kernel; - MLAS_DEQUANTIZE_LINEAR_S8_KERNEL* DequantizeLinearS8Kernel; - MLAS_DEQUANTIZE_LINEAR_U8_KERNEL* DequantizeLinearU8Kernel; uint32_t NchwcBlockSize; uint32_t PreferredBufferAlignment; int32_t MaximumThreadCount; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 45bba5363d4f2..45d3a876beb86 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -285,8 +285,6 @@ Return Value: this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel; this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; - this->DequantizeLinearS8Kernel = MlasDequantizeLinearS8Kernel; - this->DequantizeLinearU8Kernel = MlasDequantizeLinearU8Kernel; #ifndef __APPLE__ #ifndef FORCE_GENERIC_ALGORITHMS this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelSse; diff --git a/onnxruntime/core/platform/windows/device_discovery.cc b/onnxruntime/core/platform/windows/device_discovery.cc index fa645939a6395..dcc030cb3467d 100644 --- a/onnxruntime/core/platform/windows/device_discovery.cc +++ b/onnxruntime/core/platform/windows/device_discovery.cc @@ -89,10 +89,23 @@ uint64_t GetLuidKey(LUID luid) { return (uint64_t(luid.HighPart) << 32) | luid.LowPart; } +// Converts a wide string (up to 4 characters) representing a hardware ID component (e.g., "ABCD" from "VEN_ABCD") +// into a uint32_t. The conversion is done in a little-endian manner, meaning the first character +// of the string becomes the least significant byte of the integer, and the fourth character +// becomes the most significant byte. 
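As a concrete check of the little-endian packing described in the comment above, here is a small standalone sketch; the helper name and the 4-character ID L"ABCD" are made up purely for illustration and are not taken from an actual DXCore query:

    #include <cassert>
    #include <cstdint>
    #include <string>

    // Illustrative re-implementation of the packing rule: first character -> least significant byte.
    uint32_t PackFourCharId(const std::wstring& id) {
      uint32_t value = 0;
      for (size_t i = 0; i < 4 && i < id.size(); ++i) {
        value |= static_cast<uint32_t>(id[i] & 0xFF) << (i * 8);
      }
      return value;
    }

    int main() {
      // 'A'=0x41 lands in byte 0, 'B'=0x42 in byte 1, 'C'=0x43 in byte 2, 'D'=0x44 in byte 3.
      assert(PackFourCharId(L"ABCD") == 0x44434241u);
      return 0;
    }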
+uint32_t WStringToUint32Id(const std::wstring& vendor_name) { + uint32_t vendor_id = 0; + for (size_t i = 0; i < 4 && i < vendor_name.size(); ++i) { + // For little-endian, place each character at the appropriate byte position + // First character goes into lowest byte, last character into highest byte + vendor_id |= static_cast(vendor_name[i] & 0xFF) << (i * 8); + } + return vendor_id; +} + // returns info for display and processor entries. key is (vendor_id << 32 | device_id) // npus: (vendor_id << 32 | device_id) for devices we think are NPUs from DXCORE -std::unordered_map GetDeviceInfoSetupApi(const std::unordered_set& npus, - bool& have_remote_display_adapter) { +std::unordered_map GetDeviceInfoSetupApi(const std::unordered_set& npus) { std::unordered_map device_info; const GUID local_DXCORE_ADAPTER_ATTRIBUTE_D3D12_GENERIC_ML = {0xb71b0d41, 0x1088, 0x422f, 0xa2, 0x7c, 0x2, 0x50, 0xb7, 0xd3, 0xa9, 0x88}; @@ -138,7 +151,8 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde if (auto idx = hardware_id.find(prefix); idx != std::wstring::npos) { auto id = hardware_id.substr(idx + prefix.size(), 4); if (id.size() == 4) { - return static_cast(std::stoul(id, nullptr, 16)); + // DXCore reports vendor and device IDs as 32-bit integer representations of the ASCII string. + return WStringToUint32Id(id); } } @@ -156,11 +170,6 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde // Won't always have a vendor id from an ACPI entry. ACPI is not defined for this purpose. if (vendor_id == 0 && device_id == 0) { - static const std::wstring remote_display_adapter_id(L"RdpIdd_IndirectDisplay"); - if (guid == GUID_DEVCLASS_DISPLAY && remote_display_adapter_id == buffer) { - have_remote_display_adapter = true; - } - continue; } @@ -296,7 +305,7 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde } // returns LUID to DeviceInfo -std::unordered_map GetDeviceInfoD3D12(bool have_remote_display_adapter) { +std::unordered_map GetDeviceInfoD3D12() { std::unordered_map device_info; ComPtr factory; @@ -305,8 +314,6 @@ std::unordered_map GetDeviceInfoD3D12(bool have_remote_dis return device_info; } - UINT num_adapters = 0; - ComPtr adapter; for (UINT i = 0; factory->EnumAdapters1(i, adapter.ReleaseAndGetAddressOf()) != DXGI_ERROR_NOT_FOUND; ++i) { DXGI_ADAPTER_DESC1 desc; @@ -332,12 +339,9 @@ std::unordered_map GetDeviceInfoD3D12(bool have_remote_dis info.metadata[L"LUID"] = std::to_wstring(key); info.metadata[L"DxgiAdapterNumber"] = std::to_wstring(i); info.metadata[L"DxgiVideoMemory"] = std::to_wstring(desc.DedicatedVideoMemory / (1024 * 1024)) + L" MB"; - - ++num_adapters; } - // iterate by high-performance GPU preference to add that info. - UINT cur_adapter = 0; + // iterate by high-performance GPU preference to add that info for (UINT i = 0; factory->EnumAdapterByGpuPreference( i, DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE, IID_PPV_ARGS(adapter.ReleaseAndGetAddressOf())) != DXGI_ERROR_NOT_FOUND; @@ -348,41 +352,12 @@ std::unordered_map GetDeviceInfoD3D12(bool have_remote_dis } uint64_t key = GetLuidKey(desc.AdapterLuid); - auto it = device_info.find(key); - if (it == device_info.end()) { - continue; - } - DeviceInfo& info = it->second; - - // try and drop the Microsoft Remote Display Adapter. it does not have the DXGI_ADAPTER_FLAG_SOFTWARE flag set - // and the vendor id, device id and description are the same as the real device. the LUID is different to the real - // device. 
- // Assumption: it will have the worst performance index of the devices we're considering so we only check the - // last adapter - if (num_adapters > 1 && have_remote_display_adapter && cur_adapter == num_adapters - 1) { - ComPtr output; - if (adapter->EnumOutputs(0, &output) == DXGI_ERROR_NOT_FOUND) { - // D3D_DRIVER_TYPE_WARP. Software based or disabled adapter. - // An adapter can be disabled in an RDP session. e.g. integrated GPU is disabled if there's a discrete GPU - - // if we have seen this vendor_id+device_id combination with a different LUID before we drop it. - if (std::any_of(device_info.begin(), device_info.end(), - [key, &info](const auto& entry) { - const auto& entry_info = entry.second; - return key != entry.first && - info.vendor_id == entry_info.vendor_id && - info.device_id == entry_info.device_id; - })) { - device_info.erase(key); - continue; - } - } + auto it = device_info.find(key); + if (it != device_info.end()) { + DeviceInfo& info = it->second; + info.metadata[L"DxgiHighPerformanceIndex"] = std::to_wstring(i); } - - info.metadata[L"DxgiHighPerformanceIndex"] = std::to_wstring(i); - - ++cur_adapter; } return device_info; @@ -522,12 +497,10 @@ std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatfor } } - // setupapi_info. key is vendor_id+device_id - bool have_remote_display_adapter = false; // set if we see the RdpIdd_IndirectDisplay hardware ID. - std::unordered_map setupapi_info = GetDeviceInfoSetupApi(npus, have_remote_display_adapter); - // d3d12 info. key is luid - std::unordered_map luid_to_d3d12_info = GetDeviceInfoD3D12(have_remote_display_adapter); + std::unordered_map luid_to_d3d12_info = GetDeviceInfoD3D12(); + // setupapi_info. key is vendor_id+device_id + std::unordered_map setupapi_info = GetDeviceInfoSetupApi(npus); // Ensure we have at least one CPU bool found_cpu = false; diff --git a/onnxruntime/core/providers/cpu/math/softmax_shared.cc b/onnxruntime/core/providers/cpu/math/softmax_shared.cc index e123414b03b21..2817dda9d0085 100644 --- a/onnxruntime/core/providers/cpu/math/softmax_shared.cc +++ b/onnxruntime/core/providers/cpu/math/softmax_shared.cc @@ -99,7 +99,7 @@ common::Status SoftmaxCPU(size_t N, float* Ydata, bool logarithmic, onnxruntime::concurrency::ThreadPool* thread_pool) { - MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, false, 0.0f, thread_pool); + MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, false, thread_pool); return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/ml/ml_common.h b/onnxruntime/core/providers/cpu/ml/ml_common.h index f7cc2523adbf6..3359b2a69fe83 100644 --- a/onnxruntime/core/providers/cpu/ml/ml_common.h +++ b/onnxruntime/core/providers/cpu/ml/ml_common.h @@ -445,7 +445,7 @@ void batched_update_scores_inplace(gsl::span scores, int64_t num_batches_in, } if (use_mlas) { - MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow(batch_size), false, false, 0.0f, threadpool); + MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow(batch_size), false, false, threadpool); } else { while (s < s_end) { gsl::span scores_for_batch(s, s + batch_size); diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index c691be6ffd0e8..adb2aee171f39 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include #include #include "core/framework/element_type_lists.h" #include "core/framework/float8.h" @@ -302,31 +301,14 @@ struct DequantizeLinearApply { * @param[in] zero_point same shape as scale */ void op(size_t M, size_t K, size_t N, const T* input, - const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { + const OutT* scale, OutT* output, const T* zero_point) { for (size_t m = 0; m < M; m++) { for (size_t k = 0; k < K; k++) { -#if defined(ORT_CLIENT_PACKAGE_BUILD) - // TODO: Only using multithreaded/SIMD DQ when ORT is built for client/on-device workloads. - // Make this the default behavior after more testing. - if constexpr (std::is_same_v || std::is_same_v) { - ParDequantizeLinearStd(input, output, N, scale[k], zero_point ? zero_point[k] : 0, thread_pool); - input += N; - output += N; - } else { - auto zp = zero_point ? static_cast(zero_point[k]) : 0; - auto sc = static_cast(scale[k]); - for (size_t n = 0; n < N; n++) { - *output++ = static_cast(static_cast(static_cast(*input++) - zp) * sc); - } - } -#else - ORT_UNUSED_PARAMETER(thread_pool); auto zp = zero_point ? static_cast(zero_point[k]) : 0; auto sc = static_cast(scale[k]); for (size_t n = 0; n < N; n++) { *output++ = static_cast(static_cast(static_cast(*input++) - zp) * sc); } -#endif // defined(ORT_CLIENT_PACKAGE_BUILD) } } } @@ -345,8 +327,7 @@ struct DequantizeLinearApply { * @param[in] zero_point same shape as scale */ void op(size_t M, size_t K, size_t N, size_t quant_block_size, - const T* input, const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { - ORT_UNUSED_PARAMETER(thread_pool); + const T* input, const OutT* scale, OutT* output, const T* zero_point) { if (zero_point) { for (size_t m = 0; m < M; m++) { for (size_t bd = 0; bd < K; bd += quant_block_size) { @@ -387,8 +368,7 @@ template struct DequantizeLinearApply { // per-tensor/layer or per-axis quantization void op(size_t M, size_t K, size_t N, - const T* input, const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { - ORT_UNUSED_PARAMETER(thread_pool); + const T* input, const OutT* scale, OutT* output, const T* zero_point) { size_t input_index = 0; for (size_t m = 0; m < M; m++) { @@ -414,8 +394,7 @@ struct DequantizeLinearApply { // Blocked quantization // TODO(fajin) : add mlas kernel to utilize multithreading, refer MlasDequantizeBlockwise. 
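To make the (M, K, N) indexing used by DequantizeLinearApply::op easier to follow, here is a hedged, standalone sketch of the same per-axis loop over a small tensor; the function name, shapes, and values are invented for illustration and do not reflect the kernel's actual signature:

    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    // M = product of dims before the quantization axis, K = size of that axis (one scale/zero
    // point per entry), N = product of dims after the axis (elements sharing the k-th scale).
    void DequantizePerAxis(size_t M, size_t K, size_t N, const int8_t* input,
                           const float* scale, const int8_t* zero_point, float* output) {
      for (size_t m = 0; m < M; ++m) {
        for (size_t k = 0; k < K; ++k) {
          const int32_t zp = zero_point ? static_cast<int32_t>(zero_point[k]) : 0;
          for (size_t n = 0; n < N; ++n) {
            *output++ = static_cast<float>(static_cast<int32_t>(*input++) - zp) * scale[k];
          }
        }
      }
    }

    int main() {
      // A [1, 2, 2] tensor quantized along axis 1: two scales, two elements per scale.
      const int8_t input[4] = {4, 8, 4, 8};
      const float scale[2] = {1.0f, 0.25f};
      float output[4];
      DequantizePerAxis(1, 2, 2, input, scale, /*zero_point=*/nullptr, output);
      assert(output[0] == 4.0f && output[1] == 8.0f && output[2] == 1.0f && output[3] == 2.0f);
      return 0;
    }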
void op(size_t M, size_t K, size_t N, size_t quant_block_size, - const T* input, const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { - ORT_UNUSED_PARAMETER(thread_pool); + const T* input, const OutT* scale, OutT* output, const T* zero_point) { size_t input_index = 0; if (zero_point) { @@ -461,36 +440,36 @@ struct DequantizeLinearApply { #if !defined(DISABLE_FLOAT8_TYPES) -#define DEQUANTIZE_LINEAR_APPLY_FLOAT8(T) \ - template \ - struct DequantizeLinearApply { \ - /* Per-tensor/layer or per-axis quantization */ \ - void op(size_t M, size_t K, size_t N, \ - const T* input, const OutT* scale, OutT* output, const T*, concurrency::ThreadPool*) { \ - for (size_t m = 0; m < M; m++) { \ - for (size_t bd = 0; bd < K; bd++) { \ - auto sc = scale[bd]; \ - for (size_t bs = 0; bs < N; bs++, input++) { \ - *output++ = static_cast(input->ToFloat() * sc); \ - } \ - } \ - } \ - } \ - /* Blocked quantization */ \ - void op(size_t M, size_t K, size_t N, size_t quant_block_size, \ - const T* input, const OutT* scale, OutT* output, const T*, concurrency::ThreadPool*) { \ - for (size_t m = 0; m < M; m++) { \ - for (size_t bd = 0; bd < K; bd += quant_block_size) { \ - for (size_t qb = 0, qb_end = std::min(quant_block_size, K - bd); qb < qb_end; ++qb) { \ - for (size_t bs = 0; bs < N; bs++, input++) { \ - auto sc = static_cast(scale[bs]); \ - *output++ = static_cast(input->ToFloat() * sc); \ - } \ - } \ - scale += N; \ - } \ - } \ - } \ +#define DEQUANTIZE_LINEAR_APPLY_FLOAT8(T) \ + template \ + struct DequantizeLinearApply { \ + /* Per-tensor/layer or per-axis quantization */ \ + void op(size_t M, size_t K, size_t N, \ + const T* input, const OutT* scale, OutT* output, const T*) { \ + for (size_t m = 0; m < M; m++) { \ + for (size_t bd = 0; bd < K; bd++) { \ + auto sc = scale[bd]; \ + for (size_t bs = 0; bs < N; bs++, input++) { \ + *output++ = static_cast(input->ToFloat() * sc); \ + } \ + } \ + } \ + } \ + /* Blocked quantization */ \ + void op(size_t M, size_t K, size_t N, size_t quant_block_size, \ + const T* input, const OutT* scale, OutT* output, const T*) { \ + for (size_t m = 0; m < M; m++) { \ + for (size_t bd = 0; bd < K; bd += quant_block_size) { \ + for (size_t qb = 0, qb_end = std::min(quant_block_size, K - bd); qb < qb_end; ++qb) { \ + for (size_t bs = 0; bs < N; bs++, input++) { \ + auto sc = static_cast(scale[bs]); \ + *output++ = static_cast(input->ToFloat() * sc); \ + } \ + } \ + scale += N; \ + } \ + } \ + } \ }; DEQUANTIZE_LINEAR_APPLY_FLOAT8(Float8E4M3FN) @@ -534,7 +513,6 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { const auto to = x_scale.GetElementType(); const T* input = x.Data(); constexpr bool is_4bit = boost::mp11::mp_contains, T>::value; - concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); if (to == ONNX_NAMESPACE::TensorProto::FLOAT) { const float* scale = x_scale.Data(); @@ -544,12 +522,12 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { static_cast(broadcast_dim), static_cast(process_block_size), static_cast(block_size_), - input, scale, output, zero_point, thread_pool); + input, scale, output, zero_point); } else { DequantizeLinearApply().op(static_cast(process_block_count), static_cast(broadcast_dim), static_cast(process_block_size), - input, scale, output, zero_point, thread_pool); + input, scale, output, zero_point); } } else if (to == ONNX_NAMESPACE::TensorProto::FLOAT16) { const MLFloat16* scale = x_scale.Data(); @@ -559,12 +537,12 @@ Status 
DequantizeLinear::Compute(OpKernelContext* ctx) const { static_cast(broadcast_dim), static_cast(process_block_size), static_cast(block_size_), - input, scale, output, zero_point, thread_pool); + input, scale, output, zero_point); } else { DequantizeLinearApply().op(static_cast(process_block_count), static_cast(broadcast_dim), static_cast(process_block_size), - input, scale, output, zero_point, thread_pool); + input, scale, output, zero_point); } } else if (to == ONNX_NAMESPACE::TensorProto::BFLOAT16) { ORT_THROW("DequantizeLinear into BFLOAT16 is not implemented yet."); diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index f00bf51ae143d..2de496a9168a0 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -313,10 +313,8 @@ CUDA_Provider* GetProvider() { // OrtEpApi infrastructure to be able to use the CUDA EP as an OrtEpFactory for auto EP selection. struct CudaEpFactory : OrtEpFactory { CudaEpFactory(const OrtApi& ort_api_in) : ort_api{ort_api_in} { - ort_version_supported = ORT_API_VERSION; GetName = GetNameImpl; GetVendor = GetVendorImpl; - GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; @@ -333,11 +331,6 @@ struct CudaEpFactory : OrtEpFactory { return factory->vendor.c_str(); } - static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { - const auto* factory = static_cast(this_ptr); - return factory->vendor_id; - } - static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept { return ORT_VERSION; } @@ -381,7 +374,6 @@ struct CudaEpFactory : OrtEpFactory { const OrtApi& ort_api; const std::string ep_name{kCudaExecutionProvider}; // EP name const std::string vendor{"Microsoft"}; // EP vendor name - uint32_t vendor_id{0x1414}; // Microsoft vendor ID }; extern "C" { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 9611cb82d5a62..a5066a41981e5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -781,10 +781,7 @@ namespace Dml // this branch could be reached with a bad custom operator or malformed file. If // a legitimate case reaches here and DML needs to support a new input/output type // besides tensors, then remove the assert. - - // If the model has nodes that use Optional we will arrive here. It's a valid ONNX model but - // TryGetTensorDataType doesn't handle Optional. 
- // assert(false); + assert(false); nodeContainsSupportedDataTypes = false; return; } diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index c5b6507ac847b..711d81186bad1 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1304,7 +1304,7 @@ std::vector NvExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { - return std::make_unique(CUDA_PINNED, device_id); + return std::make_unique(device_id, CUDA_PINNED); }, narrow(device_id_)); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc index 21947a22e2b92..86b684f8c6ebd 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc @@ -235,7 +235,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(reshape_input, reshape_input_info)); bool needs_reshape = false; - const std::string reshape_prior_out = input_names[0] + "_prior_reshape"; + const std::string reshape4d = input_names[0] + "_pre_reshape"; if (input_shape.size() == 3) { needs_reshape = true; // build new_shape = {N, 1, C, L} @@ -245,24 +245,25 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra input_shape[1], input_shape[2]}; - QnnTensorWrapper reshape_prior_tensor( - reshape_prior_out, + const std::string reshape_node_name = "pre_reshape"; + QnnTensorWrapper rw( + reshape4d, QNN_TENSOR_TYPE_NATIVE, reshape_input_info.qnn_data_type, reshape_input_info.quant_param.Copy(), std::move(new_shape)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_prior_tensor)), - "Failed to add reshape prior tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(rw)), + "Failed to add reshape-4d tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - utils::GetNodeName(node_unit) + "_reshape_prior", + reshape_node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_RESHAPE, + "Reshape", {input_names[0]}, - {reshape_prior_out}, + {reshape4d}, {}, do_op_validation), - "Failed to create reshape prior node for pool op."); - input_names[0] = reshape_prior_out; + "Failed to create reshape-4d node."); + input_names[0] = reshape4d; input_shape = {input_shape[0], 1, input_shape[1], input_shape[2]}; } @@ -445,7 +446,9 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra } const auto& outputs = node_unit.Outputs(); const std::string real_out = outputs[0].node_arg.Name(); - const std::string pool_out = real_out + "_reshape_after"; + const std::string pool_name = "poolmax2d"; + const std::string pool_out = real_out + "_post_reshape"; + const std::string post_reshape_node_name = "post_reshape"; const std::string qnn_op = GetQnnOpType(op_type); TensorInfo output_info{}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info)); @@ -463,34 +466,33 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra "Failed to add tensor for pool_out"); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - utils::GetNodeName(node_unit) + "_pool2d", + pool_name, QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op, - 
{reshape_prior_out}, + {reshape4d}, {pool_out}, std::move(param_tensor_names), do_op_validation), - "Failed to create pool node for rank-3 input."); + "Failed to create QNN Pool node for rank-3 input."); std::vector final_shape3d = output_info.shape; - QnnTensorWrapper reshape_after_tensor( + QnnTensorWrapper reshape_back_tensor( real_out, tensor_type, output_info.qnn_data_type, output_info.quant_param.Copy(), std::move(final_shape3d)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_after_tensor)), - "Failed to add reshape after tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_back_tensor)), "Failed to add tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - utils::GetNodeName(node_unit) + "_reshape_after", + post_reshape_node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_RESHAPE, + "Reshape", {pool_out}, {real_out}, {}, do_op_validation), - "Failed to create reshape after node for pool op."); + "Failed to create reshape-back node."); return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index 502ea86b689f4..2650316dd07ac 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -9,7 +9,7 @@ namespace onnxruntime { namespace qnn { -// Operator which only need to handle node inputs & outputs, no attributes or no need to handle attributes +// Operator which only need to hanle node inputs & outputs, no attributes or no need to handle attributes class SimpleOpBuilder : public BaseOpBuilder { public: SimpleOpBuilder() : BaseOpBuilder("SimpleOpBuilder") {} @@ -38,7 +38,7 @@ class SimpleOpBuilder : public BaseOpBuilder { const logging::Logger& logger, bool do_op_validation) const ORT_MUST_USE_RESULT; - static constexpr std::array gridsample_supported_modes = {"bilinear", "nearest", "linear"}; + static constexpr std::array gridsample_supported_modes = {"bilinear", "nearest"}; static constexpr std::array gridsample_supported_padding_modes = {"zeros", "border", "reflection"}; static constexpr std::array scatternd_supported_reduction = {"none", "add", "mul"}; }; @@ -60,8 +60,8 @@ Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, // To DO: Remove once QNN CPU supports ScatterND const auto qnn_backend_type = qnn_model_wrapper.GetQnnBackendType(); if (op_type == "ScatterND") { - ORT_RETURN_IF(qnn_backend_type == QnnBackendType::CPU, - "QNN EP does not support ScatterND op on CPU backend. Falling back to ORT CPU."); + ORT_RETURN_IF_NOT(qnn_backend_type == QnnBackendType::HTP, + "QNN EP only supports ScatterND op on HTP backend. Falling back to ORT CPU."); } // ONNX's Min, Max, and Sum operators accept a variable number of inputs (i.e., variadic). 
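Returning to the pooling change above: the builder handles rank-3 input by wrapping the 2-D pool between two reshapes. A minimal sketch of the shape bookkeeping follows; the helper names and shapes are illustrative only, not QNN API calls:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Rank-3 input {N, C, L} is expanded to {N, 1, C, L}, pooled as a 2-D op, then squeezed back.
    std::vector<int64_t> ExpandTo4d(const std::vector<int64_t>& ncl) {
      return {ncl[0], 1, ncl[1], ncl[2]};
    }

    std::vector<int64_t> SqueezeTo3d(const std::vector<int64_t>& n1cl) {
      return {n1cl[0], n1cl[2], n1cl[3]};
    }

    int main() {
      const std::vector<int64_t> input_shape = {1, 8, 32};       // N, C, L
      const std::vector<int64_t> pre = ExpandTo4d(input_shape);  // {1, 1, 8, 32} fed to the 2-D pool
      assert((pre == std::vector<int64_t>{1, 1, 8, 32}));
      // Suppose the pool halves the last dimension; the post-reshape restores rank 3.
      const std::vector<int64_t> pooled = {1, 1, 8, 16};
      assert((SqueezeTo3d(pooled) == std::vector<int64_t>{1, 8, 16}));
      return 0;
    }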
@@ -233,12 +233,12 @@ Status ProcessGridSampleAttributes(QnnModelWrapper& qnn_model_wrapper, std::string mode = node_helper.Get("mode", "linear"); Qnn_Scalar_t mode_qnn_scalar = QNN_SCALAR_INIT; mode_qnn_scalar.dataType = QNN_DATATYPE_UINT_32; - if ("linear" == mode || "bilinear" == mode) { + if ("bilinear" == mode) { mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_MODE_BILINEAR; } else if ("nearest" == mode) { mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_MODE_NEAREST; } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample mode only support [linear, bilinear, nearest]."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample mode only support bilinear & nearest."); } QnnParamWrapper mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_GRID_SAMPLE_PARAM_MODE, mode_qnn_scalar); param_tensor_names.push_back(mode_param.GetParamTensorName()); @@ -254,7 +254,7 @@ Status ProcessGridSampleAttributes(QnnModelWrapper& qnn_model_wrapper, } else if ("reflection" == padding_mode) { padding_mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_PADDING_MODE_REFLECTION; } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample padding_mode only support [zeros, border, reflection]."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample padding_mode only support zeros, border & reflection."); } QnnParamWrapper padding_mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_GRID_SAMPLE_PARAM_PADDING_MODE, padding_mode_qnn_scalar); param_tensor_names.push_back(padding_mode_param.GetParamTensorName()); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 3dc103046424e..d22edaf33eb1c 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -839,23 +839,6 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord return Status::OK(); } -Status QnnBackendManager::SetContextPriority(ContextPriority context_priority) { - QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT; - ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority, context_priority_config)); - - QnnContext_Config_t* configs[] = {&context_priority_config, nullptr}; - for (const auto& context_handle : contexts_) { - auto result = qnn_interface_.contextSetConfig(context_handle, (const QnnContext_Config_t**)configs); - ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to set context priority for context handle: ", context_handle); - } - - return Status::OK(); -} - -Status QnnBackendManager::ResetContextPriority() { - return SetContextPriority(context_priority_); -} - Status QnnBackendManager::CreateContext(bool enable_htp_weight_sharing) { if (true == context_created_) { LOGS_DEFAULT(INFO) << "Context created already."; @@ -1443,33 +1426,13 @@ Status QnnBackendManager::SetHtpPowerConfig(uint32_t htp_power_config_client_id, return Status::OK(); } -Status QnnBackendManager::SetRpcPowerConfigs(uint32_t htp_power_config_client_id, - uint32_t rpc_control_latency, - uint32_t rpc_polling_time) { +Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_id, + uint32_t rpc_control_latency) { // This function is called in QNN EP's OnRunStart() even if QNN backend setup failed and the model is assigned // to a different EP. Therefore, we have to check that backend setup actually completed before trying to // set RPC control latency. 
Otherwise, this causes a segfault because the QNN backend library is unloaded. ORT_RETURN_IF_NOT(backend_setup_completed_, "Cannot set HTP RPC control latency if backend setup is not complete."); - - constexpr int kNumRpcPollingPowerConfigs = 2; - std::vector rpc_power_configs; - rpc_power_configs.reserve(kNumRpcPollingPowerConfigs); - - // Set rpc control latency here if (rpc_control_latency != 0) { - auto& rpc_control_latency_cfg = rpc_power_configs.emplace_back(); - rpc_control_latency_cfg.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY; - rpc_control_latency_cfg.rpcControlLatencyConfig = rpc_control_latency; - } - - // Note: v68 does not support rpc polling mode - if (rpc_polling_time != 0) { - auto& rpc_polling_time_cfg = rpc_power_configs.emplace_back(); - rpc_polling_time_cfg.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_POLLING_TIME; - rpc_polling_time_cfg.rpcPollingTimeConfig = rpc_polling_time; - } - - if (rpc_power_configs.size() > 0) { QnnDevice_Infrastructure_t qnn_device_infra = nullptr; auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed."); @@ -1479,6 +1442,15 @@ Status QnnBackendManager::SetRpcPowerConfigs(uint32_t htp_power_config_client_id "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type."); QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra; + // Set rpc control latency here, but note that v68 doesn't support rpc polling mode. + constexpr int kNumRpcPollingPowerConfigs = 2; + std::vector rpc_power_configs(kNumRpcPollingPowerConfigs); + QnnHtpPerfInfrastructure_PowerConfig_t& rpc_control_latency_cfg = rpc_power_configs[0]; + // v68 doesn't support this. 
+ QnnHtpPerfInfrastructure_PowerConfig_t& rpc_polling_time = rpc_power_configs[1]; + rpc_control_latency_cfg.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY; + rpc_polling_time.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_POLLING_TIME; + rpc_control_latency_cfg.rpcControlLatencyConfig = rpc_control_latency; std::vector perf_power_configs_ptr = ObtainNullTermPtrVector(rpc_power_configs); status = htp_perf_infra.setPowerConfig(htp_power_config_client_id, perf_power_configs_ptr.data()); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 2a71c7391b180..3e68df3024565 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -159,9 +159,8 @@ class QnnBackendManager : public std::enable_shared_from_this Status SetHtpPowerConfig(uint32_t htp_power_config_client_id, HtpPerformanceMode htp_performance_mode); - Status SetRpcPowerConfigs(uint32_t htp_power_config_client_id, - uint32_t rpc_control_latency, - uint32_t rpc_polling_time); + Status SetRpcControlLatency(uint32_t htp_power_config_client_id, + uint32_t rpc_control_latency); const QNN_INTERFACE_VER_TYPE& GetQnnInterface() { return qnn_interface_; } @@ -220,11 +219,6 @@ class QnnBackendManager : public std::enable_shared_from_this // For each node name, a mapping to the context handle will be created void ProcessContextFromBinListAsync(Qnn_ContextHandle_t handle, void* notifyParam); - // Sets the context priority to the given value, if valid - Status SetContextPriority(ContextPriority context_priority); - // Resets the context priority to the session default as defined by context_priority_ - Status ResetContextPriority(); - private: Status LoadBackend(); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 3acb3347acee1..236447cc95c3d 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -1356,8 +1356,7 @@ QNNExecutionProvider::PerThreadContext::PerThreadContext(qnn::QnnBackendManager* uint32_t device_id, uint32_t core_id, qnn::HtpPerformanceMode default_htp_performance_mode, - uint32_t default_rpc_control_latency, - uint32_t default_rpc_polling_time) + uint32_t default_rpc_control_latency) : qnn_backend_manager_(qnn_backend_manager) { Status rt = qnn_backend_manager_->CreateHtpPowerCfgId(device_id, core_id, htp_power_config_id_); is_htp_power_config_id_valid_ = rt.IsOK(); @@ -1368,10 +1367,9 @@ QNNExecutionProvider::PerThreadContext::PerThreadContext(qnn::QnnBackendManager* ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetHtpPowerConfig(htp_power_config_id_, default_htp_performance_mode)); } - if (default_rpc_control_latency > 0 || default_rpc_polling_time > 0) { - ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetRpcPowerConfigs(htp_power_config_id_, - default_rpc_control_latency, - default_rpc_polling_time)); + if (default_rpc_control_latency > 0) { + ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetRpcControlLatency(htp_power_config_id_, + default_rpc_control_latency)); } } } @@ -1402,8 +1400,7 @@ QNNExecutionProvider::PerThreadContext& QNNExecutionProvider::GetPerThreadContex if (context_state_.retired_context_pool.empty()) { uint32_t core_id = 0; context = std::make_shared(qnn_backend_manager_.get(), device_id_, core_id, - default_htp_performance_mode_, 
default_rpc_control_latency_, - default_rpc_polling_time_); + default_htp_performance_mode_, default_rpc_control_latency_); } else { context = context_state_.retired_context_pool.back(); context_state_.retired_context_pool.pop_back(); @@ -1471,21 +1468,15 @@ Status QNNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_optio LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency; } - uint32_t rpc_polling_time = 0; - if (qnn::HtpPerformanceMode::kHtpBurst != htp_performance_mode) { - rpc_polling_time = 9999; - } - if (GetPerThreadContext().IsHtpPowerConfigIdValid()) { if (qnn::HtpPerformanceMode::kHtpDefault != htp_performance_mode) { ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetPerThreadContext().GetHtpPowerConfigId(), htp_performance_mode)); } - if (rpc_control_latency > 0 || rpc_polling_time > 0) { - ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetRpcPowerConfigs(GetPerThreadContext().GetHtpPowerConfigId(), - rpc_control_latency, - rpc_polling_time)); + if (rpc_control_latency > 0) { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetRpcControlLatency(GetPerThreadContext().GetHtpPowerConfigId(), + rpc_control_latency)); } } @@ -1554,38 +1545,4 @@ OrtDevice QNNExecutionProvider::GetOrtDeviceByMemType(OrtMemType /* em_type */) return default_device_; } -Status QNNExecutionProvider::SetEpDynamicOptions(gsl::span keys, - gsl::span values) { - if (keys.size() != values.size()) { - LOGS_DEFAULT(ERROR) << "SetEpDynamicOptions: number of keys (" << keys.size() - << ") does not equal number of values (" << values.size() << ")."; - } - auto key_it = keys.begin(); - auto value_it = values.begin(); - - while (key_it != keys.end() && value_it != values.end()) { - std::string key(*key_it); - std::string value(*value_it); - - if (key == kOrtEpDynamicOptionsWorkloadType) { - if (value == "Default") { - ORT_RETURN_IF_ERROR(qnn_backend_manager_->ResetContextPriority()); - } else if (value == "Efficient") { - ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetContextPriority(qnn::ContextPriority::LOW)); - } else { - LOGS_DEFAULT(ERROR) << "Invalid EP Workload Type: " << value; - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid EP Workload Type."); - } - } else { - LOGS_DEFAULT(ERROR) << "EP Dynamic Option \"" << key << "\" is not currently supported."; - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported EP Dynamic Option"); - } - - key_it++; - value_it++; - } - - return Status::OK(); -} - } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 6adf613932d66..06f9726ae96cf 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -57,9 +57,6 @@ class QNNExecutionProvider : public IExecutionProvider { OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; - Status SetEpDynamicOptions(gsl::span keys, - gsl::span value) override; - private: std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, const std::unordered_map& node_unit_map, @@ -99,7 +96,6 @@ class QNNExecutionProvider : public IExecutionProvider { uint32_t device_id_ = 0; qnn::HtpPerformanceMode default_htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault; uint32_t default_rpc_control_latency_ = 0; - uint32_t default_rpc_polling_time_ = 0; bool enable_HTP_FP16_precision_ = true; bool share_ep_contexts_ = false; bool stop_share_ep_contexts_ = false; @@ -120,8 +116,7 @@ 
class QNNExecutionProvider : public IExecutionProvider { PerThreadContext(qnn::QnnBackendManager* qnn_backend_manager, uint32_t device_id, uint32_t core_id, qnn::HtpPerformanceMode default_htp_performance_mode, - uint32_t default_rpc_control_latency, - uint32_t default_rpc_polling_time); + uint32_t default_rpc_control_latency); ~PerThreadContext(); ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext); diff --git a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc index 785177ce37788..c679ea1adb286 100644 --- a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc +++ b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc @@ -125,10 +125,8 @@ struct QnnEpFactory : OrtEpFactory { OrtHardwareDeviceType hw_type, const char* qnn_backend_type) : ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type}, qnn_backend_type{qnn_backend_type} { - ort_version_supported = ORT_API_VERSION; GetName = GetNameImpl; GetVendor = GetVendorImpl; - GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; @@ -144,12 +142,7 @@ struct QnnEpFactory : OrtEpFactory { static const char* GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); - return factory->ep_vendor.c_str(); - } - - static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { - const auto* factory = static_cast(this_ptr); - return factory->ep_vendor_id; + return factory->vendor.c_str(); } static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept { @@ -202,9 +195,8 @@ struct QnnEpFactory : OrtEpFactory { } const OrtApi& ort_api; - const std::string ep_name; // EP name - const std::string ep_vendor{"Microsoft"}; // EP vendor name - uint32_t ep_vendor_id{0x1414}; // Microsoft vendor ID + const std::string ep_name; // EP name + const std::string vendor{"Microsoft"}; // EP vendor name // Qualcomm vendor ID. 
Refer to the ACPI ID registry (search Qualcomm): https://uefi.org/ACPI_ID_List const uint32_t vendor_id{'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24)}; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index 1e9fafe8aa323..90a4294fb47f0 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -7,25 +7,6 @@ #include "tensorrt_execution_provider_custom_ops.h" #include "tensorrt_execution_provider.h" -// The filename extension for a shared library is different per platform -#ifdef _WIN32 -#define LIBRARY_PREFIX -#define LIBRARY_EXTENSION ORT_TSTR(".dll") -#elif defined(__APPLE__) -#define LIBRARY_PREFIX "lib" -#define LIBRARY_EXTENSION ".dylib" -#else -#define LIBRARY_PREFIX "lib" -#define LIBRARY_EXTENSION ".so" -#endif - -#ifdef _WIN32 -#define ORT_DEF2STR_HELPER(x) L#x -#else -#define ORT_DEF2STR_HELPER(X) #X -#endif -#define ORT_DEF2STR(x) ORT_DEF2STR_HELPER(x) - namespace onnxruntime { extern TensorrtLogger& GetTensorrtLogger(bool verbose); @@ -77,31 +58,8 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& // Get all registered TRT plugins from registry LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Getting all registered TRT plugins from TRT plugin registry ..."; TensorrtLogger trt_logger = GetTensorrtLogger(false); - try { - void* library_handle = nullptr; - const auto& env = onnxruntime::GetDefaultEnv(); -#if NV_TENSORRT_MAJOR < 10 - auto full_path = env.GetRuntimePath() + - PathString(LIBRARY_PREFIX ORT_TSTR("nvinfer_plugin") LIBRARY_EXTENSION); -#else -#ifdef _WIN32 - auto full_path = PathString(LIBRARY_PREFIX ORT_TSTR("nvinfer_plugin_" ORT_DEF2STR(NV_TENSORRT_MAJOR)) LIBRARY_EXTENSION); -#else - auto full_path = PathString(LIBRARY_PREFIX ORT_TSTR("nvinfer_plugin") LIBRARY_EXTENSION ORT_TSTR("." 
ORT_DEF2STR(NV_TENSORRT_MAJOR))); -#endif -#endif - - ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, false, &library_handle)); + initLibNvInferPlugins(&trt_logger, ""); - bool (*dyn_initLibNvInferPlugins)(void* logger, char const* libNamespace); - ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(library_handle, "initLibNvInferPlugins", (void**)&dyn_initLibNvInferPlugins)); - if (!dyn_initLibNvInferPlugins(&trt_logger, "")) { - LOGS_DEFAULT(INFO) << "[TensorRT EP] Default plugin library was found but was not able to initialize default plugins."; - } - LOGS_DEFAULT(INFO) << "[TensorRT EP] Default plugins successfully loaded."; - } catch (const std::exception&) { - LOGS_DEFAULT(INFO) << "[TensorRT EP] Default plugin library is not on the path and is therefore ignored"; - } int num_plugin_creator = 0; auto plugin_creators = getPluginRegistry()->getAllCreators(&num_plugin_creator); std::unordered_set registered_plugin_names; diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc index 113a3f31be7f9..e8140a4d59eab 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.cc +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -193,21 +193,27 @@ class BucketCacheManager : public IBufferCacheManager { } void ReleaseBuffer(WGPUBuffer buffer) override { - auto buffer_size = static_cast(wgpuBufferGetSize(buffer)); - - auto it = buckets_.find(buffer_size); - if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) { - it->second.emplace_back(buffer); - } else { - wgpuBufferRelease(buffer); - } + pending_buffers_.emplace_back(buffer); } void OnRefresh(GraphCaptureState /*graph_capture_state*/) override { - // no-op + for (auto& buffer : pending_buffers_) { + auto buffer_size = static_cast(wgpuBufferGetSize(buffer)); + auto it = buckets_.find(buffer_size); + if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) { + it->second.emplace_back(buffer); + } else { + wgpuBufferRelease(buffer); + } + } + + pending_buffers_.clear(); } ~BucketCacheManager() { + for (auto& buffer : pending_buffers_) { + wgpuBufferRelease(buffer); + } for (auto& pair : buckets_) { for (auto& buffer : pair.second) { wgpuBufferRelease(buffer); @@ -236,6 +242,7 @@ class BucketCacheManager : public IBufferCacheManager { } std::unordered_map buckets_limit_; std::unordered_map> buckets_; + std::vector pending_buffers_; std::vector buckets_keys_; }; diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.cc b/onnxruntime/core/providers/webgpu/tensor/cast.cc index 313a96ba25509..7f92ea4ed3776 100644 --- a/onnxruntime/core/providers/webgpu/tensor/cast.cc +++ b/onnxruntime/core/providers/webgpu/tensor/cast.cc @@ -52,28 +52,10 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .TypeConstraint("T1", CastOpTypeConstraints()) .TypeConstraint("T2", CastOpTypeConstraints()), Cast); -ONNX_OPERATOR_VERSIONED_KERNEL_EX( - Cast, - kOnnxDomain, - 19, 20, - kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()) - .TypeConstraint("T1", CastOpTypeConstraints()) - .TypeConstraint("T2", CastOpTypeConstraints()), - Cast); -ONNX_OPERATOR_VERSIONED_KERNEL_EX( - Cast, - kOnnxDomain, - 21, 22, - kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()) - .TypeConstraint("T1", CastOpTypeConstraints()) - .TypeConstraint("T2", CastOpTypeConstraints()), - Cast); ONNX_OPERATOR_KERNEL_EX( Cast, kOnnxDomain, - 23, + 19, kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T1", CastOpTypeConstraints()) diff --git 
a/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc b/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc index 9f07e2d2a3988..f13e86c185928 100644 --- a/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc +++ b/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc @@ -146,24 +146,24 @@ Status ScatterND::ComputeInternal(ComputeContext& context) const { const auto* updates = context.Input(2); const auto& input_shape = input->Shape(); const auto& indices_shape = indices->Shape(); - auto* output = context.Output(0, input_shape); - const void* source = input->DataRaw(); - void* target = output->MutableDataRaw(); - // If source and target pointers are not equal (non-inplace operation), we need to copy the data. - if (target != source) { - ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input, *output)); - } - if (indices_shape.Size() == 0) { - // If the indices are empty, we can return early. - return Status::OK(); - } auto indices_rank = indices_shape.NumDimensions(); auto last_index_dimension = static_cast(indices_shape[indices_rank - 1]); auto num_updates_elements = static_cast(input_shape.SizeFromDimension(last_index_dimension)); // TODO: support bool with components 4. const size_t components = 1; auto output_size = static_cast((indices_shape.SizeToDimension(indices_rank - 1) + components - 1) / components); + auto* output = context.Output(0, input_shape); + if (output_size == 0) { + // If the output tensor is empty, we can return early. + return Status::OK(); + } MLDataType data_type = input->DataType(); + const void* source = input->DataRaw(); + void* target = output->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input, *output)); + } ScatterNDProgram program(reduction_, data_type); program .CacheHint(static_cast(reduction_)) diff --git a/onnxruntime/core/providers/webgpu/tensor/slice.cc b/onnxruntime/core/providers/webgpu/tensor/slice.cc index 7e8b434431781..39432db5113d1 100644 --- a/onnxruntime/core/providers/webgpu/tensor/slice.cc +++ b/onnxruntime/core/providers/webgpu/tensor/slice.cc @@ -172,8 +172,8 @@ Status Slice::ComputeInternal(ComputeContext& context) const { } if (step < 0) { // we are slicing in reverse - start = dim_value > 0 ? std::clamp(start, int64_t{0}, dim_value - 1) : 0; - end = dim_value > 0 ? 
std::clamp(end, int64_t{-1}, dim_value - 1) : -1; + start = std::clamp(start, int64_t{0}, dim_value - 1); + end = std::clamp(end, int64_t{-1}, dim_value - 1); // note that we are flipping start and end to switch to forward step signs.push_back(-1); steps.push_back(static_cast(-step)); diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 6e09f494f4a8d..460d220ecf1b9 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -123,9 +123,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 8, Cast); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Cast); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, Cast); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Cast); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Cast); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Cast); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, float, Clip); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, float, Clip); @@ -457,9 +455,7 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO_VERSIONED(6, 8, Cast), KERNEL_CREATE_INFO_VERSIONED(9, 12, Cast), KERNEL_CREATE_INFO_VERSIONED(13, 18, Cast), - KERNEL_CREATE_INFO_VERSIONED(19, 20, Cast), - KERNEL_CREATE_INFO_VERSIONED(21, 22, Cast), - KERNEL_CREATE_INFO(23, Cast), + KERNEL_CREATE_INFO(19, Cast), // // activations BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/webgpu/wgsl_templates/README.md b/onnxruntime/core/providers/webgpu/wgsl_templates/README.md index 6bd2f98cc5713..c1a62e7fa7858 100644 --- a/onnxruntime/core/providers/webgpu/wgsl_templates/README.md +++ b/onnxruntime/core/providers/webgpu/wgsl_templates/README.md @@ -64,7 +64,7 @@ This section includes instructions for how to use the template system in the dev 1. Create WGSL template files in `.wgsl.template` extension. - [Reference: Template Syntax](https://github.com/fs-eire/wgsl-template?tab=readme-ov-file#template-syntax) - - [Reference: Built-in Utilities](https://github.com/fs-eire/wgsl-template?tab=readme-ov-file#Utilities) + - [Reference: Built-in Utilities](#Utilities) - [Example: Pad](../tensor/pad.wgsl.template) 2. In the implementation of `YourProgram::GenerateShaderCode()`, load and use the generated template files. @@ -117,4 +117,4 @@ This section includes instructions for how to use the template system in the dev 1. Build ORT once with dynamic template mode 2. Launch wgsl-gen in watch mode 3. Run ORT to debug/validate the shader - 4. Make changes to the template files, and repeat step (c) + 4. 
Make changes to the template files, and repeat step (3) diff --git a/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json b/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json index df1940ed6416b..7cde6c17f54e9 100644 --- a/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json +++ b/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json @@ -9,13 +9,13 @@ "version": "1.0.0", "license": "MIT", "dependencies": { - "@fs-eire/wgsl-template": "^0.1.13" + "@fs-eire/wgsl-template": "^0.1.3" } }, "node_modules/@fs-eire/wgsl-template": { - "version": "0.1.13", - "resolved": "https://registry.npmjs.org/@fs-eire/wgsl-template/-/wgsl-template-0.1.13.tgz", - "integrity": "sha512-SOQjVCQCUmXb9qYr2E3CKNs88/FzINuhFJiobBEkSAsyKtJby9oFWGZnrEO+hIl/oDTLA01LbjiDxuf6TGHE/w==", + "version": "0.1.10", + "resolved": "https://registry.npmjs.org/@fs-eire/wgsl-template/-/wgsl-template-0.1.10.tgz", + "integrity": "sha512-F5qQZxNweZ3ZD3d9RNc/g3nTiW7jyaAVi7SlMOL4wOfXh+Nm/qca2DISNTf3kjpVqkoazMJGbZ6TPQ4a/vjw0g==", "license": "MIT", "dependencies": { "minimist": "^1.2.8" diff --git a/onnxruntime/core/providers/webgpu/wgsl_templates/package.json b/onnxruntime/core/providers/webgpu/wgsl_templates/package.json index 246e7365531e0..34831ccddeb33 100644 --- a/onnxruntime/core/providers/webgpu/wgsl_templates/package.json +++ b/onnxruntime/core/providers/webgpu/wgsl_templates/package.json @@ -10,6 +10,6 @@ "author": "", "license": "MIT", "dependencies": { - "@fs-eire/wgsl-template": "^0.1.13" + "@fs-eire/wgsl-template": "^0.1.3" } } diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 142d64caa64aa..e821265fff80d 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -99,93 +99,69 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n return true; } -// Check if a single input's rank of an ONNX op is supported by corresponding WebNN op. 
-bool IsInputRankSupported(const emscripten::val& wnn_limits, - const std::string_view webnn_op_type, - const std::string_view input_name, - const size_t input_rank, - const std::string_view node_name, - const logging::Logger& logger) { - const std::string webnn_op_type_str(webnn_op_type); - const std::string input_name_str(input_name); - - if (wnn_limits[webnn_op_type_str].isUndefined()) { - LOGS(logger, VERBOSE) << "WebNN op type: [" << webnn_op_type - << "] is not defined in WebNN MLOpSupportLimits."; - return false; - } - - const emscripten::val input_limits = wnn_limits[webnn_op_type_str][input_name_str]; - - if (input_limits.isUndefined()) { - LOGS(logger, VERBOSE) << "Node name: [" << node_name - << "], WebNN op type: [" << webnn_op_type - << "], input [" << input_name - << "]: limits are not defined in WebNN MLOpSupportLimits."; - return false; - } - - const emscripten::val rank_range = input_limits["rankRange"]; - if (rank_range.isUndefined()) { - LOGS(logger, VERBOSE) << "WebNN op type [" << webnn_op_type - << "] input [" << input_name - << "]: missing 'rankRange' attribute."; - return false; - } - - const emscripten::val min_val = rank_range["min"]; - const emscripten::val max_val = rank_range["max"]; - if (min_val.isUndefined() || max_val.isUndefined()) { - LOGS(logger, VERBOSE) << "WebNN op type [" << webnn_op_type - << "] input [" << input_name - << "]: its 'rankRange' limits is missing valid 'min' or 'max' attributes."; - return false; - } - - size_t min_rank = min_val.as(); - size_t max_rank = max_val.as(); - if (input_rank < min_rank || input_rank > max_rank) { - LOGS(logger, VERBOSE) << "Node name: [" << node_name - << "] WebNN op type [" << webnn_op_type - << "] input [" << input_name << "] rank " << input_rank - << " is not in supported range [" << min_rank << ", " << max_rank << "]"; +// Check if all input tensor ranks of the given node are supported by WebNN. 
+bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) { + const std::string_view op_type = node.OpType(); + const auto it = op_inputs_map.find(op_type); + if (it == op_inputs_map.end()) { + LOGS(logger, VERBOSE) << "Operator type: [" << op_type << "] is not found in the op inputs map."; return false; } - return true; -} - -bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) { - const std::string_view onnx_op_type = node.OpType(); - const std::string_view webnn_op_type = GetWebNNOpType(onnx_op_type); + const auto& input_defs = node.InputDefs(); + const std::string_view webnn_op_type = it->second.opType; + const std::string webnn_op_type_str(webnn_op_type); - if (webnn_op_type.empty()) { - LOGS(logger, VERBOSE) << "ONNX op type: [" << onnx_op_type << "]'s corresponding WebNN op is not found."; - return false; - } + for (const auto& input : it->second.inputs) { + if (static_cast(input.index) >= input_defs.size() || input_defs[input.index] == nullptr) { + LOGS(logger, VERBOSE) << "Input index [" << input.index + << "] for operator type [" << op_type + << "], corresponding WebNN op type [" << webnn_op_type + << "], WebNN input name [" << input.name + << "] is invalid."; + return false; + } - std::vector inputs; - if (!GetWebNNOpInputs(onnx_op_type, inputs, logger)) { - return false; - } + std::vector input_shape; + if (!GetShape(*input_defs[input.index], input_shape, logger)) { + return false; + } - const auto& input_defs = node.InputDefs(); + const std::string input_name_str(input.name); + if (wnn_limits[webnn_op_type_str].isUndefined() || + wnn_limits[webnn_op_type_str][input_name_str].isUndefined()) { + LOGS(logger, VERBOSE) << "Operator type: [" << op_type + << "], input index: [" << input.index + << "], corresponding WebNN op type: " << webnn_op_type + << ", WebNN input name " << input.name + << " is not defined in wnn_limits."; + return false; + } - for (const auto& input : inputs) { - // If it is an optional input and is absent, skip. 
- if (!TensorExists(input_defs, input.index)) { - continue; + const auto& input_limits = wnn_limits[webnn_op_type_str][input_name_str]; + if (input_limits["rankRange"].isUndefined()) { + LOGS(logger, VERBOSE) << "Operator type: [" << op_type + << "], input index: [" << input.index + << "], corresponding WebNN op type: " << webnn_op_type + << ", WebNN input name " << input.name + << "'s rankRange is not defined."; + return false; } - std::vector shape; - if (!GetShape(*input_defs[input.index], shape, logger) || - !IsInputRankSupported(wnn_limits, webnn_op_type, input.name, - shape.size(), - node.Name(), logger)) { + int input_dim_size = static_cast(input_shape.size()); + int min_rank = input_limits["rankRange"]["min"].as(); + int max_rank = input_limits["rankRange"]["max"].as(); + + if (input_dim_size < min_rank || input_dim_size > max_rank) { + LOGS(logger, VERBOSE) << "Operator type: [" << op_type + << "], input index: [" << input.index + << "], corresponding WebNN op type: " << webnn_op_type + << ", WebNN input name: " << input.name + << ", input size " << input_dim_size + << " is not in supported range [" << min_rank << ", " << max_rank << "]"; return false; } } - return true; } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 50e361ede221e..d59788600f997 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -216,13 +216,6 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger); -bool IsInputRankSupported(const emscripten::val& wnn_limits, - const std::string_view webnn_op_type, - const std::string_view input_name, - const size_t input_rank, - const std::string_view node_name, - const logging::Logger& logger); - // Get a set of nodes supported by WebNN EP. std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, const emscripten::val& wnn_builder, @@ -251,33 +244,6 @@ inline std::string_view GetWebNNOpType(const std::string_view onnx_op_type) { return (it != op_inputs_map.end()) ? 
it->second.opType : ""; } -// Get corresponding input name of WebNN op type by ONNX op type from op_input_map -inline std::string_view GetWebNNInputName(const std::string_view onnx_op_type, const int input_index) { - const auto it = op_inputs_map.find(onnx_op_type); - - if (it != op_inputs_map.end()) { - for (const auto& input : it->second.inputs) { - if (input.index == input_index) { - return input.name; - } - } - } - - return ""; -} - -inline bool GetWebNNOpInputs(const std::string_view onnx_op_type, - std::vector& inputs, - const logging::Logger& logger) { - const auto it = op_inputs_map.find(onnx_op_type); - if (it == op_inputs_map.end()) { - LOGS(logger, VERBOSE) << "WebNN op inputs not found for op type: " << onnx_op_type; - return false; - } - inputs = it->second.inputs; - return true; -} - bool AreDataTypesSame(const std::string_view op_type, gsl::span input_types, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc index fdf1709d87bac..fc630af8cf1e3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc @@ -18,6 +18,10 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool IsOpSupportedImpl(const GraphViewer&, const Node& node, + WebnnDeviceType device_type, const logging::Logger& logger) const override; }; // Add operator related. @@ -61,6 +65,20 @@ Status ArgMaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } +// Operator support related. 
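For reference, the consolidated rank check introduced above reads the `rankRange` entry that WebNN's `opSupportLimits()` exposes for each op input and compares it with the actual tensor rank. A minimal standalone sketch of that lookup (the helper name here is illustrative, not part of this change):

    #include <emscripten/val.h>
    #include <string>

    // Returns true when `rank` falls inside the [min, max] rankRange advertised
    // by wnn_limits for the given WebNN op/input pair; false when the op, the
    // input, or its rankRange entry is missing from the limits object.
    bool RankWithinLimits(const emscripten::val& wnn_limits,
                          const std::string& webnn_op,
                          const std::string& input_name,
                          size_t rank) {
      emscripten::val op_limits = wnn_limits[webnn_op];
      if (op_limits.isUndefined() || op_limits[input_name].isUndefined()) {
        return false;
      }
      emscripten::val rank_range = op_limits[input_name]["rankRange"];
      if (rank_range.isUndefined()) {
        return false;
      }
      const size_t min_rank = rank_range["min"].as<size_t>();
      const size_t max_rank = rank_range["max"].as<size_t>();
      return rank >= min_rank && rank <= max_rank;
    }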
+bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const GraphViewer& /* initializers */, + const Node& node, + WebnnDeviceType device_type, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + return true; +} + void CreateArgMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index 3c8e7fa34f7ed..b0ec006db6986 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -62,12 +62,13 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, int32_t input_type; if (!GetType(input, input_type, logger)) return false; - const std::string_view webnn_op_type = GetWebNNOpType(op_type); + if (webnn_op_type.empty()) + return false; + const std::string_view webnn_input_name = GetWebNNOpFirstInputName(op_type); return IsDataTypeSupportedByWebNNOp(op_type, webnn_op_type, input_type, wnn_limits, - webnn_input_name, "input", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + webnn_input_name, "input", logger); } bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc index 851dc373923ac..280ffc83eae89 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc @@ -73,10 +73,9 @@ bool BinaryOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod return false; } - const std::string_view webnn_input_name = GetWebNNOpFirstInputName(op_type); + std::string webnn_input_name = op_type == "PRelu" ? "input" : "a"; std::string onnx_input_name = op_type == "PRelu" || op_type == "Pow" ? 
"X" : "A"; - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger); } void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index db5e8cd51656c..8589237617745 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -75,8 +75,7 @@ bool ConcatOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod } } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger); } void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index e0bfb3bd682e8..b9383a63fe307 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -324,7 +324,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N x_zero_point = model_builder.CreateOrGetConstant(x_type, 0); } - // Scale is not used by ConvInteger but required by DequantizeLinear. So set it to default value 1.0f. + // Scale is not used by ConvInteger but required by DequantizeLinear. So set it to deafult value 1.0f. // The x_zero_point must be a scalar and the scale input should have the same shape as the zero point input. // So the x_scale must be a scalar too. x_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f); diff --git a/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc index f3c392b608e45..7528d9ad2ff51 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc @@ -77,6 +77,10 @@ bool CumSumOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + const std::string axis_name = GetTensorName(input_defs, 1); // Inputs contain optional 'axis' input. 
const auto* init = graph_viewer.GetConstantInitializer(axis_name); diff --git a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc index 37a00fcb12abd..c22dd9e97bb1a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc @@ -21,6 +21,11 @@ class DropoutOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const GraphViewer&, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; }; // Add operator related. @@ -60,13 +65,26 @@ Status DropoutOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val options = emscripten::val::object(); options.set("label", output_defs[1]->Name() + "_identity"); // Add additional identity op in case the mask is the output of a WebNN graph, - // because WebNN does not support a constant operand as output. + // beacuse WebNN does not support a constant operand as output. emscripten::val mask_output = model_builder.GetBuilder().call("identity", one_constant, options); model_builder.AddOperand(output_defs[1]->Name(), std::move(mask_output)); } return Status::OK(); } +// Operator support related. +bool DropoutOpBuilder::IsOpSupportedImpl(const GraphViewer&, + const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + return true; +} + void CreateDropoutOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc index 6aa760c0f4baf..e5b4fcddc4221 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc @@ -28,8 +28,6 @@ class EinsumOpBuilder : public BaseOpBuilder { const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; - bool HasSupportedOutputsImpl(const Node& /* node */, const emscripten::val& /* wnn_limits */, - const logging::Logger& /* logger */) const override; }; // Helper functions, thanks for DML EP's OperatorHelper. 
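Several builders in this change (ArgMax/ArgMin and Dropout above, Softmax and others below) regain an `IsOpSupportedImpl` whose only job is to reject nodes whose first input has no statically known shape. The recurring pattern, sketched with a placeholder builder name and the same helper signatures the surrounding builders use:

    // Sketch only: ExampleOpBuilder is a placeholder; GetShape and the
    // parameter types match the surrounding WebNN op builders.
    bool ExampleOpBuilder::IsOpSupportedImpl(const GraphViewer& /* graph_viewer */,
                                             const Node& node,
                                             const WebnnDeviceType /* device_type */,
                                             const logging::Logger& logger) const {
      const auto& input_defs = node.InputDefs();
      std::vector<int64_t> input_shape;
      if (!GetShape(*input_defs[0], input_shape, logger)) {
        return false;  // shape is unknown; the WebNN builder needs a static shape here
      }
      return true;  // remaining constraints are checked in HasSupportedInputsImpl
    }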
@@ -44,6 +42,12 @@ enum class RecognizedOperatorType { Total, }; +struct RecognizedOperatorInfo { + RecognizedOperatorType recognized_operator_type; + std::initializer_list component_ranks; + std::initializer_list label_indices; +}; + struct Component { uint32_t label_index_begin; uint32_t label_index_end; @@ -594,7 +598,7 @@ Status EinsumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } } - // transpose input + // tranpose input std::vector permutation(input_labels.size()); for (uint32_t idx = 0; idx < input_labels.size(); idx++) { if (idx != diagonal_idx_1 && idx != diagonal_idx_2) { @@ -616,7 +620,7 @@ Status EinsumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, options_trilu.set("upper", false); output = model_builder.GetBuilder().call("triangular", output, options_trilu); // tril - // reduceSum to achieve the diagonal values + // reducesum to achieve the diagonal values std::vector input_shape; std::vector reduced_axes; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); @@ -696,6 +700,12 @@ bool EinsumOpBuilder::IsOpSupportedImpl(const GraphViewer&, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + if (input_defs.size() > 2) { + // TODO: Support more than two inputs. + LOGS(logger, VERBOSE) << "EinSum only supports up to two inputs."; + return false; + } + NodeAttrHelper helper(node); const auto equation = helper.Get("equation", std::string(" ")); std::vector label_indices; @@ -714,6 +724,13 @@ bool EinsumOpBuilder::IsOpSupportedImpl(const GraphViewer&, return false; } + RecognizedOperatorType recognized_operator_type = DetermineRecognizedOperatorType(label_indices, components, + output_dimensions); + if (recognized_operator_type == RecognizedOperatorType::None) { + LOGS(logger, VERBOSE) << "The equation is not supported in Einsum."; + return false; + } + return true; } @@ -721,14 +738,9 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - if (input_defs.size() > 2) { - // TODO: Support more than two inputs. 
- LOGS(logger, VERBOSE) << "EinSum only supports up to two inputs."; - return false; - } - const std::string_view op_type = node.OpType(); - int32_t input0_type, input1_type; + int32_t input0_type; + int32_t input1_type; bool has_input1 = TensorExists(input_defs, 1); if (!GetType(*input_defs[0], input0_type, logger) || @@ -742,13 +754,6 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod return false; } - std::vector input0_shape; - std::vector input1_shape; - if (!GetShape(*input_defs[0], input0_shape, logger) || - (has_input1 && !GetShape(*input_defs[1], input1_shape, logger))) { - return false; - } - NodeAttrHelper helper(node); const auto equation = helper.Get("equation", std::string(" ")); std::vector label_indices; @@ -765,54 +770,17 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod RecognizedOperatorType recognized_operator_type = DetermineRecognizedOperatorType(label_indices, components, output_dimensions); - std::string_view decomposed_op_type; if (recognized_operator_type == RecognizedOperatorType::None) { LOGS(logger, VERBOSE) << "The equation is not supported in Einsum."; return false; - } else if (recognized_operator_type == RecognizedOperatorType::Multiply) { - decomposed_op_type = "Mul"; - } else if (recognized_operator_type == RecognizedOperatorType::ReduceSum) { - decomposed_op_type = "ReduceSum"; - } else if (recognized_operator_type == RecognizedOperatorType::Diagonal) { - decomposed_op_type = "Trilu"; - } else if (recognized_operator_type == RecognizedOperatorType::Transpose) { - decomposed_op_type = "Transpose"; } else if (recognized_operator_type == RecognizedOperatorType::Pairwise) { - decomposed_op_type = "MatMul"; - } else { // Identity - // For the Identity case, we simply forward the input to the output without any modification. - return true; - } - - const std::string_view wnn_input0_name = GetWebNNInputName(decomposed_op_type, 0); - const std::string_view decompose_wnn_op_type = GetWebNNOpType(decomposed_op_type); - if (decompose_wnn_op_type.empty() || - !IsDataTypeSupportedByWebNNOp(op_type, decompose_wnn_op_type, input0_type, - wnn_limits, wnn_input0_name, "inputs", logger) || - !IsInputRankSupported(wnn_limits, decompose_wnn_op_type, wnn_input0_name, - input0_shape.size(), node.Name(), logger)) { - return false; - } - - if (has_input1) { - const std::string_view wnn_input1_name = GetWebNNInputName(decomposed_op_type, 1); - return IsDataTypeSupportedByWebNNOp(op_type, decompose_wnn_op_type, input1_type, - wnn_limits, wnn_input1_name, "inputs", logger) && - IsInputRankSupported(wnn_limits, decompose_wnn_op_type, wnn_input1_name, - input1_shape.size(), node.Name(), logger); + // Map to WebNN's gemm or matmul + return IsDataTypeSupportedByWebNNOp(op_type, "matmul", input0_type, wnn_limits, "a", "inputs", logger); + } else if (recognized_operator_type == RecognizedOperatorType::ReduceSum) { + return IsDataTypeSupportedByWebNNOp(op_type, "reduceSum", input0_type, wnn_limits, "input", "inputs", logger); + } else { + return IsDataTypeSupportedByWebNNOp(op_type, "identity", input0_type, wnn_limits, "input", "inputs", logger); } - - return true; -} - -bool EinsumOpBuilder::HasSupportedOutputsImpl(const Node& /* node */, - const emscripten::val& /* wnn_limits */, - const logging::Logger& /* logger */) const { - // The Einsum op produces output with the same data type as its input. - // Therefore, checking the output data type is unnecessary. 
- // This override prevents calling the base class implementation, as the base implementation - // would return false due to Einsum being a decomposed op. - return true; } void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc index ae4c3705fdb2e..06beb56415609 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc @@ -56,14 +56,14 @@ bool GatherElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const N const auto& indices = *node.InputDefs()[1]; const std::string_view op_type = node.OpType(); - int32_t data_type, indices_type; + int32_t data_type; + int32_t indices_type; if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) { return false; } return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } void CreateGatherElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc index af508c2800f4b..9200c596c0e53 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc @@ -61,14 +61,14 @@ bool GatherNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& n const auto& indices = *node.InputDefs()[1]; const std::string_view op_type = node.OpType(); - int32_t data_type, indices_type; + int32_t data_type; + int32_t indices_type; if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) { return false; } return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc index 7111a8f6beaa3..d84c70032e1d1 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc @@ -20,6 +20,8 @@ class GatherOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; // Operator support related. + bool IsOpSupportedImpl(const GraphViewer&, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -48,20 +50,38 @@ Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. 
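The rank guard restored for Gather just below follows from ONNX Gather's shape rule: the output shape is data.shape[:axis] + indices.shape + data.shape[axis+1:], so a rank-0 data tensor has nothing to gather from. A small self-contained illustration of that rule:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // ONNX Gather: output shape = data.shape[:axis] + indices.shape + data.shape[axis+1:],
    // so output rank is (r - 1) + q and data must have rank >= 1.
    std::vector<int64_t> GatherOutputShape(const std::vector<int64_t>& data_shape,
                                           const std::vector<int64_t>& indices_shape,
                                           size_t axis) {
      std::vector<int64_t> out(data_shape.begin(), data_shape.begin() + axis);
      out.insert(out.end(), indices_shape.begin(), indices_shape.end());
      out.insert(out.end(), data_shape.begin() + axis + 1, data_shape.end());
      return out;
    }

    int main() {
      // data: [4, 3], indices: [2], axis = 0  ->  output: [2, 3]
      for (int64_t d : GatherOutputShape({4, 3}, {2}, 0)) std::cout << d << ' ';
      std::cout << '\n';
    }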
+ +bool GatherOpBuilder::IsOpSupportedImpl(const GraphViewer&, + const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + const auto rank = input_shape.size(); + if (rank < 1) { + LOGS(logger, VERBOSE) << "Gather only supports input shapes >= 1D, but input is " + << rank << "d shape"; + return false; + } + + return true; +} + bool GatherOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const std::string_view op_type = node.OpType(); - int32_t input_type, indices_type; - + int32_t input_type; + int32_t indices_type; if (!GetType(input, input_type, logger) || !GetType(indices, indices_type, logger)) return false; return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 7af17fdc5db78..02f46c85d1d06 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -91,7 +91,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N a_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); std::vector a_zero_point_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[2], a_zero_point_shape, logger), "Cannot get shape of a_zero_point"); - // Scale is not used by MatMulInteger but required by DequantizeLinear. So set it to default value 1.0f. + // Scale is not used by MatMulInteger but required by DequantizeLinear. So set it to deafult value 1.0f. // The scale input should have the same shape as the zero point input. a_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f, @@ -268,45 +268,11 @@ bool GemmOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - if (op_type == "Gemm") { - return IsInputRankSupportedByOp(node, wnn_limits, logger) && - IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger); - } else if (op_type == "MatMulInteger") { - // Check up to 4 inputs for MatMulInteger - for (size_t i = 0; i < input_defs.size(); ++i) { - std::vector shape; - if (!GetShape(*input_defs[i], shape, logger)) { - return false; - } - - // We made workaround to support 1D for input A and B, skip further checks if they are 1D - if (i <= 1 && shape.size() == 1) { - continue; - } - - // For DequantizeLinear, input indices: 0 (x), 1 (scale), 2 (zero_point) - if (!IsInputRankSupported(wnn_limits, "dequantizeLinear", - (i < 2) ? "input" : "zeroPoint", - shape.size(), node.Name(), logger)) { - return false; - } - } + if (op_type == "MatMulInteger") { + // The first decomposed op of MatMulInteger is DequantizeLinear, and so + // we only need to ensure it supports the input0_type. 
return IsDataTypeSupportedByOp("DequantizeLinear", input0_type, wnn_limits, "input", "x", logger); - } else { // MatMul - for (int i = 0; i < 2; ++i) { - std::vector shape; - if (!GetShape(*input_defs[i], shape, logger)) { - return false; - } - - if (shape.size() == 1) { - continue; - } - - if (!IsInputRankSupported(wnn_limits, "matmul", (i == 0) ? "a" : "b", shape.size(), node.Name(), logger)) { - return false; - } - } + } else { return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger); } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc index 95e75a3083cc2..dfe80dd419092 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -219,8 +219,7 @@ bool GruOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + return IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger); } bool GruOpBuilder::HasSupportedOutputsImpl(const Node& node, diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index 55d468c4843cb..42940083cad8e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -91,10 +91,8 @@ bool LogicalOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& no } } - const std::string_view webnn_input_name = GetWebNNOpFirstInputName(op_type); std::string onnx_input_name = op_type == "Not" ? "X" : "A"; - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", onnx_input_name, logger); } void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc index e8aab725375ad..8936bda875aef 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc @@ -21,6 +21,8 @@ class LRNOpBuilder : public BaseOpBuilder { // Operator support related. private: + bool IsOpSupportedImpl(const GraphViewer&, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, @@ -126,10 +128,11 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. 
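The "set it to default value 1.0f" comments in the ConvInteger and MatMulInteger hunks come from DequantizeLinear's definition, y = (x - zero_point) * scale: with the scale fixed at 1.0f, only the zero point shifts the dequantized result. A tiny worked example:

    #include <cstdint>
    #include <iostream>

    // DequantizeLinear: y = (x - zero_point) * scale. ConvInteger/MatMulInteger
    // carry no scale input, so the builders register a scalar 1.0f constant.
    float Dequantize(uint8_t x, uint8_t zero_point, float scale = 1.0f) {
      return (static_cast<float>(x) - static_cast<float>(zero_point)) * scale;
    }

    int main() {
      std::cout << Dequantize(130, 128) << '\n';  // prints 2 (scale defaults to 1.0f)
    }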
-bool LRNOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, - const emscripten::val& wnn_limits, const logging::Logger& logger) const { +bool LRNOpBuilder::IsOpSupportedImpl(const GraphViewer&, + const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - std::vector input_shape; if (!GetShape(*input_defs[0], input_shape, logger)) return false; @@ -140,6 +143,12 @@ bool LRNOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } + return true; +} + +bool LRNOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); const std::string_view op_type = node.OpType(); int32_t input_type = 0; if (!GetType(*input_defs[0], input_type, logger)) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc index 04d59e2f30d15..09e584bc66f8a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc @@ -242,8 +242,7 @@ bool LstmOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); } bool LstmOpBuilder::HasSupportedOutputsImpl(const Node& node, diff --git a/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc index 9ab403b7051d2..111d03571e974 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc @@ -48,7 +48,7 @@ void MatMulNBitsBuilder::AddInitializersToSkip(ModelBuilder& model_builder, cons // DequantizeLinear + Transpose + MatMul. Given that the CPU EP currently only supports // 4-bit quantization, we only handle 4-bit quantization here. // -// To align with WebNN's dequantizeLinear op constraints, the following transformations are +// To align with WebNN's dequantizeLinear op contraints, the following transformations are // required for MatMulNBits inputs: // 1. B: must be a constant initializer and registered as a 'uint4' WebNN constant with shape // [N, n_blocks_per_col, blob_size * 2]. 
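The B shape mentioned in the comment above follows from the MatMulNBits packing (assuming the standard 4-bit layout): each column of B is split into n_blocks_per_col = ceil(K / block_size) blocks, each stored as blob_size = block_size / 2 bytes, and re-registering the data as a WebNN 'uint4' constant unpacks two elements per byte, doubling that last dimension. For example:

    #include <cstdint>
    #include <iostream>

    // Illustrative shape arithmetic for 4-bit MatMulNBits packing.
    int main() {
      const int64_t K = 4096, block_size = 32;
      const int64_t n_blocks_per_col = (K + block_size - 1) / block_size;  // 128 blocks per column
      const int64_t blob_size = block_size / 2;                            // 16 packed bytes per block
      std::cout << "[N, " << n_blocks_per_col << ", " << blob_size * 2 << "]\n";  // uint4 constant: last dim 32
    }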
@@ -159,6 +159,10 @@ bool MatMulNBitsBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const logging::Logger& logger) const { const auto& name = node.Name(); const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + return false; + } // Inputs B and zero_points (if present) must be initializers if (!graph_viewer.GetConstantInitializer(input_defs[1]->Name())) { // B @@ -189,10 +193,6 @@ bool MatMulNBitsBuilder::HasSupportedInputsImpl(const GraphViewer&, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const std::string_view op_type = node.OpType(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) { - return false; - } int32_t A_type = 0; int32_t B_type = 0; @@ -227,13 +227,10 @@ bool MatMulNBitsBuilder::HasSupportedInputsImpl(const GraphViewer&, return false; } - // Data type: Currently, only 4-bit quantization is supported, represented as the uint4 data type in WebNN. - // Ensure that the uint4 data type is supported by WebNN's dequantizeLinear op. - // Input rank: Only the rank of the first input (A) is flexible. Verify that its rank is supported by - // WebNN's matmul op. + // We only support 4-bit quantization, which is represented as the uint4 data type in WebNN. + // Ensure that uint4 is supported. return IsDataTypeSupportedByOp("DequantizeLinear", ONNX_NAMESPACE::TensorProto_DataType_UINT4, - wnn_limits, "input", "x", logger) && - IsInputRankSupported(wnn_limits, "matmul", "a", input_shape.size(), node.Name(), logger); + wnn_limits, "input", "x", logger); } bool MatMulNBitsBuilder::HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, diff --git a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc index 9f5ac6ef15735..4e4014e3553ea 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc @@ -20,6 +20,8 @@ class MaxMinOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; // Operator support related. + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const Node& node, + WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -66,6 +68,25 @@ Status MaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. 
+bool MaxMinOpBuilder::IsOpSupportedImpl(const GraphViewer&, + const Node& node, + WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + if (input_defs.size() < 1) { + LOGS(logger, VERBOSE) << op_type << " requires at least one input (data)"; + return false; + } + + return true; +} + bool MaxMinOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); @@ -87,8 +108,7 @@ bool MaxMinOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod } } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger); } void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 9fb643f055ef3..148eacac98e4a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -46,14 +46,28 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); + std::vector scale_shape; const size_t scale_input_index = op_type == "SkipSimplifiedLayerNormalization" ? 2 : 1; + ORT_RETURN_IF_NOT(GetShape(*input_defs[scale_input_index], scale_shape, logger), "Cannot get scale shape"); + const auto scale_size = scale_shape.size(); + // Except LayerNormalization, other normalization ops' scale input should be 1-D. + if (op_type == "LayerNormalization") { + ORT_RETURN_IF_NOT(scale_size >= 1 && scale_size <= rank, + "The scale size should be less than or equal to input size."); + } else { + ORT_RETURN_IF_NOT(scale_size == 1, "The scale size should be one."); + } + emscripten::val scale = model_builder.GetOperand(input_defs[scale_input_index]->Name()); options.set("scale", scale); const size_t bias_input_index = op_type == "SkipSimplifiedLayerNormalization" ? 3 : 2; emscripten::val bias = emscripten::val::undefined(); if (TensorExists(input_defs, bias_input_index)) { - // Bias input exists. + // Bias input exists, and bias's shape should be the same as scale's shape. 
+ std::vector bias_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[bias_input_index], bias_shape, logger), "Cannot get bias shape"); + ORT_RETURN_IF_NOT(bias_shape == scale_shape, "The bias' shape should be equal to scale's shape."); bias = model_builder.GetOperand(input_defs[bias_input_index]->Name()); options.set("bias", bias); } @@ -265,6 +279,12 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const GraphViewer&, return false; } + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + LOGS(logger, VERBOSE) << "Cannot get input shape."; + return false; + } + const auto& output_defs = node.OutputDefs(); if (op_type == "SkipSimplifiedLayerNormalization") { if (output_defs.size() > 4) { @@ -296,28 +316,33 @@ bool NormalizationOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const No const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const std::string_view op_type = node.OpType(); - - std::vector input_types; - bool all_types_valid = true; - - // Iterate through all inputs and check their existence and types - for (size_t i = 0; i <= input_defs.size(); ++i) { - if (TensorExists(input_defs, i)) { - int32_t input_type; - if (!GetType(*input_defs[i], input_type, logger)) { - all_types_valid = false; - break; - } - input_types.push_back(input_type); - } - } - - // Return false if any input type is invalid - if (!all_types_valid) { + int32_t input0_type; // input data type + int32_t input1_type; // scale data type + int32_t input2_type; // B data type + int32_t input3_type; // mean data type + int32_t input4_type; // var data type + bool has_input2 = TensorExists(input_defs, 2); + bool has_input3 = TensorExists(input_defs, 3); + bool has_input4 = TensorExists(input_defs, 4); + + if (!GetType(*input_defs[0], input0_type, logger) || + !GetType(*input_defs[1], input1_type, logger) || + (has_input2 && !GetType(*input_defs[2], input2_type, logger)) || + (has_input3 && !GetType(*input_defs[3], input3_type, logger)) || + (has_input4 && !GetType(*input_defs[4], input4_type, logger))) { return false; } - // Check if all input data types are the same + std::vector input_types = {input0_type, input1_type}; + if (has_input2) { + input_types.push_back(input2_type); + } + if (has_input3) { + input_types.push_back(input3_type); + } + if (has_input4) { + input_types.push_back(input4_type); + } if (!AreDataTypesSame(op_type, input_types, logger)) { return false; } @@ -330,29 +355,13 @@ bool NormalizationOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const No const std::string_view webnn_op_type = GetWebNNOpType(decomposed_op_type); const std::string_view webnn_input_name = GetWebNNOpFirstInputName(decomposed_op_type); if (!IsDataTypeSupportedByWebNNOp( - op_type, webnn_op_type, input_types[0], wnn_limits, webnn_input_name, "input", logger)) { + op_type, webnn_op_type, input0_type, wnn_limits, webnn_input_name, "input", logger)) { return false; } } - - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) { - return false; - } - // It's complicated to check all the decomposed ops' input rank support. - // Ensure at least the first input rank is supported by the decomposed ops (pow and div accept the first input). 
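The scale/bias checks reinstated above encode two rules: LayerNormalization allows a scale of any rank from 1 up to the input rank, every other normalization expects a 1-D scale, and a bias, when present, must have exactly the scale's shape. A compact standalone restatement of those rules:

    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    // Sketch of the shape rules the normalization builder enforces:
    //  - LayerNormalization: 1 <= rank(scale) <= rank(input)
    //  - other normalizations: rank(scale) == 1
    //  - bias (if present) must have the scale's shape
    bool ScaleAndBiasOk(const std::string& op_type, size_t input_rank,
                        const std::vector<int64_t>& scale_shape,
                        const std::vector<int64_t>* bias_shape) {
      const size_t scale_rank = scale_shape.size();
      const bool scale_ok = (op_type == "LayerNormalization")
                                ? (scale_rank >= 1 && scale_rank <= input_rank)
                                : (scale_rank == 1);
      const bool bias_ok = (bias_shape == nullptr) || (*bias_shape == scale_shape);
      return scale_ok && bias_ok;
    }

    int main() {
      const std::vector<int64_t> scale{8}, bias{8};
      std::cout << ScaleAndBiasOk("LayerNormalization", 3, scale, &bias) << '\n';        // 1: input rank 3, scale/bias [8]
      std::cout << ScaleAndBiasOk("InstanceNormalization", 4, {4, 8}, nullptr) << '\n';  // 0: scale must be 1-D
    }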
- return IsInputRankSupported(wnn_limits, "pow", "a", input_shape.size(), node.Name(), logger) && - IsInputRankSupported(wnn_limits, "div", "a", input_shape.size(), node.Name(), logger); + return true; } else { - bool is_data_type_supported = IsDataTypeSupportedByOp(op_type, input_types[0], wnn_limits, "input", "X", logger); - if (op_type == "InstanceNormalization") { - // Skip input rank check for InstanceNormalization, as we will reshape the input to 4D if necessary. - return is_data_type_supported; - } - - // For other ops, check both data type and input rank compatibility. - bool is_input_rank_supported = IsInputRankSupportedByOp(node, wnn_limits, logger); - return is_input_rank_supported && is_data_type_supported; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc index 5d921c5176a64..f2a3f08b73148 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc @@ -133,6 +133,20 @@ bool PoolOpBuilder::IsOpSupportedImpl(const GraphViewer&, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& op_type = node.OpType(); + const auto& input_defs = node.InputDefs(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + const auto input_size = input_shape.size(); + if (input_size != 4) { + LOGS(logger, VERBOSE) + << op_type << " only supports rank-4 tensor, input [" + << input_defs[0]->Name() << "] has actual dim count " << input_size; + return false; + } + NodeAttrHelper helper(node); if (op_type == "AveragePool" || op_type == "LpPool" || op_type == "MaxPool") { if (helper.Get("kernel_shape", std::vector{1, 1}).size() != 2) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc index 053c41773db40..dd25fb9bf9315 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc @@ -167,8 +167,7 @@ bool QDQOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsInputRankSupportedByOp(node, wnn_limits, logger) && - IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "x", logger) && + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "x", logger) && IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "scale", "x_scale", logger) && (!has_input2 || IsDataTypeSupportedByOp(op_type, input2_type, wnn_limits, "zeroPoint", "x_zero_point", logger)); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc index 6ea9b0a440d93..a3a0397eda4a3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc @@ -128,10 +128,16 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + const auto& op_type = node.OpType(); const std::string axes_name = 
GetTensorName(input_defs, 1); // If the optional input 'axes' is provided, it must be an initializer. if (!axes_name.empty() && !graph_viewer.GetConstantInitializer(axes_name)) { - LOGS(logger, VERBOSE) << "Input axes of " << node.OpType() << " must be a constant"; + LOGS(logger, VERBOSE) << "Input axes of " << op_type << " must be a constant"; return false; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc index 0444ae3afb56a..8cbb381e0f53e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc @@ -79,6 +79,11 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + const auto& perm_name = input_defs[1]->Name(); const auto* perm_init = graph_viewer.GetConstantInitializer(perm_name); if (!perm_init) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc index 37071b1030e11..893ca9d2419c7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc @@ -285,7 +285,7 @@ Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_build sign_buffer.set(1, 1.0f); } else if (input_data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { if (model_builder.IsFloat16ArrayAvailable()) { - // Float16Array is available - use Float16Array. + // Float16Array is avaliable - use Float16Array. 
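On the Float16Array probe referenced in the comment above: Float16Array is a relatively recent JavaScript feature, so the WebNN code tests for it before allocating fp16 buffers. A minimal sketch of such a probe via Embind, mirroring the checks that appear in model_builder.h later in this diff:

    #include <emscripten/val.h>

    // Probe for the JS Float16Array global and its static from() method before
    // using it; if either is missing, callers fall back to another buffer type.
    bool Float16ArrayAvailable() {
      emscripten::val f16 = emscripten::val::global("Float16Array");
      return !f16.isUndefined() && f16.hasOwnProperty("from");
    }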
sign_buffer = emscripten::val::global("Float16Array").new_(2); sign_buffer.set(0, -1.0f); sign_buffer.set(1, 1.0f); diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc index c2974bd988f6b..f894e8bfbd517 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc @@ -71,6 +71,7 @@ bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const auto& updates = *node.InputDefs()[2]; + const std::string_view op_type = node.OpType(); int32_t data_type; int32_t indices_type; @@ -84,11 +85,8 @@ bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const return false; } - const std::string_view op_type = node.OpType(); - return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } void CreateScatterElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc index a7788cfd847e9..e61ac3dcc9617 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc @@ -63,6 +63,7 @@ bool ScatterNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const auto& updates = *node.InputDefs()[2]; + const std::string_view op_type = node.OpType(); int32_t data_type; int32_t indices_type; @@ -75,10 +76,9 @@ bool ScatterNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& if (data_type != updates_type) { return false; } - const std::string_view op_type = node.OpType(); + return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index 5efbfe932c602..8853891ff8ed6 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -136,6 +136,10 @@ bool SliceOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const No const auto& name = node.Name(); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + return false; + } if (input_defs.size() < 3) { LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 3 inputs (data, starts, ends) but got " @@ -162,17 +166,10 @@ bool 
SliceOpBuilder::HasSupportedInputsImpl(const GraphViewer& graph_viewer, con const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& input = *input_defs[0]; - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) { - return false; - } - + const std::string_view op_type = node.OpType(); int32_t input_type; - if (!GetType(input, input_type, logger)) { + if (!GetType(input, input_type, logger)) return false; - } - - const std::string_view op_type = node.OpType(); // If there is step < 0, check data type support of reverse. if (TensorExists(input_defs, 4)) { @@ -181,15 +178,13 @@ bool SliceOpBuilder::HasSupportedInputsImpl(const GraphViewer& graph_viewer, con if (!init || !ReadIntArrayFrom1DTensor(*init, steps, graph_viewer, logger)) return false; if (std::any_of(steps.begin(), steps.end(), [](int64_t step) { return step < 0; })) { - if (!IsDataTypeSupportedByWebNNOp(op_type, "reverse", input_type, wnn_limits, "input", "data", logger) || - !IsInputRankSupported(wnn_limits, "reverse", "input", input_shape.size(), node.Name(), logger)) { + if (!IsDataTypeSupportedByWebNNOp(op_type, "reverse", input_type, wnn_limits, "input", "data", logger)) { return false; } } } - return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger); } void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index 99d137f81864c..23e73bb8f1e74 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -18,6 +18,11 @@ class SoftmaxOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const GraphViewer&, const Node& node, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; }; Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -41,6 +46,20 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } +// Operator support related. 
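The negative-step handling above (and the WebGPU slice change near the top of this diff) relies on the identity that a negative-step slice equals a forward slice over reversed data, which is why 'reverse' support is checked whenever any step is negative. A small self-contained check of that equivalence:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // data = [0 1 2 3 4], slice from index 4 down to the beginning with step -2
    // gives [4, 2, 0]; the same result is obtained by reversing and slicing forward.
    int main() {
      std::vector<int64_t> data{0, 1, 2, 3, 4};
      std::vector<int64_t> direct, rewritten;
      for (int64_t i = 4; i > -1; i -= 2) direct.push_back(data[i]);        // negative-step slice
      std::vector<int64_t> reversed(data.rbegin(), data.rend());            // reverse ...
      for (int64_t i = 0; i < 5; i += 2) rewritten.push_back(reversed[i]);  // ... then forward slice
      std::cout << (direct == rewritten) << '\n';                           // prints 1
    }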
+ +bool SoftmaxOpBuilder::IsOpSupportedImpl(const GraphViewer&, + const Node& node, + const WebnnDeviceType /* device_type */, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + return true; +} + void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc index 7e34e35ebac16..1ba6df9febf14 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc @@ -127,6 +127,9 @@ bool SqueezeUnsqueezeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewe const logging::Logger& logger) const { const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; if (input_defs.size() < 1) { LOGS(logger, ERROR) << op_type << " has no input tensor"; diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc index 8973757a24e99..7a7f64b1ec96d 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc @@ -66,8 +66,7 @@ bool TernaryOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& no return false; } - return IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger) && - IsInputRankSupportedByOp(node, wnn_limits, logger); + return IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger); } void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc index 24d96588559ae..29b232026d7df 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc @@ -77,6 +77,15 @@ bool TileOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, return false; } + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + + if (input_shape.empty()) { + LOGS(logger, VERBOSE) << "Tile does not support empty input shape"; + return false; + } + return true; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc index 7a4d172c556fa..5a267557b9454 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc @@ -76,6 +76,15 @@ bool TriangularOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + const auto input_size = 
input_shape.size(); + if (input_size < 2) { + LOGS(logger, VERBOSE) << "Triangular only supports input size >= 2D shape, input is " + << input_size << "d shape"; + return false; + } const std::string diagonal_name = GetTensorName(input_defs, 1); // Inputs contain optional 'diagonal' input. diff --git a/onnxruntime/core/providers/webnn/builders/map_info.h b/onnxruntime/core/providers/webnn/builders/map_info.h index 1c30fed7a7916..5e860eea7cac9 100644 --- a/onnxruntime/core/providers/webnn/builders/map_info.h +++ b/onnxruntime/core/providers/webnn/builders/map_info.h @@ -47,7 +47,6 @@ constexpr std::array supported_fallback // Use ONNX-to-ONNX op mapping to improve the search complexity for WebNN ops in the op_inputs_map. const std::map> decomposed_op_map = { {"ConvInteger", {"Cast", "Conv", "DequantizeLinear"}}, - {"Einsum", {"MatMul", "Mul", "ReduceSum", "Reshape", "Transpose", "Trilu"}}, {"GroupQueryAttention", {"Add", "Cast", "Concat", "CumSum", "Div", "Expand", "Less", "MatMul", "Reshape", "ScatterND", "Softmax", "Transpose", "Where"}}, @@ -140,7 +139,7 @@ const std::unordered_map op_inputs_map = { {"Mul", {"mul", {{0, "a"}, {1, "b"}}}}, {"Pow", {"pow", {{0, "a"}, {1, "b"}}}}, {"Concat", {"concat", {{0, "inputs"}}}}, - {"Not", {"logicalNot", {{0, "a"}}}}, + {"Not", {"logicalNot", {{0, "input"}}}}, {"Flatten", {"reshape", {{0, "input"}}}}, {"LpPool", {"l2Pool2d", {{0, "input"}}}}, {"Reshape", {"reshape", {{0, "input"}}}}, @@ -160,6 +159,7 @@ const std::unordered_map op_inputs_map = { {"Softsign", {"softsign", {{0, "input"}}}}, {"Unsqueeze", {"reshape", {{0, "input"}}}}, {"Or", {"logicalOr", {{0, "a"}, {1, "b"}}}}, + {"Einsum", {"matmul", {{0, "a"}, {1, "b"}}}}, {"HardSwish", {"hardSwish", {{0, "input"}}}}, {"LeakyRelu", {"leakyRelu", {{0, "input"}}}}, {"MatMul", {"matmul", {{0, "a"}, {1, "b"}}}}, diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index d2cd0639affd0..4468831181d42 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -78,7 +78,7 @@ class ModelBuilder { const GraphViewer& graph_viewer_; const logging::Logger& logger_; const bool is_float16array_available_ = !emscripten::val::global("Float16Array").isUndefined() && - !emscripten::val::global("Float16Array")["from"].isUndefined(); + emscripten::val::global("Float16Array").hasOwnProperty("from"); emscripten::val wnn_context_ = emscripten::val::undefined(); emscripten::val wnn_builder_ = emscripten::val::undefined(); diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index 59b0992d827e1..d910e3ea74b57 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -128,35 +128,6 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelPath, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInformation, - _In_ OrtModelCompilationOptions* ort_model_compile_options, - const ORTCHAR_T* output_directory, - const ORTCHAR_T* model_name) { - API_IMPL_BEGIN -#if !defined(ORT_MINIMAL_BUILD) - auto model_compile_options = reinterpret_cast(ort_model_compile_options); - - std::string output_dir = PathToUTF8String(output_directory); - if (output_dir.empty()) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid output directory: path is empty"); - } - - std::string model_name_str = ToUTF8String(model_name); - if 
(model_name_str.empty()) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid model name: string is empty"); - } - - ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetEpContextBinaryInformation(output_dir, model_name_str)); - return nullptr; -#else - ORT_UNUSED_PARAMETER(ort_model_compile_options); - ORT_UNUSED_PARAMETER(output_directory); - ORT_UNUSED_PARAMETER(model_name); - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); -#endif // !defined(ORT_MINIMAL_BUILD) - API_IMPL_END -} - ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelExternalInitializersFile, _In_ OrtModelCompilationOptions* ort_model_compile_options, const ORTCHAR_T* external_initializers_file_path, @@ -277,7 +248,6 @@ static constexpr OrtCompileApi ort_compile_api = { // End of Version 22 - DO NOT MODIFY ABOVE &OrtCompileAPI::ModelCompilationOptions_SetFlags, - &OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInformation, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/compile_api.h b/onnxruntime/core/session/compile_api.h index 93cc5dbf20fce..5f11b894f2004 100644 --- a/onnxruntime/core/session/compile_api.h +++ b/onnxruntime/core/session/compile_api.h @@ -30,7 +30,5 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModel ORT_API_STATUS_IMPL(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); ORT_API_STATUS_IMPL(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_options, size_t flags); -ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextBinaryInformation, _In_ OrtModelCompilationOptions* model_compile_options, - _In_ const ORTCHAR_T* output_dir, _In_ const ORTCHAR_T* model_name); } // namespace OrtCompileAPI diff --git a/onnxruntime/core/session/ep_api_utils.h b/onnxruntime/core/session/ep_api_utils.h index a0904c32011a7..daccd24453371 100644 --- a/onnxruntime/core/session/ep_api_utils.h +++ b/onnxruntime/core/session/ep_api_utils.h @@ -16,10 +16,6 @@ struct ForwardToFactory { return static_cast(this_ptr)->GetVendor(); } - static uint32_t ORT_API_CALL GetVendorId(const OrtEpFactory* this_ptr) noexcept { - return static_cast(this_ptr)->GetVendorId(); - } - static const char* ORT_API_CALL GetVersion(const OrtEpFactory* this_ptr) noexcept { return static_cast(this_ptr)->GetVersion(); } diff --git a/onnxruntime/core/session/ep_factory_internal.cc b/onnxruntime/core/session/ep_factory_internal.cc index fa4ef2515ca92..b289010cc6c5b 100644 --- a/onnxruntime/core/session/ep_factory_internal.cc +++ b/onnxruntime/core/session/ep_factory_internal.cc @@ -14,19 +14,17 @@ namespace onnxruntime { using Forward = ForwardToFactory; -EpFactoryInternal::EpFactoryInternal(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id, +EpFactoryInternal::EpFactoryInternal(const std::string& ep_name, const std::string& vendor, GetSupportedFunc&& get_supported_func, CreateFunc&& create_func) : ep_name_{ep_name}, vendor_{vendor}, - vendor_id_{vendor_id}, get_supported_func_{std::move(get_supported_func)}, create_func_{create_func} { ort_version_supported = ORT_API_VERSION; OrtEpFactory::GetName = Forward::GetFactoryName; OrtEpFactory::GetVendor = Forward::GetVendor; - OrtEpFactory::GetVendorId = Forward::GetVendorId; OrtEpFactory::GetVersion = Forward::GetVersion; OrtEpFactory::GetSupportedDevices = Forward::GetSupportedDevices; 
OrtEpFactory::CreateEp = Forward::CreateEp; diff --git a/onnxruntime/core/session/ep_factory_internal.h b/onnxruntime/core/session/ep_factory_internal.h index ee08e2233c529..087c0c60f8f4e 100644 --- a/onnxruntime/core/session/ep_factory_internal.h +++ b/onnxruntime/core/session/ep_factory_internal.h @@ -33,13 +33,12 @@ class EpFactoryInternal : public OrtEpFactory { const OrtSessionOptions* session_options, const OrtLogger* logger, std::unique_ptr* ep)>; - EpFactoryInternal(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id, + EpFactoryInternal(const std::string& ep_name, const std::string& vendor, GetSupportedFunc&& get_supported_func, CreateFunc&& create_func); const char* GetName() const noexcept { return ep_name_.c_str(); } const char* GetVendor() const noexcept { return vendor_.c_str(); } - uint32_t GetVendorId() const noexcept { return vendor_id_; } const char* GetVersion() const noexcept; OrtStatus* GetSupportedDevices(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, @@ -68,7 +67,6 @@ class EpFactoryInternal : public OrtEpFactory { private: const std::string ep_name_; // EP name library was registered with const std::string vendor_; // EP vendor name - const uint32_t vendor_id_; // EP vendor ID const GetSupportedFunc get_supported_func_; // function to return supported devices const CreateFunc create_func_; // function to create the EP instance diff --git a/onnxruntime/core/session/ep_library_internal.cc b/onnxruntime/core/session/ep_library_internal.cc index ce5736f601b45..25f70f7549a16 100644 --- a/onnxruntime/core/session/ep_library_internal.cc +++ b/onnxruntime/core/session/ep_library_internal.cc @@ -61,8 +61,7 @@ std::unique_ptr EpLibraryInternal::CreateCpuEp() { }; std::string ep_name = kCpuExecutionProvider; - auto cpu_factory = std::make_unique(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT, - get_supported, create_cpu_ep); + auto cpu_factory = std::make_unique(ep_name, "Microsoft", get_supported, create_cpu_ep); return std::make_unique(std::move(cpu_factory)); } @@ -123,8 +122,7 @@ std::unique_ptr EpLibraryInternal::CreateDmlEp() { return nullptr; }; - auto dml_factory = std::make_unique(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT, - is_supported, create_dml_ep); + auto dml_factory = std::make_unique(ep_name, "Microsoft", is_supported, create_dml_ep); return std::make_unique(std::move(dml_factory)); } @@ -172,8 +170,7 @@ std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { return nullptr; }; - auto webgpu_factory = std::make_unique(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT, - is_supported, create_webgpu_ep); + auto webgpu_factory = std::make_unique(ep_name, "Microsoft", is_supported, create_webgpu_ep); return std::make_unique(std::move(webgpu_factory)); } diff --git a/onnxruntime/core/session/ep_library_provider_bridge.cc b/onnxruntime/core/session/ep_library_provider_bridge.cc index 70937bdc5d3e8..73423a4744576 100644 --- a/onnxruntime/core/session/ep_library_provider_bridge.cc +++ b/onnxruntime/core/session/ep_library_provider_bridge.cc @@ -72,7 +72,6 @@ Status EpLibraryProviderBridge::Load() { auto internal_factory = std::make_unique(factory->GetName(factory), factory->GetVendor(factory), - factory->GetVendorId(factory), is_supported_fn, create_fn); factory_ptrs_.push_back(internal_factory.get()); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index f147242da668f..86a61a4d0ee74 100644 --- a/onnxruntime/core/session/inference_session.cc +++ 
b/onnxruntime/core/session/inference_session.cc @@ -423,13 +423,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, { if (!external_intra_op_thread_pool_) { bool allow_intra_op_spinning = -#if !defined(ORT_CLIENT_PACKAGE_BUILD) session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowIntraOpSpinning, "1") == "1"; -#else - // default KOrtSessionOptionsConfigAllowIntraOpSpinning to "0" for ORT builds targeting client/on-device workloads, - // to reduce CPU utilization and improve power efficiency. - session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowIntraOpSpinning, "0") == "1"; -#endif OrtThreadPoolParams to = session_options_.intra_op_param; std::basic_stringstream ss; if (to.name) { @@ -467,13 +461,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, if (session_options_.execution_mode == ExecutionMode::ORT_PARALLEL) { if (!external_inter_op_thread_pool_) { bool allow_inter_op_spinning = -#if !defined(ORT_CLIENT_PACKAGE_BUILD) session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowInterOpSpinning, "1") == "1"; -#else - // default kOrtSessionOptionsConfigAllowInterOpSpinning to "0" for ORT builds targeting client/on-device workloads, - // to reduce CPU utilization and improve power efficiency. - session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowInterOpSpinning, "0") == "1"; -#endif OrtThreadPoolParams to = session_options_.inter_op_param; to.auto_set_affinity = to.thread_pool_size == 0 && session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL; std::basic_stringstream ss; diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index bbb110033f54c..5de0f03fafc08 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -72,8 +72,8 @@ Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_mod if (log_manager != nullptr && log_manager->HasDefaultLogger()) { const logging::Logger& logger = log_manager->DefaultLogger(); LOGS(logger, WARNING) << "Output model path length (" << ep_context_gen_options.output_model_file_path.size() - << ") exceeds limit of " << ConfigOptions::kMaxValueLength << " characters." - << "ORT will still generate the expected output file, but EPs will see an empty " + << ") exceeds limit of " << ConfigOptions::kMaxValueLength << " characters." 
+ << "ORT will still generate the expected output file, but EPs will see an empty " << "output model path in SessionOption's ConfigOptions."; } } @@ -98,36 +98,6 @@ Status ModelCompilationOptions::SetOutputModelBuffer(onnxruntime::AllocatorPtr a return Status::OK(); } -Status ModelCompilationOptions::SetEpContextBinaryInformation(const std::string& output_directory, - const std::string& model_name) { - if (output_directory.empty() || model_name.empty()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output_dir or model_name is empty."); - } - - std::filesystem::path output_dir_path(output_directory); - if (output_dir_path.has_filename() && output_dir_path.extension() == "") { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output_dir is not a valid directory."); - } - - std::filesystem::path ctx_model_path = output_directory / std::filesystem::path(model_name); - - if (ctx_model_path.string().size() <= ConfigOptions::kMaxValueLength) { - ORT_RETURN_IF_ERROR(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, - ctx_model_path.string().c_str())); - } else { - logging::LoggingManager* log_manager = env_.GetLoggingManager(); - if (log_manager != nullptr && log_manager->HasDefaultLogger()) { - const logging::Logger& logger = log_manager->DefaultLogger(); - LOGS(logger, WARNING) << "output_directory length with model_name length together exceeds limit of " - << ConfigOptions::kMaxValueLength << " characters." - << "ORT will still generate the expected output file, but EPs will see an empty " - << "output path in SessionOption's ConfigOptions."; - } - } - - return Status::OK(); -} - Status ModelCompilationOptions::SetEpContextEmbedMode(bool embed_ep_context_in_model) { ORT_RETURN_IF_ERROR(session_options_.value.config_options.AddConfigEntry( kOrtSessionOptionEpContextEmbedMode, embed_ep_context_in_model ? "1" : "0")); @@ -176,7 +146,7 @@ Status ModelCompilationOptions::ResetOutputModelSettings() { ep_context_gen_options.output_model_buffer_ptr = nullptr; ep_context_gen_options.output_model_buffer_size_ptr = nullptr; ep_context_gen_options.output_model_buffer_allocator = nullptr; - return Status::OK(); + return session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ""); } Status ModelCompilationOptions::CheckInputModelSettings() const { diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h index 2824df863013d..f96f0317cdaca 100644 --- a/onnxruntime/core/session/model_compilation_options.h +++ b/onnxruntime/core/session/model_compilation_options.h @@ -72,16 +72,6 @@ class ModelCompilationOptions { Status SetOutputModelBuffer(onnxruntime::AllocatorPtr allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); - /// - /// Sets information relate to EP context binary file. - /// EP use this information to decide the location and context binary file name. - /// Used while compiling model with input and output in memory buffer - /// - /// The folder path to the generated context binary file - /// Model name used to decide the context binary file name: [model_name]_[ep].bin - /// Status indicating potential error - Status SetEpContextBinaryInformation(const std::string& output_directory, const std::string& model_name); - /// /// Enables or disables the embedding of EPContext binary data into the `ep_cache_context` attribute of EPContext /// nodes. Defaults to false (dumped to file). 
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index db2a62c77d1bc..e7f60fd48a14f 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2591,29 +2591,6 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets) { - API_IMPL_BEGIN - if (num_operator_sets == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_operator_sets' argument is NULL"); - } - - ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetNumOperatorSets(*num_operator_sets)); - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::Graph_GetOperatorSets, _In_ const OrtGraph* graph, - _Out_writes_(num_operator_sets) const char** domains, - _Out_writes_(num_operator_sets) int64_t* opset_versions, _In_ size_t num_operator_sets) { - API_IMPL_BEGIN - gsl::span domains_span(domains, num_operator_sets); - gsl::span versions_span(opset_versions, num_operator_sets); - ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetOperatorSets(domains_span, versions_span)); - - return nullptr; - API_IMPL_END -} - ORT_API_STATUS_IMPL(OrtApis::Graph_GetNumInputs, _In_ const OrtGraph* graph, _Out_ size_t* num_inputs) { API_IMPL_BEGIN if (num_inputs == nullptr) { @@ -2714,91 +2691,6 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetParentNode, _In_ const OrtGraph* graph, _O API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Graph_GetGraphView, _In_ const OrtGraph* src_graph, - _In_ const OrtNode** nodes, - _In_ size_t num_nodes, - _Outptr_ OrtGraph** dst_graph) { - API_IMPL_BEGIN - - if (num_nodes == 0) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_nodes' argument should be > 0"); - } - - const EpGraph* ep_graph = EpGraph::ToInternal(src_graph); - if (ep_graph == nullptr) { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "src_graph is a ModelEditorGraph which doesn't support Graph_GetSubGraph."); - } - const Graph& graph = ep_graph->GetGraphViewer().GetGraph(); - - // Create a GraphViewer with filtered info - std::unique_ptr indexed_sub_graph = std::make_unique(); - std::unique_ptr metadef = std::make_unique(); - metadef->name = "sub_graph"; - metadef->since_version = 1; - std::unordered_set outputs; - std::unordered_set initializers; - - auto add_inputs = [&](ConstPointerContainer> defs) { - for (const auto* def : defs) { - if (def->Exists()) { - // not the output of a previous node - if (outputs.count(def->Name()) == 0) { - metadef->inputs.push_back(def->Name()); - } else { - // consumed by node so no longer subgraph output - // NOTE: Ignoring edge case where a node output is an overall graph output AND a node input - outputs.erase(def->Name()); - } - - if (graph.IsInitializedTensor(def->Name())) { - initializers.insert(def); - } - } - } - }; - - auto add_node = [&](const Node& node) { - indexed_sub_graph->nodes.push_back(node.Index()); - add_inputs(node.InputDefs()); - add_inputs(node.ImplicitInputDefs()); - - for (const auto* def : node.OutputDefs()) { - outputs.insert(def->Name()); - } - }; - - // Add nodes - for (size_t node_idx = 0; node_idx < num_nodes; node_idx++) { - const OrtNode* ort_node = nodes[node_idx]; - const EpNode* ep_node = EpNode::ToInternal(ort_node); - if (ep_node == nullptr) { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Graph_GetSubGraph."); - } - 
add_node(ep_node->GetInternalNode()); - } - - // Add initializers - for (auto& initializer : initializers) { - metadef->constant_initializers.push_back(initializer->Name()); - } - - // Add outputs - for (auto& output : outputs) { - metadef->outputs.push_back(output); - } - - indexed_sub_graph->SetMetaDef(std::move(metadef)); - auto graph_viewer = std::make_unique(graph, *indexed_sub_graph.get()); - - std::unique_ptr result; - ORT_API_RETURN_IF_STATUS_NOT_OK(EpGraph::Create(std::move(graph_viewer), std::move(indexed_sub_graph), result)); - - *dst_graph = result.release(); - - return nullptr; - API_IMPL_END -} - // // OrtNode // @@ -3030,11 +2922,10 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetNumSubgraphs, _In_ const OrtNode* node, _Ou } ORT_API_STATUS_IMPL(OrtApis::Node_GetSubgraphs, _In_ const OrtNode* node, - _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs, - _Out_writes_opt_(num_subgraphs) const char** attribute_names) { + _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs) { API_IMPL_BEGIN gsl::span graphs_span(subgraphs, num_subgraphs); - ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetSubgraphs(graphs_span, attribute_names)); + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetSubgraphs(graphs_span)); return nullptr; API_IMPL_END } @@ -3052,23 +2943,6 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetGraph, _In_ const OrtNode* node, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Node_GetEpName, _In_ const OrtNode* node, - _Outptr_result_maybenull_ const char** out) { - API_IMPL_BEGIN - if (out == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'out' argument is NULL"); - } - - const EpNode* ep_node = EpNode::ToInternal(node); - if (ep_node == nullptr) { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Node_GetEpName."); - } - - *out = ep_node->GetEpName().c_str(); - return nullptr; - API_IMPL_END -} - ORT_API(const OrtTrainingApi*, OrtApis::GetTrainingApi, uint32_t version) { #ifdef ENABLE_TRAINING_APIS if (version >= 13 && version <= ORT_API_VERSION) @@ -3720,8 +3594,6 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::ValueInfo_IsFromOuterScope, &OrtApis::Graph_GetName, &OrtApis::Graph_GetOnnxIRVersion, - &OrtApis::Graph_GetNumOperatorSets, - &OrtApis::Graph_GetOperatorSets, &OrtApis::Graph_GetNumInputs, &OrtApis::Graph_GetInputs, &OrtApis::Graph_GetNumOutputs, @@ -3731,7 +3603,6 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Graph_GetNumNodes, &OrtApis::Graph_GetNodes, &OrtApis::Graph_GetParentNode, - &OrtApis::Graph_GetGraphView, &OrtApis::Node_GetId, &OrtApis::Node_GetName, &OrtApis::Node_GetOperatorType, @@ -3751,7 +3622,6 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Node_GetNumSubgraphs, &OrtApis::Node_GetSubgraphs, &OrtApis::Node_GetGraph, - &OrtApis::Node_GetEpName, &OrtApis::GetRunConfigEntry, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 9ab927006c320..cbacbfce0740d 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -631,10 +631,6 @@ ORT_API_STATUS_IMPL(ValueInfo_IsFromOuterScope, _In_ const OrtValueInfo* value_i // OrtGraph ORT_API_STATUS_IMPL(Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ const char** graph_name); ORT_API_STATUS_IMPL(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version); -ORT_API_STATUS_IMPL(Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets); 
-ORT_API_STATUS_IMPL(Graph_GetOperatorSets, _In_ const OrtGraph* graph, - _Out_writes_(num_operator_sets) const char** domains, - _Out_writes_(num_operator_sets) int64_t* opset_versions, _In_ size_t num_operator_sets); ORT_API_STATUS_IMPL(Graph_GetNumInputs, _In_ const OrtGraph* graph, _Out_ size_t* num_inputs); ORT_API_STATUS_IMPL(Graph_GetInputs, _In_ const OrtGraph* graph, _Out_writes_(num_inputs) const OrtValueInfo** inputs, _In_ size_t num_inputs); @@ -649,8 +645,6 @@ ORT_API_STATUS_IMPL(Graph_GetNumNodes, _In_ const OrtGraph* graph, _Out_ size_t* ORT_API_STATUS_IMPL(Graph_GetNodes, const OrtGraph* graph, _Out_writes_(num_nodes) const OrtNode** nodes, _In_ size_t num_nodes); ORT_API_STATUS_IMPL(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); -ORT_API_STATUS_IMPL(Graph_GetGraphView, _In_ const OrtGraph* graph, _In_ const OrtNode** nodes, _In_ size_t num_nodes, - _Outptr_ OrtGraph** subgraph); // OrtNode ORT_API_STATUS_IMPL(Node_GetId, _In_ const OrtNode* node, _Out_ size_t* node_id); @@ -677,10 +671,8 @@ ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOp ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name); ORT_API_STATUS_IMPL(Node_GetNumSubgraphs, _In_ const OrtNode* node, _Out_ size_t* num_subgraphs); ORT_API_STATUS_IMPL(Node_GetSubgraphs, _In_ const OrtNode* node, - _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs, - _Out_writes_opt_(num_subgraphs) const char** attribute_names); + _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs); ORT_API_STATUS_IMPL(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph); -ORT_API_STATUS_IMPL(Node_GetEpName, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out); ORT_API_STATUS_IMPL(GetRunConfigEntry, _In_ const OrtRunOptions* options, _In_z_ const char* config_key, _Outptr_result_maybenull_z_ const char** config_value); diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index 211bf8b2d15a4..e8d62ab86f517 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -22,13 +22,7 @@ namespace onnxruntime { namespace { bool MatchesEpVendor(const OrtEpDevice* d) { - // match on vendor id if provided - uint32_t factory_vendor_id = d->ep_factory->GetVendorId(d->ep_factory); - if (factory_vendor_id != 0 && d->device->vendor_id == factory_vendor_id) { - return true; - } - - // match on vendor name + // TODO: Would be better to match on Id. Should the EP add that in EP metadata? 
return d->device->vendor == d->ep_vendor; } diff --git a/onnxruntime/core/util/qmath.h b/onnxruntime/core/util/qmath.h index f7d5cdb98aa1d..0172902bdf4e2 100644 --- a/onnxruntime/core/util/qmath.h +++ b/onnxruntime/core/util/qmath.h @@ -1001,53 +1001,4 @@ struct BlockedQuantizeLinear { #endif -/** - * @brief Run MlasDequantizeLinear in parallel, with provided thread pool - */ - -template -void ParDequantizeLinearStd(const InputQuantType* input, - float* output, - size_t num_elems, - float scale, - InputQuantType zero_point, - concurrency::ThreadPool* thread_pool) { - constexpr std::ptrdiff_t block_size = 128; - const std::ptrdiff_t num_blocks = (num_elems + block_size - 1) / block_size; - const TensorOpCost unit_cost{static_cast(block_size * sizeof(InputQuantType)), - static_cast(block_size * sizeof(float)), - static_cast(block_size) * 2.0}; - concurrency::ThreadPool::TryParallelFor(thread_pool, num_blocks, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { - auto begin_idx = begin * block_size; - auto end_idx = std::min(static_cast(num_elems), end * block_size); - MlasDequantizeLinear(&(input[begin_idx]), &(output[begin_idx]), end_idx - begin_idx, scale, zero_point); - }); -} - -// Note: this doesn't use MLAS kernel. There are currently no MLAS kernels for fp16 QuantizeLinear or DequantizeLinear. -template -void ParDequantizeLinearStd(const InputQuantType* input, - MLFloat16* output, - size_t num_elems, - MLFloat16 scale, - InputQuantType zero_point, - concurrency::ThreadPool* thread_pool) { - constexpr std::ptrdiff_t block_size = 128; - const std::ptrdiff_t num_blocks = (num_elems + block_size - 1) / block_size; - const TensorOpCost unit_cost{static_cast(block_size * sizeof(InputQuantType)), - static_cast(block_size * sizeof(MLFloat16)), - static_cast(block_size) * 2.0}; - - const int32_t zp_s32 = static_cast(zero_point); - const float sc_f32 = scale.ToFloat(); - - concurrency::ThreadPool::TryParallelFor(thread_pool, num_blocks, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { - auto begin_idx = begin * block_size; - auto end_idx = std::min(static_cast(num_elems), end * block_size); - for (; begin_idx != end_idx; ++begin_idx) { - output[begin_idx] = MLFloat16(static_cast(static_cast(input[begin_idx]) - zp_s32) * sc_f32); - } - }); -} - } // namespace onnxruntime diff --git a/onnxruntime/core/util/thread_utils.h b/onnxruntime/core/util/thread_utils.h index 0b99723b2c75b..d63d620dbc321 100644 --- a/onnxruntime/core/util/thread_utils.h +++ b/onnxruntime/core/util/thread_utils.h @@ -19,13 +19,7 @@ struct OrtThreadPoolParams { bool auto_set_affinity = false; // If it is true, the thread pool will spin a while after the queue became empty. -#if !defined(ORT_CLIENT_PACKAGE_BUILD) bool allow_spinning = true; -#else - // default allow_spinning to false for ORT builds targeting client/on-device workloads, - // to reduce CPU utilization and improve power efficiency. 
- bool allow_spinning = false; -#endif // It it is non-negative, thread pool will split a task by a decreasing block size // of remaining_of_total_iterations / (num_of_threads * dynamic_block_base_) diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index e3303dac6c8c5..9a297e451213a 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -42,7 +42,7 @@ def __init__(self, **data: dict[str, Any]): for k, v in data.items(): if not isinstance(k, str): raise TypeError(f"Keys must be strings not {type(k)} for k={k!r}.") - if k != "axis" and not isinstance(v, (int, str, np.ndarray, float)): + if k != "axis" and not isinstance(v, (int, str, np.ndarray)): raise TypeError(f"Values must be numpy arrays, int, float, str not {type(v)} for k={k!r}.") if k == "axis" and not isinstance(v, int) and v is not None: raise TypeError(f"Axis value must be an int or None, not {type(v)}.") diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py index 319c5aa468f7e..fbeae39c39d21 100644 --- a/onnxruntime/python/tools/quantization/registry.py +++ b/onnxruntime/python/tools/quantization/registry.py @@ -86,7 +86,6 @@ "InstanceNormalization": QDQNormalization, "LayerNormalization": QDQNormalization, "BatchNormalization": QDQNormalization, - "TopK": QDQDirect8BitOp, } diff --git a/onnxruntime/python/tools/transformers/fusion_attention_clip.py b/onnxruntime/python/tools/transformers/fusion_attention_clip.py index 8711e368cd1e6..fe93f5cd358bf 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_clip.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_clip.py @@ -269,48 +269,42 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): attention_last_node = reshape_qkv add_qk = "" - causal_mask_nodes_1 = None - causal_mask_nodes_2 = None if add_mask is not None: - if add_mask.input[1] == "attention_mask": + # 4D Add after Q x K' + add_qk_nodes = self.model.match_parent_path( + add_mask, + [ + "Where", + "Sub", + "Cast", + "Expand", + "Unsqueeze", + "Unsqueeze", + "Reshape", + "Reshape", + "Cast", + ], + [1, 2, 1, 0, 0, 0, 0, 0, 0], + ) + if add_qk_nodes is not None: add_qk = add_mask.input[1] else: - # 4D Add after Q x K' - add_qk_nodes = self.model.match_parent_path( + # Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path + # of computing causal mask. + causal_mask_nodes_1 = self.model.match_parent_path( add_mask, - [ - "Where", - "Sub", - "Cast", - "Expand", - "Unsqueeze", - "Unsqueeze", - "Reshape", - "Reshape", - "Cast", - ], - [1, 2, 1, 0, 0, 0, 0, 0, 0], + ["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], + [causal_mask_input_index, 0, 0, 0, 0, 0], ) - if add_qk_nodes is not None: - add_qk = add_mask.input[1] - else: - # Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path - # of computing causal mask. 
- causal_mask_nodes_1 = self.model.match_parent_path( - add_mask, - ["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], - [causal_mask_input_index, 0, 0, 0, 0, 0], - ) - # If the model is exported with batch_size == 1, there is no Concat node - causal_mask_nodes_2 = self.model.match_parent_path( - add_mask, - ["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], - [causal_mask_input_index, 0, 0, 0, 0], - ) - - if causal_mask_nodes_1 is None and causal_mask_nodes_2 is None: - logger.debug("fuse_attention: failed to match causal mask subgraph") - return + # If the model is exported with batch_size == 1, there is no Concat node + causal_mask_nodes_2 = self.model.match_parent_path( + add_mask, + ["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], + [causal_mask_input_index, 0, 0, 0, 0], + ) + if causal_mask_nodes_1 is None and causal_mask_nodes_2 is None: + logger.debug("fuse_attention: failed to match causal mask subgraph") + return new_node = self.create_attention_node( mask_index=None, @@ -326,7 +320,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): output=attention_last_node.output[0], add_qk_str=add_qk, scale=None, - causal=(causal_mask_nodes_1 is not None) or (causal_mask_nodes_2 is not None), + causal=(add_mask is not None), ) if new_node is None: logger.debug("fuse_attention: failed to create fused node") diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt index e16957eab80a1..6bd698f8b75b4 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt @@ -1,7 +1,7 @@ onnxscript>=0.2.3 optimum>=1.14.1 optree -transformers==4.52.1 +transformers==4.48.0 torch>=2.7.0 onnx==1.17.0 datasets>=2.8.0 diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index e092285d57358..ac696ff3788aa 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -410,7 +410,7 @@ def export_onnx_models( precision == Precision.FLOAT16, model.config.encoder_attention_heads, model.config.d_model, - model.config.decoder_layers, + model.config.num_hidden_layers, use_external_data_format, use_gpu=use_gpu, provider=provider, diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt index 37fc72cd26e07..f1758cc52280f 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -1,5 +1,5 @@ torch>=2.7.0 -transformers==4.52.3 +transformers>=4.52.3 openai-whisper==20240927 ffmpeg-python datasets diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index e10e616d35d38..fadf271ae913b 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -187,7 +187,7 @@ def input_names(self): *list( chain.from_iterable( (f"past_key_self_{i}", f"past_value_self_{i}", f"past_key_cross_{i}", f"past_value_cross_{i}") - for i in range(self.config.decoder_layers) + for i in 
range(self.config.num_hidden_layers) ) ), ] @@ -205,7 +205,7 @@ def output_names(self): f"present_key_cross_{i}", f"present_value_cross_{i}", ) - for i in range(self.config.decoder_layers) + for i in range(self.config.num_hidden_layers) ) ), ] @@ -214,7 +214,8 @@ def output_names(self): "logits", *list( chain.from_iterable( - (f"present_key_self_{i}", f"present_value_self_{i}") for i in range(self.config.decoder_layers) + (f"present_key_self_{i}", f"present_value_self_{i}") + for i in range(self.config.num_hidden_layers) ) ), ] diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index cd81edc1001be..26dc3aee7018b 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -127,7 +127,7 @@ def output_names(self): *list( chain.from_iterable( (f"present_key_cross_{i}", f"present_value_cross_{i}") - for i in range(self.config.decoder_layers) + for i in range(self.config.num_hidden_layers) ) ), ] @@ -143,7 +143,7 @@ def output_names(self): f"present_key_cross_{i}", f"present_value_cross_{i}", ) - for i in range(self.config.decoder_layers) + for i in range(self.config.num_hidden_layers) ) ), ] diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index a236c4da1738e..f66aa22eb0972 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -763,7 +763,7 @@ def optimize_onnx( is_float16: bool, num_attention_heads: int, hidden_size: int, - num_decoder_layers: int, + num_layers: int, use_external_data_format: bool = False, use_gpu: bool = False, provider: str = "cpu", @@ -801,7 +801,7 @@ def optimize_onnx( m = add_cache_indirection_to_mha(m, past_seq_len_name) if output_qk: - m = add_output_qk_to_mha(m, skip_node_idxs=list(range(0, 2 * num_decoder_layers, 2))) + m = add_output_qk_to_mha(m, skip_node_idxs=list(range(0, 2 * num_layers, 2))) m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py index 8937fea900d14..0b0882eface72 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py @@ -94,14 +94,14 @@ def get_sample_past_key_values( torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), ) - for _ in range(config.decoder_layers) + for _ in range(config.num_hidden_layers) ] cross_attention_kv_caches = [ ( torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype), torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype), ) - for _ in range(config.decoder_layers) + for _ in range(config.num_hidden_layers) ] return flatten_past_key_values(self_attention_kv_caches, cross_attention_kv_caches) @@ -187,7 +187,7 @@ def get_sample_QKs( # noqa: N802 torch.rand( batch_size, num_heads, sequence_length, config.max_source_positions, 
device=device, dtype=torch_dtype ) - for _ in range(config.decoder_layers) + for _ in range(config.num_hidden_layers) ] return QKs diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py index 4dd5d7de1752b..a7c0d3538b8da 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py @@ -156,7 +156,7 @@ def input_names(self): "alignment_heads", "sot_sequence_length", "segment_length", - *[f"cross_qk_{i}" for i in range(self.config.decoder_layers)], + *[f"cross_qk_{i}" for i in range(self.config.num_hidden_layers)], ] return input_names diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/ep.cc index 44b3f9a213abf..b498c40079f48 100644 --- a/onnxruntime/test/autoep/library/ep.cc +++ b/onnxruntime/test/autoep/library/ep.cc @@ -226,7 +226,7 @@ OrtStatus* ExampleEp::SaveConstantInitializers(const OrtGraph* graph) { /*static*/ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, - OrtEpGraphSupportInfo* graph_support_info) noexcept { + OrtEpGraphSupportInfo* graph_support_info) { ExampleEp* ep = static_cast(this_ptr); size_t num_nodes = 0; @@ -290,7 +290,7 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, _In_ const OrtNode** fused_nodes, _In_ size_t count, _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, - _Out_writes_(count) OrtNode** ep_context_nodes) noexcept { + _Out_writes_(count) OrtNode** ep_context_nodes) { ExampleEp* ep = static_cast(this_ptr); const OrtApi& ort_api = ep->ort_api; @@ -328,12 +328,6 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[0], &node_input_names[0])); RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[1], &node_input_names[1])); - const char* ep_name = nullptr; - RETURN_IF_ERROR(ort_api.Node_GetEpName(fused_nodes[0], &ep_name)); - if (std::strncmp(ep_name, "example_ep", 11) != 0) { - return ort_api.CreateStatus(ORT_EP_FAIL, "The fused node is expected to assigned to this EP to run on"); - } - // Associate the name of the fused node with our MulKernel. 
const char* fused_node_name = nullptr; RETURN_IF_ERROR(ort_api.Node_GetName(fused_nodes[0], &fused_node_name)); @@ -360,7 +354,7 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const /*static*/ void ORT_API_CALL ExampleEp::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos, - size_t num_node_compute_infos) noexcept { + size_t num_node_compute_infos) { (void)this_ptr; for (size_t i = 0; i < num_node_compute_infos; i++) { delete node_compute_infos[i]; diff --git a/onnxruntime/test/autoep/library/ep.h b/onnxruntime/test/autoep/library/ep.h index dfebcc52a0caf..b8c63f39438ba 100644 --- a/onnxruntime/test/autoep/library/ep.h +++ b/onnxruntime/test/autoep/library/ep.h @@ -31,14 +31,14 @@ class ExampleEp : public OrtEp, public ApiPtrs { private: static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, - OrtEpGraphSupportInfo* graph_support_info) noexcept; + OrtEpGraphSupportInfo* graph_support_info); static OrtStatus* ORT_API_CALL CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, _In_ const OrtNode** fused_nodes, _In_ size_t count, _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, - _Out_writes_(count) OrtNode** ep_context_nodes) noexcept; + _Out_writes_(count) OrtNode** ep_context_nodes); static void ORT_API_CALL ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos, - size_t num_node_compute_infos) noexcept; + size_t num_node_compute_infos); OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes); diff --git a/onnxruntime/test/autoep/library/ep_factory.cc b/onnxruntime/test/autoep/library/ep_factory.cc index 19a44008b8c97..d4895102b0bf1 100644 --- a/onnxruntime/test/autoep/library/ep_factory.cc +++ b/onnxruntime/test/autoep/library/ep_factory.cc @@ -14,7 +14,6 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis) ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. 
GetName = GetNameImpl; GetVendor = GetVendorImpl; - GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; @@ -88,12 +87,6 @@ const char* ORT_API_CALL ExampleEpFactory::GetVendorImpl(const OrtEpFactory* thi return factory->vendor_.c_str(); } -/*static*/ -uint32_t ORT_API_CALL ExampleEpFactory::GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { - const auto* factory = static_cast(this_ptr); - return factory->vendor_id_; -} - /*static*/ const char* ORT_API_CALL ExampleEpFactory::GetVersionImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); diff --git a/onnxruntime/test/autoep/library/ep_factory.h b/onnxruntime/test/autoep/library/ep_factory.h index 72fa1c1301841..fda77f12c4814 100644 --- a/onnxruntime/test/autoep/library/ep_factory.h +++ b/onnxruntime/test/autoep/library/ep_factory.h @@ -21,7 +21,6 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept; - static uint32_t ORT_API_CALL GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept; static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* this_ptr) noexcept; @@ -54,7 +53,6 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { const std::string ep_name_; // EP name const std::string vendor_{"Contoso"}; // EP vendor name - const uint32_t vendor_id_{0xB357}; // EP vendor ID const std::string ep_version_{"0.1.0"}; // EP version // CPU allocator so we can control the arena behavior. optional as ORT always provides a CPU allocator if needed. diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 4c3f9e8dd4dbd..7b77ca8c69225 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -527,20 +527,18 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop if (std::is_same_v) { #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); - RunTest(opts, std::move(execution_providers)); #endif #ifdef USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); - RunTest(opts, std::move(execution_providers)); #endif #ifdef USE_DML execution_providers.push_back(DefaultDmlExecutionProvider()); - RunTest(opts, std::move(execution_providers)); #endif #ifdef USE_WEBGPU execution_providers.push_back(DefaultWebGpuExecutionProvider()); - RunTest(opts, std::move(execution_providers)); #endif + + RunTest(opts, std::move(execution_providers)); } else { #ifdef USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index 17e829e37f729..60498e6510ec2 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -1,24 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include -#include #include #include #include #include -#include #include "core/common/common.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/tensor_type_and_shape.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/session/onnxruntime_cxx_api.h" -#include "core/graph/ep_api_types.h" -#include "core/graph/graph_proto_serializer.h" - -#define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL -#include "core/providers/utils/ort_graph_to_proto.h" #include "test/ep_graph/test_ep_graph_utils.h" #include "test/util/include/api_asserts.h" @@ -34,7 +26,6 @@ namespace test { // forward-declaration for utility that uses public C APIs to check that an OrtGraph is equivalent // to a graph represented by the internal ORT GraphViewer class. static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph); -static void Check_Graph_GetSubgraph(const OrtGraph& api_graph); // // Tests @@ -77,178 +68,6 @@ TEST(EpGraphTest, Check3LayerNestedSubgraph) { CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); } -TEST(EpGraphTest, Check3LayerNestedSubgraphV2) { - // The overall structure of this model is similar to the one used in "Check3LayerNestedSubgraph" test. - // The model consists of a graph with subgraphs nested across three levels. - // In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer). - auto test_graph = TestGraph::Load(ORT_TSTR("testdata/three_layer_nested_subgraph_v2.onnx")); - ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; - - CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); -} - -static void RunMNISTModel(const ORTCHAR_T* model_path, std::vector& output_data) { - auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); - Ort::SessionOptions sess_options; - Ort::Session session(*ort_env, model_path, sess_options); - - std::vector input_shape = {1, 1, 28, 28}; - std::vector input_data(28 * 28, 0.5f); - std::vector ort_inputs; - std::vector ort_input_names; - - // Add 'Input3' - ort_inputs.emplace_back(Ort::Value::CreateTensor( - memory_info, input_data.data(), input_data.size(), input_shape.data(), input_shape.size())); - ort_input_names.push_back("Input3"); - - // Run session and get outputs - std::array output_names{"Plus214_Output_0"}; - std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), - ort_inputs.size(), output_names.data(), output_names.size()); - - // Check output type and number of elements. - Ort::Value& ort_output = ort_outputs[0]; - auto output_type_shape = ort_output.GetTensorTypeAndShapeInfo(); - size_t num_output_elems = output_type_shape.GetElementCount(); - - ASSERT_EQ(output_type_shape.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); - ASSERT_EQ(num_output_elems, 10); - - // Return output data. - const float* output_values = ort_output.GetTensorData(); - output_data.assign(output_values, output_values + num_output_elems); -} - -// Test serializing an OrtGraph (MNIST) to GraphProto. Saves initializers to external file. -// Checks that the outputs of the serialized and original models are identical. 
-TEST(EpGraphTest, SerializeToProto_Mnist) { - const ORTCHAR_T* original_model_path = ORT_TSTR("testdata/mnist.onnx"); - const ORTCHAR_T* serialized_model_path = ORT_TSTR("mnist_serialized.onnx"); - std::filesystem::remove(serialized_model_path); - - { - auto test_graph = TestGraph::Load(original_model_path); - ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; - - // Serialize OrtGraph to GraphProto. Save initializers to external file. - std::string ext_ini_file_path = "mnist_serialized.bin"; - std::filesystem::remove(ext_ini_file_path); - std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary); - auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](const OrtValueInfo* value_info, - const void* data, size_t bytes, - bool& is_external, std::string& location, - int64_t& offset) -> Ort::Status { - // OrtValueInfo* could be used to query initializer's name, type, shape, - // node consumers, etc. - (void)value_info; - - if (bytes <= 127) { - is_external = false; // Keep small initializers stored inside the TensorProto. - return Ort::Status{nullptr}; - } - - offset = ext_ini_ofs.tellp(); - location = ext_ini_file_path; - ext_ini_ofs.write(static_cast(data), bytes); - ext_ini_ofs.flush(); - is_external = true; // True if is external initializer. - - return Ort::Status{nullptr}; - }; - - ONNX_NAMESPACE::ModelProto model_proto; - OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto, handle_initializer_data); - - std::ofstream ofs(serialized_model_path, std::ios::binary); - model_proto.SerializeToOstream(&ofs); - ofs.flush(); - - ASSERT_TRUE(std::filesystem::exists(serialized_model_path)); - ASSERT_TRUE(std::filesystem::exists(ext_ini_file_path)); - } - - // Compare output of the original and serialized models. Should be identical. - std::vector output_original; - std::vector output_serialized; - - RunMNISTModel(original_model_path, output_original); - RunMNISTModel(serialized_model_path, output_serialized); - - EXPECT_EQ(output_serialized, output_original); -} - -static void Run3LayerModel(const ORTCHAR_T* model_path, bool input_cond, std::vector& output_data) { - auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); - Ort::SessionOptions sess_options; - Ort::Session session(*ort_env, model_path, sess_options); - - std::vector input_shape = {1}; - std::vector ort_inputs; - std::vector ort_input_names; - - // Add 'if_cond_input' - ort_inputs.emplace_back(Ort::Value::CreateTensor( - memory_info, &input_cond, 1, input_shape.data(), input_shape.size())); - ort_input_names.push_back("if_cond_input"); - - // Run session and get outputs - std::array output_names{"if_cond_output"}; - std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), - ort_inputs.size(), output_names.data(), output_names.size()); - - // Check output type and number of elements. - Ort::Value& ort_output = ort_outputs[0]; - auto output_type_shape = ort_output.GetTensorTypeAndShapeInfo(); - size_t num_output_elems = output_type_shape.GetElementCount(); - - ASSERT_EQ(output_type_shape.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); - ASSERT_EQ(num_output_elems, 1); - - // Return output data. - const float* output_values = ort_output.GetTensorData(); - output_data.assign(output_values, output_values + num_output_elems); -} - -// Test serializing an OrtGraph to GraphProto. The model has 3 layers of nested subgraphs. -// Checks that the outputs of the serialized and original models are identical. 
-TEST(EpGraphTest, SerializeToProto_3LayerSubgraphs) { - const ORTCHAR_T* original_model_path = ORT_TSTR("testdata/three_layer_nested_subgraph.onnx"); - const ORTCHAR_T* serialized_model_path = ORT_TSTR("three_layer_nested_subgraph_serialized.onnx"); - std::filesystem::remove(serialized_model_path); - - { - auto test_graph = TestGraph::Load(original_model_path); - ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; - - // Serialize OrtGraph to ModelProto (all initializers stored within TensorProtos). - ONNX_NAMESPACE::ModelProto model_proto; - OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto); - - std::ofstream ofs(serialized_model_path, std::ios::binary); - model_proto.SerializeToOstream(&ofs); - ofs.flush(); - - ASSERT_TRUE(std::filesystem::exists(serialized_model_path)); - } - - // Compare output of the original and serialized models. Should be identical. - std::vector output_original; - std::vector output_serialized; - - { - Run3LayerModel(original_model_path, true, output_original); - Run3LayerModel(serialized_model_path, true, output_serialized); - EXPECT_EQ(output_serialized, output_original); - } - - { - Run3LayerModel(original_model_path, false, output_original); - Run3LayerModel(serialized_model_path, false, output_serialized); - EXPECT_EQ(output_serialized, output_original); - } -} - // // Utils for traversing an OrtGraph and checking against GraphViewer. // @@ -488,48 +307,6 @@ static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span nodes(num_nodes); - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&api_graph, nodes.data(), nodes.size())); - - // Select a half of nodes to create a OrtGraph - size_t num_selected_nodes = std::max((nodes.size() >> 1), (size_t)1); - std::vector selected_nodes(num_selected_nodes); - - for (size_t i = 0; i < num_selected_nodes; i++) { - selected_nodes[i] = nodes[i]; - } - - OrtGraph* sub_graph; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetGraphView(&api_graph, selected_nodes.data(), selected_nodes.size(), &sub_graph)); - - // Convert OrtGraph/GraphViewer to ModelProto and dump it to disk. - // If the GraphViewer associated with the OrtGraph somehow is incorrect, GraphViewerToProto() will throw. - const GraphViewer& sub_graph_viewer = EpGraph::ToInternal(sub_graph)->GetGraphViewer(); - std::unique_ptr model = std::make_unique(sub_graph_viewer.Name(), true, sub_graph_viewer.GetGraph().GetLogger()); - auto model_proto = std::make_unique(model->ToProto()); - GraphViewerToProto(sub_graph_viewer, *model_proto->mutable_graph(), true, true, static_cast(1)); - model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - - const char* graph_name = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetName(&api_graph, &graph_name)); - std::string name = graph_name; - name += "_half.onnx"; - - // Dump the graph for debugging - // std::fstream dump(name, std::ios::out | std::ios::trunc | std::ios::binary); - // model_proto->SerializeToOstream(&dump); - - ort_api.ReleaseGraph(sub_graph); -} - // Checks that the contents of the original GraphViewer matches the contents of the OrtGraph. // Uses the public C APIs to traverse the OrtGraph. 
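The removed `Check_Graph_GetSubgraph` helper demonstrated how an `OrtGraph` view can be built over a subset of nodes through the public C API. A hedged sketch of that pattern, using the function names visible in the deleted code (`Graph_GetNodes`, `Graph_GetGraphView`, `ReleaseGraph`); the `Graph_GetNumNodes` call and the `const OrtNode*` element type are assumptions, since the corresponding text was garbled in extraction:

```cpp
#include <algorithm>
#include <vector>
#include "onnxruntime_cxx_api.h"

// Hedged sketch: take roughly the first half of a graph's nodes, create an OrtGraph
// view over them, then release it. Error handling is simplified.
static void BuildHalfGraphView(const OrtApi& ort_api, const OrtGraph& api_graph) {
  size_t num_nodes = 0;
  Ort::ThrowOnError(ort_api.Graph_GetNumNodes(&api_graph, &num_nodes));  // assumed API name

  std::vector<const OrtNode*> nodes(num_nodes);
  Ort::ThrowOnError(ort_api.Graph_GetNodes(&api_graph, nodes.data(), nodes.size()));

  // Select half of the nodes (at least one), as the removed test did.
  size_t num_selected = std::max(nodes.size() / 2, static_cast<size_t>(1));
  std::vector<const OrtNode*> selected(nodes.begin(), nodes.begin() + num_selected);

  OrtGraph* sub_graph = nullptr;
  Ort::ThrowOnError(ort_api.Graph_GetGraphView(&api_graph, selected.data(), selected.size(), &sub_graph));

  // ... serialize or inspect the view here ...

  ort_api.ReleaseGraph(sub_graph);
}
```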
static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph) { @@ -693,10 +470,9 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ } // Check node subgraphs - std::unordered_map> node_subgraphs_map = - node->GetAttributeNameToSubgraphMap(); + std::vector> node_subgraphs = node->GetSubgraphs(); - if (!node_subgraphs_map.empty()) { + if (!node_subgraphs.empty()) { // Check node's implicit inputs to its subgraph nodes. const auto implicit_input_node_args = node->ImplicitInputDefs(); @@ -713,34 +489,18 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ // Recursively check subgraphs. size_t api_num_node_subgraphs = 0; ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumSubgraphs(api_node, &api_num_node_subgraphs)); - ASSERT_EQ(api_num_node_subgraphs, node_subgraphs_map.size()); std::vector api_node_subgraphs(api_num_node_subgraphs); - std::vector api_subgraph_attr_names(api_num_node_subgraphs); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetSubgraphs(api_node, api_node_subgraphs.data(), api_node_subgraphs.size(), - api_subgraph_attr_names.data())); - - for (const auto& [attr_name, subgraph] : node_subgraphs_map) { - // find index of this subgraph. - size_t api_subgraph_idx = api_num_node_subgraphs; - for (size_t subgraph_idx = 0; subgraph_idx < api_num_node_subgraphs; subgraph_idx++) { - if (api_subgraph_attr_names[subgraph_idx] == attr_name) { - api_subgraph_idx = subgraph_idx; - break; - } - } - ASSERT_NE(api_subgraph_idx, api_num_node_subgraphs); + ASSERT_ORTSTATUS_OK(ort_api.Node_GetSubgraphs(api_node, api_node_subgraphs.data(), api_node_subgraphs.size())); + + for (size_t subgraph_idx = 0; subgraph_idx < node_subgraphs.size(); subgraph_idx++) { + auto subgraph_viewer = std::make_unique(*node_subgraphs[subgraph_idx]); + const OrtGraph* api_subgraph = api_node_subgraphs[subgraph_idx]; - // Recursively check the subgraph - auto subgraph_viewer = std::make_unique(*subgraph); - const OrtGraph* api_subgraph = api_node_subgraphs[api_subgraph_idx]; CheckGraphCApi(*subgraph_viewer, *api_subgraph); } } } - - // Check creating an OrtGraph from a subset of nodes in an OrtGraph - Check_Graph_GetSubgraph(api_graph); } } // namespace test diff --git a/onnxruntime/test/ep_graph/test_ep_graph_utils.cc b/onnxruntime/test/ep_graph/test_ep_graph_utils.cc index 3b3bc4c6da911..b7743e65061de 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph_utils.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph_utils.cc @@ -30,7 +30,6 @@ std::unique_ptr TestGraph::Load(const ORTCHAR_T* model_path) { const OrtGraph& TestGraph::GetOrtGraph() const { return *api_graph; } const GraphViewer& TestGraph::GetGraphViewer() const { return graph_viewer; } -const Model& TestGraph::GetModel() const { return *model; } static Status GetInputIndices(const Node& consumer_node, const std::string& name, /*out*/ std::vector& indices) { diff --git a/onnxruntime/test/ep_graph/test_ep_graph_utils.h b/onnxruntime/test/ep_graph/test_ep_graph_utils.h index 2ce107cf734c6..b0ed825f21d71 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph_utils.h +++ b/onnxruntime/test/ep_graph/test_ep_graph_utils.h @@ -28,7 +28,6 @@ class TestGraph { static std::unique_ptr Load(const ORTCHAR_T* model_path); const OrtGraph& GetOrtGraph() const; const GraphViewer& GetGraphViewer() const; - const Model& GetModel() const; private: std::shared_ptr model; diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 
4c5dcd2bd7580..18bc9cf05b36d 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -36,7 +36,7 @@ struct TestOrtEp : ::OrtEp, ApiPtrs { // Individual tests should fill out the other function pointers as needed. } - static const char* ORT_API_CALL GetNameImpl(const OrtEp* /*this_ptr*/) noexcept { + static const char* ORT_API_CALL GetNameImpl(const OrtEp* /*this_ptr*/) { constexpr const char* ep_name = "TestOrtEp"; return ep_name; } @@ -50,7 +50,7 @@ struct TestOrtEpFactory : ::OrtEpFactory { ReleaseEp = ReleaseEpImpl; } - static void ORT_API_CALL ReleaseEpImpl(::OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept { + static void ORT_API_CALL ReleaseEpImpl(::OrtEpFactory* /*this_ptr*/, OrtEp* ep) { delete static_cast(ep); } }; @@ -125,7 +125,7 @@ TEST(PluginExecutionProviderTest, GetPreferredLayout) { } { - auto prefer_nhwc_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) noexcept -> ::OrtStatus* { + auto prefer_nhwc_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) -> ::OrtStatus* { *preferred_data_layout = OrtEpDataLayout::OrtEpDataLayout_NCHW; return nullptr; }; @@ -135,7 +135,7 @@ TEST(PluginExecutionProviderTest, GetPreferredLayout) { #if !defined(ORT_NO_EXCEPTIONS) { - auto invalid_layout_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) noexcept -> ::OrtStatus* { + auto invalid_layout_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) -> ::OrtStatus* { *preferred_data_layout = static_cast(-1); return nullptr; }; @@ -144,7 +144,7 @@ TEST(PluginExecutionProviderTest, GetPreferredLayout) { } { - auto failing_fn = [](OrtEp* this_ptr, OrtEpDataLayout* /*preferred_data_layout*/) noexcept -> ::OrtStatus* { + auto failing_fn = [](OrtEp* this_ptr, OrtEpDataLayout* /*preferred_data_layout*/) -> ::OrtStatus* { auto* test_ort_ep = static_cast(this_ptr); return test_ort_ep->ort_api->CreateStatus(OrtErrorCode::ORT_FAIL, "I can't decide what data layout I prefer."); }; @@ -167,7 +167,7 @@ TEST(PluginExecutionProviderTest, ShouldConvertDataLayoutForOp) { const char* /*node_domain*/, const char* node_op_type, OrtEpDataLayout target_data_layout, - int* should_convert) noexcept -> ::OrtStatus* { + int* should_convert) -> ::OrtStatus* { EXPECT_EQ(target_data_layout, OrtEpDataLayout::OrtEpDataLayout_NHWC); if (node_op_type == std::string_view{"Conv"}) { @@ -201,7 +201,7 @@ TEST(PluginExecutionProviderTest, ShouldConvertDataLayoutForOp) { const char* /*node_domain*/, const char* /*node_op_type*/, OrtEpDataLayout /*target_data_layout*/, - int* /*should_convert*/) noexcept -> ::OrtStatus* { + int* /*should_convert*/) -> ::OrtStatus* { auto* test_ort_ep = static_cast(this_ptr); return test_ort_ep->ort_api->CreateStatus(OrtErrorCode::ORT_FAIL, "To convert to NHWC or not to convert to NHWC..."); diff --git a/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp b/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp index ea36383f70621..65822eb294d7d 100644 --- a/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp +++ b/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp @@ -58,10 +58,10 @@ void COMPUTESOFTMAXINPLACE(benchmark::State& state) { std::copy(data.begin(), data.end(), input); // Copy the data to the aligned memory // warming up run - MlasComputeSoftmax(input, output, N, D, false, false, 0.0f, tp.get()); + MlasComputeSoftmax(input, output, N, D, false, false, tp.get()); for (auto _ : state) { - MlasComputeSoftmax(input, output, N, D, false, false, 0.0f, 
tp.get()); + MlasComputeSoftmax(input, output, N, D, false, false, tp.get()); } free(ptr.underlying_buffer); diff --git a/onnxruntime/test/mlas/unittest/test_dequantizelinear.cpp b/onnxruntime/test/mlas/unittest/test_dequantizelinear.cpp deleted file mode 100644 index b994981364947..0000000000000 --- a/onnxruntime/test/mlas/unittest/test_dequantizelinear.cpp +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "test_util.h" - -template -class MlasDequantizeLinearTest : public MlasTestBase { - private: - MatrixGuardBuffer BufferInput; - MatrixGuardBuffer BufferOutput; - MatrixGuardBuffer BufferOutputReference; - - void GenerateReference(const QuantInt* Input, float* OutputReference, size_t N, float Scale, QuantInt ZeroPoint) { - int32_t ZeroPointS32 = static_cast(ZeroPoint); - - for (size_t n = 0; n < N; n++) { - OutputReference[n] = static_cast(static_cast(Input[n]) - ZeroPointS32) * Scale; - } - } - - void Test(size_t N) { - QuantInt* Input = BufferInput.GetBuffer(N); - float* Output = BufferOutput.GetBuffer(N); - float* OutputReference = BufferOutputReference.GetBuffer(N); - - std::default_random_engine generator(static_cast(N)); - - std::uniform_real_distribution min_gen(-10.f, -10e-3f); - float MinimumValue = min_gen(generator); - - std::uniform_real_distribution max_gen(10e-3f, 10.f); - float MaximumValue = max_gen(generator); - - float Scale = (MaximumValue - MinimumValue) / 512.f; - - std::uniform_int_distribution zp_distribution(std::numeric_limits::min(), - std::numeric_limits::max()); - QuantInt ZeroPoint = static_cast(zp_distribution(generator)); - - for (size_t n = 0; n < N; n++) { - Input[n] = static_cast(zp_distribution(generator)); - } - - GenerateReference(Input, OutputReference, N, Scale, ZeroPoint); - MlasDequantizeLinear(Input, Output, N, Scale, ZeroPoint); - - for (size_t n = 0; n < N; n++) { - ASSERT_EQ(Output[n], OutputReference[n]) << ", size=" << N << ", index=" << n; - } - } - - public: - static const char* GetTestSuiteName() { - if constexpr (std::is_same_v) { - return "DequantizeLinearS8"; - } else { - return "DequantizeLinearU8"; - } - } - - void ExecuteShort(void) override { - for (size_t n = 1; n <= 512; n++) { - Test(n); - } - } -}; - -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - size_t count = 0; - if (is_short_execute) { - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); - } - return count; -}); diff --git a/onnxruntime/test/mlas/unittest/test_softmax.cpp b/onnxruntime/test/mlas/unittest/test_softmax.cpp index 4d7a45143b311..041b6c61cd5bf 100644 --- a/onnxruntime/test/mlas/unittest/test_softmax.cpp +++ b/onnxruntime/test/mlas/unittest/test_softmax.cpp @@ -152,7 +152,7 @@ class MlasSoftmaxTest : public MlasTestBase { } void Test(const float* Input, float* Output, float* OutputReference, size_t N, size_t D, bool LogSoftmax, bool SmoothSoftmax) { - MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, 0.0f, threadpool_); + MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_); ReferenceSoftmax(Input, OutputReference, N, D, LogSoftmax, SmoothSoftmax); constexpr float AbsoluteTolerance = 1e-6f; @@ -206,7 +206,7 @@ class MlasSoftmaxTest : public MlasTestBase { InputReference[nd] = Input[nd].ToFloat(); } - MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, 0.0f, threadpool_); + 
MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_); ReferenceSoftmax(InputReference, OutputReference, N, D, LogSoftmax, SmoothSoftmax); constexpr float AbsoluteTolerance = 5e-3f; diff --git a/onnxruntime/test/providers/cpu/math/softmax_test.cc b/onnxruntime/test/providers/cpu/math/softmax_test.cc index 215203b31f49c..649c9af7cc80b 100644 --- a/onnxruntime/test/providers/cpu/math/softmax_test.cc +++ b/onnxruntime/test/providers/cpu/math/softmax_test.cc @@ -61,8 +61,7 @@ TEST(SoftmaxOperator, webgpu_nan) { test.AddOutput("Y", dimensions, expected_result); // explicitly disable for EPs that do not handle NaN - test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCpuExecutionProvider, kCoreMLExecutionProvider, kDmlExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider, kCoreMLExecutionProvider}); } #endif diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 8fdbf0060eaa0..4e7a6356a5129 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -33,32 +33,6 @@ TEST(DequantizeLinearOpTest, Int8) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -// scalar zero & scale with uint8 (large enough input to execute MLAS vectorized loop) -TEST(DequantizeLinearOpTest, Uint8_Large) { - OpTester test("DequantizeLinear", 10); - std::vector dims{1, 1039}; // not evenly divisible by 16 (loop unroll amount) to test handling of leftover inputs - test.AddInput("x", dims, std::vector(1039, 1)); - test.AddInput("x_scale", {}, {1.0f}); - test.AddInput("x_zero_point", {}, {1}); - test.AddOutput("y", dims, std::vector(1039, 0.0f)); - // Disable Tensorrt EP due to error:node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1. - // Disable WebGPU EP because it requires dims.Size() to be multiple of 4. Fails with error: needs at least component size 4. - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kWebGpuExecutionProvider}); -} - -// scalar zero & scale with int8 (large enough input to execute MLAS vectorized loop) -TEST(DequantizeLinearOpTest, Int8_Large) { - OpTester test("DequantizeLinear", 10); - std::vector dims{1, 1039}; // not evenly divisible by 16 (loop unroll amount) to test handling of leftover inputs - test.AddInput("x", dims, std::vector(1039, 1)); - test.AddInput("x_scale", {}, {1.0f}); - test.AddInput("x_zero_point", {}, {1}); - test.AddOutput("y", dims, std::vector(1039, 0.0f)); - // Disable Tensorrt EP due to error:node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1. - // Disable WebGPU EP because it requires dims.Size() to be multiple of 4. Fails with error: needs at least component size 4. 
- test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kWebGpuExecutionProvider}); -} - // scalar zero & scale with int4 TEST(DequantizeLinearOpTest, Int4) { OpTester test("DequantizeLinear", 21); diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc index e6d113e1e4dca..895c8ab3e53e4 100644 --- a/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc @@ -235,16 +235,5 @@ TEST(ScatterNDOpTest, ScatterND_18_max) { test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } -// Test for ScatterND with empty indices - output should be same as input -TEST(ScatterNDOpTest, ScatterND_empty_indices) { - // Test with float data type and minimal empty case - OpTester test1("ScatterND", 11); - test1.AddInput("data", {2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); - test1.AddInput("indices", {0, 1}, {}); // Empty indices tensor - no indices to process - test1.AddInput("updates", {0, 3}, {}); // Empty updates tensor - test1.AddOutput("output", {2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); // Same as input - test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); -} - } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 739e39a6975e2..4febfe7ba836d 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -509,11 +509,6 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB Ort::ModelCompilationOptions compile_options(*ort_env, session_options); compile_options.SetInputModelFromBuffer(reinterpret_cast(model_data.data()), model_data.size()); compile_options.SetOutputModelBuffer(allocator, &output_model_buffer, &output_model_buffer_size); - std::string target_dir = "./testdata/"; - std::string model_name = "test_model_in_mem.onnx"; - auto pos = model_name.rfind(".onnx"); - std::string bin_file_name = model_name.substr(0, pos) + "_qnn.bin"; - compile_options.SetEpContextBinaryInformation(ToWideString(target_dir).c_str(), ToWideString(model_name).c_str()); compile_options.SetEpContextEmbedMode(false); // Compile the model. @@ -524,18 +519,12 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB ASSERT_TRUE(output_model_buffer != nullptr); ASSERT_TRUE(output_model_buffer_size > 0); - ASSERT_TRUE(std::filesystem::exists(target_dir + bin_file_name)) << "expected context binary file should exist"; - // Check that the compiled model has the expected number of EPContext nodes. CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2); - // Add session option "ep.context_file_path" so that the session can use it to locate the [model_name]_qnn.bin file - std::string ctx_model = target_dir + model_name; - session_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ctx_model.c_str()); // Should be able to create a session with the compiled model and the original session options. 
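The lines removed just above showed how the QNN EP-context test pointed a session at the external `*_qnn.bin` context binary before loading a compiled model from memory. For reference, a hedged sketch of that pattern built only from calls visible in the deleted code (`AddConfigEntry` with `kOrtSessionOptionEpContextFilePath`, and the buffer-based `Ort::Session` constructor); the path and buffer names are placeholders:

```cpp
#include <string>
#include "onnxruntime_cxx_api.h"
#include "onnxruntime_session_options_config_keys.h"

// Hedged sketch: load a compiled model held in memory while telling ORT where the
// on-disk model (and thus its sibling <model_name>_qnn.bin context binary) lives.
void CreateSessionFromCompiledBuffer(Ort::Env& env, const void* model_buffer, size_t model_buffer_size) {
  Ort::SessionOptions session_options;

  // "ep.context_file_path": the session uses this path to resolve the external
  // context binary (placeholder path, mirroring the removed test).
  const std::string ctx_model = "./testdata/test_model_in_mem.onnx";
  session_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ctx_model.c_str());

  Ort::Session session(env, model_buffer, model_buffer_size, session_options);
  // ... run inference as usual ...
}
```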
EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_buffer, output_model_buffer_size, session_options))); - std::filesystem::remove(target_dir + bin_file_name); allocator.Free(output_model_buffer); } } @@ -1660,6 +1649,7 @@ static void DumpModelWithSharedCtx(ProviderOptions provider_options, Ort::Session session2(*ort_env, ToPathString(onnx_model_path2).c_str(), so); } +#if defined(__aarch64__) || defined(_M_ARM64) static void GetModelInputNames(const std::string& model_path, std::vector& input_names, std::vector& output_names, @@ -1679,6 +1669,7 @@ static void GetModelInputNames(const std::string& model_path, output_names.push_back(output->Name()); } } +#endif // 1. Create 2 QDQ models // 2. Initialize 2 Ort sessions which share the same QNN EP from these 2 QDQ models @@ -2003,73 +1994,6 @@ TEST_F(QnnHTPBackendTests, LoadFromArrayWithQnnEpContextGenPathValidation) { }); } } - -TEST_F(QnnHTPBackendTests, QnnEpDynamicOptions) { - ProviderOptions provider_options; - provider_options["backend_type"] = "htp"; - provider_options["offload_graph_io_quantization"] = "0"; - - Ort::SessionOptions so; - so.AppendExecutionProvider("QNN", provider_options); - so.SetLogSeverityLevel(ORT_LOGGING_LEVEL_VERBOSE); - - Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx/qnn_multi_ctx_embed.onnx"), so); - - std::vector input_names; - std::vector output_names; - GetModelInputNames("testdata/qnn_ctx/qnn_multi_ctx_embed.onnx", input_names, output_names, - DefaultLoggingManager().DefaultLogger()); - - // Run sessions - // prepare input - std::vector input_dim{3, 4}; - std::vector input_value(3 * 4, 0.0f); - Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); - std::vector ort_inputs; - std::vector input_names_c; - for (size_t i = 0; i < input_names.size(); ++i) { - auto input_tensor = Ort::Value::CreateTensor(info, input_value.data(), input_value.size(), - input_dim.data(), input_dim.size()); - ort_inputs.push_back(std::move(input_tensor)); - input_names_c.push_back(input_names[i].c_str()); - } - std::vector output_names_c; - for (size_t i = 0; i < output_names.size(); ++i) { - output_names_c.push_back(output_names[i].c_str()); - } - - auto ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), - output_names_c.data(), 1); - - const char* const workload_type[] = {"ep.dynamic.workload_type"}; - const char* const efficient_type[] = {"Efficient"}; - const char* const default_type[] = {"Default"}; - - // Test Efficient & Default options - session.SetEpDynamicOptions(workload_type, efficient_type, 1); - ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), - output_names_c.data(), 1); - - session.SetEpDynamicOptions(workload_type, default_type, 1); - ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), - output_names_c.data(), 1); - - // Test invalid EP dynamic option and invalid workload type - const char* const dne[] = {"DNE"}; - try { - session.SetEpDynamicOptions(workload_type, dne, 1); - FAIL() << "Expected exception to be thrown for workload type DNE but was set successfully"; - } catch (const std::exception& e) { - EXPECT_STREQ("Invalid EP Workload Type.", e.what()); - } - - try { - session.SetEpDynamicOptions(dne, efficient_type, 1); - FAIL() << "Expected exception to be thrown for dynamic option DNE but was set successfully"; - } catch (const std::exception& e) { - EXPECT_STREQ("Unsupported EP Dynamic Option", e.what()); - } -} #endif 
// defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 4c0a53e83e274..85f8250f70fc5 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -1254,38 +1254,6 @@ TEST_F(QnnHTPBackendTests, GridSample_U16_Nearest) { true); } -// Test QDQ GridSample with `linear` mode on opset 20+. -TEST_F(QnnHTPBackendTests, GridSample_Linear_ZerosPadding) { - RunQDQOpTest("GridSample", - {TestInputDef({1, 3, 4, 6}, false, GetFloatDataInRange(-10.0f, 10.0f, 72)), - TestInputDef({1, 4, 6, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))}, - {utils::MakeAttribute("mode", "linear"), utils::MakeAttribute("padding_mode", "zeros")}, - /*opset_version=*/20, - /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All); -} - -TEST_F(QnnHTPBackendTests, GridSample_Linear_AlignCorners_BorderPadding) { - RunQDQOpTest("GridSample", - {TestInputDef({1, 3, 4, 6}, false, GetFloatDataInRange(-10.0f, 10.0f, 72)), - TestInputDef({1, 4, 6, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))}, - {utils::MakeAttribute("align_corners", static_cast(1)), - utils::MakeAttribute("mode", "linear"), - utils::MakeAttribute("padding_mode", "border")}, - /*opset_version=*/20, - /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All); -} - -TEST_F(QnnHTPBackendTests, GridSample_Linear_ReflectionPadding_U16) { - RunQDQOpTest("GridSample", - {TestInputDef({1, 3, 4, 6}, false, GetFloatDataInRange(-10.0f, 10.0f, 72)), - TestInputDef({1, 4, 6, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))}, - {utils::MakeAttribute("mode", "linear"), utils::MakeAttribute("padding_mode", "reflection")}, - /*opset_version=*/21, - /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, - /*op_domain=*/kOnnxDomain, - /*use_contrib_qdq=*/true); -} - // Test QDQ GridSample with reflection padding mode // Inaccuracy detected for output 'output', element 2. // Output quant params: scale=0.024269860237836838, zero_point=0. diff --git a/onnxruntime/test/python/quantization/test_op_topk.py b/onnxruntime/test/python/quantization/test_op_topk.py deleted file mode 100644 index 1fdd0c987d1e8..0000000000000 --- a/onnxruntime/test/python/quantization/test_op_topk.py +++ /dev/null @@ -1,103 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import unittest - -import numpy as np -from onnx import TensorProto, helper, save -from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type - -from onnxruntime.quantization import QuantFormat, QuantType, quantize_static - - -class TestTopKModel(unittest.TestCase): - @staticmethod - def construct_model(model_path, input_shape, axis_attr, k): - input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, input_shape) - k_tensor = helper.make_tensor("k", TensorProto.INT64, [1], [k]) - output_shape = input_shape[:] - output_shape[axis_attr] = k - output_values = helper.make_tensor_value_info("values", TensorProto.FLOAT, [1, k]) - output_indices = helper.make_tensor_value_info("indices", TensorProto.INT64, [1, k]) - - node = helper.make_node( - "TopK", inputs=["input", "k"], outputs=["values", "indices"], name="topk_node", axis=axis_attr - ) - - graph = helper.make_graph( - [node], - "quant_topk_op_test", - [input_tensor], - [output_values, output_indices], - initializer=[k_tensor], - ) - - model = helper.make_model( - graph, opset_imports=[helper.make_opsetid("", 16), helper.make_opsetid("com.microsoft", 1)] - ) - save(model, model_path) - - def quantize_topk_test(self, activation_type, weight_type, extra_options={}): # noqa: B006 - model_fp32_path = "topk_fp32.onnx" - input_shape = [1, 10] - axis = 1 - k = 3 - self.construct_model(model_fp32_path, input_shape, axis, k) - - input_data_list = [ - {"input": np.array([[1.8, 2.5, -5.9, 5.2, 4.1, 7.3, 0.2, -0.5, 0.845, 3.9]], dtype=np.float32)} - ] - data_reader = TestDataFeeds(input_data_list) - - activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8 - activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8" - weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" - model_qdq_path = f"topk_{activation_type_str}{weight_type_str}_{'QNoInCk' if extra_options['ForceQuantizeNoInputCheck'] else 'NoQNoInCk'}_qdq.onnx" - - # Verify QDQ mode - data_reader.rewind() - quantize_static( - model_fp32_path, - model_qdq_path, - data_reader, - quant_format=QuantFormat.QDQ, - activation_type=activation_type, - weight_type=weight_type, - extra_options=extra_options, - ) - qdqnode_counts = ( - { - "TopK": 1, - "QuantizeLinear": 2, - "DequantizeLinear": 2, - } - if extra_options["ForceQuantizeNoInputCheck"] - else { - "TopK": 1, - "QuantizeLinear": 0, - "DequantizeLinear": 0, - } - ) - check_op_type_count(self, model_qdq_path, **qdqnode_counts) - qnode_io_qtypes = { - "QuantizeLinear": [ - ["i", 2, activation_proto_qtype], - ["o", 0, activation_proto_qtype], - ] - } - check_qtype_by_node_type(self, model_qdq_path, qnode_io_qtypes) - data_reader.rewind() - check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next()) - - def test_quantize_topk_u8u8(self): - self.quantize_topk_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={"ForceQuantizeNoInputCheck": True}) - - def test_quantize_topk_u8u8_no_force_quantize_no_input_check(self): - self.quantize_topk_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={"ForceQuantizeNoInputCheck": False}) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxruntime/test/python/transformers/test_data/models/phi-4-v-instruct-vision-attention.onnx b/onnxruntime/test/python/transformers/test_data/models/phi-4-v-instruct-vision-attention.onnx deleted file mode 
100644 index 34cf26c13d3fc..0000000000000 Binary files a/onnxruntime/test/python/transformers/test_data/models/phi-4-v-instruct-vision-attention.onnx and /dev/null differ diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 7f2134b2cda4f..461c243b82212 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -13,7 +13,6 @@ import random import unittest from dataclasses import dataclass -from enum import Enum import numpy import torch @@ -39,17 +38,11 @@ ATOL = None -class Formats(Enum): +class Formats: BSNH = 0 BNSH = 1 -class QKOutputType(Enum): - NO_OUTPUT = 0 - BEFORE_SOFTMAX = 1 - AFTER_SOFTMAX = 2 - - @dataclass class Config: batch_size: int = 0 @@ -61,8 +54,6 @@ class Config: head_size: int = 0 has_position_ids: bool = False has_attention_bias: bool = False - has_head_sink: bool = False - qk_output: QKOutputType = QKOutputType.NO_OUTPUT @dataclass @@ -76,8 +67,6 @@ class PromptConfig: head_size: int = 0 has_position_ids: bool = False has_attention_bias: bool = False - has_head_sink: bool = False - qk_output: QKOutputType = QKOutputType.NO_OUTPUT # LLaMA Microsoft model @@ -162,15 +151,6 @@ def create_group_query_attention_graph_prompt( ): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length - - output_names = [ - "output", - "present_key", - "present_value", - ] - if config.qk_output != QKOutputType.NO_OUTPUT: - output_names.append("output_qk") - nodes = [ helper.make_node( "GroupQueryAttention", @@ -186,9 +166,8 @@ def create_group_query_attention_graph_prompt( "sin_cache" if rotary else "", "position_ids" if config.has_position_ids else "", "attention_bias" if config.has_attention_bias else "", - "head_sink" if config.has_head_sink else "", ], - output_names, + ["output", "present_key", "present_value"], "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, @@ -197,7 +176,6 @@ def create_group_query_attention_graph_prompt( rotary_interleaved=rotary_interleaved, softcap=softcap, smooth_softmax=1 if use_smooth_softmax else 0, - qk_output=config.qk_output.value, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -311,15 +289,6 @@ def create_group_query_attention_graph_prompt( ), ] - if config.has_head_sink: - graph_input += [ - helper.make_tensor_value_info( - "head_sink", - ort_type, - [config.num_heads], - ), - ] - graph_output = [ helper.make_tensor_value_info( "output", @@ -368,15 +337,6 @@ def create_group_query_attention_graph_prompt( ), ] - if config.qk_output != QKOutputType.NO_OUTPUT: - graph_output += [ - helper.make_tensor_value_info( - "output_qk", - ort_type, - [config.batch_size, config.num_heads, config.kv_sequence_length, config.kv_sequence_length], - ), - ] - graph = helper.make_graph( nodes, "GroupQueryAttention_Graph", @@ -405,15 +365,6 @@ def create_group_query_attention_graph_past( present_kv_seqlen = ( config.kv_sequence_length if share_buffer else config.kv_sequence_length + config.sequence_length ) - - output_names = [ - "output", - "present_key", - "present_value", - ] - if config.qk_output != QKOutputType.NO_OUTPUT: - output_names.append("output_qk") - nodes = [ helper.make_node( "GroupQueryAttention", @@ -429,9 +380,8 @@ def create_group_query_attention_graph_past( "sin_cache" if rotary else 
"", "position_ids" if config.has_position_ids else "", "attention_bias" if config.has_attention_bias else "", - "head_sink" if config.has_head_sink else "", ], - output_names, + ["output", "present_key", "present_value"], "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, @@ -440,7 +390,6 @@ def create_group_query_attention_graph_past( rotary_interleaved=rotary_interleaved, softcap=softcap, smooth_softmax=1 if use_smooth_softmax else 0, - qk_output=config.qk_output.value, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -492,7 +441,6 @@ def create_group_query_attention_graph_past( [1], ), ] - if not packed: graph_input += [ helper.make_tensor_value_info( @@ -514,7 +462,6 @@ def create_group_query_attention_graph_past( ], ), ] - if rotary: graph_input += [ helper.make_tensor_value_info( @@ -551,15 +498,6 @@ def create_group_query_attention_graph_past( ), ] - if config.has_head_sink: - graph_input += [ - helper.make_tensor_value_info( - "head_sink", - ort_type, - [config.num_heads], - ), - ] - graph_output = [ helper.make_tensor_value_info( "output", @@ -588,15 +526,6 @@ def create_group_query_attention_graph_past( ), ] - if config.qk_output != QKOutputType.NO_OUTPUT: - graph_output += [ - helper.make_tensor_value_info( - "output_qk", - ort_type, - [config.batch_size, config.num_heads, config.sequence_length, present_kv_seqlen], - ), - ] - graph = helper.make_graph( nodes, "GroupQueryAttention_Graph", @@ -623,17 +552,17 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False): """ Arguments: - q: (batch_size, seqlen_q, num_heads, d) - k: (batch_size, seqlen_k, num_heads_k, d) - v: (batch_size, seqlen_k, num_heads_k, d) + q: (batch_size, seqlen_q, nheads, d) + k: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d) query_padding_mask: (batch_size, seqlen), bool key_padding_mask: (batch_size, seqlen), bool """ assert not (kvpacked and qkvpacked) - batch_size, seqlen_q, num_heads, d = q.shape - _, seqlen_k, num_heads_k, _ = k.shape - assert k.shape == (batch_size, seqlen_k, num_heads_k, d) - assert v.shape == (batch_size, seqlen_k, num_heads_k, d) + batch_size, seqlen_q, nheads, d = q.shape + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d) if query_padding_mask is not None: q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) @@ -664,7 +593,7 @@ def output_pad_fn(output_unpad): if qkvpacked: assert (query_padding_mask == key_padding_mask).all() - assert num_heads == num_heads_k + assert nheads == nheads_k qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) qkv = torch.stack([q, k, v], dim=2) if query_padding_mask is not None: @@ -785,8 +714,6 @@ def gqa_prompt_func( seqlens_k=None, position_ids=None, attention_bias=None, - head_sink=None, - output_qk=None, window_size=-1, past_kv_format=Formats.BSNH, share_buffer=True, @@ -819,18 +746,9 @@ def gqa_prompt_func( if config.has_attention_bias: assert attention_bias is not None - if config.qk_output != QKOutputType.NO_OUTPUT: - assert output_qk is not None - if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) - - 
sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - io_binding = ort_session.io_binding() - ort_outputs = {} - if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), @@ -839,6 +757,10 @@ def gqa_prompt_func( "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } + + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + io_binding = ort_session.io_binding() if new_k is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -875,18 +797,25 @@ def gqa_prompt_func( io_binding.bind_output("output") io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v else: ort_inputs = { "query": q.detach().cpu().numpy(), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + io_binding = ort_session.io_binding() if new_k is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() io_binding.bind_cpu_input("key", ort_inputs["key"]) io_binding.bind_cpu_input("value", ort_inputs["value"]) - if cos is not None: ort_inputs["cos_cache"] = cos.detach().cpu().numpy() ort_inputs["sin_cache"] = sin.detach().cpu().numpy() @@ -907,26 +836,11 @@ def gqa_prompt_func( io_binding.bind_output("output") io_binding.bind_output("present_key") io_binding.bind_output("present_value") - - if config.has_head_sink: - ort_inputs["head_sink"] = head_sink.detach().cpu().numpy() - io_binding.bind_cpu_input("head_sink", ort_inputs["head_sink"]) - - if config.qk_output != QKOutputType.NO_OUTPUT: - ort_outputs["output_qk"] = OrtValue.ortvalue_from_numpy(output_qk.detach().cpu().numpy(), "cpu", 0) - io_binding.bind_ortvalue_output("output_qk", ort_outputs["output_qk"]) - - ort_session.run_with_iobinding(io_binding) - - out_qk = None - if config.qk_output != QKOutputType.NO_OUTPUT: - ort_output, present_k, present_v, out_qk = io_binding.copy_outputs_to_cpu() - else: + ort_session.run_with_iobinding(io_binding) ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - - return output, present_k, present_v, out_qk + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v def gqa_past_func( @@ -941,8 +855,6 @@ def gqa_past_func( seqlens_k=None, position_ids=None, attention_bias=None, - head_sink=None, - output_qk=None, past_kv_format=Formats.BSNH, share_buffer=True, window_size=-1, @@ -975,18 +887,9 @@ def gqa_past_func( if config.has_attention_bias: assert attention_bias is not None - if config.qk_output != QKOutputType.NO_OUTPUT: - assert output_qk is not None - if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, 
config.sequence_length, -1)) new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) - - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - io_binding = ort_session.io_binding() - ort_outputs = {} - if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), @@ -998,6 +901,9 @@ def gqa_past_func( .cpu() .numpy(), } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + io_binding = ort_session.io_binding() if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -1034,6 +940,11 @@ def gqa_past_func( io_binding.bind_output("output") io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v else: ort_inputs = { "query": q.detach().cpu().numpy(), @@ -1047,6 +958,9 @@ def gqa_past_func( .cpu() .numpy(), } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + io_binding = ort_session.io_binding() if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -1074,26 +988,11 @@ def gqa_past_func( io_binding.bind_output("output") io_binding.bind_output("present_key") io_binding.bind_output("present_value") - - if config.has_head_sink: - ort_inputs["head_sink"] = head_sink.detach().cpu().numpy() - io_binding.bind_cpu_input("head_sink", ort_inputs["head_sink"]) - - if config.qk_output != QKOutputType.NO_OUTPUT: - ort_outputs["output_qk"] = OrtValue.ortvalue_from_numpy(output_qk.detach().cpu().numpy(), "cpu", 0) - io_binding.bind_ortvalue_output("output_qk", ort_outputs["output_qk"]) - - ort_session.run_with_iobinding(io_binding) - - out_qk = None - if config.qk_output != QKOutputType.NO_OUTPUT: - ort_output, present_k, present_v, out_qk = io_binding.copy_outputs_to_cpu() - else: + ort_session.run_with_iobinding(io_binding) ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - - return output, present_k, present_v, out_qk + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None): @@ -1126,28 +1025,11 @@ def construct_local_mask( ) -def smooth_softmax_ref(x, head_sink): - """ - Arguments: - x: (batch_size, num_heads, seqlen_q, seqlen_k) - head_sink: (num_heads) or None - Output: - y: (batch_size, num_heads, seqlen_q, seqlen_k) - """ - assert len(x.shape) == 4 - b, n, s, t = x.shape - - if head_sink is not None: - assert len(head_sink.shape) == 1 - assert head_sink.shape[0] == x.shape[1] - sink = head_sink.reshape(1, n, 1, 1).expand(b, -1, s, -1) - else: - sink = torch.zeros(b, n, s, 1, dtype=x.dtype) - - y = torch.cat([x, sink], dim=-1) - y = torch.softmax(y, dim=-1) - y = y[..., :-1] - return y +def smooth_softmax_ref(x): + x_max = x.amax(axis=-1, keepdim=True) + x_max = torch.maximum(x_max, 
torch.zeros_like(x_max)) + w = torch.exp(x - x_max) + return w * torch.reciprocal(w.sum(axis=-1, keepdim=True) + torch.exp(-x_max)) def attention_ref( @@ -1164,17 +1046,16 @@ def attention_ref( upcast=True, reorder_ops=False, use_smooth_softmax=False, - head_sink=None, ): """ Arguments: - q: (batch_size, seqlen_q, num_heads, head_dim) - k: (batch_size, seqlen_k, num_heads_k, head_dim) - v: (batch_size, seqlen_k, num_heads_k, head_dim) + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads_k, head_dim) + v: (batch_size, seqlen_k, nheads_k, head_dim) query_padding_mask: (batch_size, seqlen_q) key_padding_mask: (batch_size, seqlen_k) dropout_p: float - dropout_mask: (batch_size, num_heads, seqlen_q, seqlen_k) + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) causal: whether to apply causal masking window_size: (int, int), left and right window size upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast @@ -1183,10 +1064,8 @@ def attention_ref( without changing the math. This is to estimate the numerical error from operation reordering. use_smooth_softmax: whether use smooth softmax or not - head_sink: (num_heads) or None Output: output: (batch_size, seqlen_q, nheads, head_dim) - masked_scores: (batch_size, nheads, seqlen_q, seqlen_k), before softmax attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout """ if causal: @@ -1206,10 +1085,8 @@ def attention_ref( scores = scores / softcap scores = scores.tanh() scores = scores * softcap - masked_scores = scores.clone() if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) - masked_scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0) if window_size[0] >= 0 or window_size[1] >= 0: local_mask = construct_local_mask( seqlen_q, @@ -1219,11 +1096,10 @@ def attention_ref( key_padding_mask, q.device, ) - masked_scores.masked_fill_(local_mask, 0.0) scores.masked_fill_(local_mask, float("-inf")) - if use_smooth_softmax or (head_sink is not None): - attention = smooth_softmax_ref(scores, head_sink) + if use_smooth_softmax: + attention = smooth_softmax_ref(scores) else: attention = torch.softmax(scores, dim=-1) @@ -1245,7 +1121,7 @@ def attention_ref( if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) - return output.to(dtype=dtype_og), masked_scores.to(dtype=dtype_og), attention.to(dtype=dtype_og) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) def attention_qkvpacked_ref( @@ -1257,7 +1133,6 @@ def attention_qkvpacked_ref( upcast=True, reorder_ops=False, use_smooth_softmax=False, - head_sink=None, ): return attention_ref( qkv[:, :, 0], @@ -1271,7 +1146,6 @@ def attention_qkvpacked_ref( causal=causal, reorder_ops=reorder_ops, use_smooth_softmax=use_smooth_softmax, - head_sink=head_sink, ) @@ -1312,10 +1186,6 @@ def get_custom_position_ids(batch_size, sequence_length, seqlens_k=None, past=Fa return position_ids -def get_custom_head_sink(num_heads, torch_type=torch.float16): - return torch.rand(num_heads, dtype=torch_type) - - def parity_check_gqa_prompt( config, torch_type, @@ -1378,8 +1248,6 @@ def parity_check_gqa_prompt( requires_grad=False, ) - head_sink = get_custom_head_sink(config.num_heads, torch_type) if config.has_head_sink else None - window_size = (-1, -1) left_window_size = -1 if local: @@ -1437,20 +1305,6 @@ def parity_check_gqa_prompt( else None ) - output_qk = ( - torch.zeros( - config.batch_size, - 
config.num_heads, - config.kv_sequence_length, - config.q_sequence_length, - device="cpu", - dtype=torch_type, - requires_grad=False, - ) - if config.qk_output != QKOutputType.NO_OUTPUT - else None - ) - arange = rearrange(torch.arange(config.buffer_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") kv_seqlens = torch.tensor([config.kv_sequence_length], device="cpu").repeat(config.batch_size) @@ -1461,7 +1315,7 @@ def parity_check_gqa_prompt( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded - out_ref, out_qk_pre_softmax_ref, out_qk_post_softmax_ref = attention_ref( + out_ref, _ = attention_ref( q_ro, k_cache_rep, v_cache_rep, @@ -1473,7 +1327,6 @@ def parity_check_gqa_prompt( window_size=window_size, softcap=softcap, use_smooth_softmax=use_smooth_softmax, - head_sink=head_sink, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1484,7 +1337,7 @@ def parity_check_gqa_prompt( # Cache seqlens is reduced by 1 since it is required to be past_seq_len + seq_len - 1 if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v, out_qk = gqa_prompt_func( + out, present_k, present_v = gqa_prompt_func( packed_qkv, k, v, @@ -1496,8 +1349,6 @@ def parity_check_gqa_prompt( cache_seqlens - 1, position_ids, attention_bias, - head_sink, - output_qk, left_window_size, past_format, True, @@ -1508,7 +1359,7 @@ def parity_check_gqa_prompt( numpy_type=numpy_type, ) else: - out, present_k, present_v, out_qk = gqa_prompt_func( + out, present_k, present_v = gqa_prompt_func( q, k, v, @@ -1520,8 +1371,6 @@ def parity_check_gqa_prompt( cache_seqlens - 1, position_ids, attention_bias, - head_sink, - output_qk, left_window_size, past_format, True, @@ -1535,22 +1384,6 @@ def parity_check_gqa_prompt( out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() - if config.qk_output != QKOutputType.NO_OUTPUT: - out_qk_ref = ( - out_qk_post_softmax_ref if config.qk_output == QKOutputType.AFTER_SOFTMAX else out_qk_pre_softmax_ref - ) - out_qk_ref = out_qk_ref.detach().cpu().numpy() - - for batch_idx in range(config.batch_size): - total_seqlen = cache_seqlens[batch_idx] - assert numpy.allclose( - out_qk[batch_idx, :, :, :total_seqlen], - out_qk_ref[batch_idx, :, :, :total_seqlen], - rtol=rtol, - atol=atol, - equal_nan=True, - ) - # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1592,8 +1425,6 @@ def parity_check_gqa_prompt( config.has_position_ids, " has_attention_bias:", config.has_attention_bias, - " qk_output:", - config.qk_output, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -1700,28 +1531,12 @@ def parity_check_gqa_prompt_no_buff( else None ) - head_sink = get_custom_head_sink(config.num_heads, torch_type=torch_type) if config.has_head_sink else None - - output_qk = ( - torch.zeros( - config.batch_size, - config.num_heads, - config.kv_sequence_length, - config.q_sequence_length, - device="cpu", - dtype=torch_type, - requires_grad=False, - ) - if config.qk_output != 
QKOutputType.NO_OUTPUT
-        else None
-    )
-
     brange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s")
     cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
     new_mask = brange < cache_seqlens_expanded
     k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
     v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
-    out_ref, out_qk_pre_softmax_ref, out_qk_post_softmax_ref = attention_ref(
+    out_ref, _ = attention_ref(
         q_ro,
         k_cache_rep,
         v_cache_rep,
@@ -1733,7 +1548,6 @@ def parity_check_gqa_prompt_no_buff(
         window_size=window_size,
         softcap=softcap,
         use_smooth_softmax=use_smooth_softmax,
-        head_sink=head_sink,
     )
     out_ref = out_ref.detach().cpu().numpy()
     if past_format == Formats.BNSH:
@@ -1744,7 +1558,7 @@ def parity_check_gqa_prompt_no_buff(
     # Cache seqlens is reduced by 1 since it is required to be past_seq_len + seq_len - 1
     if packed:
         packed_qkv = torch.concatenate([q, new_k, new_v], dim=2)
-        out, present_k, present_v, out_qk = gqa_prompt_func(
+        out, present_k, present_v = gqa_prompt_func(
             packed_qkv,
             None,
             None,
@@ -1756,8 +1570,6 @@ def parity_check_gqa_prompt_no_buff(
             cache_seqlens - 1,
             position_ids,
             attention_bias,
-            head_sink,
-            output_qk,
             left_window_size,
             past_format,
             False,
@@ -1768,7 +1580,7 @@ def parity_check_gqa_prompt_no_buff(
             numpy_type=numpy_type,
         )
     else:
-        out, present_k, present_v, out_qk = gqa_prompt_func(
+        out, present_k, present_v = gqa_prompt_func(
             q,
             None,
             None,
@@ -1780,8 +1592,6 @@ def parity_check_gqa_prompt_no_buff(
             cache_seqlens - 1,
             position_ids,
             attention_bias,
-            head_sink,
-            output_qk,
             left_window_size,
             past_format,
             False,
@@ -1795,22 +1605,6 @@ def parity_check_gqa_prompt_no_buff(
     out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size))
     out = out.detach().cpu().numpy()

-    if config.qk_output != QKOutputType.NO_OUTPUT:
-        out_qk_ref = (
-            out_qk_post_softmax_ref if config.qk_output == QKOutputType.AFTER_SOFTMAX else out_qk_pre_softmax_ref
-        )
-        out_qk_ref = out_qk_ref.detach().cpu().numpy()
-
-        for batch_idx in range(config.batch_size):
-            total_seqlen = cache_seqlens[batch_idx]
-            assert numpy.allclose(
-                out_qk[batch_idx, :, :, :total_seqlen],
-                out_qk_ref[batch_idx, :, :, :total_seqlen],
-                rtol=rtol,
-                atol=atol,
-                equal_nan=True,
-            )
-
     # Make sure past-present buffer updating correctly
     assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
     assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
@@ -1852,8 +1646,6 @@ def parity_check_gqa_prompt_no_buff(
         config.has_position_ids,
         " has_attention_bias:",
         config.has_attention_bias,
-        " qk_output:",
-        config.qk_output,
         " Mean Error:",
         numpy.mean(numpy.abs(out - out_ref)),
         correct,
@@ -1967,8 +1759,6 @@ def parity_check_gqa_past(
         cos, sin = None, None
         q_ro, k_ro = q, new_k

-    head_sink = get_custom_head_sink(config.num_heads, torch_type=torch_type) if config.has_head_sink else None
-
     arange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s")
     cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
     update_mask = torch.logical_and(
@@ -1979,7 +1769,7 @@ def parity_check_gqa_past(
     k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
     v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
     key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length
-    out_ref, out_qk_pre_softmax_ref, out_qk_post_softmax_ref = attention_ref(
+    out_ref, _ = attention_ref(
         q_ro,
         k_cache_rep,
         v_cache_rep,
@@ -1991,7 +1781,6 @@ def parity_check_gqa_past(
         window_size=window_size,
         softcap=softcap,
         use_smooth_softmax=use_smooth_softmax,
-        head_sink=head_sink,
     )
     out_ref = out_ref.detach().cpu().numpy()
     if past_format == Formats.BNSH:
@@ -2018,24 +1807,10 @@ def parity_check_gqa_past(
         else None
     )

-    output_qk = (
-        torch.zeros(
-            config.batch_size,
-            config.num_heads,
-            config.sequence_length,
-            config.kv_sequence_length,
-            device="cpu",
-            dtype=torch_type,
-            requires_grad=False,
-        )
-        if config.qk_output != QKOutputType.NO_OUTPUT
-        else None
-    )
-
     # ORT function
     if packed:
         packed_qkv = torch.concatenate([q, new_k, new_v], dim=2)
-        out, present_k, present_v, out_qk = gqa_past_func(
+        out, present_k, present_v = gqa_past_func(
             packed_qkv,
             k,
             v,
@@ -2047,8 +1822,6 @@ def parity_check_gqa_past(
             cache_seqlens,
             position_ids,
             attention_bias,
-            head_sink,
-            output_qk,
             past_format,
             True,
             left_window_size,
@@ -2059,7 +1832,7 @@ def parity_check_gqa_past(
             numpy_type=numpy_type,
         )
     else:
-        out, present_k, present_v, out_qk = gqa_past_func(
+        out, present_k, present_v = gqa_past_func(
             q,
             k,
             v,
@@ -2071,8 +1844,6 @@ def parity_check_gqa_past(
             cache_seqlens,
             position_ids,
             attention_bias,
-            head_sink,
-            output_qk,
             past_format,
             True,
             left_window_size,
@@ -2086,22 +1857,6 @@ def parity_check_gqa_past(
     out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size))
     out = out.detach().cpu().numpy()

-    if config.qk_output != QKOutputType.NO_OUTPUT:
-        out_qk_ref = (
-            out_qk_post_softmax_ref if config.qk_output == QKOutputType.AFTER_SOFTMAX else out_qk_pre_softmax_ref
-        )
-        out_qk_ref = out_qk_ref.detach().cpu().numpy()
-
-        for batch_idx in range(config.batch_size):
-            total_seqlen = cache_seqlens[batch_idx] + 1
-            assert numpy.allclose(
-                out_qk[batch_idx, :, :, :total_seqlen],
-                out_qk_ref[batch_idx, :, :, :total_seqlen],
-                rtol=rtol,
-                atol=atol,
-                equal_nan=True,
-            )
-
     # Make sure past-present buffer updating correctly
     assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
     assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
@@ -2127,8 +1882,6 @@ def parity_check_gqa_past(
         softcap,
         " smooth_softmax:",
         use_smooth_softmax,
-        " head_sink:",
-        config.has_head_sink,
         " B:",
         config.batch_size,
         " S:",
@@ -2145,8 +1898,6 @@ def parity_check_gqa_past(
         config.has_position_ids,
         " has_attention_bias:",
         config.has_attention_bias,
-        " qk_output:",
-        config.qk_output,
         " Mean Error:",
         numpy.mean(numpy.abs(out - out_ref)),
         correct,
@@ -2266,8 +2017,6 @@ def parity_check_gqa_past_no_buff(
         cos, sin = None, None
         q_ro, k_ro = q, new_k

-    head_sink = get_custom_head_sink(config.num_heads, torch_type) if config.has_head_sink else None
-
     arange = rearrange(torch.arange(config.kv_sequence_length + config.sequence_length, device="cpu"), "s -> 1 s")
     cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
     update_mask = torch.logical_and(
@@ -2278,7 +2027,7 @@ def parity_check_gqa_past_no_buff(
     k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
     v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
     key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length
-    out_ref, out_qk_pre_softmax_ref, out_qk_post_softmax_ref = attention_ref(
+    out_ref, _ = attention_ref(
         q_ro,
         k_cache_rep,
         v_cache_rep,
@@ -2290,7 +2039,6 @@ def parity_check_gqa_past_no_buff(
         window_size=window_size,
         softcap=softcap,
         use_smooth_softmax=use_smooth_softmax,
-        head_sink=head_sink,
     )
     out_ref = out_ref.detach().cpu().numpy()
     if past_format == Formats.BNSH:
@@ -2317,24 +2065,10 @@ def parity_check_gqa_past_no_buff(
         else None
     )

-    output_qk = (
-        torch.zeros(
-            config.batch_size,
-            config.num_heads,
-            config.sequence_length,
-            config.kv_sequence_length + config.sequence_length,
-            device="cpu",
-            dtype=torch_type,
-            requires_grad=False,
-        )
-        if config.qk_output != QKOutputType.NO_OUTPUT
-        else None
-    )
-
     # Flash function
     if packed:
         packed_qkv = torch.concatenate([q, new_k, new_v], dim=2)
-        out, present_k, present_v, out_qk = gqa_past_func(
+        out, present_k, present_v = gqa_past_func(
             packed_qkv,
             k,
             v,
@@ -2346,8 +2080,6 @@ def parity_check_gqa_past_no_buff(
             cache_seqlens,
             position_ids,
             attention_bias,
-            head_sink,
-            output_qk,
             past_format,
             False,
             window_size=left_window_size,
@@ -2358,7 +2090,7 @@ def parity_check_gqa_past_no_buff(
             numpy_type=numpy_type,
         )
     else:
-        out, present_k, present_v, out_qk = gqa_past_func(
+        out, present_k, present_v = gqa_past_func(
             q,
             k,
             v,
@@ -2370,8 +2102,6 @@ def parity_check_gqa_past_no_buff(
             cache_seqlens,
             position_ids,
             attention_bias,
-            head_sink,
-            output_qk,
             past_format,
             False,
             window_size=left_window_size,
@@ -2385,22 +2115,6 @@ def parity_check_gqa_past_no_buff(
     out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size))
     out = out.detach().cpu().numpy()

-    if config.qk_output != QKOutputType.NO_OUTPUT:
-        out_qk_ref = (
-            out_qk_post_softmax_ref if config.qk_output == QKOutputType.AFTER_SOFTMAX else out_qk_pre_softmax_ref
-        )
-        out_qk_ref = out_qk_ref.detach().cpu().numpy()
-
-        for batch_idx in range(config.batch_size):
-            total_seqlen = cache_seqlens[batch_idx] + 1
-            assert numpy.allclose(
-                out_qk[batch_idx, :, :, :total_seqlen],
-                out_qk_ref[batch_idx, :, :, :total_seqlen],
-                rtol=rtol,
-                atol=atol,
-                equal_nan=True,
-            )
-
     # Compare results
     all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True)
     correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET
@@ -2420,8 +2134,6 @@ def parity_check_gqa_past_no_buff(
         softcap,
         " smooth_softmax:",
         use_smooth_softmax,
-        " head_sink:",
-        config.has_head_sink,
         "past kv format:",
         "BSNH" if past_format == Formats.BSNH else "BNSH",
         " B:",
@@ -2440,8 +2152,6 @@ def parity_check_gqa_past_no_buff(
         config.has_position_ids,
         " has_attention_bias:",
         config.has_attention_bias,
-        " qk_output:",
-        config.qk_output,
         " Mean Error:",
         numpy.mean(numpy.abs(out - out_ref)),
         correct,
@@ -2470,16 +2180,7 @@ def setUp(self):
         ]

     def run_test_config(
-        self,
-        test_func,
-        config_class,
-        batches,
-        seqs,
-        num_h,
-        h_sizes,
-        pos_ids_attn_bias,
-        qk_output,
-        additional_params=None,
+        self, test_func, config_class, batches, seqs, num_h, h_sizes, pos_ids_attn_bias, additional_params=None
     ):
         if additional_params is None:
             additional_params = {}
@@ -2501,59 +2202,33 @@ def run_test_config(
                 for softcap in [0.0, 50.0]:
                     for use_smooth_softmax in [False, True]:
                         for has_pos, has_attn in pos_ids_attn_bias:
-                            for head_sink in [False, True]:
-                                if use_smooth_softmax and head_sink:
-                                    continue
-                                for output_qk in qk_output:
-                                    if config_class == PromptConfig:
-                                        config = config_class(
-                                            b,
-                                            s,
-                                            s2,
-                                            s + s2 + 8,
-                                            n,
-                                            n2,
-                                            h,
-                                            has_pos,
-                                            has_attn,
-                                            head_sink,
-                                            output_qk,
-                                        )
-                                    else:  # Config
-                                        sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
-                                        config = config_class(
-                                            b,
-                                            s,
-                                            s2,
-                                            sp,
-                                            n,
-                                            n2,
-                                            h,
-                                            has_pos,
-                                            has_attn,
-                                            head_sink,
-                                            output_qk,
-                                        )
-
-                                    params = {
-                                        "config": config,
-                                        "torch_type": precision["torch_type"],
-                                        "numpy_type": precision["numpy_type"],
-                                        "ort_type": precision["ort_type"],
-                                        "rtol": precision["rtol"],
-                                        "atol": precision["atol"],
-                                        "local": local,
-                                        "past_format": Formats.BNSH,
-                                        "rotary": rotary,
-                                        "rotary_interleaved": rotary_interleaved,
-                                        "packed": packed,
-                                        "softcap": softcap,
-                                        "use_smooth_softmax": use_smooth_softmax,
-                                    }
-                                    params.update(additional_params)
-
-                                    all_close = test_func(**params)
-                                    self.assertTrue(all_close)
+                            if config_class == PromptConfig:
+                                config = config_class(
+                                    b, s, s2, s + s2 + 8, n, n2, h, has_pos, has_attn
+                                )
+                            else:  # Config
+                                sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
+                                config = config_class(b, s, s2, sp, n, n2, h, has_pos, has_attn)
+
+                            params = {
+                                "config": config,
+                                "torch_type": precision["torch_type"],
+                                "numpy_type": precision["numpy_type"],
+                                "ort_type": precision["ort_type"],
+                                "rtol": precision["rtol"],
+                                "atol": precision["atol"],
+                                "local": local,
+                                "past_format": Formats.BNSH,
+                                "rotary": rotary,
+                                "rotary_interleaved": rotary_interleaved,
+                                "packed": packed,
+                                "softcap": softcap,
+                                "use_smooth_softmax": use_smooth_softmax,
+                            }
+                            params.update(additional_params)
+
+                            all_close = test_func(**params)
+                            self.assertTrue(all_close)

     def test_gqa_no_past(self):
         print("-------- TEST GQA NO PAST (PROMPT CASE) ---------")
@@ -2570,33 +2245,12 @@ def test_gqa_no_past(self):
         )
         num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
         h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
-        qk_output = (
-            [QKOutputType.NO_OUTPUT]
-            if pipeline_mode
-            else [QKOutputType.NO_OUTPUT, QKOutputType.BEFORE_SOFTMAX, QKOutputType.AFTER_SOFTMAX]
-        )

         # Test with buffer
-        self.run_test_config(
-            parity_check_gqa_prompt,
-            PromptConfig,
-            batches,
-            seqs,
-            num_h,
-            h_sizes,
-            pos_ids_attn_bias,
-            qk_output,
-        )
+        self.run_test_config(parity_check_gqa_prompt, PromptConfig, batches, seqs, num_h, h_sizes, pos_ids_attn_bias)

         # Test without buffer
         self.run_test_config(
-            parity_check_gqa_prompt_no_buff,
-            PromptConfig,
-            batches,
-            seqs,
-            num_h,
-            h_sizes,
-            pos_ids_attn_bias,
-            qk_output,
+            parity_check_gqa_prompt_no_buff, PromptConfig, batches, seqs, num_h, h_sizes, pos_ids_attn_bias
         )

     def test_gqa_past(self):
         print("-------- TEST GQA PAST ---------")
@@ -2614,25 +2268,11 @@ def test_gqa_past(self):
         )
         num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
         h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
-        qk_output = (
-            [QKOutputType.NO_OUTPUT]
-            if pipeline_mode
-            else [QKOutputType.NO_OUTPUT, QKOutputType.BEFORE_SOFTMAX, QKOutputType.AFTER_SOFTMAX]
-        )

         # Test with buffer
-        self.run_test_config(parity_check_gqa_past, Config, batches, seqs, num_h, h_sizes, pos_ids_attn_bias, qk_output)
+        self.run_test_config(parity_check_gqa_past, Config, batches, seqs, num_h, h_sizes, pos_ids_attn_bias)

         # Test without buffer
-        self.run_test_config(
-            parity_check_gqa_past_no_buff,
-            Config,
-            batches,
-            seqs,
-            num_h,
-            h_sizes,
-            pos_ids_attn_bias,
-            qk_output,
-        )
+        self.run_test_config(parity_check_gqa_past_no_buff, Config, batches, seqs, num_h, h_sizes, pos_ids_attn_bias)

     def test_gqa_interactive_one_batch(self):
         print("-------- TEST GQA INTERACTIVE ---------")
@@ -2647,7 +2287,6 @@ def test_gqa_interactive_one_batch(self):
             if pipeline_mode
             else [(False, False), (True, True), (False, True), (True, False)]
         )
-        qk_output = [QKOutputType.NO_OUTPUT, QKOutputType.BEFORE_SOFTMAX, QKOutputType.AFTER_SOFTMAX]
         num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
         h_sizes = [32] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]

@@ -2660,7 +2299,6 @@ def test_gqa_interactive_one_batch(self):
             num_h,
             h_sizes,
             pos_ids_attn_bias,
-            qk_output,
            additional_params={"softcap": 0.0, "use_smooth_softmax": False},
         )
         self.run_test_config(
@@ -2671,7 +2309,6 @@ def test_gqa_interactive_one_batch(self):
             num_h,
             h_sizes,
             pos_ids_attn_bias,
-            qk_output,
             additional_params={"softcap": 0.0, "use_smooth_softmax": False},
         )

diff --git a/onnxruntime/test/python/transformers/test_gqa_cuda.py b/onnxruntime/test/python/transformers/test_gqa_cuda.py
index 79976a92e54bf..2f5b638a57d0c 100644
--- a/onnxruntime/test/python/transformers/test_gqa_cuda.py
+++ b/onnxruntime/test/python/transformers/test_gqa_cuda.py
@@ -782,8 +782,7 @@ def attention_ref(
         scores.masked_fill_(local_mask, float("-inf"))

     if use_smooth_softmax:
-        head_sink = None
-        attention = smooth_softmax_ref(scores, head_sink)
+        attention = smooth_softmax_ref(scores)
     else:
         attention = torch.softmax(scores, dim=-1)

diff --git a/onnxruntime/test/python/transformers/test_paged_attention_cuda.py b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py
index ca5c9c2ce133f..410860a324a9d 100644
--- a/onnxruntime/test/python/transformers/test_paged_attention_cuda.py
+++ b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py
@@ -401,8 +401,7 @@ def attention_ref(
         scores.masked_fill_(local_mask, float("-inf"))

     if use_smooth_softmax:
-        head_sink = None
-        attention = smooth_softmax_ref(scores, head_sink)
+        attention = smooth_softmax_ref(scores)
     else:
         attention = torch.softmax(scores, dim=-1)

diff --git a/onnxruntime/test/python/transformers/test_phi_vision.py b/onnxruntime/test/python/transformers/test_phi_vision.py
index d276366706af9..67f89e633a146 100644
--- a/onnxruntime/test/python/transformers/test_phi_vision.py
+++ b/onnxruntime/test/python/transformers/test_phi_vision.py
@@ -149,7 +149,7 @@ def __init__(self):
         self.attn = PhiVCLIPAttention()
         self.ln = torch.nn.LayerNorm(20, eps=1e-05)

-    def forward(self, x, attention_mask=None):
+    def forward(self, x):
         # SkipLayerNorm ------+
         #        |            |
         # Attention           |
@@ -163,7 +163,8 @@ def forward(self, x):
         x = self.ln(x)
         residual = x

-        x = self.attn(x, attention_mask=attention_mask)
+        # Attention + MatMul
+        x = self.attn(x)

         # SkipLayerNorm
         x = residual + x
@@ -193,31 +194,14 @@ def verify_fusion(self, optimized_model, expected_model_filename):
         )

     def export(self, model, inputs):
-        path = os.path.join(os.path.dirname(__file__), "export.onnx")
-
-        if len(inputs) == 2:
-            torch.onnx.export(
-                model,
-                args=inputs,
-                f=path,
-                export_params=True,
-                opset_version=14,
-                do_constant_folding=True,
-                input_names=["input", "attention_mask"],
-                dynamic_axes={
-                    "input": {0: "batch", 1: "seq"},
-                    "attention_mask": {0: "batch", 2: "seq", 3: "seq"},
-                },
-            )
-        else:
-            torch.onnx.export(
-                model,
-                args=inputs,
-                f=path,
-                export_params=True,
-                opset_version=14,
-                do_constant_folding=True,
-            )
+        torch.onnx.export(
+            model,
+            args=inputs,
+            f=os.path.join(os.path.dirname(__file__), "export.onnx"),
+            export_params=True,
+            opset_version=14,
+            do_constant_folding=True,
+        )

     def tearDown(self):
         path = os.path.join(os.path.dirname(__file__), "export.onnx")
@@ -265,38 +249,6 @@ def test_phi_vision_attention(self):
         )
         self.verify_fusion(optimized_model, "phi-3.5-v-instruct-vision-attention.onnx")

-    def test_phi_vision_attention_with_mask(self):
-        model = PhiVCLIPAttentionAndLayerNorm()
-
-        batch, seq_len, dim = 1, 2, 20
-        mask = torch.zeros(batch, 1, seq_len, seq_len)
-        mask[:, 1:] = float("-inf")
-
-        inputs = (torch.randn(batch, seq_len, dim), mask)
-        self.export(model, inputs)
-        original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx"))
-        options = FusionOptions("clip")
-        optimized_model = optimize_model(
-            original_model,
-            model_type="clip",
-            num_heads=2,
-            hidden_size=20,
-            optimization_options=options,
-            opt_level=0,
-            use_gpu=True,
-        )
-        self.verify_fusion(optimized_model, "phi-4-v-instruct-vision-attention.onnx")
-
-        graph = optimized_model.model.graph
-        attention_node = next((n for n in graph.node if n.name == "Attention_0"), None)
-        self.assertIsNotNone(attention_node, "Could not find the Attention fused node")
-        attr_names = [attr.name for attr in attention_node.attribute]
-        self.assertNotIn(
-            "unidirectional",
-            attr_names,
-            f"The attention node should not have a 'unidirectional' attribute: {attr_names}",
-        )
-

 if __name__ == "__main__":
     unittest.main()
diff --git a/onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx b/onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx
deleted file mode 100644
index d036541a70aa0..0000000000000
Binary files a/onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx and /dev/null differ
diff --git a/requirements-lintrunner.txt b/requirements-lintrunner.txt
index f02e3e8058c29..450b955f161af 100644
--- a/requirements-lintrunner.txt
+++ b/requirements-lintrunner.txt
@@ -1,6 +1,6 @@
 # This file is auto updated by dependabot
 # When any package below is changed, you shall run "lintrunner init" again.
 lintrunner==0.12.7
-lintrunner-adapters==0.12.5
-ruff==0.12.3
-clang-format==20.1.8
+lintrunner-adapters==0.12.4
+ruff==0.12.2
+clang-format==20.1.7
diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py
index 893f3c80fa4b8..f6e37d33b2414 100644
--- a/tools/ci_build/build.py
+++ b/tools/ci_build/build.py
@@ -284,8 +284,6 @@ def generate_vcpkg_install_options(build_dir, args):
         vcpkg_install_options.append("--x-feature=vsinpu-ep")
     if args.use_webgpu:
         vcpkg_install_options.append("--x-feature=webgpu-ep")
-        if args.wgsl_template == "dynamic":
-            vcpkg_install_options.append("--x-feature=webgpu-ep-wgsl-template-dynamic")
     if args.use_webnn:
         vcpkg_install_options.append("--x-feature=webnn-ep")
     if args.use_xnnpack:
@@ -472,7 +470,6 @@ def generate_build_tree(
             else "OFF"
         ),
         "-Donnxruntime_REDUCED_OPS_BUILD=" + ("ON" if is_reduced_ops_build(args) else "OFF"),
-        "-Donnxruntime_CLIENT_PACKAGE_BUILD=" + ("ON" if args.client_package_build else "OFF"),
         "-Donnxruntime_BUILD_MS_EXPERIMENTAL_OPS=" + ("ON" if args.ms_experimental else "OFF"),
         "-Donnxruntime_ENABLE_LTO=" + ("ON" if args.enable_lto else "OFF"),
         "-Donnxruntime_USE_ACL=" + ("ON" if args.use_acl else "OFF"),
diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py
index 53d53f3e15e99..ad27b8124c458 100644
--- a/tools/ci_build/build_args.py
+++ b/tools/ci_build/build_args.py
@@ -527,15 +527,6 @@ def add_size_reduction_args(parser: argparse.ArgumentParser) -> None:
     )


-def add_client_package_args(parser: argparse.ArgumentParser) -> None:
-    """Adds arguments for client package build package."""
-    parser.add_argument(
-        "--client_package_build",
-        action="store_true",
-        help="Create ORT package with default settings more appropriate for client/on-device workloads.",
-    )
-
-
 def add_python_binding_args(parser: argparse.ArgumentParser) -> None:
     """Adds arguments for Python bindings."""
     parser.add_argument("--enable_pybind", action="store_true", help="Enable Python bindings.")
@@ -842,7 +833,6 @@ def convert_arg_line_to_args(self, arg_line: str) -> list[str]:  # Use list[str]
     add_dependency_args(parser)
     add_extension_args(parser)
     add_size_reduction_args(parser)
-    add_client_package_args(parser)

     # Language Bindings
     add_python_binding_args(parser)
diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml
index e5e2a4749ef85..ee7f8f2fa386a 100644
--- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml
@@ -32,7 +32,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK version
   type: string
-  default: 2.36.1.250708
+  default: 2.36.0.250627

 jobs:
 - job: Build_QNN_EP
diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml
index 202aa61da0b80..aa25e3f31166a 100644
--- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml
+++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml
@@ -60,7 +60,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK Version
   type: string
-  default: 2.36.1.250708
+  default: 2.36.0.250627

 resources:
   repositories:
diff --git a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml
index 69dc9d1a8f63d..7addb3217072a 100644
--- a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml
@@ -6,7 +6,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK Version
   type: string
-  default: 2.36.1.250708
+  default: 2.36.0.250627

 - name: IsReleaseBuild
   displayName: Is a release build? Set it to true if you are doing an Onnx Runtime release.
diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml
index 526ed71df2006..cf8bbbed70525 100644
--- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml
@@ -33,7 +33,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK version
   type: string
-  default: 2.36.1.250708
+  default: 2.36.0.250627

 jobs:
 - job: Build_QNN_EP
diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml
index b99246625cb77..de024f0b3456f 100644
--- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml
@@ -59,7 +59,7 @@ parameters:
 - name: qnn_sdk_version
   type: string
   displayName: 'QNN SDK version. Only for QNN packages.'
-  default: 2.36.1.250708
+  default: 2.36.0.250627

 trigger: none
diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml
index 626a638121858..4fa916db0de39 100644
--- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml
@@ -2,7 +2,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK Version
   type: string
-  default: 2.36.1.250708
+  default: 2.36.0.250627

 - name: build_config
   displayName: Build Configuration
diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml
index a87bb55441ac7..84b6d30ee32ac 100644
--- a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml
+++ b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml
@@ -72,8 +72,6 @@ stages:
           SpecificArtifact: ${{ parameters.SpecificArtifact }}
           BuildId: ${{ parameters.BuildId }}

-      - template: ../templates/set-version-number-variables-step.yml
-
       # Reconstruct the build dir
       - task: PowerShell@2
         displayName: 'PS: Extract nuget files gpu'
@@ -116,7 +114,6 @@ stages:
              -p:OrtPackageId="Microsoft.ML.OnnxRuntime.Gpu"
              -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }}
              -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix)
-             -p:PackageVersion=$(OnnxRuntimeVersion)
           workingDirectory: '$(Build.SourcesDirectory)\csharp'

       - template: ../templates/win-esrp-dll.yml
diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml
index e2c6b25f48b6d..433250f05125e 100644
--- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml
+++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml
@@ -59,7 +59,7 @@ parameters:
 - name: qnn_sdk_version
   type: string
   displayName: 'QNN SDK version. Only for QNN packages.'
-  default: 2.36.1.250708
+  default: 2.36.0.250627

 stages:
 - ${{ if eq(parameters.enable_windows_cpu, true) }}:
diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml
index 74f7f782fe1b2..ab779e164b36e 100644
--- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml
@@ -19,7 +19,7 @@ parameters:
 - name: QnnSDKVersion
   displayName: QNN SDK Version
   type: string
-  default: '2.36.1.250708'
+  default: '2.36.0.250627'

 - name: enableWebGpu
   displayName: Enable WebGPU test
diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml
index 92e862bd79008..110f83ff587c8 100644
--- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml
@@ -53,7 +53,7 @@ parameters:
 - name: QnnSDKVersion
   displayName: QNN SDK Version
   type: string
-  default: '2.36.1.250708'
+  default: '2.36.0.250627'

 - name: is1ES
   displayName: Is 1ES pipeline
diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml
index 5b48a14e2afc3..535784933a087 100644
--- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml
@@ -47,7 +47,7 @@ parameters:
 - name: QnnSDKVersion
   displayName: QNN SDK Version
   type: string
-  default: 2.36.1.250708
+  default: 2.36.0.250627

 - name: is1ES
   displayName: Is 1ES pipeline
diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml
index 930dc83b73460..3e7427cc7a2e3 100644
--- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml
@@ -1,7 +1,7 @@
 parameters:
 - name: QnnSDKVersion
   type: string
-  default: '2.36.1.250708'
+  default: '2.36.0.250627'

 steps:
 - script: |
diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml
index 96eea6cd6d2fb..e3f549e2d649f 100644
--- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml
@@ -1,7 +1,7 @@
 parameters:
 - name: QnnSDKVersion
   type: string
-  default: '2.36.1.250708'
+  default: '2.36.0.250627'

 steps:
 - powershell: |
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml
index caee5367950e6..d533fb7c83ddd 100644
--- a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml
@@ -26,7 +26,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK version
   type: string
-  default: 2.36.1.250708
+  default: 2.36.0.250627

 - name: is1ES
   displayName: 'Whether the pipeline is running in 1ES'
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml
index 185f41822a7e5..cd060d1fbf19f 100644
--- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml
@@ -7,7 +7,7 @@ parameters:
 - name: QNN_SDK
   displayName: QNN SDK Version
   type: string
-  default: 2.36.1.250708
+  default: 2.36.0.250627

 - name: ENV_SETUP_SCRIPT
   type: string
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml
index 9a1e7e5e251c9..2a2ac49b4e073 100644
--- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml
@@ -7,7 +7,7 @@ parameters:
 - name: QNN_SDK
   displayName: QNN SDK Version
   type: string
-  default: 2.36.1.250708
+  default: 2.36.0.250627

 - name: ENV_SETUP_SCRIPT
   type: string
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml
index 5affc152a0a4a..8528fa3907e96 100644
--- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml
@@ -7,7 +7,7 @@ parameters:
 - name: QNN_SDK
   displayName: QNN SDK Version
   type: string
-  default: 2.36.1.250708
+  default: 2.36.0.250627

 - name: ENV_SETUP_SCRIPT
   type: string
diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml
index 29ebb8c4e4e61..1406ce338f13e 100644
--- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml
@@ -1,5 +1,5 @@
 parameters:
-  QnnSdk: '2.36.1.250708'
+  QnnSdk: '2.36.0.250627'
   build_config: 'RelWithDebInfo'
   IsReleaseBuild: false
   DoEsrp: false
@@ -20,7 +20,7 @@ stages:
       name: ${{ parameters.qnn_ep_build_pool_name }}
     variables:
       OrtPackageId: ${{ parameters.OrtNugetPackageId }}
-      commonBuildArgs: '--compile_no_warning_as_error --skip_submodule_sync --build_shared_lib --client_package_build --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_binskim_compliant_compile_flags '
+      commonBuildArgs: '--compile_no_warning_as_error --skip_submodule_sync --build_shared_lib --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_binskim_compliant_compile_flags '

     steps:
     - template: set-version-number-variables-step.yml
diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml
index 7ebf5394e4530..78fce1f9b9602 100644
--- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml
@@ -33,7 +33,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK version
   type: string
-  default: 2.36.1.250708
+  default: 2.36.0.250627

 jobs:
 - job: 'BUILD_QNN_EP'
@@ -50,7 +50,7 @@ jobs:
       matrix:
         SHARED_LIB:
           QnnLibKind: 'shared_lib'
-          ExtraQnnBuildArgs: '--client_package_build'
+          ExtraQnnBuildArgs: ''
         STATIC_LIB:
           QnnLibKind: 'static_lib'
           ExtraQnnBuildArgs: ''
diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml
index ffeb577547f69..eb77c9422853d 100644
--- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml
@@ -33,7 +33,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK version
   type: string
-  default: 2.36.1.250708
+  default: 2.36.0.250627

 jobs:
 - job: 'BUILD_QNN_EP'