From d29328588e00bf2e578d904403fa5c6627754346 Mon Sep 17 00:00:00 2001 From: qti-yuduo Date: Wed, 9 Jul 2025 20:17:13 -0700 Subject: [PATCH 1/6] [QNN EP] Fix pool with reshape name conflicts (#25332) Naming conflicts when expand-pool2d-squeeze (implemented as reshape) logic is invoked during ONNX -> QNN op lowering. Model with multiple pool 1D ops would hit this issue. --- .../qnn/builder/opbuilder/pool_op_builder.cc | 42 +++++++++---------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc index 86b684f8c6ebd..21947a22e2b92 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc @@ -235,7 +235,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(reshape_input, reshape_input_info)); bool needs_reshape = false; - const std::string reshape4d = input_names[0] + "_pre_reshape"; + const std::string reshape_prior_out = input_names[0] + "_prior_reshape"; if (input_shape.size() == 3) { needs_reshape = true; // build new_shape = {N, 1, C, L} @@ -245,25 +245,24 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra input_shape[1], input_shape[2]}; - const std::string reshape_node_name = "pre_reshape"; - QnnTensorWrapper rw( - reshape4d, + QnnTensorWrapper reshape_prior_tensor( + reshape_prior_out, QNN_TENSOR_TYPE_NATIVE, reshape_input_info.qnn_data_type, reshape_input_info.quant_param.Copy(), std::move(new_shape)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(rw)), - "Failed to add reshape-4d tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_prior_tensor)), + "Failed to add reshape prior tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - reshape_node_name, + utils::GetNodeName(node_unit) + "_reshape_prior", QNN_OP_PACKAGE_NAME_QTI_AISW, - "Reshape", + QNN_OP_RESHAPE, {input_names[0]}, - {reshape4d}, + {reshape_prior_out}, {}, do_op_validation), - "Failed to create reshape-4d node."); - input_names[0] = reshape4d; + "Failed to create reshape prior node for pool op."); + input_names[0] = reshape_prior_out; input_shape = {input_shape[0], 1, input_shape[1], input_shape[2]}; } @@ -446,9 +445,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra } const auto& outputs = node_unit.Outputs(); const std::string real_out = outputs[0].node_arg.Name(); - const std::string pool_name = "poolmax2d"; - const std::string pool_out = real_out + "_post_reshape"; - const std::string post_reshape_node_name = "post_reshape"; + const std::string pool_out = real_out + "_reshape_after"; const std::string qnn_op = GetQnnOpType(op_type); TensorInfo output_info{}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info)); @@ -466,33 +463,34 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra "Failed to add tensor for pool_out"); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - pool_name, + utils::GetNodeName(node_unit) + "_pool2d", QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op, - {reshape4d}, + {reshape_prior_out}, {pool_out}, std::move(param_tensor_names), do_op_validation), - "Failed to create QNN Pool node for rank-3 input."); + "Failed to create pool node for rank-3 input."); std::vector final_shape3d = output_info.shape; - QnnTensorWrapper reshape_back_tensor( + QnnTensorWrapper reshape_after_tensor( real_out, tensor_type, output_info.qnn_data_type, output_info.quant_param.Copy(), std::move(final_shape3d)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_back_tensor)), "Failed to add tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_after_tensor)), + "Failed to add reshape after tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - post_reshape_node_name, + utils::GetNodeName(node_unit) + "_reshape_after", QNN_OP_PACKAGE_NAME_QTI_AISW, - "Reshape", + QNN_OP_RESHAPE, {pool_out}, {real_out}, {}, do_op_validation), - "Failed to create reshape-back node."); + "Failed to create reshape after node for pool op."); return Status::OK(); } From ff815674fdbbe6b1b78309950c2dad9d49cf4e8f Mon Sep 17 00:00:00 2001 From: Akupadhye Date: Thu, 10 Jul 2025 23:38:16 +0530 Subject: [PATCH 2/6] Added creation of QDQ for TopK node (#25309) - Added TopK in registry.py so as to create QDQ nodes for the op - Ensure that both the input and output quantization params are equal - Added unit test to verify the creation of QDQ nodes for TopK ### Description: Added support for creation of QDQ nodes for TopK when quantized with ORT static quantization tool ### Motivation and Context: Currently there is support to form a node unit for TopK operator when QDQ nodes are present and both the input and output quantization params are equal. But there was no support to create QDQ nodes for TopK operator in the ORT static quantization tool --- .../python/tools/quantization/registry.py | 1 + .../test/python/quantization/test_op_topk.py | 103 ++++++++++++++++++ 2 files changed, 104 insertions(+) create mode 100644 onnxruntime/test/python/quantization/test_op_topk.py diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py index fbeae39c39d21..319c5aa468f7e 100644 --- a/onnxruntime/python/tools/quantization/registry.py +++ b/onnxruntime/python/tools/quantization/registry.py @@ -86,6 +86,7 @@ "InstanceNormalization": QDQNormalization, "LayerNormalization": QDQNormalization, "BatchNormalization": QDQNormalization, + "TopK": QDQDirect8BitOp, } diff --git a/onnxruntime/test/python/quantization/test_op_topk.py b/onnxruntime/test/python/quantization/test_op_topk.py new file mode 100644 index 0000000000000..1fdd0c987d1e8 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_op_topk.py @@ -0,0 +1,103 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import unittest + +import numpy as np +from onnx import TensorProto, helper, save +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type + +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static + + +class TestTopKModel(unittest.TestCase): + @staticmethod + def construct_model(model_path, input_shape, axis_attr, k): + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, input_shape) + k_tensor = helper.make_tensor("k", TensorProto.INT64, [1], [k]) + output_shape = input_shape[:] + output_shape[axis_attr] = k + output_values = helper.make_tensor_value_info("values", TensorProto.FLOAT, [1, k]) + output_indices = helper.make_tensor_value_info("indices", TensorProto.INT64, [1, k]) + + node = helper.make_node( + "TopK", inputs=["input", "k"], outputs=["values", "indices"], name="topk_node", axis=axis_attr + ) + + graph = helper.make_graph( + [node], + "quant_topk_op_test", + [input_tensor], + [output_values, output_indices], + initializer=[k_tensor], + ) + + model = helper.make_model( + graph, opset_imports=[helper.make_opsetid("", 16), helper.make_opsetid("com.microsoft", 1)] + ) + save(model, model_path) + + def quantize_topk_test(self, activation_type, weight_type, extra_options={}): # noqa: B006 + model_fp32_path = "topk_fp32.onnx" + input_shape = [1, 10] + axis = 1 + k = 3 + self.construct_model(model_fp32_path, input_shape, axis, k) + + input_data_list = [ + {"input": np.array([[1.8, 2.5, -5.9, 5.2, 4.1, 7.3, 0.2, -0.5, 0.845, 3.9]], dtype=np.float32)} + ] + data_reader = TestDataFeeds(input_data_list) + + activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8 + activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8" + weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" + model_qdq_path = f"topk_{activation_type_str}{weight_type_str}_{'QNoInCk' if extra_options['ForceQuantizeNoInputCheck'] else 'NoQNoInCk'}_qdq.onnx" + + # Verify QDQ mode + data_reader.rewind() + quantize_static( + model_fp32_path, + model_qdq_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) + qdqnode_counts = ( + { + "TopK": 1, + "QuantizeLinear": 2, + "DequantizeLinear": 2, + } + if extra_options["ForceQuantizeNoInputCheck"] + else { + "TopK": 1, + "QuantizeLinear": 0, + "DequantizeLinear": 0, + } + ) + check_op_type_count(self, model_qdq_path, **qdqnode_counts) + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } + check_qtype_by_node_type(self, model_qdq_path, qnode_io_qtypes) + data_reader.rewind() + check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next()) + + def test_quantize_topk_u8u8(self): + self.quantize_topk_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={"ForceQuantizeNoInputCheck": True}) + + def test_quantize_topk_u8u8_no_force_quantize_no_input_check(self): + self.quantize_topk_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={"ForceQuantizeNoInputCheck": False}) + + +if __name__ == "__main__": + unittest.main() From d3820a25de02e24d72a9977aec28d205e9d99ee2 Mon Sep 17 00:00:00 2001 From: Wang Ning Date: Fri, 11 Jul 2025 02:32:59 +0800 Subject: [PATCH 3/6] [WebNN] Refactor webnn op input rank check and add validation for ops (#25185) ### Description Development for webnn op input rank range check ### Motivation and Context - refactor webnn op input rank check - add validation for various ops - take `gemm` op as an example to perform inputs rank check of decomposed ops @honry @fdwr PTAL --- .../core/providers/webnn/builders/helper.cc | 126 +++++++++++------- .../core/providers/webnn/builders/helper.h | 34 +++++ .../webnn/builders/impl/concat_op_builder.cc | 2 +- .../impl/gatherElements_op_builder.cc | 5 +- .../builders/impl/gatherND_op_builder.cc | 5 +- .../webnn/builders/impl/gather_op_builder.cc | 6 +- .../webnn/builders/impl/gemm_op_builder.cc | 44 +++++- .../webnn/builders/impl/gru_op_builder.cc | 2 +- .../webnn/builders/impl/logical_op_builder.cc | 2 +- .../webnn/builders/impl/lstm_op_builder.cc | 2 +- .../webnn/builders/impl/max_min_op_builder.cc | 2 +- .../webnn/builders/impl/qdq_op_builder.cc | 2 +- .../impl/scatterElements_op_builder.cc | 5 +- .../builders/impl/scatterND_op_builder.cc | 5 +- .../webnn/builders/impl/ternary_op_builder.cc | 2 +- .../core/providers/webnn/builders/map_info.h | 2 +- 16 files changed, 168 insertions(+), 78 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index e821265fff80d..142d64caa64aa 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -99,69 +99,93 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n return true; } -// Check if all input tensor ranks of the given node are supported by WebNN. -bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) { - const std::string_view op_type = node.OpType(); - const auto it = op_inputs_map.find(op_type); - if (it == op_inputs_map.end()) { - LOGS(logger, VERBOSE) << "Operator type: [" << op_type << "] is not found in the op inputs map."; +// Check if a single input's rank of an ONNX op is supported by corresponding WebNN op. +bool IsInputRankSupported(const emscripten::val& wnn_limits, + const std::string_view webnn_op_type, + const std::string_view input_name, + const size_t input_rank, + const std::string_view node_name, + const logging::Logger& logger) { + const std::string webnn_op_type_str(webnn_op_type); + const std::string input_name_str(input_name); + + if (wnn_limits[webnn_op_type_str].isUndefined()) { + LOGS(logger, VERBOSE) << "WebNN op type: [" << webnn_op_type + << "] is not defined in WebNN MLOpSupportLimits."; return false; } - const auto& input_defs = node.InputDefs(); - const std::string_view webnn_op_type = it->second.opType; - const std::string webnn_op_type_str(webnn_op_type); + const emscripten::val input_limits = wnn_limits[webnn_op_type_str][input_name_str]; - for (const auto& input : it->second.inputs) { - if (static_cast(input.index) >= input_defs.size() || input_defs[input.index] == nullptr) { - LOGS(logger, VERBOSE) << "Input index [" << input.index - << "] for operator type [" << op_type - << "], corresponding WebNN op type [" << webnn_op_type - << "], WebNN input name [" << input.name - << "] is invalid."; - return false; - } + if (input_limits.isUndefined()) { + LOGS(logger, VERBOSE) << "Node name: [" << node_name + << "], WebNN op type: [" << webnn_op_type + << "], input [" << input_name + << "]: limits are not defined in WebNN MLOpSupportLimits."; + return false; + } - std::vector input_shape; - if (!GetShape(*input_defs[input.index], input_shape, logger)) { - return false; - } + const emscripten::val rank_range = input_limits["rankRange"]; + if (rank_range.isUndefined()) { + LOGS(logger, VERBOSE) << "WebNN op type [" << webnn_op_type + << "] input [" << input_name + << "]: missing 'rankRange' attribute."; + return false; + } - const std::string input_name_str(input.name); - if (wnn_limits[webnn_op_type_str].isUndefined() || - wnn_limits[webnn_op_type_str][input_name_str].isUndefined()) { - LOGS(logger, VERBOSE) << "Operator type: [" << op_type - << "], input index: [" << input.index - << "], corresponding WebNN op type: " << webnn_op_type - << ", WebNN input name " << input.name - << " is not defined in wnn_limits."; - return false; - } + const emscripten::val min_val = rank_range["min"]; + const emscripten::val max_val = rank_range["max"]; + if (min_val.isUndefined() || max_val.isUndefined()) { + LOGS(logger, VERBOSE) << "WebNN op type [" << webnn_op_type + << "] input [" << input_name + << "]: its 'rankRange' limits is missing valid 'min' or 'max' attributes."; + return false; + } - const auto& input_limits = wnn_limits[webnn_op_type_str][input_name_str]; - if (input_limits["rankRange"].isUndefined()) { - LOGS(logger, VERBOSE) << "Operator type: [" << op_type - << "], input index: [" << input.index - << "], corresponding WebNN op type: " << webnn_op_type - << ", WebNN input name " << input.name - << "'s rankRange is not defined."; - return false; + size_t min_rank = min_val.as(); + size_t max_rank = max_val.as(); + if (input_rank < min_rank || input_rank > max_rank) { + LOGS(logger, VERBOSE) << "Node name: [" << node_name + << "] WebNN op type [" << webnn_op_type + << "] input [" << input_name << "] rank " << input_rank + << " is not in supported range [" << min_rank << ", " << max_rank << "]"; + return false; + } + + return true; +} + +bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) { + const std::string_view onnx_op_type = node.OpType(); + const std::string_view webnn_op_type = GetWebNNOpType(onnx_op_type); + + if (webnn_op_type.empty()) { + LOGS(logger, VERBOSE) << "ONNX op type: [" << onnx_op_type << "]'s corresponding WebNN op is not found."; + return false; + } + + std::vector inputs; + if (!GetWebNNOpInputs(onnx_op_type, inputs, logger)) { + return false; + } + + const auto& input_defs = node.InputDefs(); + + for (const auto& input : inputs) { + // If it is an optional input and is absent, skip. + if (!TensorExists(input_defs, input.index)) { + continue; } - int input_dim_size = static_cast(input_shape.size()); - int min_rank = input_limits["rankRange"]["min"].as(); - int max_rank = input_limits["rankRange"]["max"].as(); - - if (input_dim_size < min_rank || input_dim_size > max_rank) { - LOGS(logger, VERBOSE) << "Operator type: [" << op_type - << "], input index: [" << input.index - << "], corresponding WebNN op type: " << webnn_op_type - << ", WebNN input name: " << input.name - << ", input size " << input_dim_size - << " is not in supported range [" << min_rank << ", " << max_rank << "]"; + std::vector shape; + if (!GetShape(*input_defs[input.index], shape, logger) || + !IsInputRankSupported(wnn_limits, webnn_op_type, input.name, + shape.size(), + node.Name(), logger)) { return false; } } + return true; } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index d59788600f997..50e361ede221e 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -216,6 +216,13 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_n bool IsInputRankSupportedByOp(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger); +bool IsInputRankSupported(const emscripten::val& wnn_limits, + const std::string_view webnn_op_type, + const std::string_view input_name, + const size_t input_rank, + const std::string_view node_name, + const logging::Logger& logger); + // Get a set of nodes supported by WebNN EP. std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, const emscripten::val& wnn_builder, @@ -244,6 +251,33 @@ inline std::string_view GetWebNNOpType(const std::string_view onnx_op_type) { return (it != op_inputs_map.end()) ? it->second.opType : ""; } +// Get corresponding input name of WebNN op type by ONNX op type from op_input_map +inline std::string_view GetWebNNInputName(const std::string_view onnx_op_type, const int input_index) { + const auto it = op_inputs_map.find(onnx_op_type); + + if (it != op_inputs_map.end()) { + for (const auto& input : it->second.inputs) { + if (input.index == input_index) { + return input.name; + } + } + } + + return ""; +} + +inline bool GetWebNNOpInputs(const std::string_view onnx_op_type, + std::vector& inputs, + const logging::Logger& logger) { + const auto it = op_inputs_map.find(onnx_op_type); + if (it == op_inputs_map.end()) { + LOGS(logger, VERBOSE) << "WebNN op inputs not found for op type: " << onnx_op_type; + return false; + } + inputs = it->second.inputs; + return true; +} + bool AreDataTypesSame(const std::string_view op_type, gsl::span input_types, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index 8589237617745..e0cd48b6883c2 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -75,7 +75,7 @@ bool ConcatOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod } } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger); + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger); } void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc index 06beb56415609..b4b9d9a0d4c6b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc @@ -56,13 +56,12 @@ bool GatherElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const N const auto& indices = *node.InputDefs()[1]; const std::string_view op_type = node.OpType(); - int32_t data_type; - int32_t indices_type; + int32_t data_type, indices_type; if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) { return false; } - return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc index 9200c596c0e53..a15542061dd60 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc @@ -61,13 +61,12 @@ bool GatherNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& n const auto& indices = *node.InputDefs()[1]; const std::string_view op_type = node.OpType(); - int32_t data_type; - int32_t indices_type; + int32_t data_type, indices_type; if (!GetType(data, data_type, logger) || !GetType(indices, indices_type, logger)) { return false; } - return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc index d84c70032e1d1..86408557013a0 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc @@ -74,13 +74,13 @@ bool GatherOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod const auto& input = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const std::string_view op_type = node.OpType(); - int32_t input_type; - int32_t indices_type; + int32_t input_type, indices_type; + if (!GetType(input, input_type, logger) || !GetType(indices, indices_type, logger)) return false; - return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 02f46c85d1d06..7af17fdc5db78 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -91,7 +91,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N a_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); std::vector a_zero_point_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[2], a_zero_point_shape, logger), "Cannot get shape of a_zero_point"); - // Scale is not used by MatMulInteger but required by DequantizeLinear. So set it to deafult value 1.0f. + // Scale is not used by MatMulInteger but required by DequantizeLinear. So set it to default value 1.0f. // The scale input should have the same shape as the zero point input. a_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f, @@ -268,11 +268,45 @@ bool GemmOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - if (op_type == "MatMulInteger") { - // The first decomposed op of MatMulInteger is DequantizeLinear, and so - // we only need to ensure it supports the input0_type. + if (op_type == "Gemm") { + return IsInputRankSupportedByOp(node, wnn_limits, logger) && + IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger); + } else if (op_type == "MatMulInteger") { + // Check up to 4 inputs for MatMulInteger + for (size_t i = 0; i < input_defs.size(); ++i) { + std::vector shape; + if (!GetShape(*input_defs[i], shape, logger)) { + return false; + } + + // We made workaround to support 1D for input A and B, skip further checks if they are 1D + if (i <= 1 && shape.size() == 1) { + continue; + } + + // For DequantizeLinear, input indices: 0 (x), 1 (scale), 2 (zero_point) + if (!IsInputRankSupported(wnn_limits, "dequantizeLinear", + (i < 2) ? "input" : "zeroPoint", + shape.size(), node.Name(), logger)) { + return false; + } + } return IsDataTypeSupportedByOp("DequantizeLinear", input0_type, wnn_limits, "input", "x", logger); - } else { + } else { // MatMul + for (int i = 0; i < 2; ++i) { + std::vector shape; + if (!GetShape(*input_defs[i], shape, logger)) { + return false; + } + + if (shape.size() == 1) { + continue; + } + + if (!IsInputRankSupported(wnn_limits, "matmul", (i == 0) ? "a" : "b", shape.size(), node.Name(), logger)) { + return false; + } + } return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger); } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc index dfe80dd419092..6e86ca77464e5 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -219,7 +219,7 @@ bool GruOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger); + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input_X_type, wnn_limits, "input", "X", logger); } bool GruOpBuilder::HasSupportedOutputsImpl(const Node& node, diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index 42940083cad8e..1675615280de9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -92,7 +92,7 @@ bool LogicalOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& no } std::string onnx_input_name = op_type == "Not" ? "X" : "A"; - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", onnx_input_name, logger); + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", onnx_input_name, logger); } void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc index 09e584bc66f8a..fcdc84b75c048 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc @@ -242,7 +242,7 @@ bool LstmOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); } bool LstmOpBuilder::HasSupportedOutputsImpl(const Node& node, diff --git a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc index 4e4014e3553ea..4d9cc39bd38fe 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc @@ -108,7 +108,7 @@ bool MaxMinOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod } } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger); + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger); } void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc index dd25fb9bf9315..eccf67cc46c9a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc @@ -167,7 +167,7 @@ bool QDQOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "x", logger) && + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "x", logger) && IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "scale", "x_scale", logger) && (!has_input2 || IsDataTypeSupportedByOp(op_type, input2_type, wnn_limits, "zeroPoint", "x_zero_point", logger)); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc index f894e8bfbd517..ae3d559023625 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc @@ -71,7 +71,6 @@ bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const auto& updates = *node.InputDefs()[2]; - const std::string_view op_type = node.OpType(); int32_t data_type; int32_t indices_type; @@ -85,7 +84,9 @@ bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const return false; } - return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + const std::string_view op_type = node.OpType(); + + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc index e61ac3dcc9617..5467e91761823 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc @@ -63,7 +63,6 @@ bool ScatterNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const auto& updates = *node.InputDefs()[2]; - const std::string_view op_type = node.OpType(); int32_t data_type; int32_t indices_type; @@ -76,8 +75,8 @@ bool ScatterNDOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& if (data_type != updates_type) { return false; } - - return IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && + const std::string_view op_type = node.OpType(); + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, data_type, wnn_limits, "input", "data", logger) && IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc index 7a7f64b1ec96d..5d6d59663da61 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc @@ -66,7 +66,7 @@ bool TernaryOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& no return false; } - return IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger); + return IsInputRankSupportedByOp(node, wnn_limits, logger) && IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger); } void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/map_info.h b/onnxruntime/core/providers/webnn/builders/map_info.h index 5e860eea7cac9..bf95527beb44e 100644 --- a/onnxruntime/core/providers/webnn/builders/map_info.h +++ b/onnxruntime/core/providers/webnn/builders/map_info.h @@ -139,7 +139,7 @@ const std::unordered_map op_inputs_map = { {"Mul", {"mul", {{0, "a"}, {1, "b"}}}}, {"Pow", {"pow", {{0, "a"}, {1, "b"}}}}, {"Concat", {"concat", {{0, "inputs"}}}}, - {"Not", {"logicalNot", {{0, "input"}}}}, + {"Not", {"logicalNot", {{0, "a"}}}}, {"Flatten", {"reshape", {{0, "input"}}}}, {"LpPool", {"l2Pool2d", {{0, "input"}}}}, {"Reshape", {"reshape", {{0, "input"}}}}, From 8a27eabb05ad6bd3319792c4c4b5b5dd61c7be65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= <44298237+gedoensmax@users.noreply.github.com> Date: Fri, 11 Jul 2025 03:10:12 +0800 Subject: [PATCH 4/6] Make TRT plugins optional (#25261) ### Description The parser does no longer link agains the plugin library but also loads it dynamic. Due to that I think we should also make the library optional in ORT. @chilo-ms --- cmake/onnxruntime_providers_tensorrt.cmake | 23 +++------- .../nv_tensorrt_rtx/nv_execution_provider.cc | 2 +- .../tensorrt_execution_provider_custom_ops.cc | 44 ++++++++++++++++++- 3 files changed, 49 insertions(+), 20 deletions(-) diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index 69c81a5ec7b9d..4184e0b049afc 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -72,10 +72,9 @@ endif() # TensorRT 10 GA onwards, the TensorRT libraries will have major version appended to the end on Windows, - # for example, nvinfer_10.dll, nvinfer_plugin_10.dll, nvonnxparser_10.dll ... + # for example, nvinfer_10.dll, nvonnxparser_10.dll ... if (WIN32 AND TRT_GREATER_OR_EQUAL_TRT_10_GA) set(NVINFER_LIB "nvinfer_${NV_TENSORRT_MAJOR}") - set(NVINFER_PLUGIN_LIB "nvinfer_plugin_${NV_TENSORRT_MAJOR}") set(PARSER_LIB "nvonnxparser_${NV_TENSORRT_MAJOR}") endif() @@ -83,15 +82,11 @@ set(NVINFER_LIB "nvinfer") endif() - if (NOT NVINFER_PLUGIN_LIB) - set(NVINFER_PLUGIN_LIB "nvinfer_plugin") - endif() - if (NOT PARSER_LIB) set(PARSER_LIB "nvonnxparser") endif() - MESSAGE(STATUS "Looking for ${NVINFER_LIB} and ${NVINFER_PLUGIN_LIB}") + MESSAGE(STATUS "Looking for ${NVINFER_LIB}") find_library(TENSORRT_LIBRARY_INFER ${NVINFER_LIB} HINTS ${TENSORRT_ROOT} @@ -101,14 +96,6 @@ MESSAGE(STATUS "Can't find ${NVINFER_LIB}") endif() - find_library(TENSORRT_LIBRARY_INFER_PLUGIN ${NVINFER_PLUGIN_LIB} - HINTS ${TENSORRT_ROOT} - PATH_SUFFIXES lib lib64 lib/x64) - - if (NOT TENSORRT_LIBRARY_INFER_PLUGIN) - MESSAGE(STATUS "Can't find ${NVINFER_PLUGIN_LIB}") - endif() - if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) MESSAGE(STATUS "Looking for ${PARSER_LIB}") @@ -120,7 +107,7 @@ MESSAGE(STATUS "Can't find ${PARSER_LIB}") endif() - set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN} ${TENSORRT_LIBRARY_NVONNXPARSER}) + set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_NVONNXPARSER}) MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}") else() if (TRT_GREATER_OR_EQUAL_TRT_10_GA) @@ -153,7 +140,7 @@ endif() # Static libraries are just nvonnxparser_static on all platforms set(onnxparser_link_libs nvonnxparser_static) - set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN}) + set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER}) MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}") endif() @@ -161,7 +148,7 @@ # nvonnxparser_static is linked against tensorrt libraries in onnx-tensorrt # See https://github.com/onnx/onnx-tensorrt/blob/8af13d1b106f58df1e98945a5e7c851ddb5f0791/CMakeLists.txt#L121 # However, starting from TRT 10 GA, nvonnxparser_static doesn't link against tensorrt libraries. - # Therefore, the above code finds ${TENSORRT_LIBRARY_INFER} and ${TENSORRT_LIBRARY_INFER_PLUGIN}. + # Therefore, the above code finds ${TENSORRT_LIBRARY_INFER}. if(onnxruntime_CUDA_MINIMAL) set(trt_link_libs ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY}) else() diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 711d81186bad1..c5b6507ac847b 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1304,7 +1304,7 @@ std::vector NvExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { - return std::make_unique(device_id, CUDA_PINNED); + return std::make_unique(CUDA_PINNED, device_id); }, narrow(device_id_)); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index 90a4294fb47f0..1e9fafe8aa323 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -7,6 +7,25 @@ #include "tensorrt_execution_provider_custom_ops.h" #include "tensorrt_execution_provider.h" +// The filename extension for a shared library is different per platform +#ifdef _WIN32 +#define LIBRARY_PREFIX +#define LIBRARY_EXTENSION ORT_TSTR(".dll") +#elif defined(__APPLE__) +#define LIBRARY_PREFIX "lib" +#define LIBRARY_EXTENSION ".dylib" +#else +#define LIBRARY_PREFIX "lib" +#define LIBRARY_EXTENSION ".so" +#endif + +#ifdef _WIN32 +#define ORT_DEF2STR_HELPER(x) L#x +#else +#define ORT_DEF2STR_HELPER(X) #X +#endif +#define ORT_DEF2STR(x) ORT_DEF2STR_HELPER(x) + namespace onnxruntime { extern TensorrtLogger& GetTensorrtLogger(bool verbose); @@ -58,8 +77,31 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& // Get all registered TRT plugins from registry LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Getting all registered TRT plugins from TRT plugin registry ..."; TensorrtLogger trt_logger = GetTensorrtLogger(false); - initLibNvInferPlugins(&trt_logger, ""); + try { + void* library_handle = nullptr; + const auto& env = onnxruntime::GetDefaultEnv(); +#if NV_TENSORRT_MAJOR < 10 + auto full_path = env.GetRuntimePath() + + PathString(LIBRARY_PREFIX ORT_TSTR("nvinfer_plugin") LIBRARY_EXTENSION); +#else +#ifdef _WIN32 + auto full_path = PathString(LIBRARY_PREFIX ORT_TSTR("nvinfer_plugin_" ORT_DEF2STR(NV_TENSORRT_MAJOR)) LIBRARY_EXTENSION); +#else + auto full_path = PathString(LIBRARY_PREFIX ORT_TSTR("nvinfer_plugin") LIBRARY_EXTENSION ORT_TSTR("." ORT_DEF2STR(NV_TENSORRT_MAJOR))); +#endif +#endif + + ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, false, &library_handle)); + bool (*dyn_initLibNvInferPlugins)(void* logger, char const* libNamespace); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(library_handle, "initLibNvInferPlugins", (void**)&dyn_initLibNvInferPlugins)); + if (!dyn_initLibNvInferPlugins(&trt_logger, "")) { + LOGS_DEFAULT(INFO) << "[TensorRT EP] Default plugin library was found but was not able to initialize default plugins."; + } + LOGS_DEFAULT(INFO) << "[TensorRT EP] Default plugins successfully loaded."; + } catch (const std::exception&) { + LOGS_DEFAULT(INFO) << "[TensorRT EP] Default plugin library is not on the path and is therefore ignored"; + } int num_plugin_creator = 0; auto plugin_creators = getPluginRegistry()->getAllCreators(&num_plugin_creator); std::unordered_set registered_plugin_names; From e6658c020a9accf2263c31909eb15147b9848b20 Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Thu, 10 Jul 2025 14:54:35 -0700 Subject: [PATCH 5/6] [EP ABI] Add Graph_GetGraphView API to get a OrtGraph from a subset of nodes (#25191) Added an API that creates a sub-graph from a set of nodes in an OrtGraph. This API is needed in the GetCapability EP ABI porting when EP wants to check whether a 'sub-graph' of the graph is supported by the hardware backend. --- include/onnxruntime/core/graph/graph.h | 5 +- .../core/session/onnxruntime_c_api.h | 18 ++++ onnxruntime/core/graph/ep_api_types.cc | 24 +++++ onnxruntime/core/graph/ep_api_types.h | 30 ++++++ onnxruntime/core/graph/graph.cc | 4 + onnxruntime/core/graph/graph_viewer.cc | 12 ++- onnxruntime/core/session/onnxruntime_c_api.cc | 86 ++++++++++++++++++ onnxruntime/core/session/ort_apis.h | 2 + onnxruntime/test/ep_graph/test_ep_graph.cc | 59 ++++++++++++ .../three_layer_nested_subgraph_v2.onnx | Bin 0 -> 1892 bytes 10 files changed, 237 insertions(+), 3 deletions(-) create mode 100644 onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 54e03a31fceef..c18a42cc1bbc1 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -952,9 +952,12 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi return const_cast(this)->GetNodeArg(name); } - // search this and up through any parent_graph_ instance for a NodeArg + // Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding mutable NodeArg NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name); + // Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding const NodeArg + const NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const; + /** Gets a mutable NodeArg by name. Creates a new NodeArg that is owned by this Graph if not found. @param name The NodeArg name. @param[in] p_arg_type Optional TypeProto to use if the NodeArg needs to be created. diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index bf1dd6e20ce64..051a3f7283cbe 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -5748,6 +5748,24 @@ struct OrtApi { */ ORT_API2_STATUS(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); + /** \brief Returns an OrtGraph that contains a subset of nodes in the source OrtGraph. + * + * Note: + * The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference + * the same underlying graph. + * + * \param[in] src_graph The source OrtGraph instance. + * \param[in] nodes A subset of the nodes/OrtNodes in 'graph'. + * \param[in] num_nodes Number of nodes. + * \param[out] dst_sub_graph An OrtGraph created from a given set of nodes. Must be released by calling ReleaseGraph. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetGraphView, _In_ const OrtGraph* src_graph, _In_ const OrtNode** nodes, + _In_ size_t num_nodes, _Outptr_ OrtGraph** dst_graph); + /// @} /// \name OrtNode diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 8583fac30cfbf..7f81ab3433911 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -505,10 +505,34 @@ void EpGraph::IndexToEpNodeMap::SetEpNode(NodeIndex node_index, EpNode* ep_node) EpGraph::EpGraph(const GraphViewer& graph_viewer, PrivateTag) : OrtGraph(OrtGraphIrApi::kEpApi), graph_viewer_(graph_viewer) {} +EpGraph::EpGraph(std::unique_ptr graph_viewer, + std::unique_ptr indexed_sub_graph, + PrivateTag) + : OrtGraph(OrtGraphIrApi::kEpApi), + graph_viewer_(*graph_viewer.get()), + owned_graph_viewer_(std::move(graph_viewer)), + owned_indexed_sub_graph_(std::move(indexed_sub_graph)) {} + // Static class function to create a std::unique_ptr. Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result) { auto ep_graph = std::make_unique(graph_viewer, PrivateTag{}); + return CreateImpl(std::move(ep_graph), graph_viewer, result); +} + +// Static class function to create a std::unique_ptr. +Status EpGraph::Create(std::unique_ptr src_graph_viewer, + std::unique_ptr src_indexed_sub_graph, + /*out*/ std::unique_ptr& result) { + auto& graph_viewer = *src_graph_viewer.get(); + auto ep_graph = std::make_unique(std::move(src_graph_viewer), + std::move(src_indexed_sub_graph), + PrivateTag{}); + + return CreateImpl(std::move(ep_graph), graph_viewer, result); +} + +Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result) { AllocatorPtr initializer_allocator = CPUAllocator::DefaultInstance(); std::unordered_map> value_infos_map; diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 12fa082d3f354..7b67f21bf4eb4 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -251,15 +251,32 @@ struct EpGraph : public OrtGraph { public: EpGraph(const GraphViewer& graph_viewer, PrivateTag); + EpGraph(std::unique_ptr graph_viewer, + std::unique_ptr indexed_sub_graph, + PrivateTag); /// /// Creates an instance of EpGraph, which wraps a GraphViewer. + /// This call is used when creating an EpGraph from a GraphViewer instance. The GraphViewer instance is not onwed by this EpGraph. /// /// /// /// static Status Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result); + /// + /// Creates an instance of EpGraph, which wraps a GraphViewer. + /// This call is used when creating an EpGraph from a subset of nodes in another EpGraph. + /// In this case, due to the implementation of OrtApis::Graph_GetGraphView, the new EpGraph instance + /// must take ownership of both the GraphViewer and IndexedSubGraph. + /// + /// + /// + /// + static Status Create(std::unique_ptr graph_viewer, + std::unique_ptr indexed_sub_graph, + /*out*/ std::unique_ptr& result); + // Defines ToExternal() and ToInternal() functions to convert between OrtGraph and EpGraph. DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(OrtGraph, EpGraph, OrtGraphIrApi::kEpApi) @@ -331,9 +348,22 @@ struct EpGraph : public OrtGraph { const OrtValue* GetInitializerValue(std::string_view name) const; private: + /// + /// The real implementation of creating an EpGraph instance. + /// Please use one of the above 'Create' functions that internally call this function, and avoid calling this function directly. + /// + /// + /// + /// + /// + static Status CreateImpl(std::unique_ptr ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result); + const GraphViewer& graph_viewer_; const EpNode* parent_node_ = nullptr; + std::unique_ptr owned_graph_viewer_ = nullptr; + std::unique_ptr owned_indexed_sub_graph_ = nullptr; + std::vector> nodes_; IndexToEpNodeMap index_to_ep_node_; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index ca40bad2b4250..4d3091520d876 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1818,6 +1818,10 @@ NodeArg* Graph::GetNodeArgIncludingParentGraphs(const std::string& node_arg_name return node_arg; } +const NodeArg* Graph::GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const { + return const_cast(this)->GetNodeArgIncludingParentGraphs(node_arg_name); +} + void Graph::ReverseDFSFrom(gsl::span from, const std::function& enter, const std::function& leave, diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index 1842c2b4a0d1f..948ebaa5f7e15 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -168,7 +168,15 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) filtered_node_inputs_including_initializers_.reserve(metadef->inputs.size()); for (const auto& input : metadef->inputs) { - const auto* nodearg = graph.GetNodeArg(input); + // NodeArgs from the current scope or any outer scopes should be handled correctly. + // + // There is an edge case where the model consists of a graph with subgraphs nested across three levels. + // In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer). + // When constructing a new GraphViewer for the second- and third-layer subgraphs, + // the second-layer graph may not have the corresponding value_info for that first-layer input, + // because the second-layer graph itself doesn't consume it. + // Therefore, when working within the second-layer graph, we need to search outer scopes for the missing value_info. + const auto* nodearg = graph.GetNodeArgIncludingParentGraphs(input); ORT_ENFORCE(nodearg, "Mismatch between Graph and IndexedSubGraph. Input not found:", input); filtered_node_inputs_including_initializers_.push_back(nodearg); if (!graph.IsInitializedTensor(input)) { @@ -177,7 +185,7 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) } for (const auto& output : metadef->outputs) { - const auto* nodearg = graph.GetNodeArg(output); + const auto* nodearg = graph.GetNodeArgIncludingParentGraphs(output); ORT_ENFORCE(nodearg, "Mismatch between Graph and IndexedSubGraph. Output not found:", output); filtered_node_outputs_.push_back(nodearg); } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 18b545483b38b..312ddd7e52e00 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2714,6 +2714,91 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetParentNode, _In_ const OrtGraph* graph, _O API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::Graph_GetGraphView, _In_ const OrtGraph* src_graph, + _In_ const OrtNode** nodes, + _In_ size_t num_nodes, + _Outptr_ OrtGraph** dst_graph) { + API_IMPL_BEGIN + + if (num_nodes == 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_nodes' argument should be > 0"); + } + + const EpGraph* ep_graph = EpGraph::ToInternal(src_graph); + if (ep_graph == nullptr) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "src_graph is a ModelEditorGraph which doesn't support Graph_GetSubGraph."); + } + const Graph& graph = ep_graph->GetGraphViewer().GetGraph(); + + // Create a GraphViewer with filtered info + std::unique_ptr indexed_sub_graph = std::make_unique(); + std::unique_ptr metadef = std::make_unique(); + metadef->name = "sub_graph"; + metadef->since_version = 1; + std::unordered_set outputs; + std::unordered_set initializers; + + auto add_inputs = [&](ConstPointerContainer> defs) { + for (const auto* def : defs) { + if (def->Exists()) { + // not the output of a previous node + if (outputs.count(def->Name()) == 0) { + metadef->inputs.push_back(def->Name()); + } else { + // consumed by node so no longer subgraph output + // NOTE: Ignoring edge case where a node output is an overall graph output AND a node input + outputs.erase(def->Name()); + } + + if (graph.IsInitializedTensor(def->Name())) { + initializers.insert(def); + } + } + } + }; + + auto add_node = [&](const Node& node) { + indexed_sub_graph->nodes.push_back(node.Index()); + add_inputs(node.InputDefs()); + add_inputs(node.ImplicitInputDefs()); + + for (const auto* def : node.OutputDefs()) { + outputs.insert(def->Name()); + } + }; + + // Add nodes + for (size_t node_idx = 0; node_idx < num_nodes; node_idx++) { + const OrtNode* ort_node = nodes[node_idx]; + const EpNode* ep_node = EpNode::ToInternal(ort_node); + if (ep_node == nullptr) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Graph_GetSubGraph."); + } + add_node(ep_node->GetInternalNode()); + } + + // Add initializers + for (auto& initializer : initializers) { + metadef->constant_initializers.push_back(initializer->Name()); + } + + // Add outputs + for (auto& output : outputs) { + metadef->outputs.push_back(output); + } + + indexed_sub_graph->SetMetaDef(std::move(metadef)); + auto graph_viewer = std::make_unique(graph, *indexed_sub_graph.get()); + + std::unique_ptr result; + ORT_API_RETURN_IF_STATUS_NOT_OK(EpGraph::Create(std::move(graph_viewer), std::move(indexed_sub_graph), result)); + + *dst_graph = result.release(); + + return nullptr; + API_IMPL_END +} + // // OrtNode // @@ -3629,6 +3714,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Graph_GetNumNodes, &OrtApis::Graph_GetNodes, &OrtApis::Graph_GetParentNode, + &OrtApis::Graph_GetGraphView, &OrtApis::Node_GetId, &OrtApis::Node_GetName, &OrtApis::Node_GetOperatorType, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 75db44cb9e9ff..b53863c02cfef 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -649,6 +649,8 @@ ORT_API_STATUS_IMPL(Graph_GetNumNodes, _In_ const OrtGraph* graph, _Out_ size_t* ORT_API_STATUS_IMPL(Graph_GetNodes, const OrtGraph* graph, _Out_writes_(num_nodes) const OrtNode** nodes, _In_ size_t num_nodes); ORT_API_STATUS_IMPL(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); +ORT_API_STATUS_IMPL(Graph_GetGraphView, _In_ const OrtGraph* graph, _In_ const OrtNode** nodes, _In_ size_t num_nodes, + _Outptr_ OrtGraph** subgraph); // OrtNode ORT_API_STATUS_IMPL(Node_GetId, _In_ const OrtNode* node, _Out_ size_t* node_id); diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index e9bed3ac45529..17e829e37f729 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -7,12 +7,15 @@ #include #include #include +#include #include "core/common/common.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/tensor_type_and_shape.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/session/onnxruntime_cxx_api.h" +#include "core/graph/ep_api_types.h" +#include "core/graph/graph_proto_serializer.h" #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL #include "core/providers/utils/ort_graph_to_proto.h" @@ -31,6 +34,7 @@ namespace test { // forward-declaration for utility that uses public C APIs to check that an OrtGraph is equivalent // to a graph represented by the internal ORT GraphViewer class. static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph); +static void Check_Graph_GetSubgraph(const OrtGraph& api_graph); // // Tests @@ -73,6 +77,16 @@ TEST(EpGraphTest, Check3LayerNestedSubgraph) { CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); } +TEST(EpGraphTest, Check3LayerNestedSubgraphV2) { + // The overall structure of this model is similar to the one used in "Check3LayerNestedSubgraph" test. + // The model consists of a graph with subgraphs nested across three levels. + // In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer). + auto test_graph = TestGraph::Load(ORT_TSTR("testdata/three_layer_nested_subgraph_v2.onnx")); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); +} + static void RunMNISTModel(const ORTCHAR_T* model_path, std::vector& output_data) { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); Ort::SessionOptions sess_options; @@ -474,6 +488,48 @@ static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span nodes(num_nodes); + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&api_graph, nodes.data(), nodes.size())); + + // Select a half of nodes to create a OrtGraph + size_t num_selected_nodes = std::max((nodes.size() >> 1), (size_t)1); + std::vector selected_nodes(num_selected_nodes); + + for (size_t i = 0; i < num_selected_nodes; i++) { + selected_nodes[i] = nodes[i]; + } + + OrtGraph* sub_graph; + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetGraphView(&api_graph, selected_nodes.data(), selected_nodes.size(), &sub_graph)); + + // Convert OrtGraph/GraphViewer to ModelProto and dump it to disk. + // If the GraphViewer associated with the OrtGraph somehow is incorrect, GraphViewerToProto() will throw. + const GraphViewer& sub_graph_viewer = EpGraph::ToInternal(sub_graph)->GetGraphViewer(); + std::unique_ptr model = std::make_unique(sub_graph_viewer.Name(), true, sub_graph_viewer.GetGraph().GetLogger()); + auto model_proto = std::make_unique(model->ToProto()); + GraphViewerToProto(sub_graph_viewer, *model_proto->mutable_graph(), true, true, static_cast(1)); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + + const char* graph_name = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetName(&api_graph, &graph_name)); + std::string name = graph_name; + name += "_half.onnx"; + + // Dump the graph for debugging + // std::fstream dump(name, std::ios::out | std::ios::trunc | std::ios::binary); + // model_proto->SerializeToOstream(&dump); + + ort_api.ReleaseGraph(sub_graph); +} + // Checks that the contents of the original GraphViewer matches the contents of the OrtGraph. // Uses the public C APIs to traverse the OrtGraph. static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph) { @@ -682,6 +738,9 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ } } } + + // Check creating an OrtGraph from a subset of nodes in an OrtGraph + Check_Graph_GetSubgraph(api_graph); } } // namespace test diff --git a/onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx b/onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d036541a70aa087f6007ec7261f5f1115b0e22f2 GIT binary patch literal 1892 zcmc&#&2G~`5Y8GWaV8~=whcrEAaP&c$9+r)V_e7+CMS24QOn};JS@5uuwhojaXfGV8^v_J5P zdou3)00^KKITpQ`lc^PqiAS-P$a?eD%nd@~hP}}dH(80r++4G?cA)rhT<3}SzV zo0H)NM;CKS-}5BRvOL4jGD9eH!WEX$*~psB!&{MZ1XACWRiwTs0x4Tok{~7J8<3Kg zyCC)P1w$%{yoS^cLrIz#W*xi{bu2O*#%bu4m&0LSw8Ff{j<~^0?WofhLE6E5aOx9p z+-hn{z1&rhvR6xj#l0Rpft7%`1{)f}8YmiKUuB|a&*yC0BB8Y#SE$(ftwJ>%Q#WDR zcU7`1t}VgNkwx6ZGU0g_>iSUC`&kVX?S*?o`uvGX(i5vu@IN| z?*bMeM{k4CleJI+1N)Ij+#zSHS&GlH!_In#AF`<{cM)O@maxVVTc1$c`~O>;pxRP( zIXcxvj{r1AK$R14@*wLUUe<5PZY?Vr@Ax{%IKP6((s~;_f@}%gISCP9`Mn8CLM$zz ztjLTTtJ|gos#d`TJ`=yt>P%cC=%)K$iEO=@Z0!RYCTsuD^;qlE&7WDoU|`u|{wi#u zIc1`bUgE^=dGRX1OxccXz73K+AWBcXbEQ8{)4^}*ur}5cKJ0ex&TT6IS7u(=7R%@O gpK*_GjBuRe!e9%~W$t+nclOVTCER-|6zZFQ0aGjCcK`qY literal 0 HcmV?d00001 From 591003b1ecd13e7862d655f91bed8fba27499cf6 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 10 Jul 2025 15:17:37 -0700 Subject: [PATCH 6/6] [webgpu] a few optimization to WGSL template (#25333) ### Description This change is a follow up to #25130. - consume duktape from vcpkg if --use_vcpkg is specified - ~~add a Windows CI pipeline for dynamic WGSL template~~ (Will do in a separate PR) - upgrade wgsl-template package from 0.1.10 to 0.1.13 - support adding contribop folder as input --- .../external/onnxruntime_external_deps.cmake | 25 +++++++++++++------ cmake/onnxruntime_providers_webgpu.cmake | 11 ++++---- cmake/vcpkg.json | 8 ++++++ .../webgpu/wgsl_templates/package-lock.json | 8 +++--- .../webgpu/wgsl_templates/package.json | 2 +- tools/ci_build/build.py | 2 ++ 6 files changed, 39 insertions(+), 17 deletions(-) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index e8f6bbe895d29..228906030d14c 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -774,13 +774,24 @@ if (onnxruntime_USE_WEBGPU) endif() if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND onnxruntime_WGSL_TEMPLATE STREQUAL "dynamic") - onnxruntime_fetchcontent_declare( - duktape - URL ${DEP_URL_duktape} - URL_HASH SHA1=${DEP_SHA1_duktape} - EXCLUDE_FROM_ALL - ) - onnxruntime_fetchcontent_makeavailable(duktape) + if(onnxruntime_USE_VCPKG) + find_package(unofficial-duktape CONFIG REQUIRED) + add_library(duktape_static ALIAS unofficial::duktape::duktape) + else() + onnxruntime_fetchcontent_declare( + duktape + URL ${DEP_URL_duktape} + URL_HASH SHA1=${DEP_SHA1_duktape} + EXCLUDE_FROM_ALL + ) + onnxruntime_fetchcontent_makeavailable(duktape) + + if(NOT TARGET duktape_static) + add_library(duktape_static STATIC "${duktape_SOURCE_DIR}/src/duktape.c") + target_compile_features(duktape_static PRIVATE c_std_99) + target_include_directories(duktape_static INTERFACE $) + endif() + endif() endif() endif() diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index 5b80b1262464d..2865ad33b39f4 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -172,10 +172,12 @@ file(MAKE_DIRECTORY ${WGSL_GENERATED_DIR}) # Find all WGSL template input files - file(GLOB_RECURSE WGSL_TEMPLATE_FILES "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.wgsl.template") + file(GLOB_RECURSE WGSL_TEMPLATE_FILES + "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.wgsl.template" + "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/*.wgsl.template") # Set wgsl-gen command line options as a list - set(WGSL_GEN_OPTIONS "-i" "../" "--output" "${WGSL_GENERATED_DIR}" "-I" "wgsl_template_gen/" "--preserve-code-ref" "--verbose") + set(WGSL_GEN_OPTIONS "-i" "${ONNXRUNTIME_ROOT}/core/providers/webgpu/" "-i" "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/" "--output" "${WGSL_GENERATED_DIR}" "-I" "wgsl_template_gen/" "--preserve-code-ref" "--verbose") if (onnxruntime_WGSL_TEMPLATE STREQUAL "static") if (CMAKE_BUILD_TYPE STREQUAL "Debug") list(APPEND WGSL_GEN_OPTIONS "--generator" "static-cpp-literal") @@ -207,10 +209,9 @@ # Add the generated directory to include paths target_include_directories(onnxruntime_providers_webgpu PRIVATE ${WGSL_GENERATED_ROOT}) elseif(onnxruntime_WGSL_TEMPLATE STREQUAL "dynamic") - add_library(duktape_static STATIC "${duktape_SOURCE_DIR}/src/duktape.c") - target_compile_features(duktape_static PRIVATE c_std_99) target_link_libraries(onnxruntime_providers_webgpu duktape_static) - target_include_directories(onnxruntime_providers_webgpu PRIVATE ${duktape_SOURCE_DIR}/src) + onnxruntime_add_include_to_target(onnxruntime_providers_webgpu duktape_static) + # Define the path to the generated templates.js file target_compile_definitions(onnxruntime_providers_webgpu PRIVATE "ORT_WGSL_TEMPLATES_JS_PATH=\"${WGSL_GENERATED_TEMPLATES_JS}\"") diff --git a/cmake/vcpkg.json b/cmake/vcpkg.json index da179d0bad564..373ecec440921 100644 --- a/cmake/vcpkg.json +++ b/cmake/vcpkg.json @@ -93,6 +93,10 @@ "webgpu-ep": { "description": "Build with WebGPU EP", "dependencies": [] + }, + "webgpu-ep-wgsl-template-dynamic": { + "description": "Build with WebGPU EP with dynamic WGSL template code generator", + "dependencies": ["duktape"] } }, "overrides": [ @@ -103,6 +107,10 @@ { "name": "flatbuffers", "version": "23.5.26" + }, + { + "name": "duktape", + "version": "2.7.0#2" } ] } diff --git a/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json b/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json index 7cde6c17f54e9..df1940ed6416b 100644 --- a/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json +++ b/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json @@ -9,13 +9,13 @@ "version": "1.0.0", "license": "MIT", "dependencies": { - "@fs-eire/wgsl-template": "^0.1.3" + "@fs-eire/wgsl-template": "^0.1.13" } }, "node_modules/@fs-eire/wgsl-template": { - "version": "0.1.10", - "resolved": "https://registry.npmjs.org/@fs-eire/wgsl-template/-/wgsl-template-0.1.10.tgz", - "integrity": "sha512-F5qQZxNweZ3ZD3d9RNc/g3nTiW7jyaAVi7SlMOL4wOfXh+Nm/qca2DISNTf3kjpVqkoazMJGbZ6TPQ4a/vjw0g==", + "version": "0.1.13", + "resolved": "https://registry.npmjs.org/@fs-eire/wgsl-template/-/wgsl-template-0.1.13.tgz", + "integrity": "sha512-SOQjVCQCUmXb9qYr2E3CKNs88/FzINuhFJiobBEkSAsyKtJby9oFWGZnrEO+hIl/oDTLA01LbjiDxuf6TGHE/w==", "license": "MIT", "dependencies": { "minimist": "^1.2.8" diff --git a/onnxruntime/core/providers/webgpu/wgsl_templates/package.json b/onnxruntime/core/providers/webgpu/wgsl_templates/package.json index 34831ccddeb33..246e7365531e0 100644 --- a/onnxruntime/core/providers/webgpu/wgsl_templates/package.json +++ b/onnxruntime/core/providers/webgpu/wgsl_templates/package.json @@ -10,6 +10,6 @@ "author": "", "license": "MIT", "dependencies": { - "@fs-eire/wgsl-template": "^0.1.3" + "@fs-eire/wgsl-template": "^0.1.13" } } diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index f6e37d33b2414..f864b8eb4a74d 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -284,6 +284,8 @@ def generate_vcpkg_install_options(build_dir, args): vcpkg_install_options.append("--x-feature=vsinpu-ep") if args.use_webgpu: vcpkg_install_options.append("--x-feature=webgpu-ep") + if args.wgsl_template == "dynamic": + vcpkg_install_options.append("--x-feature=webgpu-ep-wgsl-template-dynamic") if args.use_webnn: vcpkg_install_options.append("--x-feature=webnn-ep") if args.use_xnnpack: