diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index 70e8ea7e2792f..996e0d816d51a 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -22,6 +22,7 @@ jobs: strategy: matrix: vcpkg_option: [novcpkg, vcpkg] + wgsl_template: [static, dynamic] env: OrtPackageId: Microsoft.ML.OnnxRuntime OnnxRuntimeBuildDirectory: ${{ github.workspace }} @@ -123,6 +124,7 @@ jobs: --build_nodejs ` --build_java ` --use_webgpu ` + --wgsl_template ${{ matrix.wgsl_template }} ` ${{ matrix.vcpkg_option == 'vcpkg' && '--use_vcpkg' || '' }} ` --cmake_extra_defines ` onnxruntime_BUILD_UNIT_TESTS=ON ` diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index fb4238731ffc3..b01110b2a4a03 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -151,6 +151,7 @@ option(onnxruntime_DISABLE_SPARSE_TENSORS "Disable sparse tensors data types" OF option(onnxruntime_DISABLE_OPTIONAL_TYPE "Disable optional type" OFF) option(onnxruntime_DISABLE_FLOAT8_TYPES "Disable float 8 types" OFF) option(onnxruntime_MINIMAL_BUILD "Exclude as much as possible from the build. Support ORT format models. No support for ONNX format models." OFF) +option(onnxruntime_CLIENT_PACKAGE_BUILD "Enables default settings that are more appropriate for client/on-device workloads." OFF) cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "NOT onnxruntime_ENABLE_PYTHON;NOT onnxruntime_USE_CUDA" OFF) # For now onnxruntime_DISABLE_EXCEPTIONS will only work with onnxruntime_MINIMAL_BUILD, more changes (ONNX, non-CPU EP, ...) are required to run this standalone cmake_dependent_option(onnxruntime_DISABLE_EXCEPTIONS "Disable exception handling. Requires onnxruntime_MINIMAL_BUILD currently." ON "onnxruntime_MINIMAL_BUILD;NOT onnxruntime_ENABLE_PYTHON" OFF) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 59d99ade131cd..6d517003fa6b6 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -95,6 +95,11 @@ if (onnxruntime_MINIMAL_BUILD) endif() endif() +# ORT build with default settings more appropriate for client/on-device workloads. 
+if (onnxruntime_CLIENT_PACKAGE_BUILD) + add_compile_definitions(ORT_CLIENT_PACKAGE_BUILD) +endif() + if (onnxruntime_ENABLE_LTO) include(CheckIPOSupported) check_ipo_supported(RESULT ipo_enabled OUTPUT ipo_output) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index e8f6bbe895d29..228906030d14c 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -774,13 +774,24 @@ if (onnxruntime_USE_WEBGPU) endif() if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND onnxruntime_WGSL_TEMPLATE STREQUAL "dynamic") - onnxruntime_fetchcontent_declare( - duktape - URL ${DEP_URL_duktape} - URL_HASH SHA1=${DEP_SHA1_duktape} - EXCLUDE_FROM_ALL - ) - onnxruntime_fetchcontent_makeavailable(duktape) + if(onnxruntime_USE_VCPKG) + find_package(unofficial-duktape CONFIG REQUIRED) + add_library(duktape_static ALIAS unofficial::duktape::duktape) + else() + onnxruntime_fetchcontent_declare( + duktape + URL ${DEP_URL_duktape} + URL_HASH SHA1=${DEP_SHA1_duktape} + EXCLUDE_FROM_ALL + ) + onnxruntime_fetchcontent_makeavailable(duktape) + + if(NOT TARGET duktape_static) + add_library(duktape_static STATIC "${duktape_SOURCE_DIR}/src/duktape.c") + target_compile_features(duktape_static PRIVATE c_std_99) + target_include_directories(duktape_static INTERFACE $) + endif() + endif() endif() endif() diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index f8f5546ae9465..47e7779d93b33 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -31,6 +31,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/eltwise.cpp ${MLAS_SRC_DIR}/erf.cpp ${MLAS_SRC_DIR}/compute.cpp + ${MLAS_SRC_DIR}/dequantize.cpp ${MLAS_SRC_DIR}/quantize.cpp ${MLAS_SRC_DIR}/qgemm_kernel_default.cpp ${MLAS_SRC_DIR}/qladd.cpp diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index 69c81a5ec7b9d..4184e0b049afc 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -72,10 +72,9 @@ endif() # TensorRT 10 GA onwards, the TensorRT libraries will have major version appended to the end on Windows, - # for example, nvinfer_10.dll, nvinfer_plugin_10.dll, nvonnxparser_10.dll ... + # for example, nvinfer_10.dll, nvonnxparser_10.dll ... 
if (WIN32 AND TRT_GREATER_OR_EQUAL_TRT_10_GA) set(NVINFER_LIB "nvinfer_${NV_TENSORRT_MAJOR}") - set(NVINFER_PLUGIN_LIB "nvinfer_plugin_${NV_TENSORRT_MAJOR}") set(PARSER_LIB "nvonnxparser_${NV_TENSORRT_MAJOR}") endif() @@ -83,15 +82,11 @@ set(NVINFER_LIB "nvinfer") endif() - if (NOT NVINFER_PLUGIN_LIB) - set(NVINFER_PLUGIN_LIB "nvinfer_plugin") - endif() - if (NOT PARSER_LIB) set(PARSER_LIB "nvonnxparser") endif() - MESSAGE(STATUS "Looking for ${NVINFER_LIB} and ${NVINFER_PLUGIN_LIB}") + MESSAGE(STATUS "Looking for ${NVINFER_LIB}") find_library(TENSORRT_LIBRARY_INFER ${NVINFER_LIB} HINTS ${TENSORRT_ROOT} @@ -101,14 +96,6 @@ MESSAGE(STATUS "Can't find ${NVINFER_LIB}") endif() - find_library(TENSORRT_LIBRARY_INFER_PLUGIN ${NVINFER_PLUGIN_LIB} - HINTS ${TENSORRT_ROOT} - PATH_SUFFIXES lib lib64 lib/x64) - - if (NOT TENSORRT_LIBRARY_INFER_PLUGIN) - MESSAGE(STATUS "Can't find ${NVINFER_PLUGIN_LIB}") - endif() - if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) MESSAGE(STATUS "Looking for ${PARSER_LIB}") @@ -120,7 +107,7 @@ MESSAGE(STATUS "Can't find ${PARSER_LIB}") endif() - set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN} ${TENSORRT_LIBRARY_NVONNXPARSER}) + set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_NVONNXPARSER}) MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}") else() if (TRT_GREATER_OR_EQUAL_TRT_10_GA) @@ -153,7 +140,7 @@ endif() # Static libraries are just nvonnxparser_static on all platforms set(onnxparser_link_libs nvonnxparser_static) - set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN}) + set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER}) MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}") endif() @@ -161,7 +148,7 @@ # nvonnxparser_static is linked against tensorrt libraries in onnx-tensorrt # See https://github.com/onnx/onnx-tensorrt/blob/8af13d1b106f58df1e98945a5e7c851ddb5f0791/CMakeLists.txt#L121 # However, starting from TRT 10 GA, nvonnxparser_static doesn't link against tensorrt libraries. - # Therefore, the above code finds ${TENSORRT_LIBRARY_INFER} and ${TENSORRT_LIBRARY_INFER_PLUGIN}. + # Therefore, the above code finds ${TENSORRT_LIBRARY_INFER}. 
if(onnxruntime_CUDA_MINIMAL) set(trt_link_libs ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) else() diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index 5b80b1262464d..2865ad33b39f4 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -172,10 +172,12 @@ file(MAKE_DIRECTORY ${WGSL_GENERATED_DIR}) # Find all WGSL template input files - file(GLOB_RECURSE WGSL_TEMPLATE_FILES "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.wgsl.template") + file(GLOB_RECURSE WGSL_TEMPLATE_FILES + "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.wgsl.template" + "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/*.wgsl.template") # Set wgsl-gen command line options as a list - set(WGSL_GEN_OPTIONS "-i" "../" "--output" "${WGSL_GENERATED_DIR}" "-I" "wgsl_template_gen/" "--preserve-code-ref" "--verbose") + set(WGSL_GEN_OPTIONS "-i" "${ONNXRUNTIME_ROOT}/core/providers/webgpu/" "-i" "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/" "--output" "${WGSL_GENERATED_DIR}" "-I" "wgsl_template_gen/" "--preserve-code-ref" "--verbose") if (onnxruntime_WGSL_TEMPLATE STREQUAL "static") if (CMAKE_BUILD_TYPE STREQUAL "Debug") list(APPEND WGSL_GEN_OPTIONS "--generator" "static-cpp-literal") @@ -207,10 +209,9 @@ # Add the generated directory to include paths target_include_directories(onnxruntime_providers_webgpu PRIVATE ${WGSL_GENERATED_ROOT}) elseif(onnxruntime_WGSL_TEMPLATE STREQUAL "dynamic") - add_library(duktape_static STATIC "${duktape_SOURCE_DIR}/src/duktape.c") - target_compile_features(duktape_static PRIVATE c_std_99) target_link_libraries(onnxruntime_providers_webgpu duktape_static) - target_include_directories(onnxruntime_providers_webgpu PRIVATE ${duktape_SOURCE_DIR}/src) + onnxruntime_add_include_to_target(onnxruntime_providers_webgpu duktape_static) + # Define the path to the generated templates.js file target_compile_definitions(onnxruntime_providers_webgpu PRIVATE "ORT_WGSL_TEMPLATES_JS_PATH=\"${WGSL_GENERATED_TEMPLATES_JS}\"") diff --git a/cmake/vcpkg.json b/cmake/vcpkg.json index 7c6b2fed36d1b..373ecec440921 100644 --- a/cmake/vcpkg.json +++ b/cmake/vcpkg.json @@ -43,7 +43,6 @@ "ms-gsl", "nlohmann-json", "onnx", - "optional-lite", { "name": "protobuf", "version>=": "3.21.12" @@ -94,6 +93,10 @@ "webgpu-ep": { "description": "Build with WebGPU EP", "dependencies": [] + }, + "webgpu-ep-wgsl-template-dynamic": { + "description": "Build with WebGPU EP with dynamic WGSL template code generator", + "dependencies": ["duktape"] } }, "overrides": [ @@ -104,6 +107,10 @@ { "name": "flatbuffers", "version": "23.5.26" + }, + { + "name": "duktape", + "version": "2.7.0#2" } ] } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/Tests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/Tests.cs index c28830ec72157..6e6190b8227b8 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/Tests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/Tests.cs @@ -40,10 +40,12 @@ public void RunPlatformUnitTest() var serializedResultSummary = _app.Invoke(_getResultsBackdoorMethodName)?.ToString(); Assert.IsNotEmpty(serializedResultSummary, "Test results were not returned"); + // Fix security issue (overflow with too much nesting): GHSA-5crp-9r3c-p9vr + JsonConvert.DefaultSettings = () => new JsonSerializerSettings { MaxDepth = 128 }; var testSummary = 
JsonConvert.DeserializeObject(serializedResultSummary); Assert.AreEqual(testSummary.Failed, 0, $"{testSummary.Failed} tests failed"); _app.Screenshot("Post-testing"); } } -} \ No newline at end of file +} diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/TestResultProcessor.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/TestResultProcessor.cs index 8419d261e4a41..625cc2c54055c 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/TestResultProcessor.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/TestResultProcessor.cs @@ -45,8 +45,9 @@ public TestResultSummary GetResults() public string GetSerializedResults() { var resultSummary = GetResults(); + JsonConvert.DefaultSettings = () => new JsonSerializerSettings { MaxDepth = 128 }; var serializedResultSummary = JsonConvert.SerializeObject(resultSummary, Formatting.Indented); return serializedResultSummary; } } -} \ No newline at end of file +} diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index b80918e6615e1..f3dcde1abe37a 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2545,6 +2545,8 @@ This version of the operator has been available since version 1 of the 'com.micr
left_window_size for local attention (like Mistral). Default value is -1 meaning unused.
num_heads : int (required)
Number of attention heads for q
+qk_output : int
+Output values of QK matrix multiplication before (1) or after (2) softmax normalization. Default value is 0 (don't output).
rotary_interleaved : int
Rotate using interleaved pattern. Default value is 0 (False).
scale : float
@@ -2555,7 +2557,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Softcap value for attention weights. Default value is 0.
-#### Inputs (7 - 11)
+#### Inputs (7 - 12)
query : T
@@ -2580,9 +2582,11 @@ This version of the operator has been available since version 1 of the 'com.micr
2D tensor with shape (batch_size, sequence_length). When processing the first prompt the kernel uses only the first element
attention_bias (optional) : T
additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
+head_sink (optional) : T
+1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.
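The head_sink input described above changes only the softmax denominator. One plausible formalization of that one-line description (an interpretation added here for clarity, not normative spec text) is:

$$
\text{softmax}_{\text{sink}}(s_{h,i,j}) = \frac{\exp(s_{h,i,j})}{\exp(\text{head\_sink}_h) + \sum_{j'} \exp(s_{h,i,j'})}
$$

where $s_{h,i,j}$ is the scaled $QK'$ score for head $h$, query position $i$, and key position $j$; with no head_sink the extra term in the denominator is absent.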
-#### Outputs
+#### Outputs (3 - 4)
output : T
@@ -2591,6 +2595,8 @@ This version of the operator has been available since version 1 of the 'com.micr
present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
present_value : T
present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
+output_qk (optional) : T
+Values of QK matrix multiplication, either before or after softmax normalization
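For orientation, here is a minimal sketch (not part of this PR) of how a model whose GroupQueryAttention node was authored with qk_output set to 1 or 2 could be run through the ORT C++ API while fetching the new optional output_qk output. The model file and tensor names are hypothetical, and output_qk can only be requested if the graph exposes it as a graph output.

```cpp
#include <vector>
#include "onnxruntime_cxx_api.h"

// Sketch: run a hypothetical GQA model and fetch the optional output_qk output.
int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "gqa_qk"};
  Ort::SessionOptions so;
  Ort::Session session{env, ORT_TSTR("gqa_model.onnx"), so};  // hypothetical model file

  // ... create input tensors for query/key/value, seqlens_k, total_sequence_length, etc. ...
  std::vector<const char*> input_names = {/* model-specific input names */};
  std::vector<Ort::Value> inputs;  // filled in by the caller

  std::vector<const char*> output_names = {"output", "present_key", "present_value", "output_qk"};
  auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names.data(), inputs.data(),
                             inputs.size(), output_names.data(), output_names.size());
  // outputs[3] holds the QK values, taken before or after softmax depending on qk_output.
  return 0;
}
```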
#### Type Constraints diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 1ffcabee8cc10..fa6c731231405 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -538,7 +538,7 @@ Do not modify directly.* |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| @@ -942,7 +942,7 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -1420,7 +1420,7 @@ Do not modify directly.* |FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 54e03a31fceef..c18a42cc1bbc1 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -952,9 +952,12 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi return const_cast(this)->GetNodeArg(name); } - // search this and up through any parent_graph_ instance for a NodeArg + // Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding mutable NodeArg NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name); + // Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding const NodeArg + const NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const; + /** Gets a mutable NodeArg by name. Creates a new NodeArg that is owned by this Graph if not found. @param name The NodeArg name. @param[in] p_arg_type Optional TypeProto to use if the NodeArg needs to be created. diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h new file mode 100644 index 0000000000000..37665542f614f --- /dev/null +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -0,0 +1,718 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +/* + SUMMARY: + Utilities to serialize an OrtGraph into an ONNX GraphProto or ModelProto. Can be used by execution provider + implementations that need to convert an OrtGraph instance into an ONNX protobuf model. + + Users may copy this file and modify as needed. + + USAGE: + This is a header-only implementation that includes both the function declarations and definitions. Copy this file + into a project that links with both ONNX Runtime and ONNX. + + Define the ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL preprocessor macro before the #include statement in exactly one C++ + file to define the implementation. Example: + + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + Other compilation units that depend on these utilities should include this file without defining the + preprocessor macro. + + Example program snippets are shown below. Refer to the function declarations for detailed usage information. 
+ + EXAMPLE SNIPPET (initializers stored within TensorProto): + + ```C++ + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, + OrtEpGraphSupportInfo* graph_support_info) { + onnx::GraphProto graph_proto; + OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto); + + // graph_proto stores initializers internally + } + ``` + + EXAMPLE SNIPPET (large initializers stored in external file): + + ```C++ + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, + OrtEpGraphSupportInfo* graph_support_info) { + std::string external_file_path = "weights.bin"; + std::ofstream out_file(external_file_path, std::ios::binary); + + auto handle_initializer_data = [&external_file_path, &out_file](const OrtValueInfo* value_info, + const void* data, size_t bytes, + bool& is_external, std::string& location, + int64_t& offset) -> Ort::Status { + // OrtValueInfo* could be used to query initializer's name, type, shape, consumers, etc. + (void)value_info; + + if (bytes <= 127) { + is_external = false; // Keep small initializers stored inside the TensorProto. + return Ort::Status{nullptr}; + } + + offset = out_file.tellp(); + location = external_file_path; + out_file.write(static_cast(data), bytes); + out_file.flush(); + is_external = true; // True if is external initializer + return Ort::Status{nullptr}; + } + + ONNX_NAMESPACE::GraphProto graph_proto; + OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto, handle_initializer_data); + + // graph_proto stores large initializers in an external file + } + ``` +*/ + +#ifndef INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ +#define INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ + +#include +#include "core/session/onnxruntime_cxx_api.h" +#include "onnx/onnx_pb.h" + +namespace OrtEpUtils { + +/// +/// Signature of user-provided function to handle initializer data. Called by OrtGraphToProto() for every initializer. +/// +/// If the function sets the `is_external` output parameter to false, OrtGraphToProto() stores initializer data +/// within the TensorProto as raw_data. +/// +/// Otherwise, if the function sets `is_external` to true, OrtGraphToProto() assumes that this function stores the +/// initializer data in a file. In this case, OrtGraphToProto() configures the corresponding TensorProto to point the +/// location and offset returned via the `location` and `offset` output parameters. +/// +/// It is recommended to keep small initializers with byte size <= 127 stored inline the TensorProto to ensure +/// ONNX shape inference works correctly with the serialized ONNX model. +/// +/// OrtValueInfo for the initializer. Can be used to query name, type, shape, +/// and consumer nodes. +/// Opaque pointer to the initializer data. +/// Size in bytes of the initializer data. +/// Output parameter set to true if the initializer data is stored externally. The +/// implementer is responsible for writing the initializer data to file. If set to false, +/// the initializer will be stored within the TensorProto. +/// Output parameter set to the location (e.g., file) into which the initializer is stored +/// by the implementer of this function. Ignored if `is_external` is set to false. +/// Output parameter set to the offset (e.g., file offset) into which the initializer is stored +/// by the implementer of this function. 
Ignored if `is_external` is set to false. +/// An Ort::Status indicating success or an error. Serialization exits if this returns an error. +using HandleInitializerDataFunc = std::function; + +/// +/// Serializes the provided OrtGraph to a onnx::GraphProto. +/// Allows the caller to provide a function that specifies whether an initializer should be stored +/// within a TensorProto, written to a file, or remain as an in-memory external initializer (not valid ONNX). +/// +/// OrtGraph instance to serialize. +/// Destination GraphProto into which to serialize the input OrtGraph. +/// Optional function called to allow the user to determine +/// where the initializer data is stored. +/// An Ort::Status indicating success or an error. +Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + onnx::GraphProto& graph_proto, + HandleInitializerDataFunc handle_initializer_data_func = nullptr); + +/// +/// Serializes the provided top-level OrtGraph to a onnx::ModelProto. +/// Allows the caller to provide a function that specifies whether an initializer should be stored +/// within a TensorProto, written to a file, or remain as an in-memory external initializer (not valid ONNX). +/// +/// OrtGraph instance to serialize. +/// Destination ModelProto into which to serialize the input OrtGraph. +/// Optional function called to allow the user to determine +/// where the initializer data is stored. +/// An Ort::Status indicating success or an error. +Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + onnx::ModelProto& model_proto, + HandleInitializerDataFunc handle_initializer_data_func = nullptr); +} // namespace OrtEpUtils + +// End of header +#endif // INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ + +// +// IMPLEMENTATION BELOW +// +#ifdef ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + +#include +#include +#include +#include +#include +#include + +#define ORT_EP_UTILS_C_RETURN_IF_ERROR(fn) \ + do { \ + OrtStatus* _status = (fn); \ + if (_status != nullptr) { \ + return Ort::Status{_status}; \ + } \ + } while (0) + +#define ORT_EP_UTILS_CXX_RETURN_IF_ERROR(fn) \ + do { \ + Ort::Status _status = (fn); \ + if (!_status.IsOK()) { \ + return _status; \ + } \ + } while (0) + +#define ORT_EP_UTILS_C_RETURN_IF(cond, ort_api, msg) \ + do { \ + if ((cond)) { \ + return Ort::Status{(ort_api).CreateStatus(ORT_FAIL, (msg))}; \ + } \ + } while (0) + +namespace OrtEpUtils { + +static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, + bool get_symbolic_dims, + /*out*/ ONNXTensorElementDataType& elem_type, + /*out*/ std::vector& dims, + /*out*/ std::vector& symbolic_dims); +static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); +static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); + +Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + onnx::GraphProto& graph_proto, + HandleInitializerDataFunc handle_initializer_data_func) { + const OrtApi& ort_api = Ort::GetApi(); + + // + // Set GraphProto metadata + // + const char* graph_name = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetName(&ort_graph, &graph_name)); + graph_proto.set_name(graph_name); + graph_proto.set_doc_string("Serialized from OrtGraph"); + + // + // Set GraphProto inputs and outputs + // + size_t num_graph_inputs = 0; + size_t num_graph_outputs = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumInputs(&ort_graph, &num_graph_inputs)); + 
ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOutputs(&ort_graph, &num_graph_outputs)); + + std::vector graph_inputs(num_graph_inputs); + std::vector graph_outputs(num_graph_outputs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetInputs(&ort_graph, graph_inputs.data(), graph_inputs.size())); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOutputs(&ort_graph, graph_outputs.data(), graph_outputs.size())); + + for (const OrtValueInfo* ort_value_info : graph_inputs) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_input()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); + } + + for (const OrtValueInfo* ort_value_info : graph_outputs) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_output()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); + } + + // + // Set GraphProto nodes, value_infos, and initializers. + // + + // Use std::maps to store OrtValueInfos for GraphProto.value_info and GraphProto.initializer. + // A std::map maintains its elements in a stable ordering. + std::map value_infos; // For GraphProto.value_info + std::map initializer_value_infos; // For GraphProto.initializer + + // Helper function to collect an OrtValueInfo into `value_infos` or `initializer_value_infos`. + // Optionally returns the OrtValueInfo name to the caller. + auto collect_value_info = [&ort_api, &value_infos, + &initializer_value_infos](const OrtValueInfo& ort_value_info, + /*out*/ const char** value_name_out = nullptr) -> Ort::Status { + const char* value_name = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); + + if (value_name_out != nullptr) { + *value_name_out = value_name; + } + + if (value_infos.count(value_name) != 0 || initializer_value_infos.count(value_name) != 0) { + return Ort::Status{nullptr}; // Already processed this OrtValueInfo. + } + + bool is_required_graph_input = false; + bool is_optional_graph_input = false; + bool is_graph_output = false; + bool is_constant_initializer = false; + bool is_from_outer_scope = false; + + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsRequiredGraphInput(&ort_value_info, &is_required_graph_input)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsOptionalGraphInput(&ort_value_info, &is_optional_graph_input)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsGraphOutput(&ort_value_info, &is_graph_output)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsConstantInitializer(&ort_value_info, &is_constant_initializer)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsFromOuterScope(&ort_value_info, &is_from_outer_scope)); + + // Don't add graph inputs or graph outputs to GraphProto's list of value_infos. + // Do add initializers (constant and non-constant) to GraphProto's list of initializer tensors. + // For values defined in an outer scope, just add the value info but not the initializer. + if (is_from_outer_scope) { + value_infos.emplace(value_name, &ort_value_info); + } else if (is_optional_graph_input) { + initializer_value_infos.emplace(value_name, &ort_value_info); + } else if (is_constant_initializer) { + value_infos.emplace(value_name, &ort_value_info); + initializer_value_infos.emplace(value_name, &ort_value_info); + } else if (!is_required_graph_input && !is_graph_output) { + value_infos.emplace(value_name, &ort_value_info); // This is an internal OrtValueInfo. 
+ } + + return Ort::Status{nullptr}; + }; + + size_t num_nodes = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(&ort_graph, &num_nodes)); + + std::vector nodes(num_nodes); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNodes(&ort_graph, nodes.data(), nodes.size())); + + // Loop through all nodes (topological order): add NodeProto instances to GraphProto and track OrtValueInfos + // that will be stored in GraphProto.value_info and GraphProto.initializer. + for (size_t i = 0; i < num_nodes; i++) { + const OrtNode* ort_node = nodes[i]; + onnx::NodeProto* node_proto = graph_proto.add_node(); + + const char* node_name = nullptr; + const char* node_domain = nullptr; + const char* node_op_type = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetName(ort_node, &node_name)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetDomain(ort_node, &node_domain)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOperatorType(ort_node, &node_op_type)); + + node_proto->set_name(node_name); + node_proto->set_domain(node_domain); + node_proto->set_op_type(node_op_type); + + size_t num_inputs = 0; + size_t num_implicit_inputs = 0; + size_t num_outputs = 0; + size_t num_attrs = 0; + size_t num_subgraphs = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumInputs(ort_node, &num_inputs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumImplicitInputs(ort_node, &num_implicit_inputs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(ort_node, &num_outputs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumAttributes(ort_node, &num_attrs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumSubgraphs(ort_node, &num_subgraphs)); + + // Handle node attributes + if (num_attrs > 0) { + std::vector ort_attrs(num_attrs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetAttributes(ort_node, ort_attrs.data(), ort_attrs.size())); + + for (const OrtOpAttr* ort_attr : ort_attrs) { + OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; + + Ort::Status status{ort_api.OpAttr_GetType(ort_attr, &attr_type)}; + if (!status.IsOK()) { + // This is an attribute type that ORT does not support via ReadOpAttr(), like subgraphs, so skip it. + // Can use Node_GetSubgraphs to get subgraphs. 
+ continue; + } + + onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); + } + } + + // Handle node subgraphs + if (num_subgraphs > 0) { + std::vector ort_subgraphs(num_subgraphs); + std::vector subgraph_attr_names(num_subgraphs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetSubgraphs(ort_node, ort_subgraphs.data(), ort_subgraphs.size(), + subgraph_attr_names.data())); + + for (size_t subgraph_idx = 0; subgraph_idx < num_subgraphs; subgraph_idx++) { + const OrtGraph* ort_subgraph = ort_subgraphs[subgraph_idx]; + const char* subgraph_attr_name = subgraph_attr_names[subgraph_idx]; + + onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + onnx::GraphProto* subgraph_proto = attr_proto->mutable_g(); + + attr_proto->set_name(subgraph_attr_name); + attr_proto->set_type(onnx::AttributeProto_AttributeType_GRAPH); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(*ort_subgraph, *subgraph_proto)); + } + } + + // Handle node inputs + if (num_inputs > 0) { + std::vector ort_inputs(num_inputs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetInputs(ort_node, ort_inputs.data(), ort_inputs.size())); + + for (const OrtValueInfo* ort_value_info : ort_inputs) { + if (ort_value_info == nullptr) { + // missing optional input. + node_proto->add_input(""); + continue; + } + + const char* value_name = nullptr; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); + + node_proto->add_input(value_name); + } + } + + // Handle implicit inputs to this node. + if (num_implicit_inputs > 0) { + std::vector ort_implicit_inputs(num_implicit_inputs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetImplicitInputs(ort_node, ort_implicit_inputs.data(), + ort_implicit_inputs.size())); + + for (const OrtValueInfo* ort_value_info : ort_implicit_inputs) { + assert(ort_value_info != nullptr); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, /*value_name_out*/ nullptr)); + } + } + + // Handle node outputs + if (num_outputs > 0) { + std::vector ort_outputs(num_outputs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOutputs(ort_node, ort_outputs.data(), ort_outputs.size())); + + for (const OrtValueInfo* ort_value_info : ort_outputs) { + if (ort_value_info == nullptr) { + // missing optional output. + node_proto->add_output(""); + continue; + } + + const char* value_name = nullptr; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); + + node_proto->add_output(value_name); + } + } + } + + // Add value_infos to GraphProto as ValueInfoProto objects. + for (const std::pair& entry : value_infos) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_value_info()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*entry.second, *value_info_proto)); + } + + // Add initializers to GraphProto as TensorProto objects. + for (const std::pair& entry : initializer_value_infos) { + const OrtValueInfo* initializer_value_info = entry.second; + std::string initializer_name = std::string{entry.first}; // Need a null-terminated string. 
+ std::vector initializer_dims; + std::vector initializer_sym_dims; + ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(*initializer_value_info, /*get_sym_dims*/ false, + initializer_elem_type, initializer_dims, + initializer_sym_dims)); + + onnx::TensorProto* tensor_proto = graph_proto.add_initializer(); + tensor_proto->set_name(initializer_name); + tensor_proto->set_data_type(initializer_elem_type); + + auto* tensor_proto_dims = tensor_proto->mutable_dims(); + for (int64_t dim : initializer_dims) { + tensor_proto_dims->Add(dim); + } + + const OrtValue* ort_value = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_GetInitializerValue(initializer_value_info, &ort_value)); + + const void* data = nullptr; + size_t data_bytes = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorData(ort_value, &data)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorSizeInBytes(ort_value, &data_bytes)); + + std::string ext_location; + int64_t ext_offset = 0; + bool is_external = false; + + if (handle_initializer_data_func != nullptr) { + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(handle_initializer_data_func(initializer_value_info, data, data_bytes, + is_external, ext_location, ext_offset)); + } + + if (is_external) { + tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL); + auto* ext_data_entries = tensor_proto->mutable_external_data(); + onnx::StringStringEntryProto* location_entry = ext_data_entries->Add(); + onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add(); + + location_entry->set_key("location"); + location_entry->set_value(ext_location); + offset_entry->set_key("offset"); + offset_entry->set_value(std::to_string(ext_offset)); + } else { + // User wants to store data inline the TensorProto's raw_data + tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT); + tensor_proto->set_raw_data(data, data_bytes); + } + } + + return Ort::Status{nullptr}; +} + +Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + onnx::ModelProto& model_proto, + HandleInitializerDataFunc handle_initializer_data_func) { + const OrtApi& ort_api = Ort::GetApi(); + + // Check that OrtGraph is a top-level graph (no parent node). + const OrtNode* parent_node = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetParentNode(&ort_graph, &parent_node)); + ORT_EP_UTILS_C_RETURN_IF(parent_node != nullptr, ort_api, "Cannot serialize nested OrtGraph into a ModelProto"); + + // Set model description. + model_proto.set_doc_string("Serialized from OrtGraph"); + model_proto.set_producer_name("ort_ep_utils::OrtGraphToProto"); + + // Set ir version. + int64_t ir_version = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOnnxIRVersion(&ort_graph, &ir_version)); + model_proto.set_ir_version(ir_version); + + // Set operator sets. 
+ size_t num_operator_sets = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOperatorSets(&ort_graph, &num_operator_sets)); + ORT_EP_UTILS_C_RETURN_IF(num_operator_sets == 0, ort_api, "OrtGraph should have at least one operator set."); + + std::vector domains(num_operator_sets, nullptr); + std::vector opset_versions(num_operator_sets); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOperatorSets(&ort_graph, domains.data(), opset_versions.data(), + num_operator_sets)); + + auto* operator_sets = model_proto.mutable_opset_import(); + + for (size_t i = 0; i < num_operator_sets; ++i) { + onnx::OperatorSetIdProto* operator_set = operator_sets->Add(); + operator_set->set_domain(domains[i]); + operator_set->set_version(opset_versions[i]); + } + + model_proto.clear_graph(); + onnx::GraphProto* graph_proto = model_proto.mutable_graph(); + + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(ort_graph, *graph_proto, handle_initializer_data_func)); + + return Ort::Status{nullptr}; +} + +static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, + bool get_symbolic_dims, + /*out*/ ONNXTensorElementDataType& elem_type, + /*out*/ std::vector& dims, + /*out*/ std::vector& symbolic_dims) { + const OrtApi& ort_api = Ort::GetApi(); + + const OrtTypeInfo* ort_type_info = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoTypeInfo(&ort_value_info, &ort_type_info)); + + ONNXType ort_onnx_type = ONNX_TYPE_UNKNOWN; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetOnnxTypeFromTypeInfo(ort_type_info, &ort_onnx_type)); + ORT_EP_UTILS_C_RETURN_IF(ort_onnx_type != ONNX_TYPE_TENSOR, ort_api, "Expected OrtValueInfo to represent a Tensor"); + + const OrtTensorTypeAndShapeInfo* ort_type_shape = nullptr; + ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.CastTypeInfoToTensorInfo(ort_type_info, &ort_type_shape)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorElementType(ort_type_shape, &ort_elem_type)); + + size_t num_dims = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensionsCount(ort_type_shape, &num_dims)); + + std::vector ort_dims(num_dims, 0); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensions(ort_type_shape, ort_dims.data(), ort_dims.size())); + + elem_type = ort_elem_type; + dims = std::move(ort_dims); + + if (get_symbolic_dims) { + std::vector ort_dim_syms(num_dims, nullptr); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetSymbolicDimensions(ort_type_shape, ort_dim_syms.data(), + ort_dim_syms.size())); + + symbolic_dims.reserve(num_dims); + for (const char* sym_dim : ort_dim_syms) { + symbolic_dims.push_back(sym_dim); + } + } + + return Ort::Status{nullptr}; +} + +// Create an onnx::ValueInfoProto from an OrtValueInfo (name, type, shape). +static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, + onnx::ValueInfoProto& value_info_proto) { + const OrtApi& ort_api = Ort::GetApi(); + + std::vector ort_dims; + std::vector ort_dim_syms; + ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + + // We currently only support ONNX tensors. Support for other types (e.g., ONNX_TYPE_SEQUENCE) can be added later. 
+ ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(ort_value_info, /*get_sym_dims*/ true, + ort_elem_type, ort_dims, ort_dim_syms)); + + const char* value_name = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); + value_info_proto.set_name(value_name); + + onnx::TypeProto_Tensor* type_proto_tensor = value_info_proto.mutable_type()->mutable_tensor_type(); + type_proto_tensor->set_elem_type(ort_elem_type); + + onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape(); + + for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) { + onnx::TensorShapeProto_Dimension* dim_proto = shape_proto->add_dim(); + + if (ort_dims[dim_idx] >= 0) { + dim_proto->set_dim_value(ort_dims[dim_idx]); + } else { + const std::string& dim_param = ort_dim_syms[dim_idx]; + + // If dim_param is empty, leave dim_proto with neither the dim_value or dim_param set, + // which represents an unknown dimension. + if (!dim_param.empty()) { + dim_proto->set_dim_param(dim_param); + } + } + } + + return Ort::Status{nullptr}; +} + +static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { + const OrtApi& ort_api = Ort::GetApi(); + + const char* attr_name = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetName(&ort_attr, &attr_name)); + attr_proto.set_name(attr_name); + + size_t total_attr_bytes = 0; + OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetType(&ort_attr, &attr_type)); + + switch (attr_type) { + case OrtOpAttrType::ORT_OP_ATTR_INT: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_INT); + + int64_t i_val = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &i_val, sizeof(i_val), &total_attr_bytes)); + attr_proto.set_i(i_val); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_INTS: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_INTS); + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. + Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; + std::vector i_vals(total_attr_bytes / sizeof(int64_t)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, i_vals.data(), total_attr_bytes, + &total_attr_bytes)); + + auto* ints = attr_proto.mutable_ints(); + for (int64_t val : i_vals) { + ints->Add(val); + } + break; + } + case OrtOpAttrType::ORT_OP_ATTR_FLOAT: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOAT); + + float f_val = 0.0f; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &f_val, sizeof(f_val), &total_attr_bytes)); + attr_proto.set_f(f_val); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_FLOATS: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOATS); + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. + Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; + std::vector f_vals(total_attr_bytes / sizeof(float)); + + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, f_vals.data(), total_attr_bytes, + &total_attr_bytes)); + + auto* floats = attr_proto.mutable_floats(); + for (float val : f_vals) { + floats->Add(val); + } + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRING: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_STRING); + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. 
+ Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; + std::string* str = attr_proto.mutable_s(); + + str->resize(total_attr_bytes, '\0'); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, str->data(), total_attr_bytes, + &total_attr_bytes)); + + str->resize(total_attr_bytes - 1); // remove extra ending terminating '\0' character. + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_STRINGS); + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. + Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; + std::vector chars(total_attr_bytes, '\0'); + + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, chars.data(), total_attr_bytes, + &total_attr_bytes)); + + auto* strs = attr_proto.mutable_strings(); + + // Strings are all in a single buffer, each separated with a '\0'. + // Extract each string and add it to the STRINGS attribute array. + char* at = chars.data(); + char* end = at + chars.size(); + + while (at < end) { + char* str_begin = at; + + while (*at && at < end) { + at++; + } + + strs->Add()->assign(str_begin, at - str_begin); + if (at < end) { + assert(*at == '\0'); + at++; // Skip '\0' to get to the beginning of the next string. + } + } + + break; + } + default: { + std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); + } + } + + return Ort::Status{nullptr}; +} + +} // namespace OrtEpUtils +#endif // ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 86c0b60db2bc4..82e782112974f 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -66,6 +66,7 @@ extern "C" { #define _In_reads_(X) #define _Inout_updates_(X) #define _Out_writes_(X) +#define _Out_writes_opt_(X) #define _Inout_updates_all_(X) #define _Out_writes_bytes_all_(X) #define _Out_writes_all_(X) @@ -4749,6 +4750,8 @@ struct OrtApi { * \param[in] len Number of bytes allowed to store in data * \param[out] out Number of bytes required to save the data when the call failed, or the real number of bytes saved to data on success * + * \note Does not support reading graph attributes. Refer to Node_GetSubgraphs. + * * \since Version 1.17. */ ORT_API2_STATUS(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out); @@ -5568,6 +5571,45 @@ struct OrtApi { */ ORT_API2_STATUS(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version); + /** \brief Returns the number of operator sets that the graph's model uses. + * + * \note An operator set is uniquely identified by the (domain, opset_version) pair. All models must have at + * least one entry that specifies which entry of the ONNX operator set is used. The ONNX domain is represented by + * an empty string. + * + * \param[in] graph The OrtGraph instance. + * \param[out] num_operator_sets Output parameter set to the number of operator sets that the graph's model uses. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. 
+ */ + ORT_API2_STATUS(Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets); + + /** \brief Returns the operator sets that the graph's model uses. + * + * \note An operator set is uniquely identified by the (domain, opset_version) pair. All models must have at + * least one entry that specifies which entry of the ONNX operator set is used. The ONNX domain is represented by + * an empty string. + * + * \param[in] graph The OrtGraph instance. + * \param[out] domains Pre-allocated array of `num_operator_sets` elements that is filled with + * null-terminated domain names. + * \param[out] opset_versions Pre-allocated array of `num_operator_sets` elements that is filled with + * the opset version of the corresponding domain in the `domains` array. + * \param[in] num_operator_sets The size of the `domains` and `opset_versions` arrays. + * Typical usage sets this to the result of Graph_GetNumOperatorSets(). + * An error status is returned if `num_operator_sets` is less than the actual number + * of operator sets. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetOperatorSets, _In_ const OrtGraph* graph, + _Out_writes_(num_operator_sets) const char** domains, + _Out_writes_(num_operator_sets) int64_t* opset_versions, _In_ size_t num_operator_sets); + /** \brief Returns the number of graph inputs. * * \note The count includes initializers that are included in the list of graph inputs. @@ -5706,6 +5748,24 @@ struct OrtApi { */ ORT_API2_STATUS(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); + /** \brief Returns an OrtGraph that contains a subset of nodes in the source OrtGraph. + * + * Note: + * The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference + * the same underlying graph. + * + * \param[in] src_graph The source OrtGraph instance. + * \param[in] nodes A subset of the nodes/OrtNodes in 'graph'. + * \param[in] num_nodes Number of nodes. + * \param[out] dst_sub_graph An OrtGraph created from a given set of nodes. Must be released by calling ReleaseGraph. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetGraphView, _In_ const OrtGraph* src_graph, _In_ const OrtNode** nodes, + _In_ size_t num_nodes, _Outptr_ OrtGraph** dst_graph); + /// @} /// \name OrtNode @@ -5933,20 +5993,24 @@ struct OrtApi { /** \brief Get the subgraphs, as OrtGraph instances, contained by the given node. * - * \note Only certain operator types (e.g., If and Loop) contain nested subgraphs. + * \note Only certain operator types (e.g., If and Loop) contain nested subgraphs. ONNX nodes store subgraphs in + * their attributes, however, this function must be used to obtain subgraphs from an OrtNode. * * \param[in] node The OrtNode instance. * \param[out] subgraphs Pre-allocated array of `num_subgraphs` elements that is filled with the node's subgraphs. * \param[in] num_subgraphs The size of the `num_subgraphs` array. * Typical usage sets this to the result of Node_GetNumSubgraphs(). An error status is * returned if `num_subgraphs` is less than the number of node subgraphs. + * \param[out] attribute_names Optional pre-allocated array of `num_subgraphs` elements that is filled with the + * attribute names that correspond to the subgraphs. Ignored if set to NULL. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. 
*/ ORT_API2_STATUS(Node_GetSubgraphs, _In_ const OrtNode* node, - _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs); + _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs, + _Out_writes_opt_(num_subgraphs) const char** attribute_names); /** \brief Get the node's parent OrtGraph instance. * @@ -5962,6 +6026,19 @@ struct OrtApi { */ ORT_API2_STATUS(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph); + /** \brief Returns the execution provider name that this node is assigned to run on. + * Returns NULL if the node has not been assigned to any execution provider yet. + * For plugin execution providers, the name is the one returned by OrtEp::GetName. + * + * \param[in] node The OrtNode instance. + * \param[out] out Output execution provider type and can be NULL if node has not been assigned. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Node_GetEpName, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out); + /// @} /// \name OrtRunOptions @@ -6810,6 +6887,24 @@ struct OrtCompileApi { */ ORT_API2_STATUS(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_compile_options, size_t flags); + + /** Sets information related to EP context binary file. + * + * EP uses this information to decide the location and context binary file name. + * Used while compiling model with input and output in memory buffer + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] output_directory Null terminated string of the path (wchar on Windows, char otherwise). + * \param[in] model_name Null terminated string of the model name (wchar on Windows, char otherwise). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. 
+ */ + ORT_API2_STATUS(ModelCompilationOptions_SetEpContextBinaryInformation, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const ORTCHAR_T* output_directory, + _In_ const ORTCHAR_T* model_name); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index c59baa59c91a5..d1b08f127fa2a 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1161,6 +1161,8 @@ struct ModelCompilationOptions : detail::Base { size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer + ModelCompilationOptions& SetEpContextBinaryInformation(const ORTCHAR_T* output_directory, + const ORTCHAR_T* model_name); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextBinaryInformation ModelCompilationOptions& SetFlags(size_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags }; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 612adc81d3309..ba5d53e6c2dd0 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -819,6 +819,15 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelPath( return *this; } +inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextBinaryInformation( + const ORTCHAR_T* output_directory, const ORTCHAR_T* model_name) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetEpContextBinaryInformation( + this->p_, + output_directory, + model_name)); + return *this; +} + inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelExternalInitializersFile( const ORTCHAR_T* file_path, size_t initializer_size_threshold) { Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelExternalInitializersFile( diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 44c7bb6ee424a..5d00ce4940d02 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -358,7 +358,7 @@ struct OrtEp { * * \since Version 1.22. */ - const char*(ORT_API_CALL* GetName)(_In_ const OrtEp* this_ptr); + ORT_API_T(const char*, GetName, _In_ const OrtEp* this_ptr); /** \brief Get information about the nodes supported by the OrtEp instance. * @@ -376,8 +376,8 @@ struct OrtEp { * * \since Version 1.23. */ - OrtStatus*(ORT_API_CALL* GetCapability)(_In_ OrtEp* this_ptr, _In_ const OrtGraph* graph, - _Inout_ OrtEpGraphSupportInfo* graph_support_info); + ORT_API2_STATUS(GetCapability, _In_ OrtEp* this_ptr, _In_ const OrtGraph* graph, + _Inout_ OrtEpGraphSupportInfo* graph_support_info); /** \brief Compile OrtGraph instances assigned to the OrtEp. Implementer must set a OrtNodeComputeInfo instance * for each OrtGraph in order to define its computation function. @@ -416,10 +416,10 @@ struct OrtEp { * * \since Version 1.23. 
*/ - OrtStatus*(ORT_API_CALL* Compile)(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, - _In_ const OrtNode** fused_nodes, _In_ size_t count, - _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, - _Out_writes_(count) OrtNode** ep_context_nodes); + ORT_API2_STATUS(Compile, _In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, + _In_ const OrtNode** fused_nodes, _In_ size_t count, + _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, + _Out_writes_(count) OrtNode** ep_context_nodes); /** \brief Release OrtNodeComputeInfo instances. * @@ -429,9 +429,9 @@ struct OrtEp { * * \since Version 1.23. */ - void(ORT_API_CALL* ReleaseNodeComputeInfos)(_In_ OrtEp* this_ptr, - OrtNodeComputeInfo** node_compute_infos, - _In_ size_t num_node_compute_infos); + ORT_API_T(void, ReleaseNodeComputeInfos, _In_ OrtEp* this_ptr, + OrtNodeComputeInfo** node_compute_infos, + _In_ size_t num_node_compute_infos); /** \brief Get the EP's preferred data layout. * @@ -445,8 +445,7 @@ struct OrtEp { * * \since Version 1.23. */ - OrtStatus*(ORT_API_CALL* GetPreferredDataLayout)(_In_ OrtEp* this_ptr, - _Out_ OrtEpDataLayout* preferred_data_layout); + ORT_API2_STATUS(GetPreferredDataLayout, _In_ OrtEp* this_ptr, _Out_ OrtEpDataLayout* preferred_data_layout); /** \brief Given an op with domain `domain` and type `op_type`, determine whether an associated node's data layout * should be converted to `target_data_layout`. @@ -470,11 +469,10 @@ struct OrtEp { * * \since Version 1.23. */ - OrtStatus*(ORT_API_CALL* ShouldConvertDataLayoutForOp)(_In_ OrtEp* this_ptr, - _In_z_ const char* domain, - _In_z_ const char* op_type, - _In_ OrtEpDataLayout target_data_layout, - _Outptr_ int* should_convert); + ORT_API2_STATUS(ShouldConvertDataLayoutForOp, _In_ OrtEp* this_ptr, + _In_z_ const char* domain, _In_z_ const char* op_type, + _In_ OrtEpDataLayout target_data_layout, + _Outptr_ int* should_convert); /** \brief Set dynamic options on this EP. * @@ -492,10 +490,10 @@ struct OrtEp { * * \since Version 1.23. */ - OrtStatus*(ORT_API_CALL* SetDynamicOptions)(_In_ OrtEp* this_ptr, - _In_reads_(num_options) const char* const* option_keys, - _In_reads_(num_options) const char* const* option_values, - _In_ size_t num_options); + ORT_API2_STATUS(SetDynamicOptions, _In_ OrtEp* this_ptr, + _In_reads_(num_options) const char* const* option_keys, + _In_reads_(num_options) const char* const* option_values, + _In_ size_t num_options); /** \brief Called by ORT to notify the EP of the start of a run. * @@ -508,8 +506,7 @@ struct OrtEp { * * \since Version 1.23. */ - OrtStatus*(ORT_API_CALL* OnRunStart)(_In_ OrtEp* this_ptr, - _In_ const OrtRunOptions* run_options); + ORT_API2_STATUS(OnRunStart, _In_ OrtEp* this_ptr, _In_ const OrtRunOptions* run_options); /** \brief Called by ORT to notify the EP of the end of a run. * @@ -524,9 +521,7 @@ struct OrtEp { * * \since Version 1.23. */ - OrtStatus*(ORT_API_CALL* OnRunEnd)(_In_ OrtEp* this_ptr, - _In_ const OrtRunOptions* run_options, - _In_ bool sync_stream); + ORT_API2_STATUS(OnRunEnd, _In_ OrtEp* this_ptr, _In_ const OrtRunOptions* run_options, _In_ bool sync_stream); }; /** \brief The function signature that ORT will call to create OrtEpFactory instances. @@ -586,7 +581,7 @@ struct OrtEpFactory { * * \since Version 1.22. */ - const char*(ORT_API_CALL* GetName)(const OrtEpFactory* this_ptr); + ORT_API_T(const char*, GetName, const OrtEpFactory* this_ptr); /** \brief Get the name of vendor who owns the execution provider that the factory creates. 
* @@ -597,7 +592,7 @@ struct OrtEpFactory { * * \since Version 1.22. */ - const char*(ORT_API_CALL* GetVendor)(const OrtEpFactory* this_ptr); // return EP vendor + ORT_API_T(const char*, GetVendor, const OrtEpFactory* this_ptr); // return EP vendor /** \brief Get information from the execution provider about OrtHardwareDevice support. * @@ -616,12 +611,12 @@ struct OrtEpFactory { * * \since Version 1.22. */ - OrtStatus*(ORT_API_CALL* GetSupportedDevices)(_In_ OrtEpFactory* this_ptr, - _In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_ size_t num_devices, - _Inout_ OrtEpDevice** ep_devices, - _In_ size_t max_ep_devices, - _Out_ size_t* num_ep_devices); + ORT_API2_STATUS(GetSupportedDevices, _In_ OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _Inout_ OrtEpDevice** ep_devices, + _In_ size_t max_ep_devices, + _Out_ size_t* num_ep_devices); /** \brief Function to create an OrtEp instance for use in a Session. * @@ -647,12 +642,12 @@ struct OrtEpFactory { * * \since Version 1.22. */ - OrtStatus*(ORT_API_CALL* CreateEp)(_In_ OrtEpFactory* this_ptr, - _In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, - _In_ size_t num_devices, - _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, _Outptr_ OrtEp** ep); + ORT_API2_STATUS(CreateEp, _In_ OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, + _In_ size_t num_devices, + _In_ const OrtSessionOptions* session_options, + _In_ const OrtLogger* logger, _Outptr_ OrtEp** ep); /** \brief Release the OrtEp instance. * @@ -661,7 +656,18 @@ struct OrtEpFactory { * * \since Version 1.22. */ - void(ORT_API_CALL* ReleaseEp)(OrtEpFactory* this_ptr, struct OrtEp* ep); + ORT_API_T(void, ReleaseEp, OrtEpFactory* this_ptr, struct OrtEp* ep); + + /** \brief Get the vendor id who owns the execution provider that the factory creates. + * + * This is typically the PCI vendor ID. See https://pcisig.com/membership/member-companies + * + * \param[in] this_ptr The OrtEpFactory instance. + * \return vendor_id The vendor ID of the execution provider the factory creates. + * + * \since Version 1.23. + */ + ORT_API_T(uint32_t, GetVendorId, const OrtEpFactory* this_ptr); /** \brief Get the version of the execution provider that the factory creates. * @@ -675,7 +681,7 @@ struct OrtEpFactory { * * \since Version 1.23. */ - const char*(ORT_API_CALL* GetVersion)(_In_ const OrtEpFactory* this_ptr); + ORT_API_T(const char*, GetVersion, _In_ const OrtEpFactory* this_ptr); /** \brief Create an OrtAllocator for the given OrtMemoryInfo. 
* diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 97e53e6acee5a..314cf76cc8044 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -148,7 +148,9 @@ static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = " // Configure whether to allow the inter_op/intra_op threads spinning a number of times before blocking // "0": thread will block if found no job to run -// "1": default, thread will spin a number of times before blocking +// "1": thread will spin a number of times before blocking +// The default is "0" when ORT is built with "ORT_CLIENT_PACKAGE_BUILD" and "1" otherwise. +// Thread spinning is disabled by default for client/on-device workloads to reduce cpu utilization and improve power efficiency. static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning"; static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning"; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index 5a837fd1e0bfa..c2085342efd80 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -98,7 +98,7 @@ const calculateInputIndicesImpl = ( `fn calculateInputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} { var input_indices: ${input.type.indices}; var carry = 0u; - for (var i = ${inputShape.length}; i >= 0; i--) { + for (var i = ${inputShape.length - 1}; i >= 0; i--) { let input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)}; let steps_i = ${getElementAt('uniforms.steps', 'i', inputShape.length)}; let signs_i = ${getElementAt('uniforms.signs', 'i', inputShape.length)}; diff --git a/js/web/script/pull-prebuilt-wasm-artifacts.ts b/js/web/script/pull-prebuilt-wasm-artifacts.ts index c3300f7272bb9..87008f51ff4b9 100644 --- a/js/web/script/pull-prebuilt-wasm-artifacts.ts +++ b/js/web/script/pull-prebuilt-wasm-artifacts.ts @@ -38,7 +38,6 @@ Usage: Options: -d --debug specify the debug build type of the artifacts to download. -l --latest if set, will always use the latest build, even if it is not completed yet. - --webgpu-ep if set, will use the webgpu EP wasm build instead of the default(JSEP) one. -h --help print this message and exit `; @@ -81,9 +80,8 @@ try { // The following code checks both the command line arguments and the npm_config_* environment variables to get the correct values. const debug = args.debug || process.env.npm_config_d || process.env.npm_config_debug; const latest = args.latest || process.env.npm_config_l || process.env.npm_config_latest; -const webgpuEp = args['webgpu-ep'] || process.env.npm_config_webgpu_ep; -const folderName = (debug ? 'Debug_wasm' : 'Release_wasm') + (webgpuEp ? '_webgpu' : ''); +const folderName = debug ? 
'Debug_wasm' : 'Release_wasm'; const allowImcomplete = latest; const run = args._[0]; // The first non-option argument @@ -151,13 +149,17 @@ async function downloadArtifactsForRun(run: any): Promise { if (!fs.existsSync(WASM_FOLDER)) { fs.mkdirSync(WASM_FOLDER); } else { - // TODO: revise artifacts download - const filesToDelete = ['ort-wasm-simd-threaded.jsep.mjs', 'ort-wasm-simd-threaded.jsep.wasm']; - if (!folderName.endsWith('_webgpu')) { - filesToDelete.push('ort-wasm-simd-threaded.mjs', 'ort-wasm-simd-threaded.wasm'); - } fs.readdirSync(WASM_FOLDER).forEach((file) => { - if (filesToDelete.includes(file)) { + if ( + [ + 'ort-wasm-simd-threaded.jsep.mjs', + 'ort-wasm-simd-threaded.jsep.wasm', + 'ort-wasm-simd-threaded.jsep.mjs', + 'ort-wasm-simd-threaded.jsep.wasm', + 'ort-wasm-simd-threaded.mjs', + 'ort-wasm-simd-threaded.wasm', + ].includes(file) + ) { const filePath = path.join(WASM_FOLDER, file); console.log(`Deleting old file: ${filePath}`); fs.unlinkSync(filePath); diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 243f611da49e1..80d374d3f0b25 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -53,6 +53,12 @@ enum AttentionKernelType { AttentionKernel_Default }; +enum class QKOutputType : int { + NO_OUTPUT = 0, + BEFORE_SOFTMAX = 1, + AFTER_SOFTMAX = 2 +}; + constexpr bool LAYOUT_BSNH = false; constexpr bool LAYOUT_BNSH = true; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index ac32a4445f3ca..aef47edd5fcd2 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -17,13 +17,13 @@ namespace onnxruntime { namespace contrib { template -inline void ComputeSmoothSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) { - MlasComputeSoftmax(score, score, N, D, false, true, tp); +inline void ComputeSmoothSoftmaxInplace(T* score, int D, float sink, ThreadPool* tp) { + MlasComputeSoftmax(score, score, 1, D, false, true, sink, tp); } template inline void ComputeAttentionSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) { - MlasComputeSoftmax(score, score, N, D, false, false, tp); + MlasComputeSoftmax(score, score, N, D, false, false, 0.0f, tp); } template diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index c79508cbae273..0d5117709c18a 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -35,6 +35,8 @@ class GQAAttentionBase { use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1; local_window_size_ = has_local ? 
static_cast(info.GetAttrOrDefault("local_window_size", -1)) : -1; + + qk_output_ = static_cast(info.GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))); } int num_heads_; // number of attention heads of Q @@ -44,6 +46,7 @@ class GQAAttentionBase { bool do_rotary_; // whether or not to use rotary embeddings bool rotary_interleaved_; int local_window_size_; + int qk_output_; bool use_smooth_softmax_; @@ -51,12 +54,14 @@ class GQAAttentionBase { Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH const T* K, // K data with shape BxN_kvxSxH const T* V, // V data with shape BxN_kvxSxH + const T* head_sink, // Head sink for smooth softmax, nullptr if not used const Tensor* attention_bias, // Attention bias to add to QxK' const Tensor* past_key, // past K input tensor (if not using past state) const Tensor* past_value, // past V input tensor (if not using past state) Tensor* output, // output tensor Tensor* present_key, // present K output tensor (if separating present KV) Tensor* present_value, // present V output tensor (if separating present KV) + Tensor* output_qk, // output QK buffer const Tensor* seqlens_k, // past sequence lengths tensor GroupQueryAttentionParameters& parameters, // attention parameters AllocatorPtr allocator, // allocator for temporary tensors @@ -64,6 +69,7 @@ class GQAAttentionBase { const bool is_prompt = parameters.is_first_prompt; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; + const int total_sequence_length = parameters.total_sequence_length; const int head_size = parameters.head_size; const int hidden_size = parameters.hidden_size; const bool packed_qkv = parameters.is_packed_qkv; @@ -79,8 +85,7 @@ class GQAAttentionBase { // Compute the attention score. bool gqa_mlas_supported = MlasGQASupported(CblasNoTrans, CblasTrans) && MlasGQASupported(CblasNoTrans, CblasNoTrans); - size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * - (gqa_mlas_supported ? sizeof(T) : sizeof(float)); + size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * (gqa_mlas_supported ? sizeof(T) : sizeof(float)); auto attention_probs = allocator->Alloc(bytes); BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator)); @@ -96,11 +101,13 @@ class GQAAttentionBase { const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; + T* output_qk_buffer = output_qk != nullptr ? output_qk->MutableData() : nullptr; + if (gqa_mlas_supported) { - ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_bias_data, - batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, - head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, - tp, allocator); + ComputeAttentionProbs(static_cast(attention_probs), Q, k, head_sink, seqlens_k->Data(), attention_bias_data, + batch_size, sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache, + seqlen_present_kv_cache, head_size, past_key_data, present_key_data, output_qk_buffer, + past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? 
Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; @@ -110,10 +117,10 @@ class GQAAttentionBase { hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); } else { - ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), attention_bias_data, - batch_size, sequence_length, attention_bias_shape, seqlen_past_kv_cache, seqlen_present_kv_cache, - head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, is_prompt, - tp, allocator); + ComputeAttentionProbs(static_cast(attention_probs), Q, k, head_sink, seqlens_k->Data(), attention_bias_data, + batch_size, sequence_length, total_sequence_length, attention_bias_shape, seqlen_past_kv_cache, + seqlen_present_kv_cache, head_size, past_key_data, present_key_data, output_qk_buffer, + past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; @@ -136,16 +143,19 @@ class GQAAttentionBase { void ComputeAttentionProbs(U* attention_probs, // output buffer with size BxNxSxT const T* Q, // Q data. Its size is BxNxSxH const T* K, // k data. Its size is BxNxLxH + const T* head_sink, // for smooth softmax. Its size is N. const int32_t* seqlens_k, // total - 1 sequence lengths tensor const T* attention_bias, // optional attention bias const size_t batch_size, // batch size of self-attention const size_t sequence_length, // sequence length of self-attention (S) + const size_t total_sequence_length, // total sequence length (T) const gsl::span attention_bias_shape, // shape of the attention bias const size_t past_buffer_sequence_length, // sequence length of past state const size_t present_buffer_sequence_length, // sequence length of present state const size_t head_size, // head size of self-attention const T* past_key, // past key only T* present_key, // present key only + T* output_qk, // output QK buffer const bool past_present_share_buffer, // whether present key and value share the same buffer const bool packed_qkv, // whether Q, K, V are packed const bool is_prompt, // whether it is prompt @@ -197,6 +207,11 @@ class GQAAttentionBase { const ptrdiff_t output_offset = SafeInt(i) * sequence_length * present_buffer_sequence_length; U* output = attention_probs + output_offset; + T* output_qk_thread = nullptr; + if (output_qk != nullptr) { + const ptrdiff_t output_qk_offset = SafeInt(sequence_length) * total_sequence_length * (batch_index * num_heads_ + head_index); + output_qk_thread = output_qk + output_qk_offset; + } // Compute attention bias offset based on the batch and head indexes // Attention bias is of shape (B or 1, H or 1, S, T) so handle broadcasting @@ -310,12 +325,6 @@ class GQAAttentionBase { } } - if (use_smooth_softmax_) { - ComputeSmoothSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); - } else { - ComputeAttentionSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); - } - // set causal [seq_causal_length, total_seqlen) to 0.f for (size_t total_seq_id = seq_causal_length; total_seq_id < total_seqlen; total_seq_id++) { if constexpr (std::is_same::value) { @@ -325,11 +334,30 @@ class GQAAttentionBase { } } + if (qk_output_ == static_cast(QKOutputType::BEFORE_SOFTMAX)) { + WriteOutputQKHeadChunk(output_qk_thread, output_softmax, total_sequence_length); + } + + 
if (use_smooth_softmax_ || head_sink != nullptr) { + float sink = (head_sink != nullptr) ? static_cast(head_sink[head_index]) : 0.0f; + ComputeSmoothSoftmaxInplace(output_softmax + start_offset, static_cast(window_size), sink, nullptr); + } else { + ComputeAttentionSoftmaxInplace(output_softmax + start_offset, 1, static_cast(window_size), nullptr); + } + + if (qk_output_ == static_cast(QKOutputType::AFTER_SOFTMAX)) { + WriteOutputQKHeadChunk(output_qk_thread, output_softmax, total_sequence_length); + } + output_softmax += present_buffer_sequence_length; if (attention_bias_thread != nullptr) { attention_bias_thread += attention_total_seqlen; } + + if (output_qk_thread != nullptr) { + output_qk_thread += total_sequence_length; + } } } }); @@ -455,6 +483,20 @@ class GQAAttentionBase { SafeInt(sequence_length) * batch_size * num_heads_ * head_size); } } + + template + void WriteOutputQKHeadChunk(T* output_qk, const U* attention_probs, size_t total_sequence_length) const { + if (output_qk == nullptr) { + return; + } + + if constexpr (std::is_same_v) { + std::memcpy(output_qk, attention_probs, SafeInt(total_sequence_length) * sizeof(T)); + } else { + static_assert(std::is_same_v && std::is_same_v); + MlasConvertFloatToHalfBuffer(static_cast(attention_probs), output_qk, total_sequence_length); + } + } }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index a912bd6e6b43c..eb1560ac8e341 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -95,6 +95,11 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { Tensor* present_k = context->Output(1, present_k_shape); Tensor* present_v = context->Output(2, present_v_shape); + std::vector output_qk_shape{static_cast(batch_size), static_cast(num_heads_), static_cast(parameters.sequence_length), static_cast(parameters.total_sequence_length)}; + Tensor* output_qk = context->Output(3, output_qk_shape); + + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckOutputs(output_qk, qk_output_)); + AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); @@ -206,10 +211,12 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + const T* head_sink_data = (head_sink != nullptr) ? head_sink->Data() : nullptr; + // Compute the attention score and apply the score to V return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? 
nullptr : V.Get().Data(), - attention_bias, past_key, past_value, output, present_k, present_v, - seqlens_k, parameters, allocator, context); + head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v, + output_qk, seqlens_k, parameters, allocator, context); } } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index 0f66119540b03..f01ce985658aa 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -398,6 +398,37 @@ Status CheckCustomAttentionInputs(const T* position_ids, return Status::OK(); } +template +Status CheckOutputs(const T* output_qk, int qk_output) { + const bool is_valid_qk_output = qk_output == static_cast(QKOutputType::NO_OUTPUT) || + qk_output == static_cast(QKOutputType::BEFORE_SOFTMAX) || + qk_output == static_cast(QKOutputType::AFTER_SOFTMAX); + if (!is_valid_qk_output) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "qk_output attribute received unsupported value ", qk_output); + } + + if (qk_output != static_cast(QKOutputType::NO_OUTPUT) && output_qk == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "qk_output attribute was configured but output buffer was not provided"); + } + + return Status::OK(); +} + +inline Status CheckNoQKOutput(int num_outputs, int qk_output) { + if (num_outputs > 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "output_qk optional output is not supported"); + } + + if (qk_output != static_cast(QKOutputType::NO_OUTPUT)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "qk_output attribute is not supported"); + } + + return Status::OK(); +} + } // namespace group_query_attention_helper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 68c4b01d2db20..9cb93cbcd3f32 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -109,6 +109,12 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { parameters.do_rotary = do_rotary_; parameters.rotary_interleaved = rotary_interleaved_; + // The current GQA CUDA implementation will never be able to have a QK output. + // GQA CUDA uses either flash attention or memory efficient attention. Neither kernel supports returning the QK output. 
+ ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( + context->OutputCount(), + static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); + if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index 85aef55908506..09a6550549614 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -213,6 +213,10 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { parameters.do_rotary = do_rotary_; parameters.rotary_interleaved = rotary_interleaved_; + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( + ctx->OutputCount(), + static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); + if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index f3334b13dc645..1f039177b0a21 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -178,6 +178,10 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& head_sink, params)); + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( + context.OutputCount(), + static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); + WebgpuAttentionParameters parameters(params); TensorShapeVector output_shape(3); output_shape[0] = static_cast(parameters.batch_size_); diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index 8ea593f107833..c4667d53c0674 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -170,7 +170,7 @@ std::string CPUIDInfo::GetX86Vendor(int32_t* data) { uint32_t CPUIDInfo::GetVendorId(const std::string& vendor) { if (vendor == "GenuineIntel") return 0x8086; - if (vendor == "GenuineAMD") return 0x1022; + if (vendor == "AuthenticAMD") return 0x1022; if (vendor.find("Qualcomm") == 0) return 'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24); if (vendor.find("NV") == 0) return 0x10DE; return 0; diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index c3dd9321ebb0b..47fbe08da41ff 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -247,8 +247,11 @@ struct OrtNode { /// Gets the node's subgraphs (e.g., subgraphs contained by an If or Loop node). /// /// Buffer into which to copy the subgraphs. + /// Optional buffer into which to copy the attribute name for each subgraph. + /// If set, must point to a buffer with the same number of elements as `subgraphs`. /// A status indicating success or an error. - virtual onnxruntime::Status GetSubgraphs(gsl::span subgraphs) const = 0; + virtual onnxruntime::Status GetSubgraphs(gsl::span subgraphs, + const char** opt_attribute_names) const = 0; /// /// Gets the node's parent graph, which is the graph that contains this node.
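For illustration, a minimal caller-side sketch of the extended subgraph query above (the C API Node_GetSubgraphs now takes an optional attribute_names buffer). This is a sketch only: it assumes the subgraph count is obtained through a Node_GetNumSubgraphs entry point, which is not part of the hunks shown here, and that `api` is a valid OrtApi pointer.

#include <vector>
#include "onnxruntime_c_api.h"

// Sketch: list each subgraph of a control-flow node together with the attribute it hangs off of
// (e.g. "then_branch"/"else_branch" for If, "body" for Loop).
OrtStatus* ListNodeSubgraphs(const OrtApi* api, const OrtNode* node) {
  size_t num_subgraphs = 0;
  // Assumption: a Node_GetNumSubgraphs query exists alongside Node_GetSubgraphs.
  if (OrtStatus* status = api->Node_GetNumSubgraphs(node, &num_subgraphs)) return status;
  if (num_subgraphs == 0) return nullptr;

  std::vector<const OrtGraph*> subgraphs(num_subgraphs);
  std::vector<const char*> attribute_names(num_subgraphs);  // optional; pass nullptr to skip
  if (OrtStatus* status = api->Node_GetSubgraphs(node, subgraphs.data(), subgraphs.size(),
                                                 attribute_names.data())) {
    return status;
  }

  // attribute_names[i] names the node attribute that owns subgraphs[i].
  return nullptr;
}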
@@ -280,6 +283,23 @@ struct OrtGraph { /// The model's ONNX IR version. virtual int64_t GetOnnxIRVersion() const = 0; + /// + /// Gets the number of operator sets (domain, opset version) the graph's model relies on. + /// + /// Output parameter set to the number of operator sets. + /// A status indicating success or an error. + virtual onnxruntime::Status GetNumOperatorSets(size_t& num_operator_sets) const = 0; + + /// + /// Gets the operator sets the graph's model relies on. An operator set is uniquely identified by a + /// (domain, opset version) pair. + /// + /// Buffer into which to copy the domains. + /// Buffer into which to copy the opset version for each domain. + /// A status indicating success or an error. + virtual onnxruntime::Status GetOperatorSets(gsl::span domains, + gsl::span opset_versions) const = 0; + /// /// Returns the number of graph inputs, including initializers that appear in the list of graph inputs. /// diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index f2757c2c96471..e2b17aa84d2b1 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -6,6 +6,7 @@ #include "core/graph/contrib_ops/quantization_defs.h" #include "core/graph/contrib_ops/onnx_function_util.h" #include "core/graph/contrib_ops/shape_inference_functions.h" +#include "contrib_ops/cpu/bert/attention_common.h" // Suppress a warning: global initializer calls a non-constexpr function 'symbol' which is from // ONNX_OPERATOR_SET_SCHEMA_EX macro and only happens in debug build #if defined(_WIN32) && !defined(NDEBUG) @@ -232,7 +233,8 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c // Type and shape inference for group query attention and sparse attention. 
void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index = -1, - int use_max_past_present_buffer = -1) { + int use_max_past_present_buffer = -1, + int output_qk_index = -1) { ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); int64_t kv_sequence_length = -1; @@ -277,13 +279,20 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte } } - if (ctx.getNumOutputs() > 1) { // has present output + if (ctx.getNumOutputs() >= 3) { // has present output // copy the type from query to present key ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 1); // copy the type from query to present value ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 2); + int64_t total_sequence_length_value = 0; + const auto* total_sequence_length_data = ctx.getInputData(6); + if (total_sequence_length_data != nullptr) { + const auto& data = ParseData(total_sequence_length_data); + total_sequence_length_value = static_cast(data[0]); + } + if (past_key_index >= 0 && hasInputShape(ctx, past_key_index)) { auto& past_shape = getInputShape(ctx, past_key_index); auto& past_dims = past_shape.dim(); @@ -299,30 +308,25 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); } else if (use_max_past_present_buffer == 0) { if (kv_sequence_length > 0 && past_dims[2].has_dim_value()) { - int64_t total_sequence_length = kv_sequence_length + past_dims[2].dim_value(); + const int64_t present_sequence_length = kv_sequence_length + past_dims[2].dim_value(); ONNX_NAMESPACE::TensorShapeProto present_shape; for (auto& dim : past_dims) { *present_shape.add_dim() = dim; } - // shape of present key/value is (batch_size, kv_num_heads, total_sequence_length, head_size) - present_shape.mutable_dim(2)->set_dim_value(total_sequence_length); + // shape of present key/value is (batch_size, kv_num_heads, present_sequence_length, head_size) + present_shape.mutable_dim(2)->set_dim_value(present_sequence_length); updateOutputShape(ctx, 1, present_shape); updateOutputShape(ctx, 2, present_shape); } } else if (use_max_past_present_buffer == -1) { - const auto* total_sequence_length_data = ctx.getInputData(6); - if (total_sequence_length_data != nullptr && past_dims[2].has_dim_value()) { - int64_t total_sequence_length_value = 0; - const auto& data = ParseData(total_sequence_length_data); - total_sequence_length_value = static_cast(data[0]); - + if (total_sequence_length_value > 0 && past_dims[2].has_dim_value()) { // present_sequence_length = max(past_sequence_length, total_sequence_length) - int64_t present_sequence_length = total_sequence_length_value > past_dims[2].dim_value() - ? total_sequence_length_value - : past_dims[2].dim_value(); + const int64_t present_sequence_length = total_sequence_length_value > past_dims[2].dim_value() + ? 
total_sequence_length_value + : past_dims[2].dim_value(); ONNX_NAMESPACE::TensorShapeProto present_shape; for (auto& dim : past_dims) { @@ -336,19 +340,50 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte updateOutputShape(ctx, 2, present_shape); } } + + if (output_qk_index >= 0) { + const bool did_supply_qk_buffer = ctx.hasOutput(output_qk_index); + const int64_t qk_output_type = getAttribute(ctx, "qk_output", static_cast(QKOutputType::NO_OUTPUT)); + + if (qk_output_type == static_cast(QKOutputType::NO_OUTPUT) && did_supply_qk_buffer) { + fail_shape_inference("Output QK buffer was provided but qk_output attribute was not configured"); + } + + if (qk_output_type != static_cast(QKOutputType::NO_OUTPUT) && !did_supply_qk_buffer) { + fail_shape_inference("Output QK buffer was not provided but qk_output attribute was configured"); + } + + int64_t num_heads = getAttribute(ctx, "num_heads", 0); + if (did_supply_qk_buffer && hasInputShape(ctx, 0) && total_sequence_length_value > 0 && num_heads > 0) { + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, output_qk_index); + + auto& query_shape = getInputShape(ctx, 0); + auto& query_dims = query_shape.dim(); + + if (query_dims[0].has_dim_value() && query_dims[1].has_dim_value()) { + ONNX_NAMESPACE::TensorShapeProto output_qk_shape; + *output_qk_shape.add_dim() = query_dims[0]; // batch_size + output_qk_shape.add_dim()->set_dim_value(num_heads); // num_heads + *output_qk_shape.add_dim() = query_dims[1]; // sequence_length + output_qk_shape.add_dim()->set_dim_value(total_sequence_length_value); // total_sequence_length + updateOutputShape(ctx, output_qk_index, output_qk_shape); + } + } + } } } } -void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) { +void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index, int qk_output_index) { // TODO(aciddelgado): propagate output shapes depending if kv-share buffer is on or not constexpr int use_max_past_present_buffer = -1; - BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer); + BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer, qk_output_index); } void SparseAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) { constexpr int use_max_past_present_buffer = 1; - BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer); + constexpr int qk_output_index = -1; + BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer, qk_output_index); } constexpr const char* Attention_ver1_doc = R"DOC( @@ -1127,6 +1162,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Use a smooth factor in softmax.", AttributeProto::INT, static_cast(-1)) + .Attr("qk_output", + "Output values of QK matrix multiplication before (1) or after (2) softmax normalization. Default value is 0 (don't output).", + AttributeProto::INT, + static_cast(QKOutputType::NO_OUTPUT)) .Input(0, "query", "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape" @@ -1184,6 +1223,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)", "T", OpSchema::Optional) + .Input(11, + "head_sink", + "1D tensor with shape (num_heads). 
Each head has a smooth factor adding to the denominator of softmax.", + "T", + OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", @@ -1200,10 +1244,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", "T") + .Output(3, + "output_qk", + "Values of QK matrix multiplication, either before or after softmax normalization", + "T", + OpSchema::Optional) .TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, "Constrain input and output to float tensors.") .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to int tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - GroupQueryAttentionTypeAndShapeInference(ctx, 3); + GroupQueryAttentionTypeAndShapeInference(ctx, 3, 3); })); constexpr const char* PagedAttention_ver1_doc = R"DOC( diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 698c7422a1e2a..f57543416a68f 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -129,11 +129,12 @@ Status EpNode::Create(const Node& node, const EpGraph* ep_graph, ConvertNodeArgsToValueInfos(ep_graph, value_infos_map, node_implicit_inputs, ep_node_implicit_inputs); - std::vector> node_subgraphs = node.GetSubgraphs(); - ep_node_subgraphs.reserve(node_subgraphs.size()); + std::unordered_map> subgraphs_map = node.GetAttributeNameToSubgraphMap(); + ep_node_subgraphs.reserve(subgraphs_map.size()); - for (gsl::not_null subgraph : node_subgraphs) { + for (const auto& [attr_name, subgraph] : subgraphs_map) { SubgraphState subgraph_state; + subgraph_state.attribute_name = attr_name; subgraph_state.subgraph_viewer = std::make_unique(*subgraph); ORT_RETURN_IF_ERROR(EpGraph::Create(*subgraph_state.subgraph_viewer, subgraph_state.ep_subgraph)); subgraph_state.ep_subgraph->SetParentNode(ep_node.get()); @@ -233,12 +234,17 @@ Status EpNode::GetNumSubgraphs(size_t& num_subgraphs) const { return Status::OK(); } -Status EpNode::GetSubgraphs(gsl::span dst) const { +Status EpNode::GetSubgraphs(gsl::span subgraphs, + const char** opt_attribute_names) const { const size_t num_subgraphs = subgraphs_.size(); - ORT_RETURN_IF_ERROR((CheckCopyDestination("node attributes", num_subgraphs, dst))); + ORT_RETURN_IF_ERROR((CheckCopyDestination("node subgraphs", num_subgraphs, subgraphs))); for (size_t i = 0; i < num_subgraphs; ++i) { - dst[i] = subgraphs_[i].ep_subgraph.get(); + subgraphs[i] = subgraphs_[i].ep_subgraph.get(); + + if (opt_attribute_names) { + opt_attribute_names[i] = subgraphs_[i].attribute_name.c_str(); + } } return Status::OK(); @@ -270,6 +276,10 @@ const OrtOpAttr* EpNode::GetAttribute(const std::string& name) const { } } +const std::string& EpNode::GetEpName() const { + return node_.GetExecutionProviderType(); +} + // // EpValueInfo // @@ -499,10 +509,34 @@ void EpGraph::IndexToEpNodeMap::SetEpNode(NodeIndex node_index, EpNode* ep_node) EpGraph::EpGraph(const GraphViewer& graph_viewer, PrivateTag) : OrtGraph(OrtGraphIrApi::kEpApi), graph_viewer_(graph_viewer) {} +EpGraph::EpGraph(std::unique_ptr graph_viewer, + std::unique_ptr indexed_sub_graph, + PrivateTag) + : OrtGraph(OrtGraphIrApi::kEpApi), + graph_viewer_(*graph_viewer.get()), + owned_graph_viewer_(std::move(graph_viewer)), + owned_indexed_sub_graph_(std::move(indexed_sub_graph)) {} + // Static class function to create a std::unique_ptr. 
Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result) { auto ep_graph = std::make_unique(graph_viewer, PrivateTag{}); + return CreateImpl(std::move(ep_graph), graph_viewer, result); +} + +// Static class function to create a std::unique_ptr. +Status EpGraph::Create(std::unique_ptr src_graph_viewer, + std::unique_ptr src_indexed_sub_graph, + /*out*/ std::unique_ptr& result) { + auto& graph_viewer = *src_graph_viewer.get(); + auto ep_graph = std::make_unique(std::move(src_graph_viewer), + std::move(src_indexed_sub_graph), + PrivateTag{}); + + return CreateImpl(std::move(ep_graph), graph_viewer, result); +} + +Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result) { AllocatorPtr initializer_allocator = CPUAllocator::DefaultInstance(); std::unordered_map> value_infos_map; @@ -660,6 +694,43 @@ const std::string& EpGraph::GetName() const { return graph_viewer_.Name(); } int64_t EpGraph::GetOnnxIRVersion() const { return graph_viewer_.GetOnnxIRVersion(); } +Status EpGraph::GetNumOperatorSets(size_t& num_operator_sets) const { + num_operator_sets = graph_viewer_.DomainToVersionMap().size(); + return Status::OK(); +} + +Status EpGraph::GetOperatorSets(gsl::span domains, + gsl::span opset_versions) const { + const std::unordered_map& domain_to_version = graph_viewer_.DomainToVersionMap(); + size_t num_operator_sets = domain_to_version.size(); + + ORT_RETURN_IF_ERROR((CheckCopyDestination("operator set domains", num_operator_sets, domains))); + ORT_RETURN_IF_ERROR((CheckCopyDestination("operator set versions", num_operator_sets, opset_versions))); + + // Collect (domain, version) pairs and sort them by domain to ensure user always gets a stable ordering. + std::vector> pairs; + pairs.reserve(num_operator_sets); + + for (const auto& [domain, version] : domain_to_version) { + pairs.emplace_back(domain.c_str(), version); + } + + std::sort(pairs.begin(), pairs.end(), + [](const std::pair& a, const std::pair& b) -> bool { + return std::strcmp(a.first, b.first) < 0; + }); + + // Copy sorted (domain, version) pairs into the destination buffers. + size_t index = 0; + for (const auto& [domain_c_str, version] : pairs) { + domains[index] = domain_c_str; + opset_versions[index] = version; + index++; + } + + return Status::OK(); +} + size_t EpGraph::GetNumInputs() const { return inputs_.size(); } diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 4240f5636b7ae..d3921e051e18a 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -111,6 +111,7 @@ struct EpNode : public OrtNode { struct SubgraphState { SubgraphState() = default; SubgraphState(SubgraphState&& other) = default; + std::string attribute_name; std::unique_ptr subgraph_viewer; // The graph_viewer wrapped by EpGraph below. std::unique_ptr ep_subgraph; }; @@ -182,7 +183,8 @@ struct EpNode : public OrtNode { Status GetNumSubgraphs(size_t& num_subgraphs) const override; // Gets the subgraphs contained by this node. - Status GetSubgraphs(gsl::span subgraphs) const override; + Status GetSubgraphs(gsl::span subgraphs, + const char** opt_attribute_names) const override; // Gets this node's parent graph, which is the graph that directly contains this node. Status GetGraph(const OrtGraph*& parent_graph) const override; @@ -206,6 +208,9 @@ struct EpNode : public OrtNode { // Helper that gets the node's attributes by name. 
const OrtOpAttr* GetAttribute(const std::string& name) const; + // Helper that gets the execution provider name that this node is assigned to run on. + const std::string& GetEpName() const; + private: // Back pointer to containing graph. Useful when traversing through nested subgraphs. // Will be nullptr if the EpNode was created without an owning graph. @@ -249,15 +254,32 @@ struct EpGraph : public OrtGraph { public: EpGraph(const GraphViewer& graph_viewer, PrivateTag); + EpGraph(std::unique_ptr graph_viewer, + std::unique_ptr indexed_sub_graph, + PrivateTag); /// /// Creates an instance of EpGraph, which wraps a GraphViewer. + /// This call is used when creating an EpGraph from a GraphViewer instance. The GraphViewer instance is not owned by this EpGraph. /// /// /// /// static Status Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result); + /// + /// Creates an instance of EpGraph, which wraps a GraphViewer. + /// This call is used when creating an EpGraph from a subset of nodes in another EpGraph. + /// In this case, due to the implementation of OrtApis::Graph_GetGraphView, the new EpGraph instance + /// must take ownership of both the GraphViewer and IndexedSubGraph. + /// + /// + /// + /// + static Status Create(std::unique_ptr graph_viewer, + std::unique_ptr indexed_sub_graph, + /*out*/ std::unique_ptr& result); + // Defines ToExternal() and ToInternal() functions to convert between OrtGraph and EpGraph. DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(OrtGraph, EpGraph, OrtGraphIrApi::kEpApi) @@ -271,6 +293,14 @@ struct EpGraph : public OrtGraph { // Returns the model's ONNX IR version. int64_t GetOnnxIRVersion() const override; + // Gets the number of operator sets that the graph's model uses. + Status GetNumOperatorSets(size_t& num_operator_sets) const override; + + // Gets the operator sets that the graph's model uses. An operator set is uniquely identified by a + // (domain, opset version) pair. + Status GetOperatorSets(gsl::span domains, + gsl::span opset_versions) const override; + // Get the number of graph inputs, including initializers that are listed as graph inputs. size_t GetNumInputs() const override; @@ -321,9 +351,22 @@ struct EpGraph : public OrtGraph { const OrtValue* GetInitializerValue(std::string_view name) const; private: + /// + /// The real implementation of creating an EpGraph instance. + /// Please use one of the above 'Create' functions that internally call this function, and avoid calling this function directly.
+ /// + /// + /// + /// + /// + static Status CreateImpl(std::unique_ptr ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result); + const GraphViewer& graph_viewer_; const EpNode* parent_node_ = nullptr; + std::unique_ptr owned_graph_viewer_ = nullptr; + std::unique_ptr owned_indexed_sub_graph_ = nullptr; + std::vector> nodes_; IndexToEpNodeMap index_to_ep_node_; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index ca40bad2b4250..4d3091520d876 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1818,6 +1818,10 @@ NodeArg* Graph::GetNodeArgIncludingParentGraphs(const std::string& node_arg_name return node_arg; } +const NodeArg* Graph::GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const { + return const_cast(this)->GetNodeArgIncludingParentGraphs(node_arg_name); +} + void Graph::ReverseDFSFrom(gsl::span from, const std::function& enter, const std::function& leave, diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index 1842c2b4a0d1f..948ebaa5f7e15 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -168,7 +168,15 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) filtered_node_inputs_including_initializers_.reserve(metadef->inputs.size()); for (const auto& input : metadef->inputs) { - const auto* nodearg = graph.GetNodeArg(input); + // NodeArgs from the current scope or any outer scopes should be handled correctly. + // + // There is an edge case where the model consists of a graph with subgraphs nested across three levels. + // In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer). + // When constructing a new GraphViewer for the second- and third-layer subgraphs, + // the second-layer graph may not have the corresponding value_info for that first-layer input, + // because the second-layer graph itself doesn't consume it. + // Therefore, when working within the second-layer graph, we need to search outer scopes for the missing value_info. + const auto* nodearg = graph.GetNodeArgIncludingParentGraphs(input); ORT_ENFORCE(nodearg, "Mismatch between Graph and IndexedSubGraph. Input not found:", input); filtered_node_inputs_including_initializers_.push_back(nodearg); if (!graph.IsInitializedTensor(input)) { @@ -177,7 +185,7 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) } for (const auto& output : metadef->outputs) { - const auto* nodearg = graph.GetNodeArg(output); + const auto* nodearg = graph.GetNodeArgIncludingParentGraphs(output); ORT_ENFORCE(nodearg, "Mismatch between Graph and IndexedSubGraph. 
Output not found:", output); filtered_node_outputs_.push_back(nodearg); } diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index 6330a42c115db..6e7e17374bb59 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -136,7 +136,8 @@ struct ModelEditorNode : public OrtNode { "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); } - Status GetSubgraphs(gsl::span /*subgraphs*/) const override { + Status GetSubgraphs(gsl::span /*subgraphs*/, + const char** /*opt_attribute_names*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); } @@ -176,6 +177,17 @@ struct ModelEditorGraph : public OrtGraph { return ONNX_NAMESPACE::Version::IR_VERSION; } + Status GetNumOperatorSets(size_t& /*num_operator_sets*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the graph's operator sets."); + } + + Status GetOperatorSets(gsl::span /*domains*/, + gsl::span /*opset_versions*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the graph's operator sets."); + } + size_t GetNumInputs() const override { return inputs.size(); } Status GetInputs(gsl::span /*result*/) const override { diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 3575e30721af7..4d85c35461825 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1020,6 +1020,7 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, + float Sink, MLAS_THREADPOOL* ThreadPool ); @@ -1223,6 +1224,21 @@ MlasQuantizeLinearS4( int8_t ZeroPoint ); +// +// Linear dequantization routines. +// + +template +void +MLASCALL +MlasDequantizeLinear( + const InputType* Input, + float* Output, + size_t N, + float Scale, + InputType ZeroPoint + ); + /** * @brief Requantize a block of the intermediate buffer to the output buffer, * optionally adding the supplied bias diff --git a/onnxruntime/core/mlas/lib/compute.cpp b/onnxruntime/core/mlas/lib/compute.cpp index 96a2398796777..669c73d2b9c06 100644 --- a/onnxruntime/core/mlas/lib/compute.cpp +++ b/onnxruntime/core/mlas/lib/compute.cpp @@ -74,6 +74,7 @@ struct MLAS_SOFTMAX_WORK_BLOCK { ptrdiff_t ThreadCountN; bool LogSoftmax; bool SmoothSoftmax; + float Sink; const T* Input; T* Output; size_t N; @@ -850,6 +851,7 @@ Return Value: const size_t D = WorkBlock->D; const bool LogSoftmax = WorkBlock->LogSoftmax; const bool SmoothSoftmax = WorkBlock->SmoothSoftmax; + const float Sink = WorkBlock->Sink; const float* Input = WorkBlock->Input + n * D; float* Output = WorkBlock->Output + n * D; @@ -880,11 +882,12 @@ Return Value: #else float Maximum = MlasReduceMaximumF32Kernel(Input, D); #endif - float NegativeMaximum = -Maximum; - if (SmoothSoftmax && NegativeMaximum > 0.0f) { - NegativeMaximum = 0.0f; + if (SmoothSoftmax && Sink > Maximum) { + Maximum = Sink; } + float NegativeMaximum = -Maximum; + // // Compute the exponential function for each element of the row (save to Temp if provided) and // compute the sum of these exponential functions. 
@@ -897,7 +900,7 @@ Return Value: #endif if (SmoothSoftmax) { - Accumulation += expf(NegativeMaximum); + Accumulation += expf(Sink + NegativeMaximum); } if (LogSoftmax) { @@ -1014,6 +1017,7 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, + float Sink, MLAS_THREADPOOL* ThreadPool ) /*++ @@ -1039,6 +1043,8 @@ Routine Description: SmoothSoftmax - Supplies true if a smooth factor is used in softmax operation. + Sink - Supplies the smooth factor to use in the softmax operation. + ThreadPool - Supplies the thread pool object to use, else nullptr if the base library threading support should be used. @@ -1060,6 +1066,7 @@ Return Value: WorkBlock.Output = Output; WorkBlock.N = N; WorkBlock.D = D; + WorkBlock.Sink = Sink; // // Compute the number of target threads given the complexity of the softmax @@ -1097,6 +1104,7 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, + float Sink, MLAS_THREADPOOL* ThreadPool ); @@ -1110,6 +1118,7 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, + float Sink, MLAS_THREADPOOL* ThreadPool ); diff --git a/onnxruntime/core/mlas/lib/dequantize.cpp b/onnxruntime/core/mlas/lib/dequantize.cpp new file mode 100644 index 0000000000000..175d3f668ac39 --- /dev/null +++ b/onnxruntime/core/mlas/lib/dequantize.cpp @@ -0,0 +1,395 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + dequantize.cpp + +Abstract: + + This module implements routines to dequantize buffers. + + The dequantization formula as specified in the ONNX operator documentation is: + + Output = (Input - ZeroPoint) * Scale + +--*/ + +#include "mlasi.h" + +// +// DequantizeLinear reference implementation using the C++ runtime. +// + +template +static +MLAS_FORCEINLINE +void +MlasDequantizeLinearRefImpl( + const InputType* Input, + float* Output, + size_t N, + float Scale, + InputType ZeroPoint + ) +/*++ + +Routine Description: + + This routine quantizes the input buffer using the supplied quantization + parameters. + +Arguments: + + Input - Supplies the input buffer with quantized data. + + Output - Supplies the output buffer. + + N - Supplies the number of elements to process. + + Scale - Supplies the quantization scale. + + ZeroPoint - Supplies the quantization zero point value. + +Return Value: + + None. + +--*/ +{ + int32_t ZeroPointS32 = static_cast(ZeroPoint); + + for (size_t n = 0; n < N; n++) { + Output[n] = static_cast(static_cast(Input[n]) - ZeroPointS32) * Scale; + } +} + +#if defined(MLAS_SSE2_INTRINSICS) +// Implementation for Intel SSE 2. Refer to the Intel Intrisics Guide: +// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html + +void +MLASCALL +MlasDequantizeLinearS8Kernel( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + const __m128 ScaleVector = MlasBroadcastFloat32x4(Scale); + const __m128i ZeroPointS16Vector = _mm_set1_epi16(static_cast(ZeroPoint)); // Broadcast zp to 8 int16s + const __m128i Zeros = _mm_setzero_si128(); + + while (N >= 16) { + // Load a vector of 16 int8s: [0 ... 15] + __m128i VectorS8 = _mm_loadu_si128(reinterpret_cast(Input)); + + // Sign-extend into 2 vectors of 8 int16s + __m128i SignMaskS8 = _mm_cmpgt_epi8(Zeros, VectorS8); // 0xFF for every negative byte in VectorS8 + __m128i VectorS16_0 = _mm_unpacklo_epi8(VectorS8, SignMaskS8); // [0 ... 7] + __m128i VectorS16_1 = _mm_unpackhi_epi8(VectorS8, SignMaskS8); // [8 ... 15] + + // Subtract the zero-points in int16 domain. 
+ VectorS16_0 = _mm_sub_epi16(VectorS16_0, ZeroPointS16Vector); + VectorS16_1 = _mm_sub_epi16(VectorS16_1, ZeroPointS16Vector); + + // Sign-extend into 4 vectors of 4 int32s + __m128i SignMaskS16_0 = _mm_cmpgt_epi16(Zeros, VectorS16_0); + __m128i VectorS32_0 = _mm_unpacklo_epi16(VectorS16_0, SignMaskS16_0); // [0 ... 3] + __m128i VectorS32_1 = _mm_unpackhi_epi16(VectorS16_0, SignMaskS16_0); // [4 ... 7] + + __m128i SignMaskS16_1 = _mm_cmpgt_epi16(Zeros, VectorS16_1); + __m128i VectorS32_2 = _mm_unpacklo_epi16(VectorS16_1, SignMaskS16_1); // [8 ... 11] + __m128i VectorS32_3 = _mm_unpackhi_epi16(VectorS16_1, SignMaskS16_1); // [12 ... 15] + + // Cast each int32x4 to float and multiply by the scale vector. + __m128 VectorF32_0 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_0), ScaleVector); + __m128 VectorF32_1 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_1), ScaleVector); + __m128 VectorF32_2 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_2), ScaleVector); + __m128 VectorF32_3 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_3), ScaleVector); + + // Store each int32x4 into the output. + _mm_storeu_ps(Output + 0, VectorF32_0); + _mm_storeu_ps(Output + 4, VectorF32_1); + _mm_storeu_ps(Output + 8, VectorF32_2); + _mm_storeu_ps(Output + 12, VectorF32_3); + + Input += 16; + Output += 16; + N -= 16; + } + + // Handle leftover elements (< 16) with the scalar reference implementation. + MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasDequantizeLinearU8Kernel( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + const __m128 ScaleVector = MlasBroadcastFloat32x4(Scale); + const __m128i ZeroPointS16Vector = _mm_set1_epi16(static_cast(ZeroPoint)); // Broadcast zp to 8 int16s + const __m128i Zeros = _mm_setzero_si128(); + + while (N >= 16) { + // Load a vector of 16 uint8s: [0 ... 15] + __m128i VectorU8 = _mm_loadu_si128(reinterpret_cast(Input)); + + // Zero-extend into 2 vectors of 8 uint16s + __m128i VectorU16_0 = _mm_unpacklo_epi8(VectorU8, Zeros); // [0 ... 7] + __m128i VectorU16_1 = _mm_unpackhi_epi8(VectorU8, Zeros); // [8 ... 15] + + // Subtract the zero-points as uint16s. Due to two's compliment, negative results can be reinterpreted as int16 + __m128i VectorS16_0 = _mm_sub_epi16(VectorU16_0, ZeroPointS16Vector); + __m128i VectorS16_1 = _mm_sub_epi16(VectorU16_1, ZeroPointS16Vector); + + // Sign-extend into 4 vectors of 4 int32s + __m128i SignMaskS16_0 = _mm_cmpgt_epi16(Zeros, VectorS16_0); + __m128i VectorS32_0 = _mm_unpacklo_epi16(VectorS16_0, SignMaskS16_0); // [0 ... 3] + __m128i VectorS32_1 = _mm_unpackhi_epi16(VectorS16_0, SignMaskS16_0); // [4 ... 7] + + __m128i SignMaskS16_1 = _mm_cmpgt_epi16(Zeros, VectorS16_1); + __m128i VectorS32_2 = _mm_unpacklo_epi16(VectorS16_1, SignMaskS16_1); // [8 ... 11] + __m128i VectorS32_3 = _mm_unpackhi_epi16(VectorS16_1, SignMaskS16_1); // [12 ... 15] + + // Cast each int32x4 to float and multiply by the scale vector. + __m128 VectorF32_0 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_0), ScaleVector); + __m128 VectorF32_1 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_1), ScaleVector); + __m128 VectorF32_2 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_2), ScaleVector); + __m128 VectorF32_3 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_3), ScaleVector); + + // Store each int32x4 into the output. 
+ _mm_storeu_ps(Output + 0, VectorF32_0); + _mm_storeu_ps(Output + 4, VectorF32_1); + _mm_storeu_ps(Output + 8, VectorF32_2); + _mm_storeu_ps(Output + 12, VectorF32_3); + + Input += 16; + Output += 16; + N -= 16; + } + + // Handle leftover elements (< 16) with the scalar reference implementation. + MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasDequantizeLinear( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().DequantizeLinearS8Kernel( +#else + MlasDequantizeLinearS8Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasDequantizeLinear( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().DequantizeLinearU8Kernel( +#else + MlasDequantizeLinearU8Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} +#elif defined(MLAS_NEON64_INTRINSICS) +// Implementation for ARM64 NEON. Refer to the ARM instrinsics guide: +// https://developer.arm.com/architectures/instruction-sets/intrinsics/ + +void +MLASCALL +MlasDequantizeLinearS8Kernel( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + const float32x4_t ScaleVector = MlasBroadcastFloat32x4(Scale); + const int16x8_t ZeroPointVector = vdupq_n_s16(ZeroPoint); // Broadcast ZeroPoint (sign-extended to 16bits) + + while (N >= 16) { + // Load a vector of 16 int8s: [0 ... 15] + int8x16_t VectorS8 = vld1q_s8(Input); + + // Sign-extend into 2 vectors of 8 int16s + int16x8_t VectorS16_0 = vmovl_s8(vget_low_s8(VectorS8)); // [0 ... 7] + int16x8_t VectorS16_1 = vmovl_s8(vget_high_s8(VectorS8)); // [8 ... 15] + + // Subtract the zero-points in int16 domain. + VectorS16_0 = vsubq_s16(VectorS16_0, ZeroPointVector); + VectorS16_1 = vsubq_s16(VectorS16_1, ZeroPointVector); + + // Sign-extend into 4 vectors of 4 int32s + int32x4_t VectorS32_0 = vmovl_s16(vget_low_s16(VectorS16_0)); // [0 ... 3] + int32x4_t VectorS32_1 = vmovl_s16(vget_high_s16(VectorS16_0)); // [4 ... 7] + int32x4_t VectorS32_2 = vmovl_s16(vget_low_s16(VectorS16_1)); // [8 ... 11] + int32x4_t VectorS32_3 = vmovl_s16(vget_high_s16(VectorS16_1)); // [12 ... 15] + + // Cast each int32x4 to float and multiply by the scale vector. + float32x4_t VectorF32_0 = vmulq_f32(vcvtq_f32_s32(VectorS32_0), ScaleVector); + float32x4_t VectorF32_1 = vmulq_f32(vcvtq_f32_s32(VectorS32_1), ScaleVector); + float32x4_t VectorF32_2 = vmulq_f32(vcvtq_f32_s32(VectorS32_2), ScaleVector); + float32x4_t VectorF32_3 = vmulq_f32(vcvtq_f32_s32(VectorS32_3), ScaleVector); + + // Store each int32x4 into the output. + vst1q_f32(Output + 0, VectorF32_0); + vst1q_f32(Output + 4, VectorF32_1); + vst1q_f32(Output + 8, VectorF32_2); + vst1q_f32(Output + 12, VectorF32_3); + + N -= 16; + Input += 16; + Output += 16; + } + + // Handle leftover elements (< 16) with the scalar reference implementation. + MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasDequantizeLinearU8Kernel( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + const float32x4_t ScaleVector = MlasBroadcastFloat32x4(Scale); + const uint8x8_t ZeroPointVector = vdup_n_u8(ZeroPoint); // Broadcast ZeroPoint to 8 uint8s + + while (N >= 16) { + // Load a vector of 16 uint8s: [0 ... 15] + uint8x16_t VectorU8 = vld1q_u8(Input); + + // Subtract zero-point. 
The vsubl_u8 instruction zero-extends its arguments to uint16 first. + // The reinterpret from uint16x8 to int16x8 is actually a NOP. + int16x8_t VectorS16_0 = vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(VectorU8), ZeroPointVector)); // [0 ... 7] + int16x8_t VectorS16_1 = vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(VectorU8), ZeroPointVector)); // [8 ... 15] + + // Sign-extend into 4 vectors of 4 int32s + int32x4_t VectorS32_0 = vmovl_s16(vget_low_s16(VectorS16_0)); // [0 ... 3] + int32x4_t VectorS32_1 = vmovl_s16(vget_high_s16(VectorS16_0)); // [4 ... 7] + int32x4_t VectorS32_2 = vmovl_s16(vget_low_s16(VectorS16_1)); // [8 ... 11] + int32x4_t VectorS32_3 = vmovl_s16(vget_high_s16(VectorS16_1)); // [12 ... 15] + + // Cast each int32x4 to float and multiply by the scale vector. + float32x4_t VectorF32_0 = vmulq_f32(vcvtq_f32_s32(VectorS32_0), ScaleVector); + float32x4_t VectorF32_1 = vmulq_f32(vcvtq_f32_s32(VectorS32_1), ScaleVector); + float32x4_t VectorF32_2 = vmulq_f32(vcvtq_f32_s32(VectorS32_2), ScaleVector); + float32x4_t VectorF32_3 = vmulq_f32(vcvtq_f32_s32(VectorS32_3), ScaleVector); + + // Store each int32x4 into the output. + vst1q_f32(Output + 0, VectorF32_0); + vst1q_f32(Output + 4, VectorF32_1); + vst1q_f32(Output + 8, VectorF32_2); + vst1q_f32(Output + 12, VectorF32_3); + + N -= 16; + Input += 16; + Output += 16; + } + + // Handle leftover elements (< 16) with the scalar reference implementation. + MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasDequantizeLinear( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasDequantizeLinearS8Kernel(Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasDequantizeLinear( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + MlasDequantizeLinearU8Kernel(Input, Output, N, Scale, ZeroPoint); +} +#else +// Implementation that uses the scalar reference implementation. 
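+// Illustrative example of the dequantization formula (applies to all of the implementations in this
+// file): with Scale = 0.5f and ZeroPoint = 128, the uint8 inputs {126, 127, 128, 129, 130}
+// dequantize to {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}.
+// A minimal call sketch, assuming mlas.h declares the same MlasDequantizeLinear template that the
+// specializations above implement:
+//
+//   float out[5];
+//   MlasDequantizeLinear<uint8_t>(in, out, 5, 0.5f, static_cast<uint8_t>(128));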
+ +template +void +MLASCALL +MlasDequantizeLinear( + const InputType* Input, + float* Output, + size_t N, + float Scale, + InputType ZeroPoint + ) +{ + MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +template +void +MLASCALL +MlasDequantizeLinear( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ); + +template +void +MLASCALL +MlasDequantizeLinear( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ); + +#endif diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 0af3cd2e33b02..0879d1b0ba510 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -747,6 +747,24 @@ void float Scale, int8_t ZeroPoint); +typedef +void +(MLASCALL MLAS_DEQUANTIZE_LINEAR_U8_KERNEL)( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint); + +typedef +void +(MLASCALL MLAS_DEQUANTIZE_LINEAR_S8_KERNEL)( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint); + template struct MLAS_QUANT_KERNEL { @@ -903,6 +921,8 @@ extern "C" { MLAS_QUANTIZE_LINEAR_S4_KERNEL MlasQuantizeLinearS4Kernel; MLAS_QUANTIZE_LINEAR_U4_KERNEL MlasQuantizeLinearU4Kernel; #if defined(MLAS_TARGET_AMD64) + MLAS_DEQUANTIZE_LINEAR_S8_KERNEL MlasDequantizeLinearS8Kernel; + MLAS_DEQUANTIZE_LINEAR_U8_KERNEL MlasDequantizeLinearU8Kernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernelFma3; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelFma3; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelAvx512F; @@ -1246,6 +1266,8 @@ struct MLAS_PLATFORM { MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel; MLAS_QUANTIZE_LINEAR_S4_KERNEL* QuantizeLinearS4Kernel; MLAS_QUANTIZE_LINEAR_U4_KERNEL* QuantizeLinearU4Kernel; + MLAS_DEQUANTIZE_LINEAR_S8_KERNEL* DequantizeLinearS8Kernel; + MLAS_DEQUANTIZE_LINEAR_U8_KERNEL* DequantizeLinearU8Kernel; uint32_t NchwcBlockSize; uint32_t PreferredBufferAlignment; int32_t MaximumThreadCount; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 45d3a876beb86..45bba5363d4f2 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -285,6 +285,8 @@ Return Value: this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel; this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; + this->DequantizeLinearS8Kernel = MlasDequantizeLinearS8Kernel; + this->DequantizeLinearU8Kernel = MlasDequantizeLinearU8Kernel; #ifndef __APPLE__ #ifndef FORCE_GENERIC_ALGORITHMS this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelSse; diff --git a/onnxruntime/core/platform/windows/device_discovery.cc b/onnxruntime/core/platform/windows/device_discovery.cc index dcc030cb3467d..fa645939a6395 100644 --- a/onnxruntime/core/platform/windows/device_discovery.cc +++ b/onnxruntime/core/platform/windows/device_discovery.cc @@ -89,23 +89,10 @@ uint64_t GetLuidKey(LUID luid) { return (uint64_t(luid.HighPart) << 32) | luid.LowPart; } -// Converts a wide string (up to 4 characters) representing a hardware ID component (e.g., "ABCD" from "VEN_ABCD") -// into a uint32_t. The conversion is done in a little-endian manner, meaning the first character -// of the string becomes the least significant byte of the integer, and the fourth character -// becomes the most significant byte. 
-uint32_t WStringToUint32Id(const std::wstring& vendor_name) { - uint32_t vendor_id = 0; - for (size_t i = 0; i < 4 && i < vendor_name.size(); ++i) { - // For little-endian, place each character at the appropriate byte position - // First character goes into lowest byte, last character into highest byte - vendor_id |= static_cast(vendor_name[i] & 0xFF) << (i * 8); - } - return vendor_id; -} - // returns info for display and processor entries. key is (vendor_id << 32 | device_id) // npus: (vendor_id << 32 | device_id) for devices we think are NPUs from DXCORE -std::unordered_map GetDeviceInfoSetupApi(const std::unordered_set& npus) { +std::unordered_map GetDeviceInfoSetupApi(const std::unordered_set& npus, + bool& have_remote_display_adapter) { std::unordered_map device_info; const GUID local_DXCORE_ADAPTER_ATTRIBUTE_D3D12_GENERIC_ML = {0xb71b0d41, 0x1088, 0x422f, 0xa2, 0x7c, 0x2, 0x50, 0xb7, 0xd3, 0xa9, 0x88}; @@ -151,8 +138,7 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde if (auto idx = hardware_id.find(prefix); idx != std::wstring::npos) { auto id = hardware_id.substr(idx + prefix.size(), 4); if (id.size() == 4) { - // DXCore reports vendor and device IDs as 32-bit integer representations of the ASCII string. - return WStringToUint32Id(id); + return static_cast(std::stoul(id, nullptr, 16)); } } @@ -170,6 +156,11 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde // Won't always have a vendor id from an ACPI entry. ACPI is not defined for this purpose. if (vendor_id == 0 && device_id == 0) { + static const std::wstring remote_display_adapter_id(L"RdpIdd_IndirectDisplay"); + if (guid == GUID_DEVCLASS_DISPLAY && remote_display_adapter_id == buffer) { + have_remote_display_adapter = true; + } + continue; } @@ -305,7 +296,7 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde } // returns LUID to DeviceInfo -std::unordered_map GetDeviceInfoD3D12() { +std::unordered_map GetDeviceInfoD3D12(bool have_remote_display_adapter) { std::unordered_map device_info; ComPtr factory; @@ -314,6 +305,8 @@ std::unordered_map GetDeviceInfoD3D12() { return device_info; } + UINT num_adapters = 0; + ComPtr adapter; for (UINT i = 0; factory->EnumAdapters1(i, adapter.ReleaseAndGetAddressOf()) != DXGI_ERROR_NOT_FOUND; ++i) { DXGI_ADAPTER_DESC1 desc; @@ -339,9 +332,12 @@ std::unordered_map GetDeviceInfoD3D12() { info.metadata[L"LUID"] = std::to_wstring(key); info.metadata[L"DxgiAdapterNumber"] = std::to_wstring(i); info.metadata[L"DxgiVideoMemory"] = std::to_wstring(desc.DedicatedVideoMemory / (1024 * 1024)) + L" MB"; + + ++num_adapters; } - // iterate by high-performance GPU preference to add that info + // iterate by high-performance GPU preference to add that info. + UINT cur_adapter = 0; for (UINT i = 0; factory->EnumAdapterByGpuPreference( i, DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE, IID_PPV_ARGS(adapter.ReleaseAndGetAddressOf())) != DXGI_ERROR_NOT_FOUND; @@ -352,12 +348,41 @@ std::unordered_map GetDeviceInfoD3D12() { } uint64_t key = GetLuidKey(desc.AdapterLuid); - auto it = device_info.find(key); - if (it != device_info.end()) { - DeviceInfo& info = it->second; - info.metadata[L"DxgiHighPerformanceIndex"] = std::to_wstring(i); + if (it == device_info.end()) { + continue; } + + DeviceInfo& info = it->second; + + // try and drop the Microsoft Remote Display Adapter. it does not have the DXGI_ADAPTER_FLAG_SOFTWARE flag set + // and the vendor id, device id and description are the same as the real device. the LUID is different to the real + // device. 
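+ // Note: have_remote_display_adapter is only set when GetDeviceInfoSetupApi saw the RdpIdd_IndirectDisplay
+ // hardware ID, so this filtering only applies when a remote display adapter was actually reported.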
+ // Assumption: it will have the worst performance index of the devices we're considering so we only check the + // last adapter + if (num_adapters > 1 && have_remote_display_adapter && cur_adapter == num_adapters - 1) { + ComPtr output; + if (adapter->EnumOutputs(0, &output) == DXGI_ERROR_NOT_FOUND) { + // D3D_DRIVER_TYPE_WARP. Software based or disabled adapter. + // An adapter can be disabled in an RDP session. e.g. integrated GPU is disabled if there's a discrete GPU + + // if we have seen this vendor_id+device_id combination with a different LUID before we drop it. + if (std::any_of(device_info.begin(), device_info.end(), + [key, &info](const auto& entry) { + const auto& entry_info = entry.second; + return key != entry.first && + info.vendor_id == entry_info.vendor_id && + info.device_id == entry_info.device_id; + })) { + device_info.erase(key); + continue; + } + } + } + + info.metadata[L"DxgiHighPerformanceIndex"] = std::to_wstring(i); + + ++cur_adapter; } return device_info; @@ -497,10 +522,12 @@ std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatfor } } - // d3d12 info. key is luid - std::unordered_map luid_to_d3d12_info = GetDeviceInfoD3D12(); // setupapi_info. key is vendor_id+device_id - std::unordered_map setupapi_info = GetDeviceInfoSetupApi(npus); + bool have_remote_display_adapter = false; // set if we see the RdpIdd_IndirectDisplay hardware ID. + std::unordered_map setupapi_info = GetDeviceInfoSetupApi(npus, have_remote_display_adapter); + + // d3d12 info. key is luid + std::unordered_map luid_to_d3d12_info = GetDeviceInfoD3D12(have_remote_display_adapter); // Ensure we have at least one CPU bool found_cpu = false; diff --git a/onnxruntime/core/providers/cpu/math/softmax_shared.cc b/onnxruntime/core/providers/cpu/math/softmax_shared.cc index 2817dda9d0085..e123414b03b21 100644 --- a/onnxruntime/core/providers/cpu/math/softmax_shared.cc +++ b/onnxruntime/core/providers/cpu/math/softmax_shared.cc @@ -99,7 +99,7 @@ common::Status SoftmaxCPU(size_t N, float* Ydata, bool logarithmic, onnxruntime::concurrency::ThreadPool* thread_pool) { - MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, false, thread_pool); + MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, false, 0.0f, thread_pool); return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/ml/ml_common.h b/onnxruntime/core/providers/cpu/ml/ml_common.h index 3359b2a69fe83..f7cc2523adbf6 100644 --- a/onnxruntime/core/providers/cpu/ml/ml_common.h +++ b/onnxruntime/core/providers/cpu/ml/ml_common.h @@ -445,7 +445,7 @@ void batched_update_scores_inplace(gsl::span scores, int64_t num_batches_in, } if (use_mlas) { - MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow(batch_size), false, false, threadpool); + MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow(batch_size), false, false, 0.0f, threadpool); } else { while (s < s_end) { gsl::span scores_for_batch(s, s + batch_size); diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index adb2aee171f39..c691be6ffd0e8 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include #include #include "core/framework/element_type_lists.h" #include "core/framework/float8.h" @@ -301,14 +302,31 @@ struct DequantizeLinearApply { * @param[in] zero_point same shape as scale */ void op(size_t M, size_t K, size_t N, const T* input, - const OutT* scale, OutT* output, const T* zero_point) { + const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { for (size_t m = 0; m < M; m++) { for (size_t k = 0; k < K; k++) { +#if defined(ORT_CLIENT_PACKAGE_BUILD) + // TODO: Only using multithreaded/SIMD DQ when ORT is built for client/on-device workloads. + // Make this the default behavior after more testing. + if constexpr (std::is_same_v || std::is_same_v) { + ParDequantizeLinearStd(input, output, N, scale[k], zero_point ? zero_point[k] : 0, thread_pool); + input += N; + output += N; + } else { + auto zp = zero_point ? static_cast(zero_point[k]) : 0; + auto sc = static_cast(scale[k]); + for (size_t n = 0; n < N; n++) { + *output++ = static_cast(static_cast(static_cast(*input++) - zp) * sc); + } + } +#else + ORT_UNUSED_PARAMETER(thread_pool); auto zp = zero_point ? static_cast(zero_point[k]) : 0; auto sc = static_cast(scale[k]); for (size_t n = 0; n < N; n++) { *output++ = static_cast(static_cast(static_cast(*input++) - zp) * sc); } +#endif // defined(ORT_CLIENT_PACKAGE_BUILD) } } } @@ -327,7 +345,8 @@ struct DequantizeLinearApply { * @param[in] zero_point same shape as scale */ void op(size_t M, size_t K, size_t N, size_t quant_block_size, - const T* input, const OutT* scale, OutT* output, const T* zero_point) { + const T* input, const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { + ORT_UNUSED_PARAMETER(thread_pool); if (zero_point) { for (size_t m = 0; m < M; m++) { for (size_t bd = 0; bd < K; bd += quant_block_size) { @@ -368,7 +387,8 @@ template struct DequantizeLinearApply { // per-tensor/layer or per-axis quantization void op(size_t M, size_t K, size_t N, - const T* input, const OutT* scale, OutT* output, const T* zero_point) { + const T* input, const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { + ORT_UNUSED_PARAMETER(thread_pool); size_t input_index = 0; for (size_t m = 0; m < M; m++) { @@ -394,7 +414,8 @@ struct DequantizeLinearApply { // Blocked quantization // TODO(fajin) : add mlas kernel to utilize multithreading, refer MlasDequantizeBlockwise. 
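+ // The thread_pool parameter is accepted here (and currently unused) so that Compute() can pass it
+ // uniformly for every element type; only the 8-bit per-tensor/per-axis path above uses it, and only
+ // in ORT_CLIENT_PACKAGE_BUILD builds.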
void op(size_t M, size_t K, size_t N, size_t quant_block_size, - const T* input, const OutT* scale, OutT* output, const T* zero_point) { + const T* input, const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { + ORT_UNUSED_PARAMETER(thread_pool); size_t input_index = 0; if (zero_point) { @@ -440,36 +461,36 @@ struct DequantizeLinearApply { #if !defined(DISABLE_FLOAT8_TYPES) -#define DEQUANTIZE_LINEAR_APPLY_FLOAT8(T) \ - template \ - struct DequantizeLinearApply { \ - /* Per-tensor/layer or per-axis quantization */ \ - void op(size_t M, size_t K, size_t N, \ - const T* input, const OutT* scale, OutT* output, const T*) { \ - for (size_t m = 0; m < M; m++) { \ - for (size_t bd = 0; bd < K; bd++) { \ - auto sc = scale[bd]; \ - for (size_t bs = 0; bs < N; bs++, input++) { \ - *output++ = static_cast(input->ToFloat() * sc); \ - } \ - } \ - } \ - } \ - /* Blocked quantization */ \ - void op(size_t M, size_t K, size_t N, size_t quant_block_size, \ - const T* input, const OutT* scale, OutT* output, const T*) { \ - for (size_t m = 0; m < M; m++) { \ - for (size_t bd = 0; bd < K; bd += quant_block_size) { \ - for (size_t qb = 0, qb_end = std::min(quant_block_size, K - bd); qb < qb_end; ++qb) { \ - for (size_t bs = 0; bs < N; bs++, input++) { \ - auto sc = static_cast(scale[bs]); \ - *output++ = static_cast(input->ToFloat() * sc); \ - } \ - } \ - scale += N; \ - } \ - } \ - } \ +#define DEQUANTIZE_LINEAR_APPLY_FLOAT8(T) \ + template \ + struct DequantizeLinearApply { \ + /* Per-tensor/layer or per-axis quantization */ \ + void op(size_t M, size_t K, size_t N, \ + const T* input, const OutT* scale, OutT* output, const T*, concurrency::ThreadPool*) { \ + for (size_t m = 0; m < M; m++) { \ + for (size_t bd = 0; bd < K; bd++) { \ + auto sc = scale[bd]; \ + for (size_t bs = 0; bs < N; bs++, input++) { \ + *output++ = static_cast(input->ToFloat() * sc); \ + } \ + } \ + } \ + } \ + /* Blocked quantization */ \ + void op(size_t M, size_t K, size_t N, size_t quant_block_size, \ + const T* input, const OutT* scale, OutT* output, const T*, concurrency::ThreadPool*) { \ + for (size_t m = 0; m < M; m++) { \ + for (size_t bd = 0; bd < K; bd += quant_block_size) { \ + for (size_t qb = 0, qb_end = std::min(quant_block_size, K - bd); qb < qb_end; ++qb) { \ + for (size_t bs = 0; bs < N; bs++, input++) { \ + auto sc = static_cast(scale[bs]); \ + *output++ = static_cast(input->ToFloat() * sc); \ + } \ + } \ + scale += N; \ + } \ + } \ + } \ }; DEQUANTIZE_LINEAR_APPLY_FLOAT8(Float8E4M3FN) @@ -513,6 +534,7 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { const auto to = x_scale.GetElementType(); const T* input = x.Data(); constexpr bool is_4bit = boost::mp11::mp_contains, T>::value; + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); if (to == ONNX_NAMESPACE::TensorProto::FLOAT) { const float* scale = x_scale.Data(); @@ -522,12 +544,12 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { static_cast(broadcast_dim), static_cast(process_block_size), static_cast(block_size_), - input, scale, output, zero_point); + input, scale, output, zero_point, thread_pool); } else { DequantizeLinearApply().op(static_cast(process_block_count), static_cast(broadcast_dim), static_cast(process_block_size), - input, scale, output, zero_point); + input, scale, output, zero_point, thread_pool); } } else if (to == ONNX_NAMESPACE::TensorProto::FLOAT16) { const MLFloat16* scale = x_scale.Data(); @@ -537,12 +559,12 @@ Status 
DequantizeLinear::Compute(OpKernelContext* ctx) const { static_cast(broadcast_dim), static_cast(process_block_size), static_cast(block_size_), - input, scale, output, zero_point); + input, scale, output, zero_point, thread_pool); } else { DequantizeLinearApply().op(static_cast(process_block_count), static_cast(broadcast_dim), static_cast(process_block_size), - input, scale, output, zero_point); + input, scale, output, zero_point, thread_pool); } } else if (to == ONNX_NAMESPACE::TensorProto::BFLOAT16) { ORT_THROW("DequantizeLinear into BFLOAT16 is not implemented yet."); diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 2de496a9168a0..f00bf51ae143d 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -313,8 +313,10 @@ CUDA_Provider* GetProvider() { // OrtEpApi infrastructure to be able to use the CUDA EP as an OrtEpFactory for auto EP selection. struct CudaEpFactory : OrtEpFactory { CudaEpFactory(const OrtApi& ort_api_in) : ort_api{ort_api_in} { + ort_version_supported = ORT_API_VERSION; GetName = GetNameImpl; GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; @@ -331,6 +333,11 @@ struct CudaEpFactory : OrtEpFactory { return factory->vendor.c_str(); } + static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->vendor_id; + } + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept { return ORT_VERSION; } @@ -374,6 +381,7 @@ struct CudaEpFactory : OrtEpFactory { const OrtApi& ort_api; const std::string ep_name{kCudaExecutionProvider}; // EP name const std::string vendor{"Microsoft"}; // EP vendor name + uint32_t vendor_id{0x1414}; // Microsoft vendor ID }; extern "C" { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index a5066a41981e5..9611cb82d5a62 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -781,7 +781,10 @@ namespace Dml // this branch could be reached with a bad custom operator or malformed file. If // a legitimate case reaches here and DML needs to support a new input/output type // besides tensors, then remove the assert. - assert(false); + + // If the model has nodes that use Optional we will arrive here. It's a valid ONNX model but + // TryGetTensorDataType doesn't handle Optional. 
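+ // Setting nodeContainsSupportedDataTypes = false below treats the Optional value as an unsupported
+ // data type instead of tripping the debug-only assert.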
+ // assert(false); nodeContainsSupportedDataTypes = false; return; } diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 711d81186bad1..c5b6507ac847b 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1304,7 +1304,7 @@ std::vector NvExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { - return std::make_unique(device_id, CUDA_PINNED); + return std::make_unique(CUDA_PINNED, device_id); }, narrow(device_id_)); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc index 86b684f8c6ebd..21947a22e2b92 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc @@ -235,7 +235,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(reshape_input, reshape_input_info)); bool needs_reshape = false; - const std::string reshape4d = input_names[0] + "_pre_reshape"; + const std::string reshape_prior_out = input_names[0] + "_prior_reshape"; if (input_shape.size() == 3) { needs_reshape = true; // build new_shape = {N, 1, C, L} @@ -245,25 +245,24 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra input_shape[1], input_shape[2]}; - const std::string reshape_node_name = "pre_reshape"; - QnnTensorWrapper rw( - reshape4d, + QnnTensorWrapper reshape_prior_tensor( + reshape_prior_out, QNN_TENSOR_TYPE_NATIVE, reshape_input_info.qnn_data_type, reshape_input_info.quant_param.Copy(), std::move(new_shape)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(rw)), - "Failed to add reshape-4d tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_prior_tensor)), + "Failed to add reshape prior tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - reshape_node_name, + utils::GetNodeName(node_unit) + "_reshape_prior", QNN_OP_PACKAGE_NAME_QTI_AISW, - "Reshape", + QNN_OP_RESHAPE, {input_names[0]}, - {reshape4d}, + {reshape_prior_out}, {}, do_op_validation), - "Failed to create reshape-4d node."); - input_names[0] = reshape4d; + "Failed to create reshape prior node for pool op."); + input_names[0] = reshape_prior_out; input_shape = {input_shape[0], 1, input_shape[1], input_shape[2]}; } @@ -446,9 +445,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra } const auto& outputs = node_unit.Outputs(); const std::string real_out = outputs[0].node_arg.Name(); - const std::string pool_name = "poolmax2d"; - const std::string pool_out = real_out + "_post_reshape"; - const std::string post_reshape_node_name = "post_reshape"; + const std::string pool_out = real_out + "_reshape_after"; const std::string qnn_op = GetQnnOpType(op_type); TensorInfo output_info{}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info)); @@ -466,33 +463,34 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra "Failed to add tensor for pool_out"); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - pool_name, + utils::GetNodeName(node_unit) + "_pool2d", QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op, - {reshape4d}, + 
{reshape_prior_out}, {pool_out}, std::move(param_tensor_names), do_op_validation), - "Failed to create QNN Pool node for rank-3 input."); + "Failed to create pool node for rank-3 input."); std::vector final_shape3d = output_info.shape; - QnnTensorWrapper reshape_back_tensor( + QnnTensorWrapper reshape_after_tensor( real_out, tensor_type, output_info.qnn_data_type, output_info.quant_param.Copy(), std::move(final_shape3d)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_back_tensor)), "Failed to add tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_after_tensor)), + "Failed to add reshape after tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - post_reshape_node_name, + utils::GetNodeName(node_unit) + "_reshape_after", QNN_OP_PACKAGE_NAME_QTI_AISW, - "Reshape", + QNN_OP_RESHAPE, {pool_out}, {real_out}, {}, do_op_validation), - "Failed to create reshape-back node."); + "Failed to create reshape after node for pool op."); return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index 2650316dd07ac..502ea86b689f4 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -9,7 +9,7 @@ namespace onnxruntime { namespace qnn { -// Operator which only need to hanle node inputs & outputs, no attributes or no need to handle attributes +// Operator which only need to handle node inputs & outputs, no attributes or no need to handle attributes class SimpleOpBuilder : public BaseOpBuilder { public: SimpleOpBuilder() : BaseOpBuilder("SimpleOpBuilder") {} @@ -38,7 +38,7 @@ class SimpleOpBuilder : public BaseOpBuilder { const logging::Logger& logger, bool do_op_validation) const ORT_MUST_USE_RESULT; - static constexpr std::array gridsample_supported_modes = {"bilinear", "nearest"}; + static constexpr std::array gridsample_supported_modes = {"bilinear", "nearest", "linear"}; static constexpr std::array gridsample_supported_padding_modes = {"zeros", "border", "reflection"}; static constexpr std::array scatternd_supported_reduction = {"none", "add", "mul"}; }; @@ -60,8 +60,8 @@ Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, // To DO: Remove once QNN CPU supports ScatterND const auto qnn_backend_type = qnn_model_wrapper.GetQnnBackendType(); if (op_type == "ScatterND") { - ORT_RETURN_IF_NOT(qnn_backend_type == QnnBackendType::HTP, - "QNN EP only supports ScatterND op on HTP backend. Falling back to ORT CPU."); + ORT_RETURN_IF(qnn_backend_type == QnnBackendType::CPU, + "QNN EP does not support ScatterND op on CPU backend. Falling back to ORT CPU."); } // ONNX's Min, Max, and Sum operators accept a variable number of inputs (i.e., variadic). 
@@ -233,12 +233,12 @@ Status ProcessGridSampleAttributes(QnnModelWrapper& qnn_model_wrapper, std::string mode = node_helper.Get("mode", "linear"); Qnn_Scalar_t mode_qnn_scalar = QNN_SCALAR_INIT; mode_qnn_scalar.dataType = QNN_DATATYPE_UINT_32; - if ("bilinear" == mode) { + if ("linear" == mode || "bilinear" == mode) { mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_MODE_BILINEAR; } else if ("nearest" == mode) { mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_MODE_NEAREST; } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample mode only support bilinear & nearest."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample mode only support [linear, bilinear, nearest]."); } QnnParamWrapper mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_GRID_SAMPLE_PARAM_MODE, mode_qnn_scalar); param_tensor_names.push_back(mode_param.GetParamTensorName()); @@ -254,7 +254,7 @@ Status ProcessGridSampleAttributes(QnnModelWrapper& qnn_model_wrapper, } else if ("reflection" == padding_mode) { padding_mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_PADDING_MODE_REFLECTION; } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample padding_mode only support zeros, border & reflection."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample padding_mode only support [zeros, border, reflection]."); } QnnParamWrapper padding_mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_GRID_SAMPLE_PARAM_PADDING_MODE, padding_mode_qnn_scalar); param_tensor_names.push_back(padding_mode_param.GetParamTensorName()); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index d22edaf33eb1c..3dc103046424e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -839,6 +839,23 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord return Status::OK(); } +Status QnnBackendManager::SetContextPriority(ContextPriority context_priority) { + QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT; + ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority, context_priority_config)); + + QnnContext_Config_t* configs[] = {&context_priority_config, nullptr}; + for (const auto& context_handle : contexts_) { + auto result = qnn_interface_.contextSetConfig(context_handle, (const QnnContext_Config_t**)configs); + ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to set context priority for context handle: ", context_handle); + } + + return Status::OK(); +} + +Status QnnBackendManager::ResetContextPriority() { + return SetContextPriority(context_priority_); +} + Status QnnBackendManager::CreateContext(bool enable_htp_weight_sharing) { if (true == context_created_) { LOGS_DEFAULT(INFO) << "Context created already."; @@ -1426,13 +1443,33 @@ Status QnnBackendManager::SetHtpPowerConfig(uint32_t htp_power_config_client_id, return Status::OK(); } -Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_id, - uint32_t rpc_control_latency) { +Status QnnBackendManager::SetRpcPowerConfigs(uint32_t htp_power_config_client_id, + uint32_t rpc_control_latency, + uint32_t rpc_polling_time) { // This function is called in QNN EP's OnRunStart() even if QNN backend setup failed and the model is assigned // to a different EP. Therefore, we have to check that backend setup actually completed before trying to // set RPC control latency. 
Otherwise, this causes a segfault because the QNN backend library is unloaded. ORT_RETURN_IF_NOT(backend_setup_completed_, "Cannot set HTP RPC control latency if backend setup is not complete."); + + constexpr int kNumRpcPollingPowerConfigs = 2; + std::vector rpc_power_configs; + rpc_power_configs.reserve(kNumRpcPollingPowerConfigs); + + // Set rpc control latency here if (rpc_control_latency != 0) { + auto& rpc_control_latency_cfg = rpc_power_configs.emplace_back(); + rpc_control_latency_cfg.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY; + rpc_control_latency_cfg.rpcControlLatencyConfig = rpc_control_latency; + } + + // Note: v68 does not support rpc polling mode + if (rpc_polling_time != 0) { + auto& rpc_polling_time_cfg = rpc_power_configs.emplace_back(); + rpc_polling_time_cfg.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_POLLING_TIME; + rpc_polling_time_cfg.rpcPollingTimeConfig = rpc_polling_time; + } + + if (rpc_power_configs.size() > 0) { QnnDevice_Infrastructure_t qnn_device_infra = nullptr; auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed."); @@ -1442,15 +1479,6 @@ Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_ "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type."); QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra; - // Set rpc control latency here, but note that v68 doesn't support rpc polling mode. - constexpr int kNumRpcPollingPowerConfigs = 2; - std::vector rpc_power_configs(kNumRpcPollingPowerConfigs); - QnnHtpPerfInfrastructure_PowerConfig_t& rpc_control_latency_cfg = rpc_power_configs[0]; - // v68 doesn't support this. 
- QnnHtpPerfInfrastructure_PowerConfig_t& rpc_polling_time = rpc_power_configs[1]; - rpc_control_latency_cfg.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY; - rpc_polling_time.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_POLLING_TIME; - rpc_control_latency_cfg.rpcControlLatencyConfig = rpc_control_latency; std::vector perf_power_configs_ptr = ObtainNullTermPtrVector(rpc_power_configs); status = htp_perf_infra.setPowerConfig(htp_power_config_client_id, perf_power_configs_ptr.data()); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 3e68df3024565..2a71c7391b180 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -159,8 +159,9 @@ class QnnBackendManager : public std::enable_shared_from_this Status SetHtpPowerConfig(uint32_t htp_power_config_client_id, HtpPerformanceMode htp_performance_mode); - Status SetRpcControlLatency(uint32_t htp_power_config_client_id, - uint32_t rpc_control_latency); + Status SetRpcPowerConfigs(uint32_t htp_power_config_client_id, + uint32_t rpc_control_latency, + uint32_t rpc_polling_time); const QNN_INTERFACE_VER_TYPE& GetQnnInterface() { return qnn_interface_; } @@ -219,6 +220,11 @@ class QnnBackendManager : public std::enable_shared_from_this // For each node name, a mapping to the context handle will be created void ProcessContextFromBinListAsync(Qnn_ContextHandle_t handle, void* notifyParam); + // Sets the context priority to the given value, if valid + Status SetContextPriority(ContextPriority context_priority); + // Resets the context priority to the session default as defined by context_priority_ + Status ResetContextPriority(); + private: Status LoadBackend(); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 236447cc95c3d..3acb3347acee1 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -1356,7 +1356,8 @@ QNNExecutionProvider::PerThreadContext::PerThreadContext(qnn::QnnBackendManager* uint32_t device_id, uint32_t core_id, qnn::HtpPerformanceMode default_htp_performance_mode, - uint32_t default_rpc_control_latency) + uint32_t default_rpc_control_latency, + uint32_t default_rpc_polling_time) : qnn_backend_manager_(qnn_backend_manager) { Status rt = qnn_backend_manager_->CreateHtpPowerCfgId(device_id, core_id, htp_power_config_id_); is_htp_power_config_id_valid_ = rt.IsOK(); @@ -1367,9 +1368,10 @@ QNNExecutionProvider::PerThreadContext::PerThreadContext(qnn::QnnBackendManager* ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetHtpPowerConfig(htp_power_config_id_, default_htp_performance_mode)); } - if (default_rpc_control_latency > 0) { - ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetRpcControlLatency(htp_power_config_id_, - default_rpc_control_latency)); + if (default_rpc_control_latency > 0 || default_rpc_polling_time > 0) { + ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetRpcPowerConfigs(htp_power_config_id_, + default_rpc_control_latency, + default_rpc_polling_time)); } } } @@ -1400,7 +1402,8 @@ QNNExecutionProvider::PerThreadContext& QNNExecutionProvider::GetPerThreadContex if (context_state_.retired_context_pool.empty()) { uint32_t core_id = 0; context = std::make_shared(qnn_backend_manager_.get(), device_id_, core_id, - default_htp_performance_mode_, 
default_rpc_control_latency_); + default_htp_performance_mode_, default_rpc_control_latency_, + default_rpc_polling_time_); } else { context = context_state_.retired_context_pool.back(); context_state_.retired_context_pool.pop_back(); @@ -1468,15 +1471,21 @@ Status QNNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_optio LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency; } + uint32_t rpc_polling_time = 0; + if (qnn::HtpPerformanceMode::kHtpBurst != htp_performance_mode) { + rpc_polling_time = 9999; + } + if (GetPerThreadContext().IsHtpPowerConfigIdValid()) { if (qnn::HtpPerformanceMode::kHtpDefault != htp_performance_mode) { ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetPerThreadContext().GetHtpPowerConfigId(), htp_performance_mode)); } - if (rpc_control_latency > 0) { - ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetRpcControlLatency(GetPerThreadContext().GetHtpPowerConfigId(), - rpc_control_latency)); + if (rpc_control_latency > 0 || rpc_polling_time > 0) { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetRpcPowerConfigs(GetPerThreadContext().GetHtpPowerConfigId(), + rpc_control_latency, + rpc_polling_time)); } } @@ -1545,4 +1554,38 @@ OrtDevice QNNExecutionProvider::GetOrtDeviceByMemType(OrtMemType /* em_type */) return default_device_; } +Status QNNExecutionProvider::SetEpDynamicOptions(gsl::span keys, + gsl::span values) { + if (keys.size() != values.size()) { + LOGS_DEFAULT(ERROR) << "SetEpDynamicOptions: number of keys (" << keys.size() + << ") does not equal number of values (" << values.size() << ")."; + } + auto key_it = keys.begin(); + auto value_it = values.begin(); + + while (key_it != keys.end() && value_it != values.end()) { + std::string key(*key_it); + std::string value(*value_it); + + if (key == kOrtEpDynamicOptionsWorkloadType) { + if (value == "Default") { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->ResetContextPriority()); + } else if (value == "Efficient") { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetContextPriority(qnn::ContextPriority::LOW)); + } else { + LOGS_DEFAULT(ERROR) << "Invalid EP Workload Type: " << value; + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid EP Workload Type."); + } + } else { + LOGS_DEFAULT(ERROR) << "EP Dynamic Option \"" << key << "\" is not currently supported."; + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported EP Dynamic Option"); + } + + key_it++; + value_it++; + } + + return Status::OK(); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 06f9726ae96cf..6adf613932d66 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -57,6 +57,9 @@ class QNNExecutionProvider : public IExecutionProvider { OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; + Status SetEpDynamicOptions(gsl::span keys, + gsl::span value) override; + private: std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, const std::unordered_map& node_unit_map, @@ -96,6 +99,7 @@ class QNNExecutionProvider : public IExecutionProvider { uint32_t device_id_ = 0; qnn::HtpPerformanceMode default_htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault; uint32_t default_rpc_control_latency_ = 0; + uint32_t default_rpc_polling_time_ = 0; bool enable_HTP_FP16_precision_ = true; bool share_ep_contexts_ = false; bool stop_share_ep_contexts_ = false; @@ -116,7 +120,8 @@ 
class QNNExecutionProvider : public IExecutionProvider { PerThreadContext(qnn::QnnBackendManager* qnn_backend_manager, uint32_t device_id, uint32_t core_id, qnn::HtpPerformanceMode default_htp_performance_mode, - uint32_t default_rpc_control_latency); + uint32_t default_rpc_control_latency, + uint32_t default_rpc_polling_time); ~PerThreadContext(); ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext); diff --git a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc index c679ea1adb286..785177ce37788 100644 --- a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc +++ b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc @@ -125,8 +125,10 @@ struct QnnEpFactory : OrtEpFactory { OrtHardwareDeviceType hw_type, const char* qnn_backend_type) : ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type}, qnn_backend_type{qnn_backend_type} { + ort_version_supported = ORT_API_VERSION; GetName = GetNameImpl; GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; @@ -142,7 +144,12 @@ struct QnnEpFactory : OrtEpFactory { static const char* GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); - return factory->vendor.c_str(); + return factory->ep_vendor.c_str(); + } + + static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->ep_vendor_id; } static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept { @@ -195,8 +202,9 @@ struct QnnEpFactory : OrtEpFactory { } const OrtApi& ort_api; - const std::string ep_name; // EP name - const std::string vendor{"Microsoft"}; // EP vendor name + const std::string ep_name; // EP name + const std::string ep_vendor{"Microsoft"}; // EP vendor name + uint32_t ep_vendor_id{0x1414}; // Microsoft vendor ID // Qualcomm vendor ID. 
Refer to the ACPI ID registry (search Qualcomm): https://uefi.org/ACPI_ID_List const uint32_t vendor_id{'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24)}; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index 90a4294fb47f0..1e9fafe8aa323 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -7,6 +7,25 @@ #include "tensorrt_execution_provider_custom_ops.h" #include "tensorrt_execution_provider.h" +// The filename extension for a shared library is different per platform +#ifdef _WIN32 +#define LIBRARY_PREFIX +#define LIBRARY_EXTENSION ORT_TSTR(".dll") +#elif defined(__APPLE__) +#define LIBRARY_PREFIX "lib" +#define LIBRARY_EXTENSION ".dylib" +#else +#define LIBRARY_PREFIX "lib" +#define LIBRARY_EXTENSION ".so" +#endif + +#ifdef _WIN32 +#define ORT_DEF2STR_HELPER(x) L#x +#else +#define ORT_DEF2STR_HELPER(X) #X +#endif +#define ORT_DEF2STR(x) ORT_DEF2STR_HELPER(x) + namespace onnxruntime { extern TensorrtLogger& GetTensorrtLogger(bool verbose); @@ -58,8 +77,31 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& // Get all registered TRT plugins from registry LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Getting all registered TRT plugins from TRT plugin registry ..."; TensorrtLogger trt_logger = GetTensorrtLogger(false); - initLibNvInferPlugins(&trt_logger, ""); + try { + void* library_handle = nullptr; + const auto& env = onnxruntime::GetDefaultEnv(); +#if NV_TENSORRT_MAJOR < 10 + auto full_path = env.GetRuntimePath() + + PathString(LIBRARY_PREFIX ORT_TSTR("nvinfer_plugin") LIBRARY_EXTENSION); +#else +#ifdef _WIN32 + auto full_path = PathString(LIBRARY_PREFIX ORT_TSTR("nvinfer_plugin_" ORT_DEF2STR(NV_TENSORRT_MAJOR)) LIBRARY_EXTENSION); +#else + auto full_path = PathString(LIBRARY_PREFIX ORT_TSTR("nvinfer_plugin") LIBRARY_EXTENSION ORT_TSTR("." 
ORT_DEF2STR(NV_TENSORRT_MAJOR))); +#endif +#endif + + ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, false, &library_handle)); + bool (*dyn_initLibNvInferPlugins)(void* logger, char const* libNamespace); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(library_handle, "initLibNvInferPlugins", (void**)&dyn_initLibNvInferPlugins)); + if (!dyn_initLibNvInferPlugins(&trt_logger, "")) { + LOGS_DEFAULT(INFO) << "[TensorRT EP] Default plugin library was found but was not able to initialize default plugins."; + } + LOGS_DEFAULT(INFO) << "[TensorRT EP] Default plugins successfully loaded."; + } catch (const std::exception&) { + LOGS_DEFAULT(INFO) << "[TensorRT EP] Default plugin library is not on the path and is therefore ignored"; + } int num_plugin_creator = 0; auto plugin_creators = getPluginRegistry()->getAllCreators(&num_plugin_creator); std::unordered_set registered_plugin_names; diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc index e8140a4d59eab..113a3f31be7f9 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.cc +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -193,27 +193,21 @@ class BucketCacheManager : public IBufferCacheManager { } void ReleaseBuffer(WGPUBuffer buffer) override { - pending_buffers_.emplace_back(buffer); - } + auto buffer_size = static_cast(wgpuBufferGetSize(buffer)); - void OnRefresh(GraphCaptureState /*graph_capture_state*/) override { - for (auto& buffer : pending_buffers_) { - auto buffer_size = static_cast(wgpuBufferGetSize(buffer)); - auto it = buckets_.find(buffer_size); - if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) { - it->second.emplace_back(buffer); - } else { - wgpuBufferRelease(buffer); - } + auto it = buckets_.find(buffer_size); + if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) { + it->second.emplace_back(buffer); + } else { + wgpuBufferRelease(buffer); } + } - pending_buffers_.clear(); + void OnRefresh(GraphCaptureState /*graph_capture_state*/) override { + // no-op } ~BucketCacheManager() { - for (auto& buffer : pending_buffers_) { - wgpuBufferRelease(buffer); - } for (auto& pair : buckets_) { for (auto& buffer : pair.second) { wgpuBufferRelease(buffer); @@ -242,7 +236,6 @@ class BucketCacheManager : public IBufferCacheManager { } std::unordered_map buckets_limit_; std::unordered_map> buckets_; - std::vector pending_buffers_; std::vector buckets_keys_; }; diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.cc b/onnxruntime/core/providers/webgpu/tensor/cast.cc index 7f92ea4ed3776..313a96ba25509 100644 --- a/onnxruntime/core/providers/webgpu/tensor/cast.cc +++ b/onnxruntime/core/providers/webgpu/tensor/cast.cc @@ -52,10 +52,28 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .TypeConstraint("T1", CastOpTypeConstraints()) .TypeConstraint("T2", CastOpTypeConstraints()), Cast); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 19, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Cast, + kOnnxDomain, + 21, 22, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", CastOpTypeConstraints()) + .TypeConstraint("T2", CastOpTypeConstraints()), + Cast); ONNX_OPERATOR_KERNEL_EX( Cast, kOnnxDomain, - 19, + 23, kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T1", CastOpTypeConstraints()) diff --git 
a/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc b/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc index f13e86c185928..9f07e2d2a3988 100644 --- a/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc +++ b/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc @@ -146,24 +146,24 @@ Status ScatterND::ComputeInternal(ComputeContext& context) const { const auto* updates = context.Input(2); const auto& input_shape = input->Shape(); const auto& indices_shape = indices->Shape(); - auto indices_rank = indices_shape.NumDimensions(); - auto last_index_dimension = static_cast(indices_shape[indices_rank - 1]); - auto num_updates_elements = static_cast(input_shape.SizeFromDimension(last_index_dimension)); - // TODO: support bool with components 4. - const size_t components = 1; - auto output_size = static_cast((indices_shape.SizeToDimension(indices_rank - 1) + components - 1) / components); auto* output = context.Output(0, input_shape); - if (output_size == 0) { - // If the output tensor is empty, we can return early. - return Status::OK(); - } - MLDataType data_type = input->DataType(); const void* source = input->DataRaw(); void* target = output->MutableDataRaw(); // If source and target pointers are not equal (non-inplace operation), we need to copy the data. if (target != source) { ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input, *output)); } + if (indices_shape.Size() == 0) { + // If the indices are empty, we can return early. + return Status::OK(); + } + auto indices_rank = indices_shape.NumDimensions(); + auto last_index_dimension = static_cast(indices_shape[indices_rank - 1]); + auto num_updates_elements = static_cast(input_shape.SizeFromDimension(last_index_dimension)); + // TODO: support bool with components 4. + const size_t components = 1; + auto output_size = static_cast((indices_shape.SizeToDimension(indices_rank - 1) + components - 1) / components); + MLDataType data_type = input->DataType(); ScatterNDProgram program(reduction_, data_type); program .CacheHint(static_cast(reduction_)) diff --git a/onnxruntime/core/providers/webgpu/tensor/slice.cc b/onnxruntime/core/providers/webgpu/tensor/slice.cc index 39432db5113d1..7e8b434431781 100644 --- a/onnxruntime/core/providers/webgpu/tensor/slice.cc +++ b/onnxruntime/core/providers/webgpu/tensor/slice.cc @@ -172,8 +172,8 @@ Status Slice::ComputeInternal(ComputeContext& context) const { } if (step < 0) { // we are slicing in reverse - start = std::clamp(start, int64_t{0}, dim_value - 1); - end = std::clamp(end, int64_t{-1}, dim_value - 1); + start = dim_value > 0 ? std::clamp(start, int64_t{0}, dim_value - 1) : 0; + end = dim_value > 0 ? 
std::clamp(end, int64_t{-1}, dim_value - 1) : -1; // note that we are flipping start and end to switch to forward step signs.push_back(-1); steps.push_back(static_cast(-step)); diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 460d220ecf1b9..6e09f494f4a8d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -123,7 +123,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 8, Cast); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Cast); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, Cast); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Cast); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, float, Clip); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, float, Clip); @@ -455,7 +457,9 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO_VERSIONED(6, 8, Cast), KERNEL_CREATE_INFO_VERSIONED(9, 12, Cast), KERNEL_CREATE_INFO_VERSIONED(13, 18, Cast), - KERNEL_CREATE_INFO(19, Cast), + KERNEL_CREATE_INFO_VERSIONED(19, 20, Cast), + KERNEL_CREATE_INFO_VERSIONED(21, 22, Cast), + KERNEL_CREATE_INFO(23, Cast), // // activations BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/webgpu/wgsl_templates/README.md b/onnxruntime/core/providers/webgpu/wgsl_templates/README.md index c1a62e7fa7858..6bd2f98cc5713 100644 --- a/onnxruntime/core/providers/webgpu/wgsl_templates/README.md +++ b/onnxruntime/core/providers/webgpu/wgsl_templates/README.md @@ -64,7 +64,7 @@ This section includes instructions for how to use the template system in the dev 1. Create WGSL template files in `.wgsl.template` extension. - [Reference: Template Syntax](https://github.com/fs-eire/wgsl-template?tab=readme-ov-file#template-syntax) - - [Reference: Built-in Utilities](#Utilities) + - [Reference: Built-in Utilities](https://github.com/fs-eire/wgsl-template?tab=readme-ov-file#Utilities) - [Example: Pad](../tensor/pad.wgsl.template) 2. In the implementation of `YourProgram::GenerateShaderCode()`, load and use the generated template files. @@ -117,4 +117,4 @@ This section includes instructions for how to use the template system in the dev 1. Build ORT once with dynamic template mode 2. Launch wgsl-gen in watch mode 3. Run ORT to debug/validate the shader - 4. Make changes to the template files, and repeat step (3) + 4. 
Make changes to the template files, and repeat step (c) diff --git a/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json b/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json index 7cde6c17f54e9..df1940ed6416b 100644 --- a/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json +++ b/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json @@ -9,13 +9,13 @@ "version": "1.0.0", "license": "MIT", "dependencies": { - "@fs-eire/wgsl-template": "^0.1.3" + "@fs-eire/wgsl-template": "^0.1.13" } }, "node_modules/@fs-eire/wgsl-template": { - "version": "0.1.10", - "resolved": "https://registry.npmjs.org/@fs-eire/wgsl-template/-/wgsl-template-0.1.10.tgz", - "integrity": "sha512-F5qQZxNweZ3ZD3d9RNc/g3nTiW7jyaAVi7SlMOL4wOfXh+Nm/qca2DISNTf3kjpVqkoazMJGbZ6TPQ4a/vjw0g==", + "version": "0.1.13", + "resolved": "https://registry.npmjs.org/@fs-eire/wgsl-template/-/wgsl-template-0.1.13.tgz", + "integrity": "sha512-SOQjVCQCUmXb9qYr2E3CKNs88/FzINuhFJiobBEkSAsyKtJby9oFWGZnrEO+hIl/oDTLA01LbjiDxuf6TGHE/w==", "license": "MIT", "dependencies": { "minimist": "^1.2.8" diff --git a/onnxruntime/core/providers/webgpu/wgsl_templates/package.json b/onnxruntime/core/providers/webgpu/wgsl_templates/package.json index 34831ccddeb33..246e7365531e0 100644 --- a/onnxruntime/core/providers/webgpu/wgsl_templates/package.json +++ b/onnxruntime/core/providers/webgpu/wgsl_templates/package.json @@ -10,6 +10,6 @@ "author": "", "license": "MIT", "dependencies": { - "@fs-eire/wgsl-template": "^0.1.3" + "@fs-eire/wgsl-template": "^0.1.13" } } diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index e821265fff80d..142d64caa64aa 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -99,69 +99,93 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n return true; } -// Check if all input tensor ranks of the given node are supported by WebNN. -bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) { - const std::string_view op_type = node.OpType(); - const auto it = op_inputs_map.find(op_type); - if (it == op_inputs_map.end()) { - LOGS(logger, VERBOSE) << "Operator type: [" << op_type << "] is not found in the op inputs map."; +// Check if a single input's rank of an ONNX op is supported by corresponding WebNN op. 
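+// The lookup reads wnn_limits[webnnOpType][inputName]["rankRange"] and accepts the input only when
+// rankRange.min <= input_rank <= rankRange.max; a missing op, input, or rankRange entry is treated
+// as unsupported.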
+bool IsInputRankSupported(const emscripten::val& wnn_limits, + const std::string_view webnn_op_type, + const std::string_view input_name, + const size_t input_rank, + const std::string_view node_name, + const logging::Logger& logger) { + const std::string webnn_op_type_str(webnn_op_type); + const std::string input_name_str(input_name); + + if (wnn_limits[webnn_op_type_str].isUndefined()) { + LOGS(logger, VERBOSE) << "WebNN op type: [" << webnn_op_type + << "] is not defined in WebNN MLOpSupportLimits."; return false; } - const auto& input_defs = node.InputDefs(); - const std::string_view webnn_op_type = it->second.opType; - const std::string webnn_op_type_str(webnn_op_type); + const emscripten::val input_limits = wnn_limits[webnn_op_type_str][input_name_str]; - for (const auto& input : it->second.inputs) { - if (static_cast(input.index) >= input_defs.size() || input_defs[input.index] == nullptr) { - LOGS(logger, VERBOSE) << "Input index [" << input.index - << "] for operator type [" << op_type - << "], corresponding WebNN op type [" << webnn_op_type - << "], WebNN input name [" << input.name - << "] is invalid."; - return false; - } + if (input_limits.isUndefined()) { + LOGS(logger, VERBOSE) << "Node name: [" << node_name + << "], WebNN op type: [" << webnn_op_type + << "], input [" << input_name + << "]: limits are not defined in WebNN MLOpSupportLimits."; + return false; + } - std::vector input_shape; - if (!GetShape(*input_defs[input.index], input_shape, logger)) { - return false; - } + const emscripten::val rank_range = input_limits["rankRange"]; + if (rank_range.isUndefined()) { + LOGS(logger, VERBOSE) << "WebNN op type [" << webnn_op_type + << "] input [" << input_name + << "]: missing 'rankRange' attribute."; + return false; + } - const std::string input_name_str(input.name); - if (wnn_limits[webnn_op_type_str].isUndefined() || - wnn_limits[webnn_op_type_str][input_name_str].isUndefined()) { - LOGS(logger, VERBOSE) << "Operator type: [" << op_type - << "], input index: [" << input.index - << "], corresponding WebNN op type: " << webnn_op_type - << ", WebNN input name " << input.name - << " is not defined in wnn_limits."; - return false; - } + const emscripten::val min_val = rank_range["min"]; + const emscripten::val max_val = rank_range["max"]; + if (min_val.isUndefined() || max_val.isUndefined()) { + LOGS(logger, VERBOSE) << "WebNN op type [" << webnn_op_type + << "] input [" << input_name + << "]: its 'rankRange' limits is missing valid 'min' or 'max' attributes."; + return false; + } - const auto& input_limits = wnn_limits[webnn_op_type_str][input_name_str]; - if (input_limits["rankRange"].isUndefined()) { - LOGS(logger, VERBOSE) << "Operator type: [" << op_type - << "], input index: [" << input.index - << "], corresponding WebNN op type: " << webnn_op_type - << ", WebNN input name " << input.name - << "'s rankRange is not defined."; - return false; + size_t min_rank = min_val.as(); + size_t max_rank = max_val.as(); + if (input_rank < min_rank || input_rank > max_rank) { + LOGS(logger, VERBOSE) << "Node name: [" << node_name + << "] WebNN op type [" << webnn_op_type + << "] input [" << input_name << "] rank " << input_rank + << " is not in supported range [" << min_rank << ", " << max_rank << "]"; + return false; + } + + return true; +} + +bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) { + const std::string_view onnx_op_type = node.OpType(); + const std::string_view webnn_op_type = GetWebNNOpType(onnx_op_type); + 
+ if (webnn_op_type.empty()) { + LOGS(logger, VERBOSE) << "ONNX op type: [" << onnx_op_type << "]'s corresponding WebNN op is not found."; + return false; + } + + std::vector inputs; + if (!GetWebNNOpInputs(onnx_op_type, inputs, logger)) { + return false; + } + + const auto& input_defs = node.InputDefs(); + + for (const auto& input : inputs) { + // If it is an optional input and is absent, skip. + if (!TensorExists(input_defs, input.index)) { + continue; } - int input_dim_size = static_cast(input_shape.size()); - int min_rank = input_limits["rankRange"]["min"].as(); - int max_rank = input_limits["rankRange"]["max"].as(); - - if (input_dim_size < min_rank || input_dim_size > max_rank) { - LOGS(logger, VERBOSE) << "Operator type: [" << op_type - << "], input index: [" << input.index - << "], corresponding WebNN op type: " << webnn_op_type - << ", WebNN input name: " << input.name - << ", input size " << input_dim_size - << " is not in supported range [" << min_rank << ", " << max_rank << "]"; + std::vector shape; + if (!GetShape(*input_defs[input.index], shape, logger) || + !IsInputRankSupported(wnn_limits, webnn_op_type, input.name, + shape.size(), + node.Name(), logger)) { return false; } } + return true; } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index d59788600f997..50e361ede221e 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -216,6 +216,13 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger); +bool IsInputRankSupported(const emscripten::val& wnn_limits, + const std::string_view webnn_op_type, + const std::string_view input_name, + const size_t input_rank, + const std::string_view node_name, + const logging::Logger& logger); + // Get a set of nodes supported by WebNN EP. std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, const emscripten::val& wnn_builder, @@ -244,6 +251,33 @@ inline std::string_view GetWebNNOpType(const std::string_view onnx_op_type) { return (it != op_inputs_map.end()) ? 
it->second.opType : ""; } +// Get the WebNN input name corresponding to the given ONNX op type and input index from op_inputs_map. +inline std::string_view GetWebNNInputName(const std::string_view onnx_op_type, const int input_index) { + const auto it = op_inputs_map.find(onnx_op_type); + + if (it != op_inputs_map.end()) { + for (const auto& input : it->second.inputs) { + if (input.index == input_index) { + return input.name; + } + } + } + + return ""; +} + +inline bool GetWebNNOpInputs(const std::string_view onnx_op_type, + std::vector& inputs, + const logging::Logger& logger) { + const auto it = op_inputs_map.find(onnx_op_type); + if (it == op_inputs_map.end()) { + LOGS(logger, VERBOSE) << "WebNN op inputs not found for op type: " << onnx_op_type; + return false; + } + inputs = it->second.inputs; + return true; +} + bool AreDataTypesSame(const std::string_view op_type, gsl::span input_types, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc index fc630af8cf1e3..fdf1709d87bac 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc @@ -18,10 +18,6 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - - // Operator support related. - bool IsOpSupportedImpl(const GraphViewer&, const Node& node, - WebnnDeviceType device_type, const logging::Logger& logger) const override; }; // Add operator related. @@ -65,20 +61,6 @@ Status ArgMaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -// Operator support related.
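The rank check added above boils down to looking up a per-op, per-input [min, max] window in WebNN's MLOpSupportLimits and comparing the actual input rank against it. A minimal standalone sketch of that logic, using an invented std::map in place of the browser-provided limits object (the real helper reads it through emscripten::val and also logs each early-out):

#include <cstddef>
#include <map>
#include <string>
#include <utility>

struct RankRange {
  std::size_t min;
  std::size_t max;
};

// Hypothetical stand-in for MLOpSupportLimits, e.g. {"matmul", "a"} -> {2, 4}.
using OpInputLimits = std::map<std::pair<std::string, std::string>, RankRange>;

bool IsRankSupported(const OpInputLimits& limits, const std::string& webnn_op,
                     const std::string& input_name, std::size_t input_rank) {
  const auto it = limits.find({webnn_op, input_name});
  if (it == limits.end()) {
    return false;  // Mirrors the early-outs for a missing op/input entry or a missing 'rankRange'.
  }
  return input_rank >= it->second.min && input_rank <= it->second.max;
}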
-bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const GraphViewer& /* initializers */, - const Node& node, - WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - return true; -} - void CreateArgMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index b0ec006db6986..3c8e7fa34f7ed 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -62,13 +62,12 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, int32_t input_type; if (!GetType(input, input_type, logger)) return false; - const std::string_view webnn_op_type = GetWebNNOpType(op_type); - if (webnn_op_type.empty()) - return false; + const std::string_view webnn_op_type = GetWebNNOpType(op_type); const std::string_view webnn_input_name = GetWebNNOpFirstInputName(op_type); return IsDataTypeSupportedByWebNNOp(op_type, webnn_op_type, input_type, wnn_limits, - webnn_input_name, "input", logger); + webnn_input_name, "input", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc index 280ffc83eae89..851dc373923ac 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc @@ -73,9 +73,10 @@ bool BinaryOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod return false; } - std::string webnn_input_name = op_type == "PRelu" ? "input" : "a"; + const std::string_view webnn_input_name = GetWebNNOpFirstInputName(op_type); std::string onnx_input_name = op_type == "PRelu" || op_type == "Pow" ? 
"X" : "A"; - return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index 8589237617745..db5e8cd51656c 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -75,7 +75,8 @@ bool ConcatOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod } } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index b9383a63fe307..e0bfb3bd682e8 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -324,7 +324,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N x_zero_point = model_builder.CreateOrGetConstant(x_type, 0); } - // Scale is not used by ConvInteger but required by DequantizeLinear. So set it to deafult value 1.0f. + // Scale is not used by ConvInteger but required by DequantizeLinear. So set it to default value 1.0f. // The x_zero_point must be a scalar and the scale input should have the same shape as the zero point input. // So the x_scale must be a scalar too. x_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f); diff --git a/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc index 7528d9ad2ff51..f3c392b608e45 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc @@ -77,10 +77,6 @@ bool CumSumOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - const std::string axis_name = GetTensorName(input_defs, 1); // Inputs contain optional 'axis' input. 
const auto* init = graph_viewer.GetConstantInitializer(axis_name); diff --git a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc index c22dd9e97bb1a..37a00fcb12abd 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc @@ -21,11 +21,6 @@ class DropoutOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - - // Operator support related. - private: - bool IsOpSupportedImpl(const GraphViewer&, const Node& node, - const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; }; // Add operator related. @@ -65,26 +60,13 @@ Status DropoutOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val options = emscripten::val::object(); options.set("label", output_defs[1]->Name() + "_identity"); // Add additional identity op in case the mask is the output of a WebNN graph, - // beacuse WebNN does not support a constant operand as output. + // because WebNN does not support a constant operand as output. emscripten::val mask_output = model_builder.GetBuilder().call("identity", one_constant, options); model_builder.AddOperand(output_defs[1]->Name(), std::move(mask_output)); } return Status::OK(); } -// Operator support related. -bool DropoutOpBuilder::IsOpSupportedImpl(const GraphViewer&, - const Node& node, - const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - return true; -} - void CreateDropoutOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc index e5b4fcddc4221..6aa760c0f4baf 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc @@ -28,6 +28,8 @@ class EinsumOpBuilder : public BaseOpBuilder { const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; + bool HasSupportedOutputsImpl(const Node& /* node */, const emscripten::val& /* wnn_limits */, + const logging::Logger& /* logger */) const override; }; // Helper functions, thanks for DML EP's OperatorHelper. 
@@ -42,12 +44,6 @@ enum class RecognizedOperatorType { Total, }; -struct RecognizedOperatorInfo { - RecognizedOperatorType recognized_operator_type; - std::initializer_list component_ranks; - std::initializer_list label_indices; -}; - struct Component { uint32_t label_index_begin; uint32_t label_index_end; @@ -598,7 +594,7 @@ Status EinsumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } } - // tranpose input + // transpose input std::vector permutation(input_labels.size()); for (uint32_t idx = 0; idx < input_labels.size(); idx++) { if (idx != diagonal_idx_1 && idx != diagonal_idx_2) { @@ -620,7 +616,7 @@ Status EinsumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, options_trilu.set("upper", false); output = model_builder.GetBuilder().call("triangular", output, options_trilu); // tril - // reducesum to achieve the diagonal values + // reduceSum to achieve the diagonal values std::vector input_shape; std::vector reduced_axes; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); @@ -700,12 +696,6 @@ bool EinsumOpBuilder::IsOpSupportedImpl(const GraphViewer&, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - if (input_defs.size() > 2) { - // TODO: Support more than two inputs. - LOGS(logger, VERBOSE) << "EinSum only supports up to two inputs."; - return false; - } - NodeAttrHelper helper(node); const auto equation = helper.Get("equation", std::string(" ")); std::vector label_indices; @@ -724,13 +714,6 @@ bool EinsumOpBuilder::IsOpSupportedImpl(const GraphViewer&, return false; } - RecognizedOperatorType recognized_operator_type = DetermineRecognizedOperatorType(label_indices, components, - output_dimensions); - if (recognized_operator_type == RecognizedOperatorType::None) { - LOGS(logger, VERBOSE) << "The equation is not supported in Einsum."; - return false; - } - return true; } @@ -738,9 +721,14 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + if (input_defs.size() > 2) { + // TODO: Support more than two inputs. 
+ LOGS(logger, VERBOSE) << "EinSum only supports up to two inputs."; + return false; + } + const std::string_view op_type = node.OpType(); - int32_t input0_type; - int32_t input1_type; + int32_t input0_type, input1_type; bool has_input1 = TensorExists(input_defs, 1); if (!GetType(*input_defs[0], input0_type, logger) || @@ -754,6 +742,13 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod return false; } + std::vector input0_shape; + std::vector input1_shape; + if (!GetShape(*input_defs[0], input0_shape, logger) || + (has_input1 && !GetShape(*input_defs[1], input1_shape, logger))) { + return false; + } + NodeAttrHelper helper(node); const auto equation = helper.Get("equation", std::string(" ")); std::vector label_indices; @@ -770,17 +765,54 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod RecognizedOperatorType recognized_operator_type = DetermineRecognizedOperatorType(label_indices, components, output_dimensions); + std::string_view decomposed_op_type; if (recognized_operator_type == RecognizedOperatorType::None) { LOGS(logger, VERBOSE) << "The equation is not supported in Einsum."; return false; - } else if (recognized_operator_type == RecognizedOperatorType::Pairwise) { - // Map to WebNN's gemm or matmul - return IsDataTypeSupportedByWebNNOp(op_type, "matmul", input0_type, wnn_limits, "a", "inputs", logger); + } else if (recognized_operator_type == RecognizedOperatorType::Multiply) { + decomposed_op_type = "Mul"; } else if (recognized_operator_type == RecognizedOperatorType::ReduceSum) { - return IsDataTypeSupportedByWebNNOp(op_type, "reduceSum", input0_type, wnn_limits, "input", "inputs", logger); - } else { - return IsDataTypeSupportedByWebNNOp(op_type, "identity", input0_type, wnn_limits, "input", "inputs", logger); + decomposed_op_type = "ReduceSum"; + } else if (recognized_operator_type == RecognizedOperatorType::Diagonal) { + decomposed_op_type = "Trilu"; + } else if (recognized_operator_type == RecognizedOperatorType::Transpose) { + decomposed_op_type = "Transpose"; + } else if (recognized_operator_type == RecognizedOperatorType::Pairwise) { + decomposed_op_type = "MatMul"; + } else { // Identity + // For the Identity case, we simply forward the input to the output without any modification. + return true; + } + + const std::string_view wnn_input0_name = GetWebNNInputName(decomposed_op_type, 0); + const std::string_view decompose_wnn_op_type = GetWebNNOpType(decomposed_op_type); + if (decompose_wnn_op_type.empty() || + !IsDataTypeSupportedByWebNNOp(op_type, decompose_wnn_op_type, input0_type, + wnn_limits, wnn_input0_name, "inputs", logger) || + !IsInputRankSupported(wnn_limits, decompose_wnn_op_type, wnn_input0_name, + input0_shape.size(), node.Name(), logger)) { + return false; + } + + if (has_input1) { + const std::string_view wnn_input1_name = GetWebNNInputName(decomposed_op_type, 1); + return IsDataTypeSupportedByWebNNOp(op_type, decompose_wnn_op_type, input1_type, + wnn_limits, wnn_input1_name, "inputs", logger) && + IsInputRankSupported(wnn_limits, decompose_wnn_op_type, wnn_input1_name, + input1_shape.size(), node.Name(), logger); } + + return true; +} + +bool EinsumOpBuilder::HasSupportedOutputsImpl(const Node& /* node */, + const emscripten::val& /* wnn_limits */, + const logging::Logger& /* logger */) const { + // The Einsum op produces output with the same data type as its input. + // Therefore, checking the output data type is unnecessary. 
+ // This override prevents calling the base class implementation, as the base implementation + // would return false due to Einsum being a decomposed op. + return true; } void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc index 06beb56415609..ae4c3705fdb2e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc @@ -56,14 +56,14 @@ bool GatherElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const N const auto& indices = *node.InputDefs()[1]; const std::string_view op_type = node.OpType(); - int32_t data_type; - int32_t indices_type; + int32_t data_type, indices_type; if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) { return false; } return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateGatherElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc index 9200c596c0e53..af508c2800f4b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc @@ -61,14 +61,14 @@ bool GatherNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& n const auto& indices = *node.InputDefs()[1]; const std::string_view op_type = node.OpType(); - int32_t data_type; - int32_t indices_type; + int32_t data_type, indices_type; if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) { return false; } return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc index d84c70032e1d1..7111a8f6beaa3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc @@ -20,8 +20,6 @@ class GatherOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; // Operator support related. - bool IsOpSupportedImpl(const GraphViewer&, const Node& node, - const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -50,38 +48,20 @@ Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. 
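For reference, the branch ladder in the reworked Einsum input check above maps each recognized equation pattern to the ONNX op it is decomposed into before validating data type and rank against the corresponding WebNN op. A sketch of that mapping as a standalone function; the enum is trimmed to the values exercised here (the real RecognizedOperatorType has more members), and the function itself is not part of the patch:

#include <string_view>

enum class RecognizedOperatorType { None, Identity, Multiply, ReduceSum, Diagonal, Transpose, Pairwise };

// Returns the decomposed ONNX op to validate against; empty means no extra validation is needed.
std::string_view DecomposedOnnxOpFor(RecognizedOperatorType type) {
  switch (type) {
    case RecognizedOperatorType::Multiply:  return "Mul";
    case RecognizedOperatorType::ReduceSum: return "ReduceSum";
    case RecognizedOperatorType::Diagonal:  return "Trilu";
    case RecognizedOperatorType::Transpose: return "Transpose";
    case RecognizedOperatorType::Pairwise:  return "MatMul";
    default:                                return "";  // Identity is forwarded as-is; None is rejected earlier.
  }
}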
- -bool GatherOpBuilder::IsOpSupportedImpl(const GraphViewer&, - const Node& node, - const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - const auto rank = input_shape.size(); - if (rank < 1) { - LOGS(logger, VERBOSE) << "Gather only supports input shapes >= 1D, but input is " - << rank << "d shape"; - return false; - } - - return true; -} - bool GatherOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const std::string_view op_type = node.OpType(); - int32_t input_type; - int32_t indices_type; + int32_t input_type, indices_type; + if (!GetType(input, input_type, logger) || !GetType(indices, indices_type, logger)) return false; return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 02f46c85d1d06..7af17fdc5db78 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -91,7 +91,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N a_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); std::vector a_zero_point_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[2], a_zero_point_shape, logger), "Cannot get shape of a_zero_point"); - // Scale is not used by MatMulInteger but required by DequantizeLinear. So set it to deafult value 1.0f. + // Scale is not used by MatMulInteger but required by DequantizeLinear. So set it to default value 1.0f. // The scale input should have the same shape as the zero point input. a_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f, @@ -268,11 +268,45 @@ bool GemmOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - if (op_type == "MatMulInteger") { - // The first decomposed op of MatMulInteger is DequantizeLinear, and so - // we only need to ensure it supports the input0_type. + if (op_type == "Gemm") { + return IsInputRankSupportedByOp(node, wnn_limits, logger) && + IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger); + } else if (op_type == "MatMulInteger") { + // Check up to 4 inputs for MatMulInteger + for (size_t i = 0; i < input_defs.size(); ++i) { + std::vector shape; + if (!GetShape(*input_defs[i], shape, logger)) { + return false; + } + + // We made workaround to support 1D for input A and B, skip further checks if they are 1D + if (i <= 1 && shape.size() == 1) { + continue; + } + + // For DequantizeLinear, input indices: 0 (x), 1 (scale), 2 (zero_point) + if (!IsInputRankSupported(wnn_limits, "dequantizeLinear", + (i < 2) ? 
"input" : "zeroPoint", + shape.size(), node.Name(), logger)) { + return false; + } + } return IsDataTypeSupportedByOp("DequantizeLinear", input0_type, wnn_limits, "input", "x", logger); - } else { + } else { // MatMul + for (int i = 0; i < 2; ++i) { + std::vector shape; + if (!GetShape(*input_defs[i], shape, logger)) { + return false; + } + + if (shape.size() == 1) { + continue; + } + + if (!IsInputRankSupported(wnn_limits, "matmul", (i == 0) ? "a" : "b", shape.size(), node.Name(), logger)) { + return false; + } + } return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger); } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc index dfe80dd419092..95e75a3083cc2 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -219,7 +219,8 @@ bool GruOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger); + return IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } bool GruOpBuilder::HasSupportedOutputsImpl(const Node& node, diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index 42940083cad8e..55d468c4843cb 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -91,8 +91,10 @@ bool LogicalOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& no } } + const std::string_view webnn_input_name = GetWebNNOpFirstInputName(op_type); std::string onnx_input_name = op_type == "Not" ? "X" : "A"; - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", onnx_input_name, logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc index 8936bda875aef..e8aab725375ad 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc @@ -21,8 +21,6 @@ class LRNOpBuilder : public BaseOpBuilder { // Operator support related. private: - bool IsOpSupportedImpl(const GraphViewer&, const Node& node, - const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, @@ -128,11 +126,10 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. 
-bool LRNOpBuilder::IsOpSupportedImpl(const GraphViewer&, - const Node& node, - const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { +bool LRNOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + std::vector input_shape; if (!GetShape(*input_defs[0], input_shape, logger)) return false; @@ -143,12 +140,6 @@ bool LRNOpBuilder::IsOpSupportedImpl(const GraphViewer&, return false; } - return true; -} - -bool LRNOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, - const emscripten::val& wnn_limits, const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); const std::string_view op_type = node.OpType(); int32_t input_type = 0; if (!GetType(*input_defs[0], input_type, logger)) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc index 09e584bc66f8a..04d59e2f30d15 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc @@ -242,7 +242,8 @@ bool LstmOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } bool LstmOpBuilder::HasSupportedOutputsImpl(const Node& node, diff --git a/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc index 111d03571e974..9ab403b7051d2 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc @@ -48,7 +48,7 @@ void MatMulNBitsBuilder::AddInitializersToSkip(ModelBuilder& model_builder, cons // DequantizeLinear + Transpose + MatMul. Given that the CPU EP currently only supports // 4-bit quantization, we only handle 4-bit quantization here. // -// To align with WebNN's dequantizeLinear op contraints, the following transformations are +// To align with WebNN's dequantizeLinear op constraints, the following transformations are // required for MatMulNBits inputs: // 1. B: must be a constant initializer and registered as a 'uint4' WebNN constant with shape // [N, n_blocks_per_col, blob_size * 2]. 
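As a sketch of the shape arithmetic behind that 'uint4' registration: for 4-bit quantization each byte of the packed blob holds two nibbles, so viewing B as uint4 doubles the last dimension, [N, n_blocks_per_col, blob_size] bytes becoming [N, n_blocks_per_col, blob_size * 2] elements. The helper below is illustrative only and assumes the standard MatMulNBits 4-bit blob layout:

#include <array>
#include <cstdint>

// Shape of the uint4 WebNN constant created from MatMulNBits' packed B initializer.
std::array<int64_t, 3> Uint4ConstantShapeForB(int64_t N, int64_t n_blocks_per_col, int64_t blob_size) {
  return {N, n_blocks_per_col, blob_size * 2};  // Two uint4 elements per blob byte.
}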
@@ -159,10 +159,6 @@ bool MatMulNBitsBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const logging::Logger& logger) const { const auto& name = node.Name(); const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) { - return false; - } // Inputs B and zero_points (if present) must be initializers if (!graph_viewer.GetConstantInitializer(input_defs[1]->Name())) { // B @@ -193,6 +189,10 @@ bool MatMulNBitsBuilder::HasSupportedInputsImpl(const GraphViewer&, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const std::string_view op_type = node.OpType(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + return false; + } int32_t A_type = 0; int32_t B_type = 0; @@ -227,10 +227,13 @@ bool MatMulNBitsBuilder::HasSupportedInputsImpl(const GraphViewer&, return false; } - // We only support 4-bit quantization, which is represented as the uint4 data type in WebNN. - // Ensure that uint4 is supported. + // Data type: Currently, only 4-bit quantization is supported, represented as the uint4 data type in WebNN. + // Ensure that the uint4 data type is supported by WebNN's dequantizeLinear op. + // Input rank: Only the rank of the first input (A) is flexible. Verify that its rank is supported by + // WebNN's matmul op. return IsDataTypeSupportedByOp("DequantizeLinear", ONNX_NAMESPACE::TensorProto_DataType_UINT4, - wnn_limits, "input", "x", logger); + wnn_limits, "input", "x", logger) && + IsInputRankSupported(wnn_limits, "matmul", "a", input_shape.size(), node.Name(), logger); } bool MatMulNBitsBuilder::HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, diff --git a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc index 4e4014e3553ea..9f5ac6ef15735 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc @@ -20,8 +20,6 @@ class MaxMinOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; // Operator support related. - bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const Node& node, - WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -68,25 +66,6 @@ Status MaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. 
-bool MaxMinOpBuilder::IsOpSupportedImpl(const GraphViewer&, - const Node& node, - WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - const auto& op_type = node.OpType(); - - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - if (input_defs.size() < 1) { - LOGS(logger, VERBOSE) << op_type << " requires at least one input (data)"; - return false; - } - - return true; -} - bool MaxMinOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); @@ -108,7 +87,8 @@ bool MaxMinOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod } } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger); + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 148eacac98e4a..9fb643f055ef3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -46,28 +46,14 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); - std::vector scale_shape; const size_t scale_input_index = op_type == "SkipSimplifiedLayerNormalization" ? 2 : 1; - ORT_RETURN_IF_NOT(GetShape(*input_defs[scale_input_index], scale_shape, logger), "Cannot get scale shape"); - const auto scale_size = scale_shape.size(); - // Except LayerNormalization, other normalization ops' scale input should be 1-D. - if (op_type == "LayerNormalization") { - ORT_RETURN_IF_NOT(scale_size >= 1 && scale_size <= rank, - "The scale size should be less than or equal to input size."); - } else { - ORT_RETURN_IF_NOT(scale_size == 1, "The scale size should be one."); - } - emscripten::val scale = model_builder.GetOperand(input_defs[scale_input_index]->Name()); options.set("scale", scale); const size_t bias_input_index = op_type == "SkipSimplifiedLayerNormalization" ? 3 : 2; emscripten::val bias = emscripten::val::undefined(); if (TensorExists(input_defs, bias_input_index)) { - // Bias input exists, and bias's shape should be the same as scale's shape. - std::vector bias_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[bias_input_index], bias_shape, logger), "Cannot get bias shape"); - ORT_RETURN_IF_NOT(bias_shape == scale_shape, "The bias' shape should be equal to scale's shape."); + // Bias input exists. 
bias = model_builder.GetOperand(input_defs[bias_input_index]->Name()); options.set("bias", bias); } @@ -279,12 +265,6 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const GraphViewer&, return false; } - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) { - LOGS(logger, VERBOSE) << "Cannot get input shape."; - return false; - } - const auto& output_defs = node.OutputDefs(); if (op_type == "SkipSimplifiedLayerNormalization") { if (output_defs.size() > 4) { @@ -316,33 +296,28 @@ bool NormalizationOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const No const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const std::string_view op_type = node.OpType(); - int32_t input0_type; // input data type - int32_t input1_type; // scale data type - int32_t input2_type; // B data type - int32_t input3_type; // mean data type - int32_t input4_type; // var data type - bool has_input2 = TensorExists(input_defs, 2); - bool has_input3 = TensorExists(input_defs, 3); - bool has_input4 = TensorExists(input_defs, 4); - - if (!GetType(*input_defs[0], input0_type, logger) || - !GetType(*input_defs[1], input1_type, logger) || - (has_input2 && !GetType(*input_defs[2], input2_type, logger)) || - (has_input3 && !GetType(*input_defs[3], input3_type, logger)) || - (has_input4 && !GetType(*input_defs[4], input4_type, logger))) { - return false; - } - std::vector input_types = {input0_type, input1_type}; - if (has_input2) { - input_types.push_back(input2_type); - } - if (has_input3) { - input_types.push_back(input3_type); + std::vector input_types; + bool all_types_valid = true; + + // Iterate through all inputs and check their existence and types + for (size_t i = 0; i <= input_defs.size(); ++i) { + if (TensorExists(input_defs, i)) { + int32_t input_type; + if (!GetType(*input_defs[i], input_type, logger)) { + all_types_valid = false; + break; + } + input_types.push_back(input_type); + } } - if (has_input4) { - input_types.push_back(input4_type); + + // Return false if any input type is invalid + if (!all_types_valid) { + return false; } + + // Check if all input data types are the same if (!AreDataTypesSame(op_type, input_types, logger)) { return false; } @@ -355,13 +330,29 @@ bool NormalizationOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const No const std::string_view webnn_op_type = GetWebNNOpType(decomposed_op_type); const std::string_view webnn_input_name = GetWebNNOpFirstInputName(decomposed_op_type); if (!IsDataTypeSupportedByWebNNOp( - op_type, webnn_op_type, input0_type, wnn_limits, webnn_input_name, "input", logger)) { + op_type, webnn_op_type, input_types[0], wnn_limits, webnn_input_name, "input", logger)) { return false; } } - return true; + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + return false; + } + // It's complicated to check all the decomposed ops' input rank support. + // Ensure at least the first input rank is supported by the decomposed ops (pow and div accept the first input). 
+ return IsInputRankSupported(wnn_limits, "pow", "a", input_shape.size(), node.Name(), logger) && + IsInputRankSupported(wnn_limits, "div", "a", input_shape.size(), node.Name(), logger); } else { - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); + bool is_data_type_supported = IsDataTypeSupportedByOp(op_type, input_types[0], wnn_limits, "input", "X", logger); + if (op_type == "InstanceNormalization") { + // Skip input rank check for InstanceNormalization, as we will reshape the input to 4D if necessary. + return is_data_type_supported; + } + + // For other ops, check both data type and input rank compatibility. + bool is_input_rank_supported = IsInputRankSupportedByOp(node, wnn_limits, logger); + return is_input_rank_supported && is_data_type_supported; } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc index f2a3f08b73148..5d921c5176a64 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc @@ -133,20 +133,6 @@ bool PoolOpBuilder::IsOpSupportedImpl(const GraphViewer&, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& op_type = node.OpType(); - const auto& input_defs = node.InputDefs(); - - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - const auto input_size = input_shape.size(); - if (input_size != 4) { - LOGS(logger, VERBOSE) - << op_type << " only supports rank-4 tensor, input [" - << input_defs[0]->Name() << "] has actual dim count " << input_size; - return false; - } - NodeAttrHelper helper(node); if (op_type == "AveragePool" || op_type == "LpPool" || op_type == "MaxPool") { if (helper.Get("kernel_shape", std::vector{1, 1}).size() != 2) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc index dd25fb9bf9315..053c41773db40 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc @@ -167,7 +167,8 @@ bool QDQOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "x", logger) && + return IsInputRankSupportedByOp(node, wnn_limits, logger) && + IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "x", logger) && IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "scale", "x_scale", logger) && (!has_input2 || IsDataTypeSupportedByOp(op_type, input2_type, wnn_limits, "zeroPoint", "x_zero_point", logger)); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc index a3a0397eda4a3..6ea9b0a440d93 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc @@ -128,16 +128,10 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - const auto& op_type = node.OpType(); const std::string axes_name = 
GetTensorName(input_defs, 1); // If the optional input 'axes' is provided, it must be an initializer. if (!axes_name.empty() && !graph_viewer.GetConstantInitializer(axes_name)) { - LOGS(logger, VERBOSE) << "Input axes of " << op_type << " must be a constant"; + LOGS(logger, VERBOSE) << "Input axes of " << node.OpType() << " must be a constant"; return false; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc index 8cbb381e0f53e..0444ae3afb56a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc @@ -79,11 +79,6 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - const auto& perm_name = input_defs[1]->Name(); const auto* perm_init = graph_viewer.GetConstantInitializer(perm_name); if (!perm_init) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc index 893ca9d2419c7..37071b1030e11 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc @@ -285,7 +285,7 @@ Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_build sign_buffer.set(1, 1.0f); } else if (input_data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { if (model_builder.IsFloat16ArrayAvailable()) { - // Float16Array is avaliable - use Float16Array. + // Float16Array is available - use Float16Array. 
sign_buffer = emscripten::val::global("Float16Array").new_(2); sign_buffer.set(0, -1.0f); sign_buffer.set(1, 1.0f); diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc index f894e8bfbd517..c2974bd988f6b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc @@ -71,7 +71,6 @@ bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const auto& updates = *node.InputDefs()[2]; - const std::string_view op_type = node.OpType(); int32_t data_type; int32_t indices_type; @@ -85,8 +84,11 @@ bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const return false; } + const std::string_view op_type = node.OpType(); + return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateScatterElementsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc index e61ac3dcc9617..a7788cfd847e9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc @@ -63,7 +63,6 @@ bool ScatterNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const auto& updates = *node.InputDefs()[2]; - const std::string_view op_type = node.OpType(); int32_t data_type; int32_t indices_type; @@ -76,9 +75,10 @@ bool ScatterNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& if (data_type != updates_type) { return false; } - + const std::string_view op_type = node.OpType(); return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && - IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateScatterNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index 8853891ff8ed6..5efbfe932c602 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -136,10 +136,6 @@ bool SliceOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const No const auto& name = node.Name(); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) { - return false; - } if (input_defs.size() < 3) { LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 3 inputs (data, starts, ends) but got " @@ -166,10 +162,17 @@ bool 
SliceOpBuilder::HasSupportedInputsImpl(const GraphViewer& graph_viewer, con const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& input = *input_defs[0]; - const std::string_view op_type = node.OpType(); + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + return false; + } + int32_t input_type; - if (!GetType(input, input_type, logger)) + if (!GetType(input, input_type, logger)) { return false; + } + + const std::string_view op_type = node.OpType(); // If there is step < 0, check data type support of reverse. if (TensorExists(input_defs, 4)) { @@ -178,13 +181,15 @@ bool SliceOpBuilder::HasSupportedInputsImpl(const GraphViewer& graph_viewer, con if (!init || !ReadIntArrayFrom1DTensor(*init, steps, graph_viewer, logger)) return false; if (std::any_of(steps.begin(), steps.end(), [](int64_t step) { return step < 0; })) { - if (!IsDataTypeSupportedByWebNNOp(op_type, "reverse", input_type, wnn_limits, "input", "data", logger)) { + if (!IsDataTypeSupportedByWebNNOp(op_type, "reverse", input_type, wnn_limits, "input", "data", logger) || + !IsInputRankSupported(wnn_limits, "reverse", "input", input_shape.size(), node.Name(), logger)) { return false; } } } - return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger); + return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index 23e73bb8f1e74..99d137f81864c 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -18,11 +18,6 @@ class SoftmaxOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - - // Operator support related. - private: - bool IsOpSupportedImpl(const GraphViewer&, const Node& node, - const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; }; Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -46,20 +41,6 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -// Operator support related. 
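The extra Slice check above only applies when a negative step is present, since negative stepping is implemented with WebNN's reverse op. A minimal restatement of that test over a plain vector of step values (the real code reads them from a constant initializer via ReadIntArrayFrom1DTensor):

#include <algorithm>
#include <cstdint>
#include <vector>

// True when WebNN 'reverse' support must also be verified for this Slice node.
bool NeedsReverseSupport(const std::vector<int64_t>& steps) {
  return std::any_of(steps.begin(), steps.end(), [](int64_t step) { return step < 0; });
}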
- -bool SoftmaxOpBuilder::IsOpSupportedImpl(const GraphViewer&, - const Node& node, - const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - return true; -} - void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc index 1ba6df9febf14..7e34e35ebac16 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc @@ -127,9 +127,6 @@ bool SqueezeUnsqueezeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewe const logging::Logger& logger) const { const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; if (input_defs.size() < 1) { LOGS(logger, ERROR) << op_type << " has no input tensor"; diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc index 7a7f64b1ec96d..8973757a24e99 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc @@ -66,7 +66,8 @@ bool TernaryOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& no return false; } - return IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger); + return IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger) && + IsInputRankSupportedByOp(node, wnn_limits, logger); } void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc index 29b232026d7df..24d96588559ae 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/tile_op_builder.cc @@ -77,15 +77,6 @@ bool TileOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, return false; } - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - if (input_shape.empty()) { - LOGS(logger, VERBOSE) << "Tile does not support empty input shape"; - return false; - } - return true; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc index 5a267557b9454..7a4d172c556fa 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc @@ -76,15 +76,6 @@ bool TriangularOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - const auto input_size = 
input_shape.size(); - if (input_size < 2) { - LOGS(logger, VERBOSE) << "Triangular only supports input size >= 2D shape, input is " - << input_size << "d shape"; - return false; - } const std::string diagonal_name = GetTensorName(input_defs, 1); // Inputs contain optional 'diagonal' input. diff --git a/onnxruntime/core/providers/webnn/builders/map_info.h b/onnxruntime/core/providers/webnn/builders/map_info.h index 5e860eea7cac9..1c30fed7a7916 100644 --- a/onnxruntime/core/providers/webnn/builders/map_info.h +++ b/onnxruntime/core/providers/webnn/builders/map_info.h @@ -47,6 +47,7 @@ constexpr std::array supported_fallback // Use ONNX-to-ONNX op mapping to improve the search complexity for WebNN ops in the op_inputs_map. const std::map> decomposed_op_map = { {"ConvInteger", {"Cast", "Conv", "DequantizeLinear"}}, + {"Einsum", {"MatMul", "Mul", "ReduceSum", "Reshape", "Transpose", "Trilu"}}, {"GroupQueryAttention", {"Add", "Cast", "Concat", "CumSum", "Div", "Expand", "Less", "MatMul", "Reshape", "ScatterND", "Softmax", "Transpose", "Where"}}, @@ -139,7 +140,7 @@ const std::unordered_map op_inputs_map = { {"Mul", {"mul", {{0, "a"}, {1, "b"}}}}, {"Pow", {"pow", {{0, "a"}, {1, "b"}}}}, {"Concat", {"concat", {{0, "inputs"}}}}, - {"Not", {"logicalNot", {{0, "input"}}}}, + {"Not", {"logicalNot", {{0, "a"}}}}, {"Flatten", {"reshape", {{0, "input"}}}}, {"LpPool", {"l2Pool2d", {{0, "input"}}}}, {"Reshape", {"reshape", {{0, "input"}}}}, @@ -159,7 +160,6 @@ const std::unordered_map op_inputs_map = { {"Softsign", {"softsign", {{0, "input"}}}}, {"Unsqueeze", {"reshape", {{0, "input"}}}}, {"Or", {"logicalOr", {{0, "a"}, {1, "b"}}}}, - {"Einsum", {"matmul", {{0, "a"}, {1, "b"}}}}, {"HardSwish", {"hardSwish", {{0, "input"}}}}, {"LeakyRelu", {"leakyRelu", {{0, "input"}}}}, {"MatMul", {"matmul", {{0, "a"}, {1, "b"}}}}, diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index 4468831181d42..d2cd0639affd0 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -78,7 +78,7 @@ class ModelBuilder { const GraphViewer& graph_viewer_; const logging::Logger& logger_; const bool is_float16array_available_ = !emscripten::val::global("Float16Array").isUndefined() && - emscripten::val::global("Float16Array").hasOwnProperty("from"); + !emscripten::val::global("Float16Array")["from"].isUndefined(); emscripten::val wnn_context_ = emscripten::val::undefined(); emscripten::val wnn_builder_ = emscripten::val::undefined(); diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index d910e3ea74b57..59b0992d827e1 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -128,6 +128,35 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelPath, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInformation, + _In_ OrtModelCompilationOptions* ort_model_compile_options, + const ORTCHAR_T* output_directory, + const ORTCHAR_T* model_name) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + + std::string output_dir = PathToUTF8String(output_directory); + if (output_dir.empty()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid output directory: path is empty"); + } + + std::string model_name_str = ToUTF8String(model_name); + if 
(model_name_str.empty()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid model name: string is empty"); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetEpContextBinaryInformation(output_dir, model_name_str)); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(output_directory); + ORT_UNUSED_PARAMETER(model_name); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelExternalInitializersFile, _In_ OrtModelCompilationOptions* ort_model_compile_options, const ORTCHAR_T* external_initializers_file_path, @@ -248,6 +277,7 @@ static constexpr OrtCompileApi ort_compile_api = { // End of Version 22 - DO NOT MODIFY ABOVE &OrtCompileAPI::ModelCompilationOptions_SetFlags, + &OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInformation, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/compile_api.h b/onnxruntime/core/session/compile_api.h index 5f11b894f2004..93cc5dbf20fce 100644 --- a/onnxruntime/core/session/compile_api.h +++ b/onnxruntime/core/session/compile_api.h @@ -30,5 +30,7 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModel ORT_API_STATUS_IMPL(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); ORT_API_STATUS_IMPL(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_options, size_t flags); +ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextBinaryInformation, _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const ORTCHAR_T* output_dir, _In_ const ORTCHAR_T* model_name); } // namespace OrtCompileAPI diff --git a/onnxruntime/core/session/ep_api_utils.h b/onnxruntime/core/session/ep_api_utils.h index daccd24453371..a0904c32011a7 100644 --- a/onnxruntime/core/session/ep_api_utils.h +++ b/onnxruntime/core/session/ep_api_utils.h @@ -16,6 +16,10 @@ struct ForwardToFactory { return static_cast(this_ptr)->GetVendor(); } + static uint32_t ORT_API_CALL GetVendorId(const OrtEpFactory* this_ptr) noexcept { + return static_cast(this_ptr)->GetVendorId(); + } + static const char* ORT_API_CALL GetVersion(const OrtEpFactory* this_ptr) noexcept { return static_cast(this_ptr)->GetVersion(); } diff --git a/onnxruntime/core/session/ep_factory_internal.cc b/onnxruntime/core/session/ep_factory_internal.cc index b289010cc6c5b..fa4ef2515ca92 100644 --- a/onnxruntime/core/session/ep_factory_internal.cc +++ b/onnxruntime/core/session/ep_factory_internal.cc @@ -14,17 +14,19 @@ namespace onnxruntime { using Forward = ForwardToFactory; -EpFactoryInternal::EpFactoryInternal(const std::string& ep_name, const std::string& vendor, +EpFactoryInternal::EpFactoryInternal(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id, GetSupportedFunc&& get_supported_func, CreateFunc&& create_func) : ep_name_{ep_name}, vendor_{vendor}, + vendor_id_{vendor_id}, get_supported_func_{std::move(get_supported_func)}, create_func_{create_func} { ort_version_supported = ORT_API_VERSION; OrtEpFactory::GetName = Forward::GetFactoryName; OrtEpFactory::GetVendor = Forward::GetVendor; + OrtEpFactory::GetVendorId = Forward::GetVendorId; OrtEpFactory::GetVersion = Forward::GetVersion; OrtEpFactory::GetSupportedDevices = Forward::GetSupportedDevices; 
OrtEpFactory::CreateEp = Forward::CreateEp; diff --git a/onnxruntime/core/session/ep_factory_internal.h b/onnxruntime/core/session/ep_factory_internal.h index 087c0c60f8f4e..ee08e2233c529 100644 --- a/onnxruntime/core/session/ep_factory_internal.h +++ b/onnxruntime/core/session/ep_factory_internal.h @@ -33,12 +33,13 @@ class EpFactoryInternal : public OrtEpFactory { const OrtSessionOptions* session_options, const OrtLogger* logger, std::unique_ptr* ep)>; - EpFactoryInternal(const std::string& ep_name, const std::string& vendor, + EpFactoryInternal(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id, GetSupportedFunc&& get_supported_func, CreateFunc&& create_func); const char* GetName() const noexcept { return ep_name_.c_str(); } const char* GetVendor() const noexcept { return vendor_.c_str(); } + uint32_t GetVendorId() const noexcept { return vendor_id_; } const char* GetVersion() const noexcept; OrtStatus* GetSupportedDevices(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, @@ -67,6 +68,7 @@ class EpFactoryInternal : public OrtEpFactory { private: const std::string ep_name_; // EP name library was registered with const std::string vendor_; // EP vendor name + const uint32_t vendor_id_; // EP vendor ID const GetSupportedFunc get_supported_func_; // function to return supported devices const CreateFunc create_func_; // function to create the EP instance diff --git a/onnxruntime/core/session/ep_library_internal.cc b/onnxruntime/core/session/ep_library_internal.cc index 25f70f7549a16..ce5736f601b45 100644 --- a/onnxruntime/core/session/ep_library_internal.cc +++ b/onnxruntime/core/session/ep_library_internal.cc @@ -61,7 +61,8 @@ std::unique_ptr EpLibraryInternal::CreateCpuEp() { }; std::string ep_name = kCpuExecutionProvider; - auto cpu_factory = std::make_unique(ep_name, "Microsoft", get_supported, create_cpu_ep); + auto cpu_factory = std::make_unique(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT, + get_supported, create_cpu_ep); return std::make_unique(std::move(cpu_factory)); } @@ -122,7 +123,8 @@ std::unique_ptr EpLibraryInternal::CreateDmlEp() { return nullptr; }; - auto dml_factory = std::make_unique(ep_name, "Microsoft", is_supported, create_dml_ep); + auto dml_factory = std::make_unique(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT, + is_supported, create_dml_ep); return std::make_unique(std::move(dml_factory)); } @@ -170,7 +172,8 @@ std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { return nullptr; }; - auto webgpu_factory = std::make_unique(ep_name, "Microsoft", is_supported, create_webgpu_ep); + auto webgpu_factory = std::make_unique(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT, + is_supported, create_webgpu_ep); return std::make_unique(std::move(webgpu_factory)); } diff --git a/onnxruntime/core/session/ep_library_provider_bridge.cc b/onnxruntime/core/session/ep_library_provider_bridge.cc index 73423a4744576..70937bdc5d3e8 100644 --- a/onnxruntime/core/session/ep_library_provider_bridge.cc +++ b/onnxruntime/core/session/ep_library_provider_bridge.cc @@ -72,6 +72,7 @@ Status EpLibraryProviderBridge::Load() { auto internal_factory = std::make_unique(factory->GetName(factory), factory->GetVendor(factory), + factory->GetVendorId(factory), is_supported_fn, create_fn); factory_ptrs_.push_back(internal_factory.get()); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 86a61a4d0ee74..f147242da668f 100644 --- a/onnxruntime/core/session/inference_session.cc +++ 
b/onnxruntime/core/session/inference_session.cc @@ -423,7 +423,13 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, { if (!external_intra_op_thread_pool_) { bool allow_intra_op_spinning = +#if !defined(ORT_CLIENT_PACKAGE_BUILD) session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowIntraOpSpinning, "1") == "1"; +#else + // default kOrtSessionOptionsConfigAllowIntraOpSpinning to "0" for ORT builds targeting client/on-device workloads, + // to reduce CPU utilization and improve power efficiency. + session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowIntraOpSpinning, "0") == "1"; +#endif OrtThreadPoolParams to = session_options_.intra_op_param; std::basic_stringstream ss; if (to.name) { @@ -461,7 +467,13 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, if (session_options_.execution_mode == ExecutionMode::ORT_PARALLEL) { if (!external_inter_op_thread_pool_) { bool allow_inter_op_spinning = +#if !defined(ORT_CLIENT_PACKAGE_BUILD) session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowInterOpSpinning, "1") == "1"; +#else + // default kOrtSessionOptionsConfigAllowInterOpSpinning to "0" for ORT builds targeting client/on-device workloads, + // to reduce CPU utilization and improve power efficiency. + session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigAllowInterOpSpinning, "0") == "1"; +#endif OrtThreadPoolParams to = session_options_.inter_op_param; to.auto_set_affinity = to.thread_pool_size == 0 && session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL; std::basic_stringstream ss; diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index 5de0f03fafc08..bbb110033f54c 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -72,8 +72,8 @@ Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_mod if (log_manager != nullptr && log_manager->HasDefaultLogger()) { const logging::Logger& logger = log_manager->DefaultLogger(); LOGS(logger, WARNING) << "Output model path length (" << ep_context_gen_options.output_model_file_path.size() - << ") exceeds limit of " << ConfigOptions::kMaxKeyLength << " characters." - << "ORT will still generated the expected output file, but EPs will see an empty " + << ") exceeds limit of " << ConfigOptions::kMaxValueLength << " characters. "
+ << "ORT will still generate the expected output file, but EPs will see an empty " << "output model path in SessionOption's ConfigOptions."; } } @@ -98,6 +98,36 @@ Status ModelCompilationOptions::SetOutputModelBuffer(onnxruntime::AllocatorPtr a return Status::OK(); } +Status ModelCompilationOptions::SetEpContextBinaryInformation(const std::string& output_directory, + const std::string& model_name) { + if (output_directory.empty() || model_name.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output_dir or model_name is empty."); + } + + std::filesystem::path output_dir_path(output_directory); + if (output_dir_path.has_filename() && output_dir_path.extension() == "") { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output_dir is not a valid directory."); + } + + std::filesystem::path ctx_model_path = output_directory / std::filesystem::path(model_name); + + if (ctx_model_path.string().size() <= ConfigOptions::kMaxValueLength) { + ORT_RETURN_IF_ERROR(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, + ctx_model_path.string().c_str())); + } else { + logging::LoggingManager* log_manager = env_.GetLoggingManager(); + if (log_manager != nullptr && log_manager->HasDefaultLogger()) { + const logging::Logger& logger = log_manager->DefaultLogger(); + LOGS(logger, WARNING) << "Combined length of output_directory and model_name exceeds the limit of " + << ConfigOptions::kMaxValueLength << " characters. " + << "ORT will still generate the expected output file, but EPs will see an empty " + << "output path in SessionOption's ConfigOptions."; + } + } + + return Status::OK(); +} + Status ModelCompilationOptions::SetEpContextEmbedMode(bool embed_ep_context_in_model) { ORT_RETURN_IF_ERROR(session_options_.value.config_options.AddConfigEntry( kOrtSessionOptionEpContextEmbedMode, embed_ep_context_in_model ? "1" : "0")); @@ -146,7 +176,7 @@ Status ModelCompilationOptions::ResetOutputModelSettings() { ep_context_gen_options.output_model_buffer_ptr = nullptr; ep_context_gen_options.output_model_buffer_size_ptr = nullptr; ep_context_gen_options.output_model_buffer_allocator = nullptr; - return session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ""); + return Status::OK(); } Status ModelCompilationOptions::CheckInputModelSettings() const { diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h index f96f0317cdaca..2824df863013d 100644 --- a/onnxruntime/core/session/model_compilation_options.h +++ b/onnxruntime/core/session/model_compilation_options.h @@ -72,6 +72,16 @@ class ModelCompilationOptions { Status SetOutputModelBuffer(onnxruntime::AllocatorPtr allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); + /// + /// Sets information related to the EP context binary file. + /// EPs use this information to decide the location and name of the context binary file. + /// Used when compiling a model whose input and output are provided as in-memory buffers. + /// + /// The folder path to the generated context binary file + /// Model name used to decide the context binary file name: [model_name]_[ep].bin + /// Status indicating success or a potential error + Status SetEpContextBinaryInformation(const std::string& output_directory, const std::string& model_name); + /// + /// Enables or disables the embedding of EPContext binary data into the `ep_cache_context` attribute of EPContext /// nodes. Defaults to false (dumped to file).
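Note (not part of the diff): a minimal usage sketch of the new ModelCompilationOptions_SetEpContextBinaryInformation entry point added above. It assumes the OrtCompileApi pointer, the OrtEnv, and an OrtModelCompilationOptions already configured for in-memory input/output models were obtained through the existing compile-API flow; the variable names and the output directory/model name are illustrative only.

    // Hypothetical setup: `compile_api` (const OrtCompileApi*), `env` (OrtEnv*), and
    // `compile_options` (OrtModelCompilationOptions*) are assumed to exist already.
    // Because the output model goes to a memory buffer, EPs would otherwise see an empty
    // output path; this call tells them where to place the EP context binary and how to
    // name it ([model_name]_[ep].bin).
    OrtStatus* status = compile_api->ModelCompilationOptions_SetEpContextBinaryInformation(
        compile_options, ORT_TSTR("./ep_context_out"), ORT_TSTR("my_model"));
    if (status == nullptr) {
      status = compile_api->CompileModel(env, compile_options);
    }
    // Error handling and releasing `status` are omitted for brevity.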
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index e7f60fd48a14f..db2a62c77d1bc 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2591,6 +2591,29 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets) { + API_IMPL_BEGIN + if (num_operator_sets == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_operator_sets' argument is NULL"); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetNumOperatorSets(*num_operator_sets)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Graph_GetOperatorSets, _In_ const OrtGraph* graph, + _Out_writes_(num_operator_sets) const char** domains, + _Out_writes_(num_operator_sets) int64_t* opset_versions, _In_ size_t num_operator_sets) { + API_IMPL_BEGIN + gsl::span domains_span(domains, num_operator_sets); + gsl::span versions_span(opset_versions, num_operator_sets); + ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetOperatorSets(domains_span, versions_span)); + + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::Graph_GetNumInputs, _In_ const OrtGraph* graph, _Out_ size_t* num_inputs) { API_IMPL_BEGIN if (num_inputs == nullptr) { @@ -2691,6 +2714,91 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetParentNode, _In_ const OrtGraph* graph, _O API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::Graph_GetGraphView, _In_ const OrtGraph* src_graph, + _In_ const OrtNode** nodes, + _In_ size_t num_nodes, + _Outptr_ OrtGraph** dst_graph) { + API_IMPL_BEGIN + + if (num_nodes == 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_nodes' argument should be > 0"); + } + + const EpGraph* ep_graph = EpGraph::ToInternal(src_graph); + if (ep_graph == nullptr) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "src_graph is a ModelEditorGraph which doesn't support Graph_GetSubGraph."); + } + const Graph& graph = ep_graph->GetGraphViewer().GetGraph(); + + // Create a GraphViewer with filtered info + std::unique_ptr indexed_sub_graph = std::make_unique(); + std::unique_ptr metadef = std::make_unique(); + metadef->name = "sub_graph"; + metadef->since_version = 1; + std::unordered_set outputs; + std::unordered_set initializers; + + auto add_inputs = [&](ConstPointerContainer> defs) { + for (const auto* def : defs) { + if (def->Exists()) { + // not the output of a previous node + if (outputs.count(def->Name()) == 0) { + metadef->inputs.push_back(def->Name()); + } else { + // consumed by node so no longer subgraph output + // NOTE: Ignoring edge case where a node output is an overall graph output AND a node input + outputs.erase(def->Name()); + } + + if (graph.IsInitializedTensor(def->Name())) { + initializers.insert(def); + } + } + } + }; + + auto add_node = [&](const Node& node) { + indexed_sub_graph->nodes.push_back(node.Index()); + add_inputs(node.InputDefs()); + add_inputs(node.ImplicitInputDefs()); + + for (const auto* def : node.OutputDefs()) { + outputs.insert(def->Name()); + } + }; + + // Add nodes + for (size_t node_idx = 0; node_idx < num_nodes; node_idx++) { + const OrtNode* ort_node = nodes[node_idx]; + const EpNode* ep_node = EpNode::ToInternal(ort_node); + if (ep_node == nullptr) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Graph_GetSubGraph."); + } + 
add_node(ep_node->GetInternalNode()); + } + + // Add initializers + for (auto& initializer : initializers) { + metadef->constant_initializers.push_back(initializer->Name()); + } + + // Add outputs + for (auto& output : outputs) { + metadef->outputs.push_back(output); + } + + indexed_sub_graph->SetMetaDef(std::move(metadef)); + auto graph_viewer = std::make_unique(graph, *indexed_sub_graph.get()); + + std::unique_ptr result; + ORT_API_RETURN_IF_STATUS_NOT_OK(EpGraph::Create(std::move(graph_viewer), std::move(indexed_sub_graph), result)); + + *dst_graph = result.release(); + + return nullptr; + API_IMPL_END +} + // // OrtNode // @@ -2922,10 +3030,11 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetNumSubgraphs, _In_ const OrtNode* node, _Ou } ORT_API_STATUS_IMPL(OrtApis::Node_GetSubgraphs, _In_ const OrtNode* node, - _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs) { + _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs, + _Out_writes_opt_(num_subgraphs) const char** attribute_names) { API_IMPL_BEGIN gsl::span graphs_span(subgraphs, num_subgraphs); - ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetSubgraphs(graphs_span)); + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetSubgraphs(graphs_span, attribute_names)); return nullptr; API_IMPL_END } @@ -2943,6 +3052,23 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetGraph, _In_ const OrtNode* node, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::Node_GetEpName, _In_ const OrtNode* node, + _Outptr_result_maybenull_ const char** out) { + API_IMPL_BEGIN + if (out == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'out' argument is NULL"); + } + + const EpNode* ep_node = EpNode::ToInternal(node); + if (ep_node == nullptr) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Node_GetEpName."); + } + + *out = ep_node->GetEpName().c_str(); + return nullptr; + API_IMPL_END +} + ORT_API(const OrtTrainingApi*, OrtApis::GetTrainingApi, uint32_t version) { #ifdef ENABLE_TRAINING_APIS if (version >= 13 && version <= ORT_API_VERSION) @@ -3594,6 +3720,8 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::ValueInfo_IsFromOuterScope, &OrtApis::Graph_GetName, &OrtApis::Graph_GetOnnxIRVersion, + &OrtApis::Graph_GetNumOperatorSets, + &OrtApis::Graph_GetOperatorSets, &OrtApis::Graph_GetNumInputs, &OrtApis::Graph_GetInputs, &OrtApis::Graph_GetNumOutputs, @@ -3603,6 +3731,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Graph_GetNumNodes, &OrtApis::Graph_GetNodes, &OrtApis::Graph_GetParentNode, + &OrtApis::Graph_GetGraphView, &OrtApis::Node_GetId, &OrtApis::Node_GetName, &OrtApis::Node_GetOperatorType, @@ -3622,6 +3751,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Node_GetNumSubgraphs, &OrtApis::Node_GetSubgraphs, &OrtApis::Node_GetGraph, + &OrtApis::Node_GetEpName, &OrtApis::GetRunConfigEntry, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index cbacbfce0740d..9ab927006c320 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -631,6 +631,10 @@ ORT_API_STATUS_IMPL(ValueInfo_IsFromOuterScope, _In_ const OrtValueInfo* value_i // OrtGraph ORT_API_STATUS_IMPL(Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ const char** graph_name); ORT_API_STATUS_IMPL(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version); +ORT_API_STATUS_IMPL(Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets); 
+ORT_API_STATUS_IMPL(Graph_GetOperatorSets, _In_ const OrtGraph* graph, + _Out_writes_(num_operator_sets) const char** domains, + _Out_writes_(num_operator_sets) int64_t* opset_versions, _In_ size_t num_operator_sets); ORT_API_STATUS_IMPL(Graph_GetNumInputs, _In_ const OrtGraph* graph, _Out_ size_t* num_inputs); ORT_API_STATUS_IMPL(Graph_GetInputs, _In_ const OrtGraph* graph, _Out_writes_(num_inputs) const OrtValueInfo** inputs, _In_ size_t num_inputs); @@ -645,6 +649,8 @@ ORT_API_STATUS_IMPL(Graph_GetNumNodes, _In_ const OrtGraph* graph, _Out_ size_t* ORT_API_STATUS_IMPL(Graph_GetNodes, const OrtGraph* graph, _Out_writes_(num_nodes) const OrtNode** nodes, _In_ size_t num_nodes); ORT_API_STATUS_IMPL(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); +ORT_API_STATUS_IMPL(Graph_GetGraphView, _In_ const OrtGraph* graph, _In_ const OrtNode** nodes, _In_ size_t num_nodes, + _Outptr_ OrtGraph** subgraph); // OrtNode ORT_API_STATUS_IMPL(Node_GetId, _In_ const OrtNode* node, _Out_ size_t* node_id); @@ -671,8 +677,10 @@ ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOp ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name); ORT_API_STATUS_IMPL(Node_GetNumSubgraphs, _In_ const OrtNode* node, _Out_ size_t* num_subgraphs); ORT_API_STATUS_IMPL(Node_GetSubgraphs, _In_ const OrtNode* node, - _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs); + _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs, + _Out_writes_opt_(num_subgraphs) const char** attribute_names); ORT_API_STATUS_IMPL(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph); +ORT_API_STATUS_IMPL(Node_GetEpName, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out); ORT_API_STATUS_IMPL(GetRunConfigEntry, _In_ const OrtRunOptions* options, _In_z_ const char* config_key, _Outptr_result_maybenull_z_ const char** config_value); diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index e8d62ab86f517..211bf8b2d15a4 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -22,7 +22,13 @@ namespace onnxruntime { namespace { bool MatchesEpVendor(const OrtEpDevice* d) { - // TODO: Would be better to match on Id. Should the EP add that in EP metadata? 
+ // match on vendor id if provided + uint32_t factory_vendor_id = d->ep_factory->GetVendorId(d->ep_factory); + if (factory_vendor_id != 0 && d->device->vendor_id == factory_vendor_id) { + return true; + } + + // match on vendor name return d->device->vendor == d->ep_vendor; } diff --git a/onnxruntime/core/util/qmath.h b/onnxruntime/core/util/qmath.h index 0172902bdf4e2..f7d5cdb98aa1d 100644 --- a/onnxruntime/core/util/qmath.h +++ b/onnxruntime/core/util/qmath.h @@ -1001,4 +1001,53 @@ struct BlockedQuantizeLinear { #endif +/** + * @brief Run MlasDequantizeLinear in parallel, with provided thread pool + */ + +template +void ParDequantizeLinearStd(const InputQuantType* input, + float* output, + size_t num_elems, + float scale, + InputQuantType zero_point, + concurrency::ThreadPool* thread_pool) { + constexpr std::ptrdiff_t block_size = 128; + const std::ptrdiff_t num_blocks = (num_elems + block_size - 1) / block_size; + const TensorOpCost unit_cost{static_cast(block_size * sizeof(InputQuantType)), + static_cast(block_size * sizeof(float)), + static_cast(block_size) * 2.0}; + concurrency::ThreadPool::TryParallelFor(thread_pool, num_blocks, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + auto begin_idx = begin * block_size; + auto end_idx = std::min(static_cast(num_elems), end * block_size); + MlasDequantizeLinear(&(input[begin_idx]), &(output[begin_idx]), end_idx - begin_idx, scale, zero_point); + }); +} + +// Note: this doesn't use MLAS kernel. There are currently no MLAS kernels for fp16 QuantizeLinear or DequantizeLinear. +template +void ParDequantizeLinearStd(const InputQuantType* input, + MLFloat16* output, + size_t num_elems, + MLFloat16 scale, + InputQuantType zero_point, + concurrency::ThreadPool* thread_pool) { + constexpr std::ptrdiff_t block_size = 128; + const std::ptrdiff_t num_blocks = (num_elems + block_size - 1) / block_size; + const TensorOpCost unit_cost{static_cast(block_size * sizeof(InputQuantType)), + static_cast(block_size * sizeof(MLFloat16)), + static_cast(block_size) * 2.0}; + + const int32_t zp_s32 = static_cast(zero_point); + const float sc_f32 = scale.ToFloat(); + + concurrency::ThreadPool::TryParallelFor(thread_pool, num_blocks, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + auto begin_idx = begin * block_size; + auto end_idx = std::min(static_cast(num_elems), end * block_size); + for (; begin_idx != end_idx; ++begin_idx) { + output[begin_idx] = MLFloat16(static_cast(static_cast(input[begin_idx]) - zp_s32) * sc_f32); + } + }); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/util/thread_utils.h b/onnxruntime/core/util/thread_utils.h index d63d620dbc321..0b99723b2c75b 100644 --- a/onnxruntime/core/util/thread_utils.h +++ b/onnxruntime/core/util/thread_utils.h @@ -19,7 +19,13 @@ struct OrtThreadPoolParams { bool auto_set_affinity = false; // If it is true, the thread pool will spin a while after the queue became empty. +#if !defined(ORT_CLIENT_PACKAGE_BUILD) bool allow_spinning = true; +#else + // default allow_spinning to false for ORT builds targeting client/on-device workloads, + // to reduce CPU utilization and improve power efficiency. 
+ bool allow_spinning = false; +#endif // It it is non-negative, thread pool will split a task by a decreasing block size // of remaining_of_total_iterations / (num_of_threads * dynamic_block_base_) diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index 9a297e451213a..e3303dac6c8c5 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -42,7 +42,7 @@ def __init__(self, **data: dict[str, Any]): for k, v in data.items(): if not isinstance(k, str): raise TypeError(f"Keys must be strings not {type(k)} for k={k!r}.") - if k != "axis" and not isinstance(v, (int, str, np.ndarray)): + if k != "axis" and not isinstance(v, (int, str, np.ndarray, float)): raise TypeError(f"Values must be numpy arrays, int, float, str not {type(v)} for k={k!r}.") if k == "axis" and not isinstance(v, int) and v is not None: raise TypeError(f"Axis value must be an int or None, not {type(v)}.") diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py index fbeae39c39d21..319c5aa468f7e 100644 --- a/onnxruntime/python/tools/quantization/registry.py +++ b/onnxruntime/python/tools/quantization/registry.py @@ -86,6 +86,7 @@ "InstanceNormalization": QDQNormalization, "LayerNormalization": QDQNormalization, "BatchNormalization": QDQNormalization, + "TopK": QDQDirect8BitOp, } diff --git a/onnxruntime/python/tools/transformers/fusion_attention_clip.py b/onnxruntime/python/tools/transformers/fusion_attention_clip.py index fe93f5cd358bf..8711e368cd1e6 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_clip.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_clip.py @@ -269,42 +269,48 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): attention_last_node = reshape_qkv add_qk = "" + causal_mask_nodes_1 = None + causal_mask_nodes_2 = None if add_mask is not None: - # 4D Add after Q x K' - add_qk_nodes = self.model.match_parent_path( - add_mask, - [ - "Where", - "Sub", - "Cast", - "Expand", - "Unsqueeze", - "Unsqueeze", - "Reshape", - "Reshape", - "Cast", - ], - [1, 2, 1, 0, 0, 0, 0, 0, 0], - ) - if add_qk_nodes is not None: + if add_mask.input[1] == "attention_mask": add_qk = add_mask.input[1] else: - # Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path - # of computing causal mask. - causal_mask_nodes_1 = self.model.match_parent_path( - add_mask, - ["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], - [causal_mask_input_index, 0, 0, 0, 0, 0], - ) - # If the model is exported with batch_size == 1, there is no Concat node - causal_mask_nodes_2 = self.model.match_parent_path( + # 4D Add after Q x K' + add_qk_nodes = self.model.match_parent_path( add_mask, - ["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], - [causal_mask_input_index, 0, 0, 0, 0], + [ + "Where", + "Sub", + "Cast", + "Expand", + "Unsqueeze", + "Unsqueeze", + "Reshape", + "Reshape", + "Cast", + ], + [1, 2, 1, 0, 0, 0, 0, 0, 0], ) - if causal_mask_nodes_1 is None and causal_mask_nodes_2 is None: - logger.debug("fuse_attention: failed to match causal mask subgraph") - return + if add_qk_nodes is not None: + add_qk = add_mask.input[1] + else: + # Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path + # of computing causal mask. 
+ causal_mask_nodes_1 = self.model.match_parent_path( + add_mask, + ["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], + [causal_mask_input_index, 0, 0, 0, 0, 0], + ) + # If the model is exported with batch_size == 1, there is no Concat node + causal_mask_nodes_2 = self.model.match_parent_path( + add_mask, + ["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], + [causal_mask_input_index, 0, 0, 0, 0], + ) + + if causal_mask_nodes_1 is None and causal_mask_nodes_2 is None: + logger.debug("fuse_attention: failed to match causal mask subgraph") + return new_node = self.create_attention_node( mask_index=None, @@ -320,7 +326,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): output=attention_last_node.output[0], add_qk_str=add_qk, scale=None, - causal=(add_mask is not None), + causal=(causal_mask_nodes_1 is not None) or (causal_mask_nodes_2 is not None), ) if new_node is None: logger.debug("fuse_attention: failed to create fused node") diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt index 6bd698f8b75b4..e16957eab80a1 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt @@ -1,7 +1,7 @@ onnxscript>=0.2.3 optimum>=1.14.1 optree -transformers==4.48.0 +transformers==4.52.1 torch>=2.7.0 onnx==1.17.0 datasets>=2.8.0 diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index ac696ff3788aa..e092285d57358 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -410,7 +410,7 @@ def export_onnx_models( precision == Precision.FLOAT16, model.config.encoder_attention_heads, model.config.d_model, - model.config.num_hidden_layers, + model.config.decoder_layers, use_external_data_format, use_gpu=use_gpu, provider=provider, diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt index f1758cc52280f..37fc72cd26e07 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -1,5 +1,5 @@ torch>=2.7.0 -transformers>=4.52.3 +transformers==4.52.3 openai-whisper==20240927 ffmpeg-python datasets diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index fadf271ae913b..e10e616d35d38 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -187,7 +187,7 @@ def input_names(self): *list( chain.from_iterable( (f"past_key_self_{i}", f"past_value_self_{i}", f"past_key_cross_{i}", f"past_value_cross_{i}") - for i in range(self.config.num_hidden_layers) + for i in range(self.config.decoder_layers) ) ), ] @@ -205,7 +205,7 @@ def output_names(self): f"present_key_cross_{i}", f"present_value_cross_{i}", ) - for i in range(self.config.num_hidden_layers) + for i in range(self.config.decoder_layers) ) ), ] @@ -214,8 +214,7 @@ def output_names(self): "logits", *list( chain.from_iterable( - (f"present_key_self_{i}", f"present_value_self_{i}") - for i in 
range(self.config.num_hidden_layers) + (f"present_key_self_{i}", f"present_value_self_{i}") for i in range(self.config.decoder_layers) ) ), ] diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index 26dc3aee7018b..cd81edc1001be 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -127,7 +127,7 @@ def output_names(self): *list( chain.from_iterable( (f"present_key_cross_{i}", f"present_value_cross_{i}") - for i in range(self.config.num_hidden_layers) + for i in range(self.config.decoder_layers) ) ), ] @@ -143,7 +143,7 @@ def output_names(self): f"present_key_cross_{i}", f"present_value_cross_{i}", ) - for i in range(self.config.num_hidden_layers) + for i in range(self.config.decoder_layers) ) ), ] diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index f66aa22eb0972..a236c4da1738e 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -763,7 +763,7 @@ def optimize_onnx( is_float16: bool, num_attention_heads: int, hidden_size: int, - num_layers: int, + num_decoder_layers: int, use_external_data_format: bool = False, use_gpu: bool = False, provider: str = "cpu", @@ -801,7 +801,7 @@ def optimize_onnx( m = add_cache_indirection_to_mha(m, past_seq_len_name) if output_qk: - m = add_output_qk_to_mha(m, skip_node_idxs=list(range(0, 2 * num_layers, 2))) + m = add_output_qk_to_mha(m, skip_node_idxs=list(range(0, 2 * num_decoder_layers, 2))) m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py index 0b0882eface72..8937fea900d14 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_inputs.py @@ -94,14 +94,14 @@ def get_sample_past_key_values( torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), ) - for _ in range(config.num_hidden_layers) + for _ in range(config.decoder_layers) ] cross_attention_kv_caches = [ ( torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype), torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype), ) - for _ in range(config.num_hidden_layers) + for _ in range(config.decoder_layers) ] return flatten_past_key_values(self_attention_kv_caches, cross_attention_kv_caches) @@ -187,7 +187,7 @@ def get_sample_QKs( # noqa: N802 torch.rand( batch_size, num_heads, sequence_length, config.max_source_positions, device=device, dtype=torch_dtype ) - for _ in range(config.num_hidden_layers) + for _ in range(config.decoder_layers) ] return QKs diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py index a7c0d3538b8da..4dd5d7de1752b 100644 --- 
a/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_jump_times.py @@ -156,7 +156,7 @@ def input_names(self): "alignment_heads", "sot_sequence_length", "segment_length", - *[f"cross_qk_{i}" for i in range(self.config.num_hidden_layers)], + *[f"cross_qk_{i}" for i in range(self.config.decoder_layers)], ] return input_names diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/ep.cc index b498c40079f48..44b3f9a213abf 100644 --- a/onnxruntime/test/autoep/library/ep.cc +++ b/onnxruntime/test/autoep/library/ep.cc @@ -226,7 +226,7 @@ OrtStatus* ExampleEp::SaveConstantInitializers(const OrtGraph* graph) { /*static*/ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, - OrtEpGraphSupportInfo* graph_support_info) { + OrtEpGraphSupportInfo* graph_support_info) noexcept { ExampleEp* ep = static_cast(this_ptr); size_t num_nodes = 0; @@ -290,7 +290,7 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, _In_ const OrtNode** fused_nodes, _In_ size_t count, _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, - _Out_writes_(count) OrtNode** ep_context_nodes) { + _Out_writes_(count) OrtNode** ep_context_nodes) noexcept { ExampleEp* ep = static_cast(this_ptr); const OrtApi& ort_api = ep->ort_api; @@ -328,6 +328,12 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[0], &node_input_names[0])); RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[1], &node_input_names[1])); + const char* ep_name = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetEpName(fused_nodes[0], &ep_name)); + if (std::strncmp(ep_name, "example_ep", 11) != 0) { + return ort_api.CreateStatus(ORT_EP_FAIL, "The fused node is expected to be assigned to this EP"); + } + // Associate the name of the fused node with our MulKernel.
const char* fused_node_name = nullptr; RETURN_IF_ERROR(ort_api.Node_GetName(fused_nodes[0], &fused_node_name)); @@ -354,7 +360,7 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const /*static*/ void ORT_API_CALL ExampleEp::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos, - size_t num_node_compute_infos) { + size_t num_node_compute_infos) noexcept { (void)this_ptr; for (size_t i = 0; i < num_node_compute_infos; i++) { delete node_compute_infos[i]; diff --git a/onnxruntime/test/autoep/library/ep.h b/onnxruntime/test/autoep/library/ep.h index b8c63f39438ba..dfebcc52a0caf 100644 --- a/onnxruntime/test/autoep/library/ep.h +++ b/onnxruntime/test/autoep/library/ep.h @@ -31,14 +31,14 @@ class ExampleEp : public OrtEp, public ApiPtrs { private: static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; static OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, - OrtEpGraphSupportInfo* graph_support_info); + OrtEpGraphSupportInfo* graph_support_info) noexcept; static OrtStatus* ORT_API_CALL CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, _In_ const OrtNode** fused_nodes, _In_ size_t count, _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, - _Out_writes_(count) OrtNode** ep_context_nodes); + _Out_writes_(count) OrtNode** ep_context_nodes) noexcept; static void ORT_API_CALL ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos, - size_t num_node_compute_infos); + size_t num_node_compute_infos) noexcept; OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes); diff --git a/onnxruntime/test/autoep/library/ep_factory.cc b/onnxruntime/test/autoep/library/ep_factory.cc index d4895102b0bf1..19a44008b8c97 100644 --- a/onnxruntime/test/autoep/library/ep_factory.cc +++ b/onnxruntime/test/autoep/library/ep_factory.cc @@ -14,6 +14,7 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis) ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. 
GetName = GetNameImpl; GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; @@ -87,6 +88,12 @@ const char* ORT_API_CALL ExampleEpFactory::GetVendorImpl(const OrtEpFactory* thi return factory->vendor_.c_str(); } +/*static*/ +uint32_t ORT_API_CALL ExampleEpFactory::GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->vendor_id_; +} + /*static*/ const char* ORT_API_CALL ExampleEpFactory::GetVersionImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); diff --git a/onnxruntime/test/autoep/library/ep_factory.h b/onnxruntime/test/autoep/library/ep_factory.h index fda77f12c4814..72fa1c1301841 100644 --- a/onnxruntime/test/autoep/library/ep_factory.h +++ b/onnxruntime/test/autoep/library/ep_factory.h @@ -21,6 +21,7 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept; + static uint32_t ORT_API_CALL GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept; static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* this_ptr) noexcept; @@ -53,6 +54,7 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { const std::string ep_name_; // EP name const std::string vendor_{"Contoso"}; // EP vendor name + const uint32_t vendor_id_{0xB357}; // EP vendor ID const std::string ep_version_{"0.1.0"}; // EP version // CPU allocator so we can control the arena behavior. optional as ORT always provides a CPU allocator if needed. diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 7b77ca8c69225..4c3f9e8dd4dbd 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -527,18 +527,20 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop if (std::is_same_v) { #ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); + RunTest(opts, std::move(execution_providers)); #endif #ifdef USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); + RunTest(opts, std::move(execution_providers)); #endif #ifdef USE_DML execution_providers.push_back(DefaultDmlExecutionProvider()); + RunTest(opts, std::move(execution_providers)); #endif #ifdef USE_WEBGPU execution_providers.push_back(DefaultWebGpuExecutionProvider()); -#endif - RunTest(opts, std::move(execution_providers)); +#endif } else { #ifdef USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index 60498e6510ec2..17e829e37f729 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -1,16 +1,24 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include +#include #include #include #include #include +#include #include "core/common/common.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/tensor_type_and_shape.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/session/onnxruntime_cxx_api.h" +#include "core/graph/ep_api_types.h" +#include "core/graph/graph_proto_serializer.h" + +#define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL +#include "core/providers/utils/ort_graph_to_proto.h" #include "test/ep_graph/test_ep_graph_utils.h" #include "test/util/include/api_asserts.h" @@ -26,6 +34,7 @@ namespace test { // forward-declaration for utility that uses public C APIs to check that an OrtGraph is equivalent // to a graph represented by the internal ORT GraphViewer class. static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph); +static void Check_Graph_GetSubgraph(const OrtGraph& api_graph); // // Tests @@ -68,6 +77,178 @@ TEST(EpGraphTest, Check3LayerNestedSubgraph) { CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); } +TEST(EpGraphTest, Check3LayerNestedSubgraphV2) { + // The overall structure of this model is similar to the one used in "Check3LayerNestedSubgraph" test. + // The model consists of a graph with subgraphs nested across three levels. + // In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer). + auto test_graph = TestGraph::Load(ORT_TSTR("testdata/three_layer_nested_subgraph_v2.onnx")); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); +} + +static void RunMNISTModel(const ORTCHAR_T* model_path, std::vector& output_data) { + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + Ort::SessionOptions sess_options; + Ort::Session session(*ort_env, model_path, sess_options); + + std::vector input_shape = {1, 1, 28, 28}; + std::vector input_data(28 * 28, 0.5f); + std::vector ort_inputs; + std::vector ort_input_names; + + // Add 'Input3' + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input_data.data(), input_data.size(), input_shape.data(), input_shape.size())); + ort_input_names.push_back("Input3"); + + // Run session and get outputs + std::array output_names{"Plus214_Output_0"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check output type and number of elements. + Ort::Value& ort_output = ort_outputs[0]; + auto output_type_shape = ort_output.GetTensorTypeAndShapeInfo(); + size_t num_output_elems = output_type_shape.GetElementCount(); + + ASSERT_EQ(output_type_shape.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + ASSERT_EQ(num_output_elems, 10); + + // Return output data. + const float* output_values = ort_output.GetTensorData(); + output_data.assign(output_values, output_values + num_output_elems); +} + +// Test serializing an OrtGraph (MNIST) to GraphProto. Saves initializers to external file. +// Checks that the outputs of the serialized and original models are identical. 
+TEST(EpGraphTest, SerializeToProto_Mnist) { + const ORTCHAR_T* original_model_path = ORT_TSTR("testdata/mnist.onnx"); + const ORTCHAR_T* serialized_model_path = ORT_TSTR("mnist_serialized.onnx"); + std::filesystem::remove(serialized_model_path); + + { + auto test_graph = TestGraph::Load(original_model_path); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + // Serialize OrtGraph to GraphProto. Save initializers to external file. + std::string ext_ini_file_path = "mnist_serialized.bin"; + std::filesystem::remove(ext_ini_file_path); + std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary); + auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](const OrtValueInfo* value_info, + const void* data, size_t bytes, + bool& is_external, std::string& location, + int64_t& offset) -> Ort::Status { + // OrtValueInfo* could be used to query initializer's name, type, shape, + // node consumers, etc. + (void)value_info; + + if (bytes <= 127) { + is_external = false; // Keep small initializers stored inside the TensorProto. + return Ort::Status{nullptr}; + } + + offset = ext_ini_ofs.tellp(); + location = ext_ini_file_path; + ext_ini_ofs.write(static_cast(data), bytes); + ext_ini_ofs.flush(); + is_external = true; // True if is external initializer. + + return Ort::Status{nullptr}; + }; + + ONNX_NAMESPACE::ModelProto model_proto; + OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto, handle_initializer_data); + + std::ofstream ofs(serialized_model_path, std::ios::binary); + model_proto.SerializeToOstream(&ofs); + ofs.flush(); + + ASSERT_TRUE(std::filesystem::exists(serialized_model_path)); + ASSERT_TRUE(std::filesystem::exists(ext_ini_file_path)); + } + + // Compare output of the original and serialized models. Should be identical. + std::vector output_original; + std::vector output_serialized; + + RunMNISTModel(original_model_path, output_original); + RunMNISTModel(serialized_model_path, output_serialized); + + EXPECT_EQ(output_serialized, output_original); +} + +static void Run3LayerModel(const ORTCHAR_T* model_path, bool input_cond, std::vector& output_data) { + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + Ort::SessionOptions sess_options; + Ort::Session session(*ort_env, model_path, sess_options); + + std::vector input_shape = {1}; + std::vector ort_inputs; + std::vector ort_input_names; + + // Add 'if_cond_input' + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, &input_cond, 1, input_shape.data(), input_shape.size())); + ort_input_names.push_back("if_cond_input"); + + // Run session and get outputs + std::array output_names{"if_cond_output"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check output type and number of elements. + Ort::Value& ort_output = ort_outputs[0]; + auto output_type_shape = ort_output.GetTensorTypeAndShapeInfo(); + size_t num_output_elems = output_type_shape.GetElementCount(); + + ASSERT_EQ(output_type_shape.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + ASSERT_EQ(num_output_elems, 1); + + // Return output data. + const float* output_values = ort_output.GetTensorData(); + output_data.assign(output_values, output_values + num_output_elems); +} + +// Test serializing an OrtGraph to GraphProto. The model has 3 layers of nested subgraphs. +// Checks that the outputs of the serialized and original models are identical. 
+TEST(EpGraphTest, SerializeToProto_3LayerSubgraphs) { + const ORTCHAR_T* original_model_path = ORT_TSTR("testdata/three_layer_nested_subgraph.onnx"); + const ORTCHAR_T* serialized_model_path = ORT_TSTR("three_layer_nested_subgraph_serialized.onnx"); + std::filesystem::remove(serialized_model_path); + + { + auto test_graph = TestGraph::Load(original_model_path); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + // Serialize OrtGraph to ModelProto (all initializers stored within TensorProtos). + ONNX_NAMESPACE::ModelProto model_proto; + OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto); + + std::ofstream ofs(serialized_model_path, std::ios::binary); + model_proto.SerializeToOstream(&ofs); + ofs.flush(); + + ASSERT_TRUE(std::filesystem::exists(serialized_model_path)); + } + + // Compare output of the original and serialized models. Should be identical. + std::vector output_original; + std::vector output_serialized; + + { + Run3LayerModel(original_model_path, true, output_original); + Run3LayerModel(serialized_model_path, true, output_serialized); + EXPECT_EQ(output_serialized, output_original); + } + + { + Run3LayerModel(original_model_path, false, output_original); + Run3LayerModel(serialized_model_path, false, output_serialized); + EXPECT_EQ(output_serialized, output_original); + } +} + // // Utils for traversing an OrtGraph and checking against GraphViewer. // @@ -307,6 +488,48 @@ static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span nodes(num_nodes); + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&api_graph, nodes.data(), nodes.size())); + + // Select a half of nodes to create a OrtGraph + size_t num_selected_nodes = std::max((nodes.size() >> 1), (size_t)1); + std::vector selected_nodes(num_selected_nodes); + + for (size_t i = 0; i < num_selected_nodes; i++) { + selected_nodes[i] = nodes[i]; + } + + OrtGraph* sub_graph; + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetGraphView(&api_graph, selected_nodes.data(), selected_nodes.size(), &sub_graph)); + + // Convert OrtGraph/GraphViewer to ModelProto and dump it to disk. + // If the GraphViewer associated with the OrtGraph somehow is incorrect, GraphViewerToProto() will throw. + const GraphViewer& sub_graph_viewer = EpGraph::ToInternal(sub_graph)->GetGraphViewer(); + std::unique_ptr model = std::make_unique(sub_graph_viewer.Name(), true, sub_graph_viewer.GetGraph().GetLogger()); + auto model_proto = std::make_unique(model->ToProto()); + GraphViewerToProto(sub_graph_viewer, *model_proto->mutable_graph(), true, true, static_cast(1)); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + + const char* graph_name = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetName(&api_graph, &graph_name)); + std::string name = graph_name; + name += "_half.onnx"; + + // Dump the graph for debugging + // std::fstream dump(name, std::ios::out | std::ios::trunc | std::ios::binary); + // model_proto->SerializeToOstream(&dump); + + ort_api.ReleaseGraph(sub_graph); +} + // Checks that the contents of the original GraphViewer matches the contents of the OrtGraph. // Uses the public C APIs to traverse the OrtGraph. 
static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph) { @@ -470,9 +693,10 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ } // Check node subgraphs - std::vector> node_subgraphs = node->GetSubgraphs(); + std::unordered_map> node_subgraphs_map = + node->GetAttributeNameToSubgraphMap(); - if (!node_subgraphs.empty()) { + if (!node_subgraphs_map.empty()) { // Check node's implicit inputs to its subgraph nodes. const auto implicit_input_node_args = node->ImplicitInputDefs(); @@ -489,18 +713,34 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ // Recursively check subgraphs. size_t api_num_node_subgraphs = 0; ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumSubgraphs(api_node, &api_num_node_subgraphs)); + ASSERT_EQ(api_num_node_subgraphs, node_subgraphs_map.size()); std::vector api_node_subgraphs(api_num_node_subgraphs); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetSubgraphs(api_node, api_node_subgraphs.data(), api_node_subgraphs.size())); - - for (size_t subgraph_idx = 0; subgraph_idx < node_subgraphs.size(); subgraph_idx++) { - auto subgraph_viewer = std::make_unique(*node_subgraphs[subgraph_idx]); - const OrtGraph* api_subgraph = api_node_subgraphs[subgraph_idx]; + std::vector api_subgraph_attr_names(api_num_node_subgraphs); + ASSERT_ORTSTATUS_OK(ort_api.Node_GetSubgraphs(api_node, api_node_subgraphs.data(), api_node_subgraphs.size(), + api_subgraph_attr_names.data())); + + for (const auto& [attr_name, subgraph] : node_subgraphs_map) { + // find index of this subgraph. + size_t api_subgraph_idx = api_num_node_subgraphs; + for (size_t subgraph_idx = 0; subgraph_idx < api_num_node_subgraphs; subgraph_idx++) { + if (api_subgraph_attr_names[subgraph_idx] == attr_name) { + api_subgraph_idx = subgraph_idx; + break; + } + } + ASSERT_NE(api_subgraph_idx, api_num_node_subgraphs); + // Recursively check the subgraph + auto subgraph_viewer = std::make_unique(*subgraph); + const OrtGraph* api_subgraph = api_node_subgraphs[api_subgraph_idx]; CheckGraphCApi(*subgraph_viewer, *api_subgraph); } } } + + // Check creating an OrtGraph from a subset of nodes in an OrtGraph + Check_Graph_GetSubgraph(api_graph); } } // namespace test diff --git a/onnxruntime/test/ep_graph/test_ep_graph_utils.cc b/onnxruntime/test/ep_graph/test_ep_graph_utils.cc index b7743e65061de..3b3bc4c6da911 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph_utils.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph_utils.cc @@ -30,6 +30,7 @@ std::unique_ptr TestGraph::Load(const ORTCHAR_T* model_path) { const OrtGraph& TestGraph::GetOrtGraph() const { return *api_graph; } const GraphViewer& TestGraph::GetGraphViewer() const { return graph_viewer; } +const Model& TestGraph::GetModel() const { return *model; } static Status GetInputIndices(const Node& consumer_node, const std::string& name, /*out*/ std::vector& indices) { diff --git a/onnxruntime/test/ep_graph/test_ep_graph_utils.h b/onnxruntime/test/ep_graph/test_ep_graph_utils.h index b0ed825f21d71..2ce107cf734c6 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph_utils.h +++ b/onnxruntime/test/ep_graph/test_ep_graph_utils.h @@ -28,6 +28,7 @@ class TestGraph { static std::unique_ptr Load(const ORTCHAR_T* model_path); const OrtGraph& GetOrtGraph() const; const GraphViewer& GetGraphViewer() const; + const Model& GetModel() const; private: std::shared_ptr model; diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 
18bc9cf05b36d..4c5dcd2bd7580 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -36,7 +36,7 @@ struct TestOrtEp : ::OrtEp, ApiPtrs { // Individual tests should fill out the other function pointers as needed. } - static const char* ORT_API_CALL GetNameImpl(const OrtEp* /*this_ptr*/) { + static const char* ORT_API_CALL GetNameImpl(const OrtEp* /*this_ptr*/) noexcept { constexpr const char* ep_name = "TestOrtEp"; return ep_name; } @@ -50,7 +50,7 @@ struct TestOrtEpFactory : ::OrtEpFactory { ReleaseEp = ReleaseEpImpl; } - static void ORT_API_CALL ReleaseEpImpl(::OrtEpFactory* /*this_ptr*/, OrtEp* ep) { + static void ORT_API_CALL ReleaseEpImpl(::OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept { delete static_cast(ep); } }; @@ -125,7 +125,7 @@ TEST(PluginExecutionProviderTest, GetPreferredLayout) { } { - auto prefer_nhwc_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) -> ::OrtStatus* { + auto prefer_nhwc_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) noexcept -> ::OrtStatus* { *preferred_data_layout = OrtEpDataLayout::OrtEpDataLayout_NCHW; return nullptr; }; @@ -135,7 +135,7 @@ TEST(PluginExecutionProviderTest, GetPreferredLayout) { #if !defined(ORT_NO_EXCEPTIONS) { - auto invalid_layout_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) -> ::OrtStatus* { + auto invalid_layout_fn = [](OrtEp* /*this_ptr*/, OrtEpDataLayout* preferred_data_layout) noexcept -> ::OrtStatus* { *preferred_data_layout = static_cast(-1); return nullptr; }; @@ -144,7 +144,7 @@ TEST(PluginExecutionProviderTest, GetPreferredLayout) { } { - auto failing_fn = [](OrtEp* this_ptr, OrtEpDataLayout* /*preferred_data_layout*/) -> ::OrtStatus* { + auto failing_fn = [](OrtEp* this_ptr, OrtEpDataLayout* /*preferred_data_layout*/) noexcept -> ::OrtStatus* { auto* test_ort_ep = static_cast(this_ptr); return test_ort_ep->ort_api->CreateStatus(OrtErrorCode::ORT_FAIL, "I can't decide what data layout I prefer."); }; @@ -167,7 +167,7 @@ TEST(PluginExecutionProviderTest, ShouldConvertDataLayoutForOp) { const char* /*node_domain*/, const char* node_op_type, OrtEpDataLayout target_data_layout, - int* should_convert) -> ::OrtStatus* { + int* should_convert) noexcept -> ::OrtStatus* { EXPECT_EQ(target_data_layout, OrtEpDataLayout::OrtEpDataLayout_NHWC); if (node_op_type == std::string_view{"Conv"}) { @@ -201,7 +201,7 @@ TEST(PluginExecutionProviderTest, ShouldConvertDataLayoutForOp) { const char* /*node_domain*/, const char* /*node_op_type*/, OrtEpDataLayout /*target_data_layout*/, - int* /*should_convert*/) -> ::OrtStatus* { + int* /*should_convert*/) noexcept -> ::OrtStatus* { auto* test_ort_ep = static_cast(this_ptr); return test_ort_ep->ort_api->CreateStatus(OrtErrorCode::ORT_FAIL, "To convert to NHWC or not to convert to NHWC..."); diff --git a/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp b/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp index 65822eb294d7d..ea36383f70621 100644 --- a/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp +++ b/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp @@ -58,10 +58,10 @@ void COMPUTESOFTMAXINPLACE(benchmark::State& state) { std::copy(data.begin(), data.end(), input); // Copy the data to the aligned memory // warming up run - MlasComputeSoftmax(input, output, N, D, false, false, tp.get()); + MlasComputeSoftmax(input, output, N, D, false, false, 0.0f, tp.get()); for (auto _ : state) { - MlasComputeSoftmax(input, output, N, D, false, false, 
tp.get()); + MlasComputeSoftmax(input, output, N, D, false, false, 0.0f, tp.get()); } free(ptr.underlying_buffer); diff --git a/onnxruntime/test/mlas/unittest/test_dequantizelinear.cpp b/onnxruntime/test/mlas/unittest/test_dequantizelinear.cpp new file mode 100644 index 0000000000000..b994981364947 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_dequantizelinear.cpp @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test_util.h" + +template +class MlasDequantizeLinearTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferInput; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutputReference; + + void GenerateReference(const QuantInt* Input, float* OutputReference, size_t N, float Scale, QuantInt ZeroPoint) { + int32_t ZeroPointS32 = static_cast(ZeroPoint); + + for (size_t n = 0; n < N; n++) { + OutputReference[n] = static_cast(static_cast(Input[n]) - ZeroPointS32) * Scale; + } + } + + void Test(size_t N) { + QuantInt* Input = BufferInput.GetBuffer(N); + float* Output = BufferOutput.GetBuffer(N); + float* OutputReference = BufferOutputReference.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + + std::uniform_real_distribution min_gen(-10.f, -10e-3f); + float MinimumValue = min_gen(generator); + + std::uniform_real_distribution max_gen(10e-3f, 10.f); + float MaximumValue = max_gen(generator); + + float Scale = (MaximumValue - MinimumValue) / 512.f; + + std::uniform_int_distribution zp_distribution(std::numeric_limits::min(), + std::numeric_limits::max()); + QuantInt ZeroPoint = static_cast(zp_distribution(generator)); + + for (size_t n = 0; n < N; n++) { + Input[n] = static_cast(zp_distribution(generator)); + } + + GenerateReference(Input, OutputReference, N, Scale, ZeroPoint); + MlasDequantizeLinear(Input, Output, N, Scale, ZeroPoint); + + for (size_t n = 0; n < N; n++) { + ASSERT_EQ(Output[n], OutputReference[n]) << ", size=" << N << ", index=" << n; + } + } + + public: + static const char* GetTestSuiteName() { + if constexpr (std::is_same_v) { + return "DequantizeLinearS8"; + } else { + return "DequantizeLinearU8"; + } + } + + void ExecuteShort(void) override { + for (size_t n = 1; n <= 512; n++) { + Test(n); + } + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + } + return count; +}); diff --git a/onnxruntime/test/mlas/unittest/test_softmax.cpp b/onnxruntime/test/mlas/unittest/test_softmax.cpp index 041b6c61cd5bf..4d7a45143b311 100644 --- a/onnxruntime/test/mlas/unittest/test_softmax.cpp +++ b/onnxruntime/test/mlas/unittest/test_softmax.cpp @@ -152,7 +152,7 @@ class MlasSoftmaxTest : public MlasTestBase { } void Test(const float* Input, float* Output, float* OutputReference, size_t N, size_t D, bool LogSoftmax, bool SmoothSoftmax) { - MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_); + MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, 0.0f, threadpool_); ReferenceSoftmax(Input, OutputReference, N, D, LogSoftmax, SmoothSoftmax); constexpr float AbsoluteTolerance = 1e-6f; @@ -206,7 +206,7 @@ class MlasSoftmaxTest : public MlasTestBase { InputReference[nd] = Input[nd].ToFloat(); } - MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_); + MlasComputeSoftmax(Input, 
Output, N, D, LogSoftmax, SmoothSoftmax, 0.0f, threadpool_); ReferenceSoftmax(InputReference, OutputReference, N, D, LogSoftmax, SmoothSoftmax); constexpr float AbsoluteTolerance = 5e-3f; diff --git a/onnxruntime/test/providers/cpu/math/softmax_test.cc b/onnxruntime/test/providers/cpu/math/softmax_test.cc index 649c9af7cc80b..215203b31f49c 100644 --- a/onnxruntime/test/providers/cpu/math/softmax_test.cc +++ b/onnxruntime/test/providers/cpu/math/softmax_test.cc @@ -61,7 +61,8 @@ TEST(SoftmaxOperator, webgpu_nan) { test.AddOutput("Y", dimensions, expected_result); // explicitly disable for EPs that do not handle NaN - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider, kCoreMLExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCpuExecutionProvider, kCoreMLExecutionProvider, kDmlExecutionProvider}); } #endif diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 4e7a6356a5129..8fdbf0060eaa0 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -33,6 +33,32 @@ TEST(DequantizeLinearOpTest, Int8) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +// scalar zero & scale with uint8 (large enough input to execute MLAS vectorized loop) +TEST(DequantizeLinearOpTest, Uint8_Large) { + OpTester test("DequantizeLinear", 10); + std::vector dims{1, 1039}; // not evenly divisible by 16 (loop unroll amount) to test handling of leftover inputs + test.AddInput("x", dims, std::vector(1039, 1)); + test.AddInput("x_scale", {}, {1.0f}); + test.AddInput("x_zero_point", {}, {1}); + test.AddOutput("y", dims, std::vector(1039, 0.0f)); + // Disable Tensorrt EP due to error:node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1. + // Disable WebGPU EP because it requires dims.Size() to be multiple of 4. Fails with error: needs at least component size 4. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kWebGpuExecutionProvider}); +} + +// scalar zero & scale with int8 (large enough input to execute MLAS vectorized loop) +TEST(DequantizeLinearOpTest, Int8_Large) { + OpTester test("DequantizeLinear", 10); + std::vector dims{1, 1039}; // not evenly divisible by 16 (loop unroll amount) to test handling of leftover inputs + test.AddInput("x", dims, std::vector(1039, 1)); + test.AddInput("x_scale", {}, {1.0f}); + test.AddInput("x_zero_point", {}, {1}); + test.AddOutput("y", dims, std::vector(1039, 0.0f)); + // Disable Tensorrt EP due to error:node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1. + // Disable WebGPU EP because it requires dims.Size() to be multiple of 4. Fails with error: needs at least component size 4. 
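  // (Worked example for the data above: every element dequantizes to (1 - 1) * 1.0f = 0.0f, which is why the
  //  expected output is all zeros. 1039 = 64 * 16 + 15, so with the 16-element unroll noted above the vectorized
  //  loop covers 1024 elements and the remaining 15 exercise the scalar tail.)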
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kWebGpuExecutionProvider}); +} + // scalar zero & scale with int4 TEST(DequantizeLinearOpTest, Int4) { OpTester test("DequantizeLinear", 21); diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc index 895c8ab3e53e4..e6d113e1e4dca 100644 --- a/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/scatter_nd_op_test.cc @@ -235,5 +235,16 @@ TEST(ScatterNDOpTest, ScatterND_18_max) { test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } +// Test for ScatterND with empty indices - output should be same as input +TEST(ScatterNDOpTest, ScatterND_empty_indices) { + // Test with float data type and minimal empty case + OpTester test1("ScatterND", 11); + test1.AddInput("data", {2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + test1.AddInput("indices", {0, 1}, {}); // Empty indices tensor - no indices to process + test1.AddInput("updates", {0, 3}, {}); // Empty updates tensor + test1.AddOutput("output", {2, 3}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); // Same as input + test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 4febfe7ba836d..739e39a6975e2 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -509,6 +509,11 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB Ort::ModelCompilationOptions compile_options(*ort_env, session_options); compile_options.SetInputModelFromBuffer(reinterpret_cast(model_data.data()), model_data.size()); compile_options.SetOutputModelBuffer(allocator, &output_model_buffer, &output_model_buffer_size); + std::string target_dir = "./testdata/"; + std::string model_name = "test_model_in_mem.onnx"; + auto pos = model_name.rfind(".onnx"); + std::string bin_file_name = model_name.substr(0, pos) + "_qnn.bin"; + compile_options.SetEpContextBinaryInformation(ToWideString(target_dir).c_str(), ToWideString(model_name).c_str()); compile_options.SetEpContextEmbedMode(false); // Compile the model. @@ -519,12 +524,18 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB ASSERT_TRUE(output_model_buffer != nullptr); ASSERT_TRUE(output_model_buffer_size > 0); + ASSERT_TRUE(std::filesystem::exists(target_dir + bin_file_name)) << "expected context binary file should exist"; + // Check that the compiled model has the expected number of EPContext nodes. CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2); + // Add session option "ep.context_file_path" so that the session can use it to locate the [model_name]_qnn.bin file + std::string ctx_model = target_dir + model_name; + session_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ctx_model.c_str()); // Should be able to create a session with the compiled model and the original session options. 
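  // (The compiled model exists only in memory, so ORT has no model path from which to resolve the relative
  //  reference to the external [model_name]_qnn.bin; ep.context_file_path supplies that base path instead.)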
EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_buffer, output_model_buffer_size, session_options))); + std::filesystem::remove(target_dir + bin_file_name); allocator.Free(output_model_buffer); } } @@ -1649,7 +1660,6 @@ static void DumpModelWithSharedCtx(ProviderOptions provider_options, Ort::Session session2(*ort_env, ToPathString(onnx_model_path2).c_str(), so); } -#if defined(__aarch64__) || defined(_M_ARM64) static void GetModelInputNames(const std::string& model_path, std::vector& input_names, std::vector& output_names, @@ -1669,7 +1679,6 @@ static void GetModelInputNames(const std::string& model_path, output_names.push_back(output->Name()); } } -#endif // 1. Create 2 QDQ models // 2. Initialize 2 Ort sessions which share the same QNN EP from these 2 QDQ models @@ -1994,6 +2003,73 @@ TEST_F(QnnHTPBackendTests, LoadFromArrayWithQnnEpContextGenPathValidation) { }); } } + +TEST_F(QnnHTPBackendTests, QnnEpDynamicOptions) { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + + Ort::SessionOptions so; + so.AppendExecutionProvider("QNN", provider_options); + so.SetLogSeverityLevel(ORT_LOGGING_LEVEL_VERBOSE); + + Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx/qnn_multi_ctx_embed.onnx"), so); + + std::vector input_names; + std::vector output_names; + GetModelInputNames("testdata/qnn_ctx/qnn_multi_ctx_embed.onnx", input_names, output_names, + DefaultLoggingManager().DefaultLogger()); + + // Run sessions + // prepare input + std::vector input_dim{3, 4}; + std::vector input_value(3 * 4, 0.0f); + Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); + std::vector ort_inputs; + std::vector input_names_c; + for (size_t i = 0; i < input_names.size(); ++i) { + auto input_tensor = Ort::Value::CreateTensor(info, input_value.data(), input_value.size(), + input_dim.data(), input_dim.size()); + ort_inputs.push_back(std::move(input_tensor)); + input_names_c.push_back(input_names[i].c_str()); + } + std::vector output_names_c; + for (size_t i = 0; i < output_names.size(); ++i) { + output_names_c.push_back(output_names[i].c_str()); + } + + auto ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); + + const char* const workload_type[] = {"ep.dynamic.workload_type"}; + const char* const efficient_type[] = {"Efficient"}; + const char* const default_type[] = {"Default"}; + + // Test Efficient & Default options + session.SetEpDynamicOptions(workload_type, efficient_type, 1); + ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); + + session.SetEpDynamicOptions(workload_type, default_type, 1); + ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); + + // Test invalid EP dynamic option and invalid workload type + const char* const dne[] = {"DNE"}; + try { + session.SetEpDynamicOptions(workload_type, dne, 1); + FAIL() << "Expected exception to be thrown for workload type DNE but was set successfully"; + } catch (const std::exception& e) { + EXPECT_STREQ("Invalid EP Workload Type.", e.what()); + } + + try { + session.SetEpDynamicOptions(dne, efficient_type, 1); + FAIL() << "Expected exception to be thrown for dynamic option DNE but was set successfully"; + } catch (const std::exception& e) { + EXPECT_STREQ("Unsupported EP Dynamic Option", e.what()); + } +} #endif 
// defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 85f8250f70fc5..4c0a53e83e274 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -1254,6 +1254,38 @@ TEST_F(QnnHTPBackendTests, GridSample_U16_Nearest) { true); } +// Test QDQ GridSample with `linear` mode on opset 20+. +TEST_F(QnnHTPBackendTests, GridSample_Linear_ZerosPadding) { + RunQDQOpTest("GridSample", + {TestInputDef({1, 3, 4, 6}, false, GetFloatDataInRange(-10.0f, 10.0f, 72)), + TestInputDef({1, 4, 6, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))}, + {utils::MakeAttribute("mode", "linear"), utils::MakeAttribute("padding_mode", "zeros")}, + /*opset_version=*/20, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, GridSample_Linear_AlignCorners_BorderPadding) { + RunQDQOpTest("GridSample", + {TestInputDef({1, 3, 4, 6}, false, GetFloatDataInRange(-10.0f, 10.0f, 72)), + TestInputDef({1, 4, 6, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))}, + {utils::MakeAttribute("align_corners", static_cast(1)), + utils::MakeAttribute("mode", "linear"), + utils::MakeAttribute("padding_mode", "border")}, + /*opset_version=*/20, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, GridSample_Linear_ReflectionPadding_U16) { + RunQDQOpTest("GridSample", + {TestInputDef({1, 3, 4, 6}, false, GetFloatDataInRange(-10.0f, 10.0f, 72)), + TestInputDef({1, 4, 6, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))}, + {utils::MakeAttribute("mode", "linear"), utils::MakeAttribute("padding_mode", "reflection")}, + /*opset_version=*/21, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*op_domain=*/kOnnxDomain, + /*use_contrib_qdq=*/true); +} + // Test QDQ GridSample with reflection padding mode // Inaccuracy detected for output 'output', element 2. // Output quant params: scale=0.024269860237836838, zero_point=0. diff --git a/onnxruntime/test/python/quantization/test_op_topk.py b/onnxruntime/test/python/quantization/test_op_topk.py new file mode 100644 index 0000000000000..1fdd0c987d1e8 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_op_topk.py @@ -0,0 +1,103 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +import unittest + +import numpy as np +from onnx import TensorProto, helper, save +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type + +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static + + +class TestTopKModel(unittest.TestCase): + @staticmethod + def construct_model(model_path, input_shape, axis_attr, k): + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, input_shape) + k_tensor = helper.make_tensor("k", TensorProto.INT64, [1], [k]) + output_shape = input_shape[:] + output_shape[axis_attr] = k + output_values = helper.make_tensor_value_info("values", TensorProto.FLOAT, [1, k]) + output_indices = helper.make_tensor_value_info("indices", TensorProto.INT64, [1, k]) + + node = helper.make_node( + "TopK", inputs=["input", "k"], outputs=["values", "indices"], name="topk_node", axis=axis_attr + ) + + graph = helper.make_graph( + [node], + "quant_topk_op_test", + [input_tensor], + [output_values, output_indices], + initializer=[k_tensor], + ) + + model = helper.make_model( + graph, opset_imports=[helper.make_opsetid("", 16), helper.make_opsetid("com.microsoft", 1)] + ) + save(model, model_path) + + def quantize_topk_test(self, activation_type, weight_type, extra_options={}): # noqa: B006 + model_fp32_path = "topk_fp32.onnx" + input_shape = [1, 10] + axis = 1 + k = 3 + self.construct_model(model_fp32_path, input_shape, axis, k) + + input_data_list = [ + {"input": np.array([[1.8, 2.5, -5.9, 5.2, 4.1, 7.3, 0.2, -0.5, 0.845, 3.9]], dtype=np.float32)} + ] + data_reader = TestDataFeeds(input_data_list) + + activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8 + activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8" + weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" + model_qdq_path = f"topk_{activation_type_str}{weight_type_str}_{'QNoInCk' if extra_options['ForceQuantizeNoInputCheck'] else 'NoQNoInCk'}_qdq.onnx" + + # Verify QDQ mode + data_reader.rewind() + quantize_static( + model_fp32_path, + model_qdq_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) + qdqnode_counts = ( + { + "TopK": 1, + "QuantizeLinear": 2, + "DequantizeLinear": 2, + } + if extra_options["ForceQuantizeNoInputCheck"] + else { + "TopK": 1, + "QuantizeLinear": 0, + "DequantizeLinear": 0, + } + ) + check_op_type_count(self, model_qdq_path, **qdqnode_counts) + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } + check_qtype_by_node_type(self, model_qdq_path, qnode_io_qtypes) + data_reader.rewind() + check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next()) + + def test_quantize_topk_u8u8(self): + self.quantize_topk_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={"ForceQuantizeNoInputCheck": True}) + + def test_quantize_topk_u8u8_no_force_quantize_no_input_check(self): + self.quantize_topk_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={"ForceQuantizeNoInputCheck": False}) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_data/models/phi-4-v-instruct-vision-attention.onnx b/onnxruntime/test/python/transformers/test_data/models/phi-4-v-instruct-vision-attention.onnx new file mode 100644 
index 0000000000000..34cf26c13d3fc Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/phi-4-v-instruct-vision-attention.onnx differ diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 461c243b82212..7f2134b2cda4f 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -13,6 +13,7 @@ import random import unittest from dataclasses import dataclass +from enum import Enum import numpy import torch @@ -38,11 +39,17 @@ ATOL = None -class Formats: +class Formats(Enum): BSNH = 0 BNSH = 1 +class QKOutputType(Enum): + NO_OUTPUT = 0 + BEFORE_SOFTMAX = 1 + AFTER_SOFTMAX = 2 + + @dataclass class Config: batch_size: int = 0 @@ -54,6 +61,8 @@ class Config: head_size: int = 0 has_position_ids: bool = False has_attention_bias: bool = False + has_head_sink: bool = False + qk_output: QKOutputType = QKOutputType.NO_OUTPUT @dataclass @@ -67,6 +76,8 @@ class PromptConfig: head_size: int = 0 has_position_ids: bool = False has_attention_bias: bool = False + has_head_sink: bool = False + qk_output: QKOutputType = QKOutputType.NO_OUTPUT # LLaMA Microsoft model @@ -151,6 +162,15 @@ def create_group_query_attention_graph_prompt( ): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length + + output_names = [ + "output", + "present_key", + "present_value", + ] + if config.qk_output != QKOutputType.NO_OUTPUT: + output_names.append("output_qk") + nodes = [ helper.make_node( "GroupQueryAttention", @@ -166,8 +186,9 @@ def create_group_query_attention_graph_prompt( "sin_cache" if rotary else "", "position_ids" if config.has_position_ids else "", "attention_bias" if config.has_attention_bias else "", + "head_sink" if config.has_head_sink else "", ], - ["output", "present_key", "present_value"], + output_names, "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, @@ -176,6 +197,7 @@ def create_group_query_attention_graph_prompt( rotary_interleaved=rotary_interleaved, softcap=softcap, smooth_softmax=1 if use_smooth_softmax else 0, + qk_output=config.qk_output.value, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -289,6 +311,15 @@ def create_group_query_attention_graph_prompt( ), ] + if config.has_head_sink: + graph_input += [ + helper.make_tensor_value_info( + "head_sink", + ort_type, + [config.num_heads], + ), + ] + graph_output = [ helper.make_tensor_value_info( "output", @@ -337,6 +368,15 @@ def create_group_query_attention_graph_prompt( ), ] + if config.qk_output != QKOutputType.NO_OUTPUT: + graph_output += [ + helper.make_tensor_value_info( + "output_qk", + ort_type, + [config.batch_size, config.num_heads, config.kv_sequence_length, config.kv_sequence_length], + ), + ] + graph = helper.make_graph( nodes, "GroupQueryAttention_Graph", @@ -365,6 +405,15 @@ def create_group_query_attention_graph_past( present_kv_seqlen = ( config.kv_sequence_length if share_buffer else config.kv_sequence_length + config.sequence_length ) + + output_names = [ + "output", + "present_key", + "present_value", + ] + if config.qk_output != QKOutputType.NO_OUTPUT: + output_names.append("output_qk") + nodes = [ helper.make_node( "GroupQueryAttention", @@ -380,8 +429,9 @@ def create_group_query_attention_graph_past( "sin_cache" if rotary else "", 
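    # (Session creation and IO binding are now set up once here and shared by the share_buffer and
    #  no-buffer branches below. When config.has_head_sink is set, head_sink is bound as an extra CPU
    #  input: the reference smooth softmax appends one sink logit per head to each score row, takes the
    #  softmax over the widened row and then drops the sink column, so each row sums to less than 1.
    #  When config.qk_output is enabled, output_qk is bound as a preallocated output that receives the
    #  QK scores either before or after softmax.)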
"position_ids" if config.has_position_ids else "", "attention_bias" if config.has_attention_bias else "", + "head_sink" if config.has_head_sink else "", ], - ["output", "present_key", "present_value"], + output_names, "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, @@ -390,6 +440,7 @@ def create_group_query_attention_graph_past( rotary_interleaved=rotary_interleaved, softcap=softcap, smooth_softmax=1 if use_smooth_softmax else 0, + qk_output=config.qk_output.value, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -441,6 +492,7 @@ def create_group_query_attention_graph_past( [1], ), ] + if not packed: graph_input += [ helper.make_tensor_value_info( @@ -462,6 +514,7 @@ def create_group_query_attention_graph_past( ], ), ] + if rotary: graph_input += [ helper.make_tensor_value_info( @@ -498,6 +551,15 @@ def create_group_query_attention_graph_past( ), ] + if config.has_head_sink: + graph_input += [ + helper.make_tensor_value_info( + "head_sink", + ort_type, + [config.num_heads], + ), + ] + graph_output = [ helper.make_tensor_value_info( "output", @@ -526,6 +588,15 @@ def create_group_query_attention_graph_past( ), ] + if config.qk_output != QKOutputType.NO_OUTPUT: + graph_output += [ + helper.make_tensor_value_info( + "output_qk", + ort_type, + [config.batch_size, config.num_heads, config.sequence_length, present_kv_seqlen], + ), + ] + graph = helper.make_graph( nodes, "GroupQueryAttention_Graph", @@ -552,17 +623,17 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False): """ Arguments: - q: (batch_size, seqlen_q, nheads, d) - k: (batch_size, seqlen_k, nheads_k, d) - v: (batch_size, seqlen_k, nheads_k, d) + q: (batch_size, seqlen_q, num_heads, d) + k: (batch_size, seqlen_k, num_heads_k, d) + v: (batch_size, seqlen_k, num_heads_k, d) query_padding_mask: (batch_size, seqlen), bool key_padding_mask: (batch_size, seqlen), bool """ assert not (kvpacked and qkvpacked) - batch_size, seqlen_q, nheads, d = q.shape - _, seqlen_k, nheads_k, _ = k.shape - assert k.shape == (batch_size, seqlen_k, nheads_k, d) - assert v.shape == (batch_size, seqlen_k, nheads_k, d) + batch_size, seqlen_q, num_heads, d = q.shape + _, seqlen_k, num_heads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, num_heads_k, d) + assert v.shape == (batch_size, seqlen_k, num_heads_k, d) if query_padding_mask is not None: q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) @@ -593,7 +664,7 @@ def output_pad_fn(output_unpad): if qkvpacked: assert (query_padding_mask == key_padding_mask).all() - assert nheads == nheads_k + assert num_heads == num_heads_k qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) qkv = torch.stack([q, k, v], dim=2) if query_padding_mask is not None: @@ -714,6 +785,8 @@ def gqa_prompt_func( seqlens_k=None, position_ids=None, attention_bias=None, + head_sink=None, + output_qk=None, window_size=-1, past_kv_format=Formats.BSNH, share_buffer=True, @@ -746,9 +819,18 @@ def gqa_prompt_func( if config.has_attention_bias: assert attention_bias is not None + if config.qk_output != QKOutputType.NO_OUTPUT: + assert output_qk is not None + if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) + + 
sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + io_binding = ort_session.io_binding() + ort_outputs = {} + if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), @@ -757,10 +839,6 @@ def gqa_prompt_func( "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } - - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - io_binding = ort_session.io_binding() if new_k is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -797,25 +875,18 @@ def gqa_prompt_func( io_binding.bind_output("output") io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) - ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v else: ort_inputs = { "query": q.detach().cpu().numpy(), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - io_binding = ort_session.io_binding() if new_k is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() io_binding.bind_cpu_input("key", ort_inputs["key"]) io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: ort_inputs["cos_cache"] = cos.detach().cpu().numpy() ort_inputs["sin_cache"] = sin.detach().cpu().numpy() @@ -836,11 +907,26 @@ def gqa_prompt_func( io_binding.bind_output("output") io_binding.bind_output("present_key") io_binding.bind_output("present_value") - ort_session.run_with_iobinding(io_binding) + + if config.has_head_sink: + ort_inputs["head_sink"] = head_sink.detach().cpu().numpy() + io_binding.bind_cpu_input("head_sink", ort_inputs["head_sink"]) + + if config.qk_output != QKOutputType.NO_OUTPUT: + ort_outputs["output_qk"] = OrtValue.ortvalue_from_numpy(output_qk.detach().cpu().numpy(), "cpu", 0) + io_binding.bind_ortvalue_output("output_qk", ort_outputs["output_qk"]) + + ort_session.run_with_iobinding(io_binding) + + out_qk = None + if config.qk_output != QKOutputType.NO_OUTPUT: + ort_output, present_k, present_v, out_qk = io_binding.copy_outputs_to_cpu() + else: ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + + return output, present_k, present_v, out_qk def gqa_past_func( @@ -855,6 +941,8 @@ def gqa_past_func( seqlens_k=None, position_ids=None, attention_bias=None, + head_sink=None, + output_qk=None, past_kv_format=Formats.BSNH, share_buffer=True, window_size=-1, @@ -887,9 +975,18 @@ def gqa_past_func( if config.has_attention_bias: assert attention_bias is not None + if config.qk_output != QKOutputType.NO_OUTPUT: + assert output_qk is not None + if new_k is not None: new_k = torch.reshape(new_k, (config.batch_size, 
config.sequence_length, -1)) new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) + + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + io_binding = ort_session.io_binding() + ort_outputs = {} + if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), @@ -901,9 +998,6 @@ def gqa_past_func( .cpu() .numpy(), } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - io_binding = ort_session.io_binding() if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -940,11 +1034,6 @@ def gqa_past_func( io_binding.bind_output("output") io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) - ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v else: ort_inputs = { "query": q.detach().cpu().numpy(), @@ -958,9 +1047,6 @@ def gqa_past_func( .cpu() .numpy(), } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - io_binding = ort_session.io_binding() if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() @@ -988,11 +1074,26 @@ def gqa_past_func( io_binding.bind_output("output") io_binding.bind_output("present_key") io_binding.bind_output("present_value") - ort_session.run_with_iobinding(io_binding) + + if config.has_head_sink: + ort_inputs["head_sink"] = head_sink.detach().cpu().numpy() + io_binding.bind_cpu_input("head_sink", ort_inputs["head_sink"]) + + if config.qk_output != QKOutputType.NO_OUTPUT: + ort_outputs["output_qk"] = OrtValue.ortvalue_from_numpy(output_qk.detach().cpu().numpy(), "cpu", 0) + io_binding.bind_ortvalue_output("output_qk", ort_outputs["output_qk"]) + + ort_session.run_with_iobinding(io_binding) + + out_qk = None + if config.qk_output != QKOutputType.NO_OUTPUT: + ort_output, present_k, present_v, out_qk = io_binding.copy_outputs_to_cpu() + else: ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + + return output, present_k, present_v, out_qk def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None): @@ -1025,11 +1126,28 @@ def construct_local_mask( ) -def smooth_softmax_ref(x): - x_max = x.amax(axis=-1, keepdim=True) - x_max = torch.maximum(x_max, torch.zeros_like(x_max)) - w = torch.exp(x - x_max) - return w * torch.reciprocal(w.sum(axis=-1, keepdim=True) + torch.exp(-x_max)) +def smooth_softmax_ref(x, head_sink): + """ + Arguments: + x: (batch_size, num_heads, seqlen_q, seqlen_k) + head_sink: (num_heads) or None + Output: + y: (batch_size, num_heads, seqlen_q, seqlen_k) + """ + assert len(x.shape) == 4 + b, n, s, t = x.shape + + if head_sink is not None: + assert len(head_sink.shape) == 1 + assert head_sink.shape[0] == x.shape[1] + sink = head_sink.reshape(1, n, 1, 1).expand(b, -1, s, -1) + else: + sink = torch.zeros(b, 
n, s, 1, dtype=x.dtype) + + y = torch.cat([x, sink], dim=-1) + y = torch.softmax(y, dim=-1) + y = y[..., :-1] + return y def attention_ref( @@ -1046,16 +1164,17 @@ def attention_ref( upcast=True, reorder_ops=False, use_smooth_softmax=False, + head_sink=None, ): """ Arguments: - q: (batch_size, seqlen_q, nheads, head_dim) - k: (batch_size, seqlen_k, nheads_k, head_dim) - v: (batch_size, seqlen_k, nheads_k, head_dim) + q: (batch_size, seqlen_q, num_heads, head_dim) + k: (batch_size, seqlen_k, num_heads_k, head_dim) + v: (batch_size, seqlen_k, num_heads_k, head_dim) query_padding_mask: (batch_size, seqlen_q) key_padding_mask: (batch_size, seqlen_k) dropout_p: float - dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + dropout_mask: (batch_size, num_heads, seqlen_q, seqlen_k) causal: whether to apply causal masking window_size: (int, int), left and right window size upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast @@ -1064,8 +1183,10 @@ def attention_ref( without changing the math. This is to estimate the numerical error from operation reordering. use_smooth_softmax: whether use smooth softmax or not + head_sink: (num_heads) or None Output: output: (batch_size, seqlen_q, nheads, head_dim) + masked_scores: (batch_size, nheads, seqlen_q, seqlen_k), before softmax attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout """ if causal: @@ -1085,8 +1206,10 @@ def attention_ref( scores = scores / softcap scores = scores.tanh() scores = scores * softcap + masked_scores = scores.clone() if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + masked_scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0) if window_size[0] >= 0 or window_size[1] >= 0: local_mask = construct_local_mask( seqlen_q, @@ -1096,10 +1219,11 @@ def attention_ref( key_padding_mask, q.device, ) + masked_scores.masked_fill_(local_mask, 0.0) scores.masked_fill_(local_mask, float("-inf")) - if use_smooth_softmax: - attention = smooth_softmax_ref(scores) + if use_smooth_softmax or (head_sink is not None): + attention = smooth_softmax_ref(scores, head_sink) else: attention = torch.softmax(scores, dim=-1) @@ -1121,7 +1245,7 @@ def attention_ref( if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) - return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + return output.to(dtype=dtype_og), masked_scores.to(dtype=dtype_og), attention.to(dtype=dtype_og) def attention_qkvpacked_ref( @@ -1133,6 +1257,7 @@ def attention_qkvpacked_ref( upcast=True, reorder_ops=False, use_smooth_softmax=False, + head_sink=None, ): return attention_ref( qkv[:, :, 0], @@ -1146,6 +1271,7 @@ def attention_qkvpacked_ref( causal=causal, reorder_ops=reorder_ops, use_smooth_softmax=use_smooth_softmax, + head_sink=head_sink, ) @@ -1186,6 +1312,10 @@ def get_custom_position_ids(batch_size, sequence_length, seqlens_k=None, past=Fa return position_ids +def get_custom_head_sink(num_heads, torch_type=torch.float16): + return torch.rand(num_heads, dtype=torch_type) + + def parity_check_gqa_prompt( config, torch_type, @@ -1248,6 +1378,8 @@ def parity_check_gqa_prompt( requires_grad=False, ) + head_sink = get_custom_head_sink(config.num_heads, torch_type) if config.has_head_sink else None + window_size = (-1, -1) left_window_size = -1 if local: @@ -1305,6 +1437,20 @@ def parity_check_gqa_prompt( else None ) + output_qk = ( + torch.zeros( + config.batch_size, + config.num_heads, + 
config.kv_sequence_length, + config.q_sequence_length, + device="cpu", + dtype=torch_type, + requires_grad=False, + ) + if config.qk_output != QKOutputType.NO_OUTPUT + else None + ) + arange = rearrange(torch.arange(config.buffer_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") kv_seqlens = torch.tensor([config.kv_sequence_length], device="cpu").repeat(config.batch_size) @@ -1315,7 +1461,7 @@ def parity_check_gqa_prompt( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded - out_ref, _ = attention_ref( + out_ref, out_qk_pre_softmax_ref, out_qk_post_softmax_ref = attention_ref( q_ro, k_cache_rep, v_cache_rep, @@ -1327,6 +1473,7 @@ def parity_check_gqa_prompt( window_size=window_size, softcap=softcap, use_smooth_softmax=use_smooth_softmax, + head_sink=head_sink, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1337,7 +1484,7 @@ def parity_check_gqa_prompt( # Cache seqlens is reduced by 1 since it is required to be past_seq_len + seq_len - 1 if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v = gqa_prompt_func( + out, present_k, present_v, out_qk = gqa_prompt_func( packed_qkv, k, v, @@ -1349,6 +1496,8 @@ def parity_check_gqa_prompt( cache_seqlens - 1, position_ids, attention_bias, + head_sink, + output_qk, left_window_size, past_format, True, @@ -1359,7 +1508,7 @@ def parity_check_gqa_prompt( numpy_type=numpy_type, ) else: - out, present_k, present_v = gqa_prompt_func( + out, present_k, present_v, out_qk = gqa_prompt_func( q, k, v, @@ -1371,6 +1520,8 @@ def parity_check_gqa_prompt( cache_seqlens - 1, position_ids, attention_bias, + head_sink, + output_qk, left_window_size, past_format, True, @@ -1384,6 +1535,22 @@ def parity_check_gqa_prompt( out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + if config.qk_output != QKOutputType.NO_OUTPUT: + out_qk_ref = ( + out_qk_post_softmax_ref if config.qk_output == QKOutputType.AFTER_SOFTMAX else out_qk_pre_softmax_ref + ) + out_qk_ref = out_qk_ref.detach().cpu().numpy() + + for batch_idx in range(config.batch_size): + total_seqlen = cache_seqlens[batch_idx] + assert numpy.allclose( + out_qk[batch_idx, :, :, :total_seqlen], + out_qk_ref[batch_idx, :, :, :total_seqlen], + rtol=rtol, + atol=atol, + equal_nan=True, + ) + # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1425,6 +1592,8 @@ def parity_check_gqa_prompt( config.has_position_ids, " has_attention_bias:", config.has_attention_bias, + " qk_output:", + config.qk_output, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -1531,12 +1700,28 @@ def parity_check_gqa_prompt_no_buff( else None ) + head_sink = get_custom_head_sink(config.num_heads, torch_type=torch_type) if config.has_head_sink else None + + output_qk = ( + torch.zeros( + config.batch_size, + config.num_heads, + config.kv_sequence_length, + config.q_sequence_length, + device="cpu", + dtype=torch_type, + requires_grad=False, + ) + if config.qk_output != QKOutputType.NO_OUTPUT + else None + ) + 
brange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") new_mask = brange < cache_seqlens_expanded k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - out_ref, _ = attention_ref( + out_ref, out_qk_pre_softmax_ref, out_qk_post_softmax_ref = attention_ref( q_ro, k_cache_rep, v_cache_rep, @@ -1548,6 +1733,7 @@ def parity_check_gqa_prompt_no_buff( window_size=window_size, softcap=softcap, use_smooth_softmax=use_smooth_softmax, + head_sink=head_sink, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1558,7 +1744,7 @@ def parity_check_gqa_prompt_no_buff( # Cache seqlens is reduced by 1 since it is required to be past_seq_len + seq_len - 1 if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v = gqa_prompt_func( + out, present_k, present_v, out_qk = gqa_prompt_func( packed_qkv, None, None, @@ -1570,6 +1756,8 @@ def parity_check_gqa_prompt_no_buff( cache_seqlens - 1, position_ids, attention_bias, + head_sink, + output_qk, left_window_size, past_format, False, @@ -1580,7 +1768,7 @@ def parity_check_gqa_prompt_no_buff( numpy_type=numpy_type, ) else: - out, present_k, present_v = gqa_prompt_func( + out, present_k, present_v, out_qk = gqa_prompt_func( q, None, None, @@ -1592,6 +1780,8 @@ def parity_check_gqa_prompt_no_buff( cache_seqlens - 1, position_ids, attention_bias, + head_sink, + output_qk, left_window_size, past_format, False, @@ -1605,6 +1795,22 @@ def parity_check_gqa_prompt_no_buff( out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + if config.qk_output != QKOutputType.NO_OUTPUT: + out_qk_ref = ( + out_qk_post_softmax_ref if config.qk_output == QKOutputType.AFTER_SOFTMAX else out_qk_pre_softmax_ref + ) + out_qk_ref = out_qk_ref.detach().cpu().numpy() + + for batch_idx in range(config.batch_size): + total_seqlen = cache_seqlens[batch_idx] + assert numpy.allclose( + out_qk[batch_idx, :, :, :total_seqlen], + out_qk_ref[batch_idx, :, :, :total_seqlen], + rtol=rtol, + atol=atol, + equal_nan=True, + ) + # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1646,6 +1852,8 @@ def parity_check_gqa_prompt_no_buff( config.has_position_ids, " has_attention_bias:", config.has_attention_bias, + " qk_output:", + config.qk_output, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -1759,6 +1967,8 @@ def parity_check_gqa_past( cos, sin = None, None q_ro, k_ro = q, new_k + head_sink = get_custom_head_sink(config.num_heads, torch_type=torch_type) if config.has_head_sink else None + arange = rearrange(torch.arange(config.kv_sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = torch.logical_and( @@ -1769,7 +1979,7 @@ def parity_check_gqa_past( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + 
config.sequence_length - out_ref, _ = attention_ref( + out_ref, out_qk_pre_softmax_ref, out_qk_post_softmax_ref = attention_ref( q_ro, k_cache_rep, v_cache_rep, @@ -1781,6 +1991,7 @@ def parity_check_gqa_past( window_size=window_size, softcap=softcap, use_smooth_softmax=use_smooth_softmax, + head_sink=head_sink, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1807,10 +2018,24 @@ def parity_check_gqa_past( else None ) + output_qk = ( + torch.zeros( + config.batch_size, + config.num_heads, + config.sequence_length, + config.kv_sequence_length, + device="cpu", + dtype=torch_type, + requires_grad=False, + ) + if config.qk_output != QKOutputType.NO_OUTPUT + else None + ) + # ORT function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v = gqa_past_func( + out, present_k, present_v, out_qk = gqa_past_func( packed_qkv, k, v, @@ -1822,6 +2047,8 @@ def parity_check_gqa_past( cache_seqlens, position_ids, attention_bias, + head_sink, + output_qk, past_format, True, left_window_size, @@ -1832,7 +2059,7 @@ def parity_check_gqa_past( numpy_type=numpy_type, ) else: - out, present_k, present_v = gqa_past_func( + out, present_k, present_v, out_qk = gqa_past_func( q, k, v, @@ -1844,6 +2071,8 @@ def parity_check_gqa_past( cache_seqlens, position_ids, attention_bias, + head_sink, + output_qk, past_format, True, left_window_size, @@ -1857,6 +2086,22 @@ def parity_check_gqa_past( out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + if config.qk_output != QKOutputType.NO_OUTPUT: + out_qk_ref = ( + out_qk_post_softmax_ref if config.qk_output == QKOutputType.AFTER_SOFTMAX else out_qk_pre_softmax_ref + ) + out_qk_ref = out_qk_ref.detach().cpu().numpy() + + for batch_idx in range(config.batch_size): + total_seqlen = cache_seqlens[batch_idx] + 1 + assert numpy.allclose( + out_qk[batch_idx, :, :, :total_seqlen], + out_qk_ref[batch_idx, :, :, :total_seqlen], + rtol=rtol, + atol=atol, + equal_nan=True, + ) + # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1882,6 +2127,8 @@ def parity_check_gqa_past( softcap, " smooth_softmax:", use_smooth_softmax, + " head_sink:", + config.has_head_sink, " B:", config.batch_size, " S:", @@ -1898,6 +2145,8 @@ def parity_check_gqa_past( config.has_position_ids, " has_attention_bias:", config.has_attention_bias, + " qk_output:", + config.qk_output, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -2017,6 +2266,8 @@ def parity_check_gqa_past_no_buff( cos, sin = None, None q_ro, k_ro = q, new_k + head_sink = get_custom_head_sink(config.num_heads, torch_type) if config.has_head_sink else None + arange = rearrange(torch.arange(config.kv_sequence_length + config.sequence_length, device="cpu"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = torch.logical_and( @@ -2027,7 +2278,7 @@ def parity_check_gqa_past_no_buff( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length - out_ref, _ = attention_ref( + out_ref, out_qk_pre_softmax_ref, 
out_qk_post_softmax_ref = attention_ref( q_ro, k_cache_rep, v_cache_rep, @@ -2039,6 +2290,7 @@ def parity_check_gqa_past_no_buff( window_size=window_size, softcap=softcap, use_smooth_softmax=use_smooth_softmax, + head_sink=head_sink, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -2065,10 +2317,24 @@ def parity_check_gqa_past_no_buff( else None ) + output_qk = ( + torch.zeros( + config.batch_size, + config.num_heads, + config.sequence_length, + config.kv_sequence_length + config.sequence_length, + device="cpu", + dtype=torch_type, + requires_grad=False, + ) + if config.qk_output != QKOutputType.NO_OUTPUT + else None + ) + # Flash function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v = gqa_past_func( + out, present_k, present_v, out_qk = gqa_past_func( packed_qkv, k, v, @@ -2080,6 +2346,8 @@ def parity_check_gqa_past_no_buff( cache_seqlens, position_ids, attention_bias, + head_sink, + output_qk, past_format, False, window_size=left_window_size, @@ -2090,7 +2358,7 @@ def parity_check_gqa_past_no_buff( numpy_type=numpy_type, ) else: - out, present_k, present_v = gqa_past_func( + out, present_k, present_v, out_qk = gqa_past_func( q, k, v, @@ -2102,6 +2370,8 @@ def parity_check_gqa_past_no_buff( cache_seqlens, position_ids, attention_bias, + head_sink, + output_qk, past_format, False, window_size=left_window_size, @@ -2115,6 +2385,22 @@ def parity_check_gqa_past_no_buff( out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + if config.qk_output != QKOutputType.NO_OUTPUT: + out_qk_ref = ( + out_qk_post_softmax_ref if config.qk_output == QKOutputType.AFTER_SOFTMAX else out_qk_pre_softmax_ref + ) + out_qk_ref = out_qk_ref.detach().cpu().numpy() + + for batch_idx in range(config.batch_size): + total_seqlen = cache_seqlens[batch_idx] + 1 + assert numpy.allclose( + out_qk[batch_idx, :, :, :total_seqlen], + out_qk_ref[batch_idx, :, :, :total_seqlen], + rtol=rtol, + atol=atol, + equal_nan=True, + ) + # Compare results all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True) correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET @@ -2134,6 +2420,8 @@ def parity_check_gqa_past_no_buff( softcap, " smooth_softmax:", use_smooth_softmax, + " head_sink:", + config.has_head_sink, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -2152,6 +2440,8 @@ def parity_check_gqa_past_no_buff( config.has_position_ids, " has_attention_bias:", config.has_attention_bias, + " qk_output:", + config.qk_output, " Mean Error:", numpy.mean(numpy.abs(out - out_ref)), correct, @@ -2180,7 +2470,16 @@ def setUp(self): ] def run_test_config( - self, test_func, config_class, batches, seqs, num_h, h_sizes, pos_ids_attn_bias, additional_params=None + self, + test_func, + config_class, + batches, + seqs, + num_h, + h_sizes, + pos_ids_attn_bias, + qk_output, + additional_params=None, ): if additional_params is None: additional_params = {} @@ -2202,33 +2501,59 @@ def run_test_config( for softcap in [0.0, 50.0]: for use_smooth_softmax in [False, True]: for has_pos, has_attn in pos_ids_attn_bias: - if config_class == PromptConfig: - config = config_class( - b, s, s2, s + s2 + 8, n, n2, h, has_pos, has_attn - ) - else: # Config - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = config_class(b, s, s2, sp, n, n2, h, has_pos, has_attn) - - params = { - "config": config, - "torch_type": 
precision["torch_type"], - "numpy_type": precision["numpy_type"], - "ort_type": precision["ort_type"], - "rtol": precision["rtol"], - "atol": precision["atol"], - "local": local, - "past_format": Formats.BNSH, - "rotary": rotary, - "rotary_interleaved": rotary_interleaved, - "packed": packed, - "softcap": softcap, - "use_smooth_softmax": use_smooth_softmax, - } - params.update(additional_params) - - all_close = test_func(**params) - self.assertTrue(all_close) + for head_sink in [False, True]: + if use_smooth_softmax and head_sink: + continue + for output_qk in qk_output: + if config_class == PromptConfig: + config = config_class( + b, + s, + s2, + s + s2 + 8, + n, + n2, + h, + has_pos, + has_attn, + head_sink, + output_qk, + ) + else: # Config + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = config_class( + b, + s, + s2, + sp, + n, + n2, + h, + has_pos, + has_attn, + head_sink, + output_qk, + ) + + params = { + "config": config, + "torch_type": precision["torch_type"], + "numpy_type": precision["numpy_type"], + "ort_type": precision["ort_type"], + "rtol": precision["rtol"], + "atol": precision["atol"], + "local": local, + "past_format": Formats.BNSH, + "rotary": rotary, + "rotary_interleaved": rotary_interleaved, + "packed": packed, + "softcap": softcap, + "use_smooth_softmax": use_smooth_softmax, + } + params.update(additional_params) + + all_close = test_func(**params) + self.assertTrue(all_close) def test_gqa_no_past(self): print("-------- TEST GQA NO PAST (PROMPT CASE) ---------") @@ -2245,12 +2570,33 @@ def test_gqa_no_past(self): ) num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + qk_output = ( + [QKOutputType.NO_OUTPUT] + if pipeline_mode + else [QKOutputType.NO_OUTPUT, QKOutputType.BEFORE_SOFTMAX, QKOutputType.AFTER_SOFTMAX] + ) # Test with buffer - self.run_test_config(parity_check_gqa_prompt, PromptConfig, batches, seqs, num_h, h_sizes, pos_ids_attn_bias) + self.run_test_config( + parity_check_gqa_prompt, + PromptConfig, + batches, + seqs, + num_h, + h_sizes, + pos_ids_attn_bias, + qk_output, + ) # Test without buffer self.run_test_config( - parity_check_gqa_prompt_no_buff, PromptConfig, batches, seqs, num_h, h_sizes, pos_ids_attn_bias + parity_check_gqa_prompt_no_buff, + PromptConfig, + batches, + seqs, + num_h, + h_sizes, + pos_ids_attn_bias, + qk_output, ) def test_gqa_past(self): @@ -2268,11 +2614,25 @@ def test_gqa_past(self): ) num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + qk_output = ( + [QKOutputType.NO_OUTPUT] + if pipeline_mode + else [QKOutputType.NO_OUTPUT, QKOutputType.BEFORE_SOFTMAX, QKOutputType.AFTER_SOFTMAX] + ) # Test with buffer - self.run_test_config(parity_check_gqa_past, Config, batches, seqs, num_h, h_sizes, pos_ids_attn_bias) + self.run_test_config(parity_check_gqa_past, Config, batches, seqs, num_h, h_sizes, pos_ids_attn_bias, qk_output) # Test without buffer - self.run_test_config(parity_check_gqa_past_no_buff, Config, batches, seqs, num_h, h_sizes, pos_ids_attn_bias) + self.run_test_config( + parity_check_gqa_past_no_buff, + Config, + batches, + seqs, + num_h, + h_sizes, + pos_ids_attn_bias, + qk_output, + ) def test_gqa_interactive_one_batch(self): print("-------- TEST GQA INTERACTIVE ---------") @@ -2287,6 +2647,7 @@ def test_gqa_interactive_one_batch(self): if pipeline_mode else [(False, False), (True, True), (False, 
True), (True, False)] ) + qk_output = [QKOutputType.NO_OUTPUT, QKOutputType.BEFORE_SOFTMAX, QKOutputType.AFTER_SOFTMAX] num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [32] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] @@ -2299,6 +2660,7 @@ def test_gqa_interactive_one_batch(self): num_h, h_sizes, pos_ids_attn_bias, + qk_output, additional_params={"softcap": 0.0, "use_smooth_softmax": False}, ) self.run_test_config( @@ -2309,6 +2671,7 @@ def test_gqa_interactive_one_batch(self): num_h, h_sizes, pos_ids_attn_bias, + qk_output, additional_params={"softcap": 0.0, "use_smooth_softmax": False}, ) diff --git a/onnxruntime/test/python/transformers/test_gqa_cuda.py b/onnxruntime/test/python/transformers/test_gqa_cuda.py index 2f5b638a57d0c..79976a92e54bf 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cuda.py +++ b/onnxruntime/test/python/transformers/test_gqa_cuda.py @@ -782,7 +782,8 @@ def attention_ref( scores.masked_fill_(local_mask, float("-inf")) if use_smooth_softmax: - attention = smooth_softmax_ref(scores) + head_sink = None + attention = smooth_softmax_ref(scores, head_sink) else: attention = torch.softmax(scores, dim=-1) diff --git a/onnxruntime/test/python/transformers/test_paged_attention_cuda.py b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py index 410860a324a9d..ca5c9c2ce133f 100644 --- a/onnxruntime/test/python/transformers/test_paged_attention_cuda.py +++ b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py @@ -401,7 +401,8 @@ def attention_ref( scores.masked_fill_(local_mask, float("-inf")) if use_smooth_softmax: - attention = smooth_softmax_ref(scores) + head_sink = None + attention = smooth_softmax_ref(scores, head_sink) else: attention = torch.softmax(scores, dim=-1) diff --git a/onnxruntime/test/python/transformers/test_phi_vision.py b/onnxruntime/test/python/transformers/test_phi_vision.py index 67f89e633a146..d276366706af9 100644 --- a/onnxruntime/test/python/transformers/test_phi_vision.py +++ b/onnxruntime/test/python/transformers/test_phi_vision.py @@ -149,7 +149,7 @@ def __init__(self): self.attn = PhiVCLIPAttention() self.ln = torch.nn.LayerNorm(20, eps=1e-05) - def forward(self, x): + def forward(self, x, attention_mask=None): # SkipLayerNorm ------+ # | | # Attention | @@ -163,8 +163,7 @@ def forward(self, x): x = self.ln(x) residual = x - # Attention + MatMul - x = self.attn(x) + x = self.attn(x, attention_mask=attention_mask) # SkipLayerNorm x = residual + x @@ -194,14 +193,31 @@ def verify_fusion(self, optimized_model, expected_model_filename): ) def export(self, model, inputs): - torch.onnx.export( - model, - args=inputs, - f=os.path.join(os.path.dirname(__file__), "export.onnx"), - export_params=True, - opset_version=14, - do_constant_folding=True, - ) + path = os.path.join(os.path.dirname(__file__), "export.onnx") + + if len(inputs) == 2: + torch.onnx.export( + model, + args=inputs, + f=path, + export_params=True, + opset_version=14, + do_constant_folding=True, + input_names=["input", "attention_mask"], + dynamic_axes={ + "input": {0: "batch", 1: "seq"}, + "attention_mask": {0: "batch", 2: "seq", 3: "seq"}, + }, + ) + else: + torch.onnx.export( + model, + args=inputs, + f=path, + export_params=True, + opset_version=14, + do_constant_folding=True, + ) def tearDown(self): path = os.path.join(os.path.dirname(__file__), "export.onnx") @@ -249,6 +265,38 @@ def test_phi_vision_attention(self): ) self.verify_fusion(optimized_model, 
"phi-3.5-v-instruct-vision-attention.onnx") + def test_phi_vision_attention_with_mask(self): + model = PhiVCLIPAttentionAndLayerNorm() + + batch, seq_len, dim = 1, 2, 20 + mask = torch.zeros(batch, 1, seq_len, seq_len) + mask[:, 1:] = float("-inf") + + inputs = (torch.randn(batch, seq_len, dim), mask) + self.export(model, inputs) + original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx")) + options = FusionOptions("clip") + optimized_model = optimize_model( + original_model, + model_type="clip", + num_heads=2, + hidden_size=20, + optimization_options=options, + opt_level=0, + use_gpu=True, + ) + self.verify_fusion(optimized_model, "phi-4-v-instruct-vision-attention.onnx") + + graph = optimized_model.model.graph + attention_node = next((n for n in graph.node if n.name == "Attention_0"), None) + self.assertIsNotNone(attention_node, "Could not find the Attention fused node") + attr_names = [attr.name for attr in attention_node.attribute] + self.assertNotIn( + "unidirectional", + attr_names, + f"The attention node should not have a 'unidirectional' attribute: {attr_names}", + ) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx b/onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx new file mode 100644 index 0000000000000..d036541a70aa0 Binary files /dev/null and b/onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx differ diff --git a/requirements-lintrunner.txt b/requirements-lintrunner.txt index 450b955f161af..f02e3e8058c29 100644 --- a/requirements-lintrunner.txt +++ b/requirements-lintrunner.txt @@ -1,6 +1,6 @@ # This file is auto updated by dependabot # When any package below is changed, you shall run "lintrunner init" again. lintrunner==0.12.7 -lintrunner-adapters==0.12.4 -ruff==0.12.2 -clang-format==20.1.7 +lintrunner-adapters==0.12.5 +ruff==0.12.3 +clang-format==20.1.8 diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index f6e37d33b2414..893f3c80fa4b8 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -284,6 +284,8 @@ def generate_vcpkg_install_options(build_dir, args): vcpkg_install_options.append("--x-feature=vsinpu-ep") if args.use_webgpu: vcpkg_install_options.append("--x-feature=webgpu-ep") + if args.wgsl_template == "dynamic": + vcpkg_install_options.append("--x-feature=webgpu-ep-wgsl-template-dynamic") if args.use_webnn: vcpkg_install_options.append("--x-feature=webnn-ep") if args.use_xnnpack: @@ -470,6 +472,7 @@ def generate_build_tree( else "OFF" ), "-Donnxruntime_REDUCED_OPS_BUILD=" + ("ON" if is_reduced_ops_build(args) else "OFF"), + "-Donnxruntime_CLIENT_PACKAGE_BUILD=" + ("ON" if args.client_package_build else "OFF"), "-Donnxruntime_BUILD_MS_EXPERIMENTAL_OPS=" + ("ON" if args.ms_experimental else "OFF"), "-Donnxruntime_ENABLE_LTO=" + ("ON" if args.enable_lto else "OFF"), "-Donnxruntime_USE_ACL=" + ("ON" if args.use_acl else "OFF"), diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py index ad27b8124c458..53d53f3e15e99 100644 --- a/tools/ci_build/build_args.py +++ b/tools/ci_build/build_args.py @@ -527,6 +527,15 @@ def add_size_reduction_args(parser: argparse.ArgumentParser) -> None: ) +def add_client_package_args(parser: argparse.ArgumentParser) -> None: + """Adds arguments for client package build package.""" + parser.add_argument( + "--client_package_build", + action="store_true", + help="Create ORT package with default settings more appropriate for client/on-device workloads.", + ) + + def 
add_python_binding_args(parser: argparse.ArgumentParser) -> None: """Adds arguments for Python bindings.""" parser.add_argument("--enable_pybind", action="store_true", help="Enable Python bindings.") @@ -833,6 +842,7 @@ def convert_arg_line_to_args(self, arg_line: str) -> list[str]: # Use list[str] add_dependency_args(parser) add_extension_args(parser) add_size_reduction_args(parser) + add_client_package_args(parser) # Language Bindings add_python_binding_args(parser) diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index ee7f8f2fa386a..e5e2a4749ef85 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index aa25e3f31166a..202aa61da0b80 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -60,7 +60,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 resources: repositories: diff --git a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml index 7addb3217072a..69dc9d1a8f63d 100644 --- a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml @@ -6,7 +6,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: IsReleaseBuild displayName: Is a release build? Set it to true if you are doing an Onnx Runtime release. diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index cf8bbbed70525..526ed71df2006 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index de024f0b3456f..b99246625cb77 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' 
- default: 2.36.0.250627 + default: 2.36.1.250708 trigger: none diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index 4fa916db0de39..626a638121858 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: build_config displayName: Build Configuration diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml index 84b6d30ee32ac..a87bb55441ac7 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml @@ -72,6 +72,8 @@ stages: SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} + - template: ../templates/set-version-number-variables-step.yml + # Reconstruct the build dir - task: PowerShell@2 displayName: 'PS: Extract nuget files gpu' @@ -114,6 +116,7 @@ stages: -p:OrtPackageId="Microsoft.ML.OnnxRuntime.Gpu" -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) + -p:PackageVersion=$(OnnxRuntimeVersion) workingDirectory: '$(Build.SourcesDirectory)\csharp' - template: ../templates/win-esrp-dll.yml diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index 433250f05125e..e2c6b25f48b6d 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' 
- default: 2.36.0.250627 + default: 2.36.1.250708 stages: - ${{ if eq(parameters.enable_windows_cpu, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml index ab779e164b36e..74f7f782fe1b2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml @@ -19,7 +19,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.36.0.250627' + default: '2.36.1.250708' - name: enableWebGpu displayName: Enable WebGPU test diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml index 110f83ff587c8..92e862bd79008 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml @@ -53,7 +53,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.36.0.250627' + default: '2.36.1.250708' - name: is1ES displayName: Is 1ES pipeline diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 535784933a087..5b48a14e2afc3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -47,7 +47,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: is1ES displayName: Is 1ES pipeline diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml index 3e7427cc7a2e3..930dc83b73460 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.36.0.250627' + default: '2.36.1.250708' steps: - script: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml index e3f549e2d649f..96eea6cd6d2fb 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.36.0.250627' + default: '2.36.1.250708' steps: - powershell: | diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml index d533fb7c83ddd..caee5367950e6 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml @@ -26,7 +26,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: is1ES displayName: 'Whether the pipeline is running in 1ES' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index cd060d1fbf19f..185f41822a7e5 100644 --- 
a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index 2a2ac49b4e073..9a1e7e5e251c9 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index 8528fa3907e96..5affc152a0a4a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index 1406ce338f13e..29ebb8c4e4e61 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -1,5 +1,5 @@ parameters: - QnnSdk: '2.36.0.250627' + QnnSdk: '2.36.1.250708' build_config: 'RelWithDebInfo' IsReleaseBuild: false DoEsrp: false @@ -20,7 +20,7 @@ stages: name: ${{ parameters.qnn_ep_build_pool_name }} variables: OrtPackageId: ${{ parameters.OrtNugetPackageId }} - commonBuildArgs: '--compile_no_warning_as_error --skip_submodule_sync --build_shared_lib --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_binskim_compliant_compile_flags ' + commonBuildArgs: '--compile_no_warning_as_error --skip_submodule_sync --build_shared_lib --client_package_build --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_binskim_compliant_compile_flags ' steps: - template: set-version-number-variables-step.yml diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 78fce1f9b9602..7ebf5394e4530 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 jobs: - job: 'BUILD_QNN_EP' @@ -50,7 +50,7 @@ jobs: matrix: SHARED_LIB: QnnLibKind: 'shared_lib' - ExtraQnnBuildArgs: '' + ExtraQnnBuildArgs: '--client_package_build' STATIC_LIB: QnnLibKind: 'static_lib' ExtraQnnBuildArgs: '' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index eb77c9422853d..ffeb577547f69 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ 
b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 jobs: - job: 'BUILD_QNN_EP'
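
Note on the head_sink plumbing added to the GQA parity tests above: smooth_softmax_ref now receives the optional per-head sink tensor produced by get_custom_head_sink, with None falling back to the existing smooth softmax. The sketch below is only an illustration of one plausible reference semantics, assuming the sink contributes a single extra per-head logit to the softmax denominator and that a zero sink recovers the plain smooth softmax; it is not the implementation the tests ship.

    import torch

    def smooth_softmax_with_head_sink(scores, head_sink=None):
        # scores: (batch, num_heads, q_seq, kv_seq) attention logits.
        # head_sink: optional (num_heads,) tensor of per-head sink logits (assumption).
        b, h, q, _ = scores.shape
        if head_sink is None:
            sink = scores.new_zeros(b, h, q, 1)            # zero sink logit == plain smooth softmax
        else:
            sink = head_sink.to(scores.dtype).view(1, h, 1, 1).expand(b, h, q, 1)
        probs = torch.softmax(torch.cat([scores, sink], dim=-1), dim=-1)
        return probs[..., :-1]                             # drop the sink column; rows sum to less than 1

Under this reading, the sink only absorbs probability mass and never contributes a value vector, which is consistent with run_test_config skipping the combination of use_smooth_softmax and head_sink.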
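
Note on the new output_qk checks: the tests allocate the ORT-side QK buffer padded to the full KV capacity and then compare only the first total_seqlen columns of each batch against the reference, since the tail of the padded buffer is not defined. As a rough restatement (an assumption that ignores masking, softcap, and the smooth-softmax/head-sink variants), BEFORE_SOFTMAX corresponds to the scaled, bias-adjusted Q·K^T scores and AFTER_SOFTMAX to the normalized probabilities:

    import torch

    def qk_reference(q, k, scale, attention_bias=None):
        # q: (B, H, S_q, D); k: (B, H, S_kv, D), after GQA head repetition and rotary embedding.
        scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale
        if attention_bias is not None:
            scores = scores + attention_bias               # additive bias, broadcast over heads
        return scores, torch.softmax(scores, dim=-1)       # (before_softmax, after_softmax)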
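
Usage note for the build switches wired up in tools/ci_build/build.py and build_args.py: --client_package_build sets onnxruntime_CLIENT_PACKAGE_BUILD, and combining --use_vcpkg with --wgsl_template dynamic adds the webgpu-ep-wgsl-template-dynamic vcpkg feature. The snippet below is purely illustrative; apart from those flags, the command line (build directory, configuration, and the remaining options) is an assumption, not something this change prescribes.

    # Hypothetical local invocation exercising the new options.
    import subprocess
    import sys

    subprocess.run(
        [
            sys.executable, "tools/ci_build/build.py",
            "--build_dir", "build/Windows",
            "--config", "RelWithDebInfo",
            "--build_shared_lib",
            "--use_webgpu",
            "--wgsl_template", "dynamic",      # dynamic WGSL template generation
            "--use_vcpkg",                     # pulls in --x-feature=webgpu-ep-wgsl-template-dynamic
            "--client_package_build",          # defaults tuned for client/on-device workloads
        ],
        check=True,
    )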