diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index cfa0c430b053b..c140c781196a4 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -51,7 +51,7 @@ OpBuilderRegistrations::OpBuilderRegistrations() { CreateSimpleOpBuilder("Sum", *this); CreateSimpleOpBuilder("Tanh", *this); - CreateSimpleOpBuilder("Concat", *this); + CreateConcatOpBuilder("Concat", *this); CreateSimpleOpBuilder("QuantizeLinear", *this); CreateSimpleOpBuilder("DequantizeLinear", *this); diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h index 1312069892671..eb1c6db82e4f1 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h @@ -127,5 +127,7 @@ void CreateSTFTOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_ void CreateInverseOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); + } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index b51732bf0fe12..0bb3accb4d754 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -13,34 +13,6 @@ bool IsOptionalNodeUnitIODef(const NodeUnitIODef& node_io_def) { const NodeArg& arg = node_io_def.node_arg; return !arg.Exists() || arg.Name().empty(); } - -// Function to check whether we should skip processing null input which has 0 dim in shape. -// Such null inputs often exist in models saved from PyTorch, especially for Concat. -bool DoesConcatInputShapeContainZero(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - const NodeUnitIODef& node_io_def, - const logging::Logger& logger) { - // Although the 0 dim issue should be handled for all op types, restricting in Concat for now since current cases - // only happen on one of Concat inputs. One may rename the function and relax the checking here to extend for other - // ops. - if (node_unit.OpType() != "Concat") { - return false; - } - - std::vector input_shape; - if (!qnn_model_wrapper.GetOnnxShape(node_io_def.node_arg, input_shape)) { - return false; - } - - for (const uint32_t& dim : input_shape) { - if (dim == 0) { - LOGS(logger, WARNING) << "Tensor has 0 dim, ignore this input: " << node_io_def.node_arg.Name(); - return true; - } - } - - return false; -} } // namespace std::string BaseOpBuilder::GetOpBuilderType() const { @@ -154,9 +126,7 @@ Status BaseOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const auto& inputs = node_unit.Inputs(); const auto input_count = GetInputCountQnnRequired(node_unit); for (size_t input_i = 0; input_i < input_count; ++input_i) { - if (!DoesConcatInputShapeContainZero(qnn_model_wrapper, node_unit, inputs[input_i], logger)) { - ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[input_i], logger, input_names)); - } + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[input_i], logger, input_names)); } return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/concat_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/concat_op_builder.cc new file mode 100644 index 0000000000000..542447b1818f2 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/concat_op_builder.cc @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_utils.h" + +namespace onnxruntime { +namespace qnn { + +class ConcatOpBuilder : public BaseOpBuilder { + public: + ConcatOpBuilder() : BaseOpBuilder("ConcatOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConcatOpBuilder); + + protected: + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const override ORT_MUST_USE_RESULT; + + Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const override ORT_MUST_USE_RESULT; +}; + +Status ConcatOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool /*do_op_validation*/) const { + const auto& inputs = node_unit.Inputs(); + + for (const auto& input : inputs) { + const auto& input_name = input.node_arg.Name(); + bool has_zero_dim = false; + + // Check if the tensor has a 0 dimension + if (qnn_model_wrapper.IsConstantInput(input_name)) { + // Process constant inputs (initializers) + const auto* input_tensor = qnn_model_wrapper.GetConstantTensor(input_name); + if (input_tensor != nullptr) { + const auto& shape = input_tensor->dims(); + if (std::find(shape.begin(), shape.end(), 0) != shape.end()) { + // Found a 0 dimension, skip this input + LOGS(logger, VERBOSE) << "Constant input tensor " << input_name << " has a 0 dimension, excluding from Concat"; + has_zero_dim = true; + } + } + } else { + // Process non-constant inputs + std::vector shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input.node_arg, shape), "Cannot get shape"); + + if (std::find(shape.begin(), shape.end(), 0) != shape.end()) { + // Found a 0 dimension, skip this input + LOGS(logger, VERBOSE) << "Input tensor " << input_name << " has a 0 dimension, excluding from Concat"; + has_zero_dim = true; + } + } + + // Process the input if it doesn't have a 0 dimension + if (!has_zero_dim) { + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, input, logger, input_names)); + } + } + + // If all inputs have 0 dimensions, return an error as Concat requires at least one non-zero dimension input + if (input_names.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Concat operation requires at least one input without a 0 dimension"); + } + + return Status::OK(); +} + +Status ConcatOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const { + if (input_names.size() < 1) { + return Status::OK(); + } + + std::vector param_tensor_names; + + // Process axis attribute + int32_t default_axis = 0; + Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT; + ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, default_axis)); + QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(), QNN_OP_CONCAT_PARAM_AXIS, axis_qnn_scalar); + param_tensor_names.push_back(axis_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(axis_param)); + + // Process outputs + return ProcessOutputs(qnn_model_wrapper, node_unit, + std::move(input_names), + std::move(param_tensor_names), + logger, do_op_validation, GetQnnOpType(node_unit.OpType())); +} + +void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.AddOpBuilder(op_type, std::make_unique()); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index daae3d939660f..b84b361b61367 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -268,6 +268,26 @@ static void RunFP16OpTest(const std::string& op_type, tolerance); } +// Test Concat with empty input +TEST_F(QnnHTPBackendTests, Concat_EmptyInput) { + RunOpTest("Concat", + {TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({1, 0, 4, 4}, false, {})}, + {utils::MakeAttribute("axis", static_cast(1))}, + 13, + ExpectedEPNodeAssignment::All); +} + +// Test Concat with empty initializer +TEST_F(QnnHTPBackendTests, Concat_EmptyInitializer) { + RunOpTest("Concat", + {TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({1, 0, 4, 4}, true, {})}, // true makes this an initializer + {utils::MakeAttribute("axis", static_cast(1))}, + 13, + ExpectedEPNodeAssignment::All); +} + // Test the accuracy of QDQ Sigmoid. TEST_F(QnnHTPBackendTests, UnaryOp_Sigmoid) { RunQDQOpTest("Sigmoid",