diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index e8f6bbe895d29..228906030d14c 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -774,13 +774,24 @@ if (onnxruntime_USE_WEBGPU) endif() if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND onnxruntime_WGSL_TEMPLATE STREQUAL "dynamic") - onnxruntime_fetchcontent_declare( - duktape - URL ${DEP_URL_duktape} - URL_HASH SHA1=${DEP_SHA1_duktape} - EXCLUDE_FROM_ALL - ) - onnxruntime_fetchcontent_makeavailable(duktape) + if(onnxruntime_USE_VCPKG) + find_package(unofficial-duktape CONFIG REQUIRED) + add_library(duktape_static ALIAS unofficial::duktape::duktape) + else() + onnxruntime_fetchcontent_declare( + duktape + URL ${DEP_URL_duktape} + URL_HASH SHA1=${DEP_SHA1_duktape} + EXCLUDE_FROM_ALL + ) + onnxruntime_fetchcontent_makeavailable(duktape) + + if(NOT TARGET duktape_static) + add_library(duktape_static STATIC "${duktape_SOURCE_DIR}/src/duktape.c") + target_compile_features(duktape_static PRIVATE c_std_99) + target_include_directories(duktape_static INTERFACE $) + endif() + endif() endif() endif() diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index 69c81a5ec7b9d..4184e0b049afc 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -72,10 +72,9 @@ endif() # TensorRT 10 GA onwards, the TensorRT libraries will have major version appended to the end on Windows, - # for example, nvinfer_10.dll, nvinfer_plugin_10.dll, nvonnxparser_10.dll ... + # for example, nvinfer_10.dll, nvonnxparser_10.dll ... 
if (WIN32 AND TRT_GREATER_OR_EQUAL_TRT_10_GA) set(NVINFER_LIB "nvinfer_${NV_TENSORRT_MAJOR}") - set(NVINFER_PLUGIN_LIB "nvinfer_plugin_${NV_TENSORRT_MAJOR}") set(PARSER_LIB "nvonnxparser_${NV_TENSORRT_MAJOR}") endif() @@ -83,15 +82,11 @@ set(NVINFER_LIB "nvinfer") endif() - if (NOT NVINFER_PLUGIN_LIB) - set(NVINFER_PLUGIN_LIB "nvinfer_plugin") - endif() - if (NOT PARSER_LIB) set(PARSER_LIB "nvonnxparser") endif() - MESSAGE(STATUS "Looking for ${NVINFER_LIB} and ${NVINFER_PLUGIN_LIB}") + MESSAGE(STATUS "Looking for ${NVINFER_LIB}") find_library(TENSORRT_LIBRARY_INFER ${NVINFER_LIB} HINTS ${TENSORRT_ROOT} @@ -101,14 +96,6 @@ MESSAGE(STATUS "Can't find ${NVINFER_LIB}") endif() - find_library(TENSORRT_LIBRARY_INFER_PLUGIN ${NVINFER_PLUGIN_LIB} - HINTS ${TENSORRT_ROOT} - PATH_SUFFIXES lib lib64 lib/x64) - - if (NOT TENSORRT_LIBRARY_INFER_PLUGIN) - MESSAGE(STATUS "Can't find ${NVINFER_PLUGIN_LIB}") - endif() - if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) MESSAGE(STATUS "Looking for ${PARSER_LIB}") @@ -120,7 +107,7 @@ MESSAGE(STATUS "Can't find ${PARSER_LIB}") endif() - set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN} ${TENSORRT_LIBRARY_NVONNXPARSER}) + set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_NVONNXPARSER}) MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}") else() if (TRT_GREATER_OR_EQUAL_TRT_10_GA) @@ -153,7 +140,7 @@ endif() # Static libraries are just nvonnxparser_static on all platforms set(onnxparser_link_libs nvonnxparser_static) - set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN}) + set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER}) MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}") endif() @@ -161,7 +148,7 @@ # nvonnxparser_static is linked against tensorrt libraries in onnx-tensorrt # See https://github.com/onnx/onnx-tensorrt/blob/8af13d1b106f58df1e98945a5e7c851ddb5f0791/CMakeLists.txt#L121 # However, starting from TRT 10 GA, nvonnxparser_static doesn't link 
against tensorrt libraries. - # Therefore, the above code finds ${TENSORRT_LIBRARY_INFER} and ${TENSORRT_LIBRARY_INFER_PLUGIN}. + # Therefore, the above code finds ${TENSORRT_LIBRARY_INFER}. if(onnxruntime_CUDA_MINIMAL) set(trt_link_libs ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) else() diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index 5b80b1262464d..2865ad33b39f4 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -172,10 +172,12 @@ file(MAKE_DIRECTORY ${WGSL_GENERATED_DIR}) # Find all WGSL template input files - file(GLOB_RECURSE WGSL_TEMPLATE_FILES "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.wgsl.template") + file(GLOB_RECURSE WGSL_TEMPLATE_FILES + "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.wgsl.template" + "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/*.wgsl.template") # Set wgsl-gen command line options as a list - set(WGSL_GEN_OPTIONS "-i" "../" "--output" "${WGSL_GENERATED_DIR}" "-I" "wgsl_template_gen/" "--preserve-code-ref" "--verbose") + set(WGSL_GEN_OPTIONS "-i" "${ONNXRUNTIME_ROOT}/core/providers/webgpu/" "-i" "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/" "--output" "${WGSL_GENERATED_DIR}" "-I" "wgsl_template_gen/" "--preserve-code-ref" "--verbose") if (onnxruntime_WGSL_TEMPLATE STREQUAL "static") if (CMAKE_BUILD_TYPE STREQUAL "Debug") list(APPEND WGSL_GEN_OPTIONS "--generator" "static-cpp-literal") @@ -207,10 +209,9 @@ # Add the generated directory to include paths target_include_directories(onnxruntime_providers_webgpu PRIVATE ${WGSL_GENERATED_ROOT}) elseif(onnxruntime_WGSL_TEMPLATE STREQUAL "dynamic") - add_library(duktape_static STATIC "${duktape_SOURCE_DIR}/src/duktape.c") - target_compile_features(duktape_static PRIVATE c_std_99) target_link_libraries(onnxruntime_providers_webgpu duktape_static) - target_include_directories(onnxruntime_providers_webgpu PRIVATE ${duktape_SOURCE_DIR}/src) + onnxruntime_add_include_to_target(onnxruntime_providers_webgpu 
duktape_static) + # Define the path to the generated templates.js file target_compile_definitions(onnxruntime_providers_webgpu PRIVATE "ORT_WGSL_TEMPLATES_JS_PATH=\"${WGSL_GENERATED_TEMPLATES_JS}\"") diff --git a/cmake/vcpkg.json b/cmake/vcpkg.json index da179d0bad564..373ecec440921 100644 --- a/cmake/vcpkg.json +++ b/cmake/vcpkg.json @@ -93,6 +93,10 @@ "webgpu-ep": { "description": "Build with WebGPU EP", "dependencies": [] + }, + "webgpu-ep-wgsl-template-dynamic": { + "description": "Build with WebGPU EP with dynamic WGSL template code generator", + "dependencies": ["duktape"] } }, "overrides": [ @@ -103,6 +107,10 @@ { "name": "flatbuffers", "version": "23.5.26" + }, + { + "name": "duktape", + "version": "2.7.0#2" } ] } diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 54e03a31fceef..c18a42cc1bbc1 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -952,9 +952,12 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi return const_cast(this)->GetNodeArg(name); } - // search this and up through any parent_graph_ instance for a NodeArg + // Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding mutable NodeArg NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name); + // Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding const NodeArg + const NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const; + /** Gets a mutable NodeArg by name. Creates a new NodeArg that is owned by this Graph if not found. @param name The NodeArg name. @param[in] p_arg_type Optional TypeProto to use if the NodeArg needs to be created. 
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index bf1dd6e20ce64..051a3f7283cbe 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -5748,6 +5748,24 @@ struct OrtApi { */ ORT_API2_STATUS(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); + /** \brief Returns an OrtGraph that contains a subset of nodes in the source OrtGraph. + * + * Note: + * The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference + * the same underlying graph. + * + * \param[in] src_graph The source OrtGraph instance. + * \param[in] nodes A subset of the nodes/OrtNodes in 'src_graph'. + * \param[in] num_nodes Number of nodes. + * \param[out] dst_graph An OrtGraph created from a given set of nodes. Must be released by calling ReleaseGraph. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. 
+ */ + ORT_API2_STATUS(Graph_GetGraphView, _In_ const OrtGraph* src_graph, _In_ const OrtNode** nodes, + _In_ size_t num_nodes, _Outptr_ OrtGraph** dst_graph); + /// @} /// \name OrtNode diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 8583fac30cfbf..7f81ab3433911 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -505,10 +505,34 @@ void EpGraph::IndexToEpNodeMap::SetEpNode(NodeIndex node_index, EpNode* ep_node) EpGraph::EpGraph(const GraphViewer& graph_viewer, PrivateTag) : OrtGraph(OrtGraphIrApi::kEpApi), graph_viewer_(graph_viewer) {} +EpGraph::EpGraph(std::unique_ptr graph_viewer, + std::unique_ptr indexed_sub_graph, + PrivateTag) + : OrtGraph(OrtGraphIrApi::kEpApi), + graph_viewer_(*graph_viewer.get()), + owned_graph_viewer_(std::move(graph_viewer)), + owned_indexed_sub_graph_(std::move(indexed_sub_graph)) {} + // Static class function to create a std::unique_ptr. Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result) { auto ep_graph = std::make_unique(graph_viewer, PrivateTag{}); + return CreateImpl(std::move(ep_graph), graph_viewer, result); +} + +// Static class function to create a std::unique_ptr. 
+Status EpGraph::Create(std::unique_ptr src_graph_viewer, + std::unique_ptr src_indexed_sub_graph, + /*out*/ std::unique_ptr& result) { + auto& graph_viewer = *src_graph_viewer.get(); + auto ep_graph = std::make_unique(std::move(src_graph_viewer), + std::move(src_indexed_sub_graph), + PrivateTag{}); + + return CreateImpl(std::move(ep_graph), graph_viewer, result); +} + +Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result) { AllocatorPtr initializer_allocator = CPUAllocator::DefaultInstance(); std::unordered_map> value_infos_map; diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 12fa082d3f354..7b67f21bf4eb4 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -251,15 +251,32 @@ struct EpGraph : public OrtGraph { public: EpGraph(const GraphViewer& graph_viewer, PrivateTag); + EpGraph(std::unique_ptr graph_viewer, + std::unique_ptr indexed_sub_graph, + PrivateTag); /// /// Creates an instance of EpGraph, which wraps a GraphViewer. + /// This call is used when creating an EpGraph from a GraphViewer instance. The GraphViewer instance is not owned by this EpGraph. /// /// /// /// static Status Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result); + /// + /// Creates an instance of EpGraph, which wraps a GraphViewer. + /// This call is used when creating an EpGraph from a subset of nodes in another EpGraph. + /// In this case, due to the implementation of OrtApis::Graph_GetGraphView, the new EpGraph instance + /// must take ownership of both the GraphViewer and IndexedSubGraph. + /// + /// + /// + /// + static Status Create(std::unique_ptr graph_viewer, + std::unique_ptr indexed_sub_graph, + /*out*/ std::unique_ptr& result); + + // Defines ToExternal() and ToInternal() functions to convert between OrtGraph and EpGraph. 
DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(OrtGraph, EpGraph, OrtGraphIrApi::kEpApi) @@ -331,9 +348,22 @@ struct EpGraph : public OrtGraph { const OrtValue* GetInitializerValue(std::string_view name) const; private: + /// + /// The real implementation of creating an EpGraph instance. + /// Please use one of the above 'Create' functions that internally call this function, and avoid calling this function directly. + /// + /// + /// + /// + /// + static Status CreateImpl(std::unique_ptr ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result); + const GraphViewer& graph_viewer_; const EpNode* parent_node_ = nullptr; + std::unique_ptr owned_graph_viewer_ = nullptr; + std::unique_ptr owned_indexed_sub_graph_ = nullptr; + std::vector> nodes_; IndexToEpNodeMap index_to_ep_node_; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index ca40bad2b4250..4d3091520d876 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1818,6 +1818,10 @@ NodeArg* Graph::GetNodeArgIncludingParentGraphs(const std::string& node_arg_name return node_arg; } +const NodeArg* Graph::GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const { + return const_cast(this)->GetNodeArgIncludingParentGraphs(node_arg_name); +} + void Graph::ReverseDFSFrom(gsl::span from, const std::function& enter, const std::function& leave, diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index 1842c2b4a0d1f..948ebaa5f7e15 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -168,7 +168,15 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) filtered_node_inputs_including_initializers_.reserve(metadef->inputs.size()); for (const auto& input : metadef->inputs) { - const auto* nodearg = graph.GetNodeArg(input); + // NodeArgs from the current scope or any outer scopes should be handled correctly. 
+ // + // There is an edge case where the model consists of a graph with subgraphs nested across three levels. + // In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer). + // When constructing a new GraphViewer for the second- and third-layer subgraphs, + // the second-layer graph may not have the corresponding value_info for that first-layer input, + // because the second-layer graph itself doesn't consume it. + // Therefore, when working within the second-layer graph, we need to search outer scopes for the missing value_info. + const auto* nodearg = graph.GetNodeArgIncludingParentGraphs(input); ORT_ENFORCE(nodearg, "Mismatch between Graph and IndexedSubGraph. Input not found:", input); filtered_node_inputs_including_initializers_.push_back(nodearg); if (!graph.IsInitializedTensor(input)) { @@ -177,7 +185,7 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) } for (const auto& output : metadef->outputs) { - const auto* nodearg = graph.GetNodeArg(output); + const auto* nodearg = graph.GetNodeArgIncludingParentGraphs(output); ORT_ENFORCE(nodearg, "Mismatch between Graph and IndexedSubGraph. 
Output not found:", output); filtered_node_outputs_.push_back(nodearg); } diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 711d81186bad1..c5b6507ac847b 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1304,7 +1304,7 @@ std::vector NvExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { - return std::make_unique(device_id, CUDA_PINNED); + return std::make_unique(CUDA_PINNED, device_id); }, narrow(device_id_)); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc index 86b684f8c6ebd..21947a22e2b92 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc @@ -235,7 +235,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(reshape_input, reshape_input_info)); bool needs_reshape = false; - const std::string reshape4d = input_names[0] + "_pre_reshape"; + const std::string reshape_prior_out = input_names[0] + "_prior_reshape"; if (input_shape.size() == 3) { needs_reshape = true; // build new_shape = {N, 1, C, L} @@ -245,25 +245,24 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra input_shape[1], input_shape[2]}; - const std::string reshape_node_name = "pre_reshape"; - QnnTensorWrapper rw( - reshape4d, + QnnTensorWrapper reshape_prior_tensor( + reshape_prior_out, QNN_TENSOR_TYPE_NATIVE, reshape_input_info.qnn_data_type, reshape_input_info.quant_param.Copy(), std::move(new_shape)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(rw)), - "Failed to add 
reshape-4d tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_prior_tensor)), + "Failed to add reshape prior tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - reshape_node_name, + utils::GetNodeName(node_unit) + "_reshape_prior", QNN_OP_PACKAGE_NAME_QTI_AISW, - "Reshape", + QNN_OP_RESHAPE, {input_names[0]}, - {reshape4d}, + {reshape_prior_out}, {}, do_op_validation), - "Failed to create reshape-4d node."); - input_names[0] = reshape4d; + "Failed to create reshape prior node for pool op."); + input_names[0] = reshape_prior_out; input_shape = {input_shape[0], 1, input_shape[1], input_shape[2]}; } @@ -446,9 +445,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra } const auto& outputs = node_unit.Outputs(); const std::string real_out = outputs[0].node_arg.Name(); - const std::string pool_name = "poolmax2d"; - const std::string pool_out = real_out + "_post_reshape"; - const std::string post_reshape_node_name = "post_reshape"; + const std::string pool_out = real_out + "_reshape_after"; const std::string qnn_op = GetQnnOpType(op_type); TensorInfo output_info{}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info)); @@ -466,33 +463,34 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra "Failed to add tensor for pool_out"); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - pool_name, + utils::GetNodeName(node_unit) + "_pool2d", QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op, - {reshape4d}, + {reshape_prior_out}, {pool_out}, std::move(param_tensor_names), do_op_validation), - "Failed to create QNN Pool node for rank-3 input."); + "Failed to create pool node for rank-3 input."); std::vector final_shape3d = output_info.shape; - QnnTensorWrapper reshape_back_tensor( + QnnTensorWrapper reshape_after_tensor( real_out, tensor_type, output_info.qnn_data_type, output_info.quant_param.Copy(), std::move(final_shape3d)); - 
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_back_tensor)), "Failed to add tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_after_tensor)), + "Failed to add reshape after tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - post_reshape_node_name, + utils::GetNodeName(node_unit) + "_reshape_after", QNN_OP_PACKAGE_NAME_QTI_AISW, - "Reshape", + QNN_OP_RESHAPE, {pool_out}, {real_out}, {}, do_op_validation), - "Failed to create reshape-back node."); + "Failed to create reshape after node for pool op."); return Status::OK(); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index 90a4294fb47f0..1e9fafe8aa323 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -7,6 +7,25 @@ #include "tensorrt_execution_provider_custom_ops.h" #include "tensorrt_execution_provider.h" +// The filename extension for a shared library is different per platform +#ifdef _WIN32 +#define LIBRARY_PREFIX +#define LIBRARY_EXTENSION ORT_TSTR(".dll") +#elif defined(__APPLE__) +#define LIBRARY_PREFIX "lib" +#define LIBRARY_EXTENSION ".dylib" +#else +#define LIBRARY_PREFIX "lib" +#define LIBRARY_EXTENSION ".so" +#endif + +#ifdef _WIN32 +#define ORT_DEF2STR_HELPER(x) L#x +#else +#define ORT_DEF2STR_HELPER(X) #X +#endif +#define ORT_DEF2STR(x) ORT_DEF2STR_HELPER(x) + namespace onnxruntime { extern TensorrtLogger& GetTensorrtLogger(bool verbose); @@ -58,8 +77,31 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& // Get all registered TRT plugins from registry LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Getting all registered TRT plugins from TRT plugin registry ..."; TensorrtLogger trt_logger = GetTensorrtLogger(false); - initLibNvInferPlugins(&trt_logger, ""); + try { + void* 
library_handle = nullptr; + const auto& env = onnxruntime::GetDefaultEnv(); +#if NV_TENSORRT_MAJOR < 10 + auto full_path = env.GetRuntimePath() + + PathString(LIBRARY_PREFIX ORT_TSTR("nvinfer_plugin") LIBRARY_EXTENSION); +#else +#ifdef _WIN32 + auto full_path = PathString(LIBRARY_PREFIX ORT_TSTR("nvinfer_plugin_" ORT_DEF2STR(NV_TENSORRT_MAJOR)) LIBRARY_EXTENSION); +#else + auto full_path = PathString(LIBRARY_PREFIX ORT_TSTR("nvinfer_plugin") LIBRARY_EXTENSION ORT_TSTR("." ORT_DEF2STR(NV_TENSORRT_MAJOR))); +#endif +#endif + + ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, false, &library_handle)); + bool (*dyn_initLibNvInferPlugins)(void* logger, char const* libNamespace); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(library_handle, "initLibNvInferPlugins", (void**)&dyn_initLibNvInferPlugins)); + if (!dyn_initLibNvInferPlugins(&trt_logger, "")) { + LOGS_DEFAULT(INFO) << "[TensorRT EP] Default plugin library was found but was not able to initialize default plugins."; + } + LOGS_DEFAULT(INFO) << "[TensorRT EP] Default plugins successfully loaded."; + } catch (const std::exception&) { + LOGS_DEFAULT(INFO) << "[TensorRT EP] Default plugin library is not on the path and is therefore ignored"; + } int num_plugin_creator = 0; auto plugin_creators = getPluginRegistry()->getAllCreators(&num_plugin_creator); std::unordered_set registered_plugin_names; diff --git a/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json b/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json index 7cde6c17f54e9..df1940ed6416b 100644 --- a/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json +++ b/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json @@ -9,13 +9,13 @@ "version": "1.0.0", "license": "MIT", "dependencies": { - "@fs-eire/wgsl-template": "^0.1.3" + "@fs-eire/wgsl-template": "^0.1.13" } }, "node_modules/@fs-eire/wgsl-template": { - "version": "0.1.10", - "resolved": 
"https://registry.npmjs.org/@fs-eire/wgsl-template/-/wgsl-template-0.1.10.tgz", - "integrity": "sha512-F5qQZxNweZ3ZD3d9RNc/g3nTiW7jyaAVi7SlMOL4wOfXh+Nm/qca2DISNTf3kjpVqkoazMJGbZ6TPQ4a/vjw0g==", + "version": "0.1.13", + "resolved": "https://registry.npmjs.org/@fs-eire/wgsl-template/-/wgsl-template-0.1.13.tgz", + "integrity": "sha512-SOQjVCQCUmXb9qYr2E3CKNs88/FzINuhFJiobBEkSAsyKtJby9oFWGZnrEO+hIl/oDTLA01LbjiDxuf6TGHE/w==", "license": "MIT", "dependencies": { "minimist": "^1.2.8" diff --git a/onnxruntime/core/providers/webgpu/wgsl_templates/package.json b/onnxruntime/core/providers/webgpu/wgsl_templates/package.json index 34831ccddeb33..246e7365531e0 100644 --- a/onnxruntime/core/providers/webgpu/wgsl_templates/package.json +++ b/onnxruntime/core/providers/webgpu/wgsl_templates/package.json @@ -10,6 +10,6 @@ "author": "", "license": "MIT", "dependencies": { - "@fs-eire/wgsl-template": "^0.1.3" + "@fs-eire/wgsl-template": "^0.1.13" } } diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index e821265fff80d..142d64caa64aa 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -99,69 +99,93 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n return true; } -// Check if all input tensor ranks of the given node are supported by WebNN. -bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) { - const std::string_view op_type = node.OpType(); - const auto it = op_inputs_map.find(op_type); - if (it == op_inputs_map.end()) { - LOGS(logger, VERBOSE) << "Operator type: [" << op_type << "] is not found in the op inputs map."; +// Check if a single input's rank of an ONNX op is supported by corresponding WebNN op. 
+bool IsInputRankSupported(const emscripten::val& wnn_limits, + const std::string_view webnn_op_type, + const std::string_view input_name, + const size_t input_rank, + const std::string_view node_name, + const logging::Logger& logger) { + const std::string webnn_op_type_str(webnn_op_type); + const std::string input_name_str(input_name); + + if (wnn_limits[webnn_op_type_str].isUndefined()) { + LOGS(logger, VERBOSE) << "WebNN op type: [" << webnn_op_type + << "] is not defined in WebNN MLOpSupportLimits."; return false; } - const auto& input_defs = node.InputDefs(); - const std::string_view webnn_op_type = it->second.opType; - const std::string webnn_op_type_str(webnn_op_type); + const emscripten::val input_limits = wnn_limits[webnn_op_type_str][input_name_str]; - for (const auto& input : it->second.inputs) { - if (static_cast(input.index) >= input_defs.size() || input_defs[input.index] == nullptr) { - LOGS(logger, VERBOSE) << "Input index [" << input.index - << "] for operator type [" << op_type - << "], corresponding WebNN op type [" << webnn_op_type - << "], WebNN input name [" << input.name - << "] is invalid."; - return false; - } + if (input_limits.isUndefined()) { + LOGS(logger, VERBOSE) << "Node name: [" << node_name + << "], WebNN op type: [" << webnn_op_type + << "], input [" << input_name + << "]: limits are not defined in WebNN MLOpSupportLimits."; + return false; + } - std::vector input_shape; - if (!GetShape(*input_defs[input.index], input_shape, logger)) { - return false; - } + const emscripten::val rank_range = input_limits["rankRange"]; + if (rank_range.isUndefined()) { + LOGS(logger, VERBOSE) << "WebNN op type [" << webnn_op_type + << "] input [" << input_name + << "]: missing 'rankRange' attribute."; + return false; + } - const std::string input_name_str(input.name); - if (wnn_limits[webnn_op_type_str].isUndefined() || - wnn_limits[webnn_op_type_str][input_name_str].isUndefined()) { - LOGS(logger, VERBOSE) << "Operator type: [" << op_type - << "], 
input index: [" << input.index - << "], corresponding WebNN op type: " << webnn_op_type - << ", WebNN input name " << input.name - << " is not defined in wnn_limits."; - return false; - } + const emscripten::val min_val = rank_range["min"]; + const emscripten::val max_val = rank_range["max"]; + if (min_val.isUndefined() || max_val.isUndefined()) { + LOGS(logger, VERBOSE) << "WebNN op type [" << webnn_op_type + << "] input [" << input_name + << "]: its 'rankRange' limits is missing valid 'min' or 'max' attributes."; + return false; + } - const auto& input_limits = wnn_limits[webnn_op_type_str][input_name_str]; - if (input_limits["rankRange"].isUndefined()) { - LOGS(logger, VERBOSE) << "Operator type: [" << op_type - << "], input index: [" << input.index - << "], corresponding WebNN op type: " << webnn_op_type - << ", WebNN input name " << input.name - << "'s rankRange is not defined."; - return false; + size_t min_rank = min_val.as(); + size_t max_rank = max_val.as(); + if (input_rank < min_rank || input_rank > max_rank) { + LOGS(logger, VERBOSE) << "Node name: [" << node_name + << "] WebNN op type [" << webnn_op_type + << "] input [" << input_name << "] rank " << input_rank + << " is not in supported range [" << min_rank << ", " << max_rank << "]"; + return false; + } + + return true; +} + +bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) { + const std::string_view onnx_op_type = node.OpType(); + const std::string_view webnn_op_type = GetWebNNOpType(onnx_op_type); + + if (webnn_op_type.empty()) { + LOGS(logger, VERBOSE) << "ONNX op type: [" << onnx_op_type << "]'s corresponding WebNN op is not found."; + return false; + } + + std::vector inputs; + if (!GetWebNNOpInputs(onnx_op_type, inputs, logger)) { + return false; + } + + const auto& input_defs = node.InputDefs(); + + for (const auto& input : inputs) { + // If it is an optional input and is absent, skip. 
+ if (!TensorExists(input_defs, input.index)) { + continue; } - int input_dim_size = static_cast(input_shape.size()); - int min_rank = input_limits["rankRange"]["min"].as(); - int max_rank = input_limits["rankRange"]["max"].as(); - - if (input_dim_size < min_rank || input_dim_size > max_rank) { - LOGS(logger, VERBOSE) << "Operator type: [" << op_type - << "], input index: [" << input.index - << "], corresponding WebNN op type: " << webnn_op_type - << ", WebNN input name: " << input.name - << ", input size " << input_dim_size - << " is not in supported range [" << min_rank << ", " << max_rank << "]"; + std::vector shape; + if (!GetShape(*input_defs[input.index], shape, logger) || + !IsInputRankSupported(wnn_limits, webnn_op_type, input.name, + shape.size(), + node.Name(), logger)) { return false; } } + return true; } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index d59788600f997..50e361ede221e 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -216,6 +216,13 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger); +bool IsInputRankSupported(const emscripten::val& wnn_limits, + const std::string_view webnn_op_type, + const std::string_view input_name, + const size_t input_rank, + const std::string_view node_name, + const logging::Logger& logger); + // Get a set of nodes supported by WebNN EP. std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, const emscripten::val& wnn_builder, @@ -244,6 +251,33 @@ inline std::string_view GetWebNNOpType(const std::string_view onnx_op_type) { return (it != op_inputs_map.end()) ? 
it->second.opType : ""; } +// Get corresponding input name of WebNN op type by ONNX op type from op_input_map +inline std::string_view GetWebNNInputName(const std::string_view onnx_op_type, const int input_index) { + const auto it = op_inputs_map.find(onnx_op_type); + + if (it != op_inputs_map.end()) { + for (const auto& input : it->second.inputs) { + if (input.index == input_index) { + return input.name; + } + } + } + + return ""; +} + +inline bool GetWebNNOpInputs(const std::string_view onnx_op_type, + std::vector& inputs, + const logging::Logger& logger) { + const auto it = op_inputs_map.find(onnx_op_type); + if (it == op_inputs_map.end()) { + LOGS(logger, VERBOSE) << "WebNN op inputs not found for op type: " << onnx_op_type; + return false; + } + inputs = it->second.inputs; + return true; +} + bool AreDataTypesSame(const std::string_view op_type, gsl::span input_types, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index 8589237617745..e0cd48b6883c2 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -75,7 +75,7 @@ bool ConcatOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod } } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger); + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger); } void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc index 06beb56415609..b4b9d9a0d4c6b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc 
+++ b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc @@ -56,13 +56,12 @@ bool GatherElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const N const auto& indices = *node.InputDefs()[1]; const std::string_view op_type = node.OpType(); - int32_t data_type; - int32_t indices_type; + int32_t data_type, indices_type; if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) { return false; } - return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc index 9200c596c0e53..a15542061dd60 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc @@ -61,13 +61,12 @@ bool GatherNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& n const auto& indices = *node.InputDefs()[1]; const std::string_view op_type = node.OpType(); - int32_t data_type; - int32_t indices_type; + int32_t data_type, indices_type; if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) { return false; } - return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc index 
d84c70032e1d1..86408557013a0 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc @@ -74,13 +74,13 @@ bool GatherOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod const auto& input = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const std::string_view op_type = node.OpType(); - int32_t input_type; - int32_t indices_type; + int32_t input_type, indices_type; + if (!GetType(input, input_type, logger) || !GetType(indices, indices_type, logger)) return false; - return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 02f46c85d1d06..7af17fdc5db78 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -91,7 +91,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N a_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); std::vector a_zero_point_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[2], a_zero_point_shape, logger), "Cannot get shape of a_zero_point"); - // Scale is not used by MatMulInteger but required by DequantizeLinear. So set it to deafult value 1.0f. + // Scale is not used by MatMulInteger but required by DequantizeLinear. So set it to default value 1.0f. // The scale input should have the same shape as the zero point input. 
a_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f, @@ -268,11 +268,45 @@ bool GemmOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - if (op_type == "MatMulInteger") { - // The first decomposed op of MatMulInteger is DequantizeLinear, and so - // we only need to ensure it supports the input0_type. + if (op_type == "Gemm") { + return IsInputRankSupportedByOp(node, wnn_limits, logger) && + IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger); + } else if (op_type == "MatMulInteger") { + // Check up to 4 inputs for MatMulInteger + for (size_t i = 0; i < input_defs.size(); ++i) { + std::vector shape; + if (!GetShape(*input_defs[i], shape, logger)) { + return false; + } + + // We made workaround to support 1D for input A and B, skip further checks if they are 1D + if (i <= 1 && shape.size() == 1) { + continue; + } + + // For DequantizeLinear, input indices: 0 (x), 1 (scale), 2 (zero_point) + if (!IsInputRankSupported(wnn_limits, "dequantizeLinear", + (i < 2) ? "input" : "zeroPoint", + shape.size(), node.Name(), logger)) { + return false; + } + } return IsDataTypeSupportedByOp("DequantizeLinear", input0_type, wnn_limits, "input", "x", logger); - } else { + } else { // MatMul + for (int i = 0; i < 2; ++i) { + std::vector shape; + if (!GetShape(*input_defs[i], shape, logger)) { + return false; + } + + if (shape.size() == 1) { + continue; + } + + if (!IsInputRankSupported(wnn_limits, "matmul", (i == 0) ? 
"a" : "b", shape.size(), node.Name(), logger)) { + return false; + } + } return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger); } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc index dfe80dd419092..6e86ca77464e5 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -219,7 +219,7 @@ bool GruOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger); + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger); } bool GruOpBuilder::HasSupportedOutputsImpl(const Node& node, diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index 42940083cad8e..1675615280de9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -92,7 +92,7 @@ bool LogicalOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& no } std::string onnx_input_name = op_type == "Not" ? 
"X" : "A"; - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", onnx_input_name, logger); + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", onnx_input_name, logger); } void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc index 09e584bc66f8a..fcdc84b75c048 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc @@ -242,7 +242,7 @@ bool LstmOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); } bool LstmOpBuilder::HasSupportedOutputsImpl(const Node& node, diff --git a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc index 4e4014e3553ea..4d9cc39bd38fe 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc @@ -108,7 +108,7 @@ bool MaxMinOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod } } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger); + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger); } void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc 
b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc index dd25fb9bf9315..eccf67cc46c9a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc @@ -167,7 +167,7 @@ bool QDQOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "x", logger) && + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "x", logger) && IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "scale", "x_scale", logger) && (!has_input2 || IsDataTypeSupportedByOp(op_type, input2_type, wnn_limits, "zeroPoint", "x_zero_point", logger)); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc index f894e8bfbd517..ae3d559023625 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc @@ -71,7 +71,6 @@ bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const auto& updates = *node.InputDefs()[2]; - const std::string_view op_type = node.OpType(); int32_t data_type; int32_t indices_type; @@ -85,7 +84,9 @@ bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const return false; } - return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + const std::string_view op_type = node.OpType(); + + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); 
} diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc index e61ac3dcc9617..5467e91761823 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc @@ -63,7 +63,6 @@ bool ScatterNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const auto& updates = *node.InputDefs()[2]; - const std::string_view op_type = node.OpType(); int32_t data_type; int32_t indices_type; @@ -76,8 +75,8 @@ bool ScatterNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& if (data_type != updates_type) { return false; } - - return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + const std::string_view op_type = node.OpType(); + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc index 7a7f64b1ec96d..5d6d59663da61 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc @@ -66,7 +66,7 @@ bool TernaryOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& no return false; } - return IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger); + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger); } void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& 
op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/map_info.h b/onnxruntime/core/providers/webnn/builders/map_info.h index 5e860eea7cac9..bf95527beb44e 100644 --- a/onnxruntime/core/providers/webnn/builders/map_info.h +++ b/onnxruntime/core/providers/webnn/builders/map_info.h @@ -139,7 +139,7 @@ const std::unordered_map op_inputs_map = { {"Mul", {"mul", {{0, "a"}, {1, "b"}}}}, {"Pow", {"pow", {{0, "a"}, {1, "b"}}}}, {"Concat", {"concat", {{0, "inputs"}}}}, - {"Not", {"logicalNot", {{0, "input"}}}}, + {"Not", {"logicalNot", {{0, "a"}}}}, {"Flatten", {"reshape", {{0, "input"}}}}, {"LpPool", {"l2Pool2d", {{0, "input"}}}}, {"Reshape", {"reshape", {{0, "input"}}}}, diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 18b545483b38b..312ddd7e52e00 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2714,6 +2714,91 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetParentNode, _In_ const OrtGraph* graph, _O API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::Graph_GetGraphView, _In_ const OrtGraph* src_graph, + _In_ const OrtNode** nodes, + _In_ size_t num_nodes, + _Outptr_ OrtGraph** dst_graph) { + API_IMPL_BEGIN + + if (num_nodes == 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_nodes' argument should be > 0"); + } + + const EpGraph* ep_graph = EpGraph::ToInternal(src_graph); + if (ep_graph == nullptr) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "src_graph is a ModelEditorGraph which doesn't support Graph_GetSubGraph."); + } + const Graph& graph = ep_graph->GetGraphViewer().GetGraph(); + + // Create a GraphViewer with filtered info + std::unique_ptr indexed_sub_graph = std::make_unique(); + std::unique_ptr metadef = std::make_unique(); + metadef->name = "sub_graph"; + metadef->since_version = 1; + std::unordered_set outputs; + std::unordered_set initializers; + + auto add_inputs = 
[&](ConstPointerContainer> defs) { + for (const auto* def : defs) { + if (def->Exists()) { + // not the output of a previous node + if (outputs.count(def->Name()) == 0) { + metadef->inputs.push_back(def->Name()); + } else { + // consumed by node so no longer subgraph output + // NOTE: Ignoring edge case where a node output is an overall graph output AND a node input + outputs.erase(def->Name()); + } + + if (graph.IsInitializedTensor(def->Name())) { + initializers.insert(def); + } + } + } + }; + + auto add_node = [&](const Node& node) { + indexed_sub_graph->nodes.push_back(node.Index()); + add_inputs(node.InputDefs()); + add_inputs(node.ImplicitInputDefs()); + + for (const auto* def : node.OutputDefs()) { + outputs.insert(def->Name()); + } + }; + + // Add nodes + for (size_t node_idx = 0; node_idx < num_nodes; node_idx++) { + const OrtNode* ort_node = nodes[node_idx]; + const EpNode* ep_node = EpNode::ToInternal(ort_node); + if (ep_node == nullptr) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Graph_GetSubGraph."); + } + add_node(ep_node->GetInternalNode()); + } + + // Add initializers + for (auto& initializer : initializers) { + metadef->constant_initializers.push_back(initializer->Name()); + } + + // Add outputs + for (auto& output : outputs) { + metadef->outputs.push_back(output); + } + + indexed_sub_graph->SetMetaDef(std::move(metadef)); + auto graph_viewer = std::make_unique(graph, *indexed_sub_graph.get()); + + std::unique_ptr result; + ORT_API_RETURN_IF_STATUS_NOT_OK(EpGraph::Create(std::move(graph_viewer), std::move(indexed_sub_graph), result)); + + *dst_graph = result.release(); + + return nullptr; + API_IMPL_END +} + // // OrtNode // @@ -3629,6 +3714,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Graph_GetNumNodes, &OrtApis::Graph_GetNodes, &OrtApis::Graph_GetParentNode, + &OrtApis::Graph_GetGraphView, &OrtApis::Node_GetId, &OrtApis::Node_GetName, 
&OrtApis::Node_GetOperatorType, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 75db44cb9e9ff..b53863c02cfef 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -649,6 +649,8 @@ ORT_API_STATUS_IMPL(Graph_GetNumNodes, _In_ const OrtGraph* graph, _Out_ size_t* ORT_API_STATUS_IMPL(Graph_GetNodes, const OrtGraph* graph, _Out_writes_(num_nodes) const OrtNode** nodes, _In_ size_t num_nodes); ORT_API_STATUS_IMPL(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); +ORT_API_STATUS_IMPL(Graph_GetGraphView, _In_ const OrtGraph* graph, _In_ const OrtNode** nodes, _In_ size_t num_nodes, + _Outptr_ OrtGraph** subgraph); // OrtNode ORT_API_STATUS_IMPL(Node_GetId, _In_ const OrtNode* node, _Out_ size_t* node_id); diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py index fbeae39c39d21..319c5aa468f7e 100644 --- a/onnxruntime/python/tools/quantization/registry.py +++ b/onnxruntime/python/tools/quantization/registry.py @@ -86,6 +86,7 @@ "InstanceNormalization": QDQNormalization, "LayerNormalization": QDQNormalization, "BatchNormalization": QDQNormalization, + "TopK": QDQDirect8BitOp, } diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index e9bed3ac45529..17e829e37f729 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -7,12 +7,15 @@ #include #include #include +#include #include "core/common/common.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/tensor_type_and_shape.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/session/onnxruntime_cxx_api.h" +#include "core/graph/ep_api_types.h" +#include "core/graph/graph_proto_serializer.h" #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL #include "core/providers/utils/ort_graph_to_proto.h" @@ -31,6 +34,7 @@ 
namespace test { // forward-declaration for utility that uses public C APIs to check that an OrtGraph is equivalent // to a graph represented by the internal ORT GraphViewer class. static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph); +static void Check_Graph_GetSubgraph(const OrtGraph& api_graph); // // Tests @@ -73,6 +77,16 @@ TEST(EpGraphTest, Check3LayerNestedSubgraph) { CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); } +TEST(EpGraphTest, Check3LayerNestedSubgraphV2) { + // The overall structure of this model is similar to the one used in "Check3LayerNestedSubgraph" test. + // The model consists of a graph with subgraphs nested across three levels. + // In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer). + auto test_graph = TestGraph::Load(ORT_TSTR("testdata/three_layer_nested_subgraph_v2.onnx")); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); +} + static void RunMNISTModel(const ORTCHAR_T* model_path, std::vector& output_data) { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); Ort::SessionOptions sess_options; @@ -474,6 +488,48 @@ static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span nodes(num_nodes); + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&api_graph, nodes.data(), nodes.size())); + + // Select a half of nodes to create a OrtGraph + size_t num_selected_nodes = std::max((nodes.size() >> 1), (size_t)1); + std::vector selected_nodes(num_selected_nodes); + + for (size_t i = 0; i < num_selected_nodes; i++) { + selected_nodes[i] = nodes[i]; + } + + OrtGraph* sub_graph; + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetGraphView(&api_graph, selected_nodes.data(), selected_nodes.size(), &sub_graph)); + + // Convert OrtGraph/GraphViewer to ModelProto and dump it to disk. 
+ // If the GraphViewer associated with the OrtGraph somehow is incorrect, GraphViewerToProto() will throw. + const GraphViewer& sub_graph_viewer = EpGraph::ToInternal(sub_graph)->GetGraphViewer(); + std::unique_ptr model = std::make_unique(sub_graph_viewer.Name(), true, sub_graph_viewer.GetGraph().GetLogger()); + auto model_proto = std::make_unique(model->ToProto()); + GraphViewerToProto(sub_graph_viewer, *model_proto->mutable_graph(), true, true, static_cast(1)); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + + const char* graph_name = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetName(&api_graph, &graph_name)); + std::string name = graph_name; + name += "_half.onnx"; + + // Dump the graph for debugging + // std::fstream dump(name, std::ios::out | std::ios::trunc | std::ios::binary); + // model_proto->SerializeToOstream(&dump); + + ort_api.ReleaseGraph(sub_graph); +} + // Checks that the contents of the original GraphViewer matches the contents of the OrtGraph. // Uses the public C APIs to traverse the OrtGraph. static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph) { @@ -682,6 +738,9 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ } } } + + // Check creating an OrtGraph from a subset of nodes in an OrtGraph + Check_Graph_GetSubgraph(api_graph); } } // namespace test diff --git a/onnxruntime/test/python/quantization/test_op_topk.py b/onnxruntime/test/python/quantization/test_op_topk.py new file mode 100644 index 0000000000000..1fdd0c987d1e8 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_op_topk.py @@ -0,0 +1,103 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +import unittest + +import numpy as np +from onnx import TensorProto, helper, save +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type + +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static + + +class TestTopKModel(unittest.TestCase): + @staticmethod + def construct_model(model_path, input_shape, axis_attr, k): + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, input_shape) + k_tensor = helper.make_tensor("k", TensorProto.INT64, [1], [k]) + output_shape = input_shape[:] + output_shape[axis_attr] = k + output_values = helper.make_tensor_value_info("values", TensorProto.FLOAT, [1, k]) + output_indices = helper.make_tensor_value_info("indices", TensorProto.INT64, [1, k]) + + node = helper.make_node( + "TopK", inputs=["input", "k"], outputs=["values", "indices"], name="topk_node", axis=axis_attr + ) + + graph = helper.make_graph( + [node], + "quant_topk_op_test", + [input_tensor], + [output_values, output_indices], + initializer=[k_tensor], + ) + + model = helper.make_model( + graph, opset_imports=[helper.make_opsetid("", 16), helper.make_opsetid("com.microsoft", 1)] + ) + save(model, model_path) + + def quantize_topk_test(self, activation_type, weight_type, extra_options={}): # noqa: B006 + model_fp32_path = "topk_fp32.onnx" + input_shape = [1, 10] + axis = 1 + k = 3 + self.construct_model(model_fp32_path, input_shape, axis, k) + + input_data_list = [ + {"input": np.array([[1.8, 2.5, -5.9, 5.2, 4.1, 7.3, 0.2, -0.5, 0.845, 3.9]], dtype=np.float32)} + ] + data_reader = TestDataFeeds(input_data_list) + + activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8 + activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8" + weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" + model_qdq_path = 
f"topk_{activation_type_str}{weight_type_str}_{'QNoInCk' if extra_options['ForceQuantizeNoInputCheck'] else 'NoQNoInCk'}_qdq.onnx" + + # Verify QDQ mode + data_reader.rewind() + quantize_static( + model_fp32_path, + model_qdq_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) + qdqnode_counts = ( + { + "TopK": 1, + "QuantizeLinear": 2, + "DequantizeLinear": 2, + } + if extra_options["ForceQuantizeNoInputCheck"] + else { + "TopK": 1, + "QuantizeLinear": 0, + "DequantizeLinear": 0, + } + ) + check_op_type_count(self, model_qdq_path, **qdqnode_counts) + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } + check_qtype_by_node_type(self, model_qdq_path, qnode_io_qtypes) + data_reader.rewind() + check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next()) + + def test_quantize_topk_u8u8(self): + self.quantize_topk_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={"ForceQuantizeNoInputCheck": True}) + + def test_quantize_topk_u8u8_no_force_quantize_no_input_check(self): + self.quantize_topk_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={"ForceQuantizeNoInputCheck": False}) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx b/onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx new file mode 100644 index 0000000000000..d036541a70aa0 Binary files /dev/null and b/onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx differ diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index f6e37d33b2414..f864b8eb4a74d 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -284,6 +284,8 @@ def generate_vcpkg_install_options(build_dir, args): vcpkg_install_options.append("--x-feature=vsinpu-ep") if args.use_webgpu: 
vcpkg_install_options.append("--x-feature=webgpu-ep") + if args.wgsl_template == "dynamic": + vcpkg_install_options.append("--x-feature=webgpu-ep-wgsl-template-dynamic") if args.use_webnn: vcpkg_install_options.append("--x-feature=webnn-ep") if args.use_xnnpack: