onnxruntime/core/providers/qnn/builder/op_builder_factory.cc
@@ -51,7 +51,7 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
CreateSimpleOpBuilder("Sum", *this);
CreateSimpleOpBuilder("Tanh", *this);

CreateSimpleOpBuilder("Concat", *this);
CreateConcatOpBuilder("Concat", *this);

CreateSimpleOpBuilder("QuantizeLinear", *this);
CreateSimpleOpBuilder("DequantizeLinear", *this);
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.h
@@ -127,5 +127,7 @@ void CreateSTFTOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_

void CreateInverseOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

} // namespace qnn
} // namespace onnxruntime
onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc
@@ -13,34 +13,6 @@ bool IsOptionalNodeUnitIODef(const NodeUnitIODef& node_io_def) {
const NodeArg& arg = node_io_def.node_arg;
return !arg.Exists() || arg.Name().empty();
}

// Function to check whether we should skip processing null input which has 0 dim in shape.
// Such null inputs often exist in models saved from PyTorch, especially for Concat.
bool DoesConcatInputShapeContainZero(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const NodeUnitIODef& node_io_def,
const logging::Logger& logger) {
// Although the 0 dim issue should be handled for all op types, restricting in Concat for now since current cases
// only happen on one of Concat inputs. One may rename the function and relax the checking here to extend for other
// ops.
if (node_unit.OpType() != "Concat") {
return false;
}

std::vector<uint32_t> input_shape;
if (!qnn_model_wrapper.GetOnnxShape(node_io_def.node_arg, input_shape)) {
return false;
}

for (const uint32_t& dim : input_shape) {
if (dim == 0) {
LOGS(logger, WARNING) << "Tensor has 0 dim, ignore this input: " << node_io_def.node_arg.Name();
return true;
}
}

return false;
}
} // namespace

std::string BaseOpBuilder::GetOpBuilderType() const {
@@ -154,9 +126,7 @@ Status BaseOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const auto& inputs = node_unit.Inputs();
const auto input_count = GetInputCountQnnRequired(node_unit);
for (size_t input_i = 0; input_i < input_count; ++input_i) {
if (!DoesConcatInputShapeContainZero(qnn_model_wrapper, node_unit, inputs[input_i], logger)) {
ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[input_i], logger, input_names));
}
ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[input_i], logger, input_names));
}

return Status::OK();
111 changes: 111 additions & 0 deletions onnxruntime/core/providers/qnn/builder/opbuilder/concat_op_builder.cc
@@ -0,0 +1,111 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/qnn_utils.h"

namespace onnxruntime {
namespace qnn {

class ConcatOpBuilder : public BaseOpBuilder {
public:
ConcatOpBuilder() : BaseOpBuilder("ConcatOpBuilder") {}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConcatOpBuilder);

protected:
Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
std::vector<std::string>& input_names,
bool do_op_validation) const override ORT_MUST_USE_RESULT;

Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
const logging::Logger& logger,
bool do_op_validation) const override ORT_MUST_USE_RESULT;
};

Status ConcatOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
std::vector<std::string>& input_names,
bool /*do_op_validation*/) const {
const auto& inputs = node_unit.Inputs();

for (const auto& input : inputs) {
const auto& input_name = input.node_arg.Name();
bool has_zero_dim = false;

// Check if the tensor has a 0 dimension
if (qnn_model_wrapper.IsConstantInput(input_name)) {
// Process constant inputs (initializers)
const auto* input_tensor = qnn_model_wrapper.GetConstantTensor(input_name);
if (input_tensor != nullptr) {
const auto& shape = input_tensor->dims();
if (std::find(shape.begin(), shape.end(), 0) != shape.end()) {
// Found a 0 dimension, skip this input
LOGS(logger, VERBOSE) << "Constant input tensor " << input_name << " has a 0 dimension, excluding from Concat";
has_zero_dim = true;
}
}
} else {
// Process non-constant inputs
std::vector<uint32_t> shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input.node_arg, shape), "Cannot get shape");

if (std::find(shape.begin(), shape.end(), 0) != shape.end()) {
// Found a 0 dimension, skip this input
LOGS(logger, VERBOSE) << "Input tensor " << input_name << " has a 0 dimension, excluding from Concat";
has_zero_dim = true;
}
}

// Process the input if it doesn't have a 0 dimension
if (!has_zero_dim) {
ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, input, logger, input_names));
}
}

// If every input has a 0 dimension, return an error: Concat requires at least one input without a zero-sized dimension
if (input_names.empty()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Concat operation requires at least one input without a 0 dimension");
}

return Status::OK();
}

Status ConcatOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
const logging::Logger& logger,
bool do_op_validation) const {
if (input_names.size() < 1) {
return Status::OK();
}

std::vector<std::string> param_tensor_names;

// Process axis attribute
int32_t default_axis = 0;
Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT;
ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, default_axis));
QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(), QNN_OP_CONCAT_PARAM_AXIS, axis_qnn_scalar);
param_tensor_names.push_back(axis_param.GetParamTensorName());
qnn_model_wrapper.AddParamWrapper(std::move(axis_param));

// Process outputs
return ProcessOutputs(qnn_model_wrapper, node_unit,
std::move(input_names),
std::move(param_tensor_names),
logger, do_op_validation, GetQnnOpType(node_unit.OpType()));
}

void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.AddOpBuilder(op_type, std::make_unique<ConcatOpBuilder>());
}

} // namespace qnn
} // namespace onnxruntime
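
For reference, below is a minimal, self-contained sketch (not part of the PR; the struct, names, and shapes are illustrative, with shapes chosen to match the new tests) of the zero-dimension filtering that ConcatOpBuilder::ProcessInputs applies before handing inputs to the QNN Concat op:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Keep only the inputs whose shape contains no zero-sized dimension,
// mirroring the check the new builder performs before calling ProcessInput.
int main() {
  struct Input {
    std::string name;
    std::vector<uint32_t> shape;
  };
  const std::vector<Input> inputs = {
      {"x0", {1, 3, 4, 4}},
      {"x1", {1, 0, 4, 4}},  // empty along the concat axis
  };

  std::vector<std::string> kept;
  for (const auto& input : inputs) {
    const bool has_zero_dim = std::any_of(
        input.shape.begin(), input.shape.end(),
        [](uint32_t dim) { return dim == 0; });
    if (!has_zero_dim) {
      kept.push_back(input.name);  // only non-empty inputs are passed on
    }
  }

  for (const auto& name : kept) {
    std::cout << name << "\n";  // prints: x0
  }
  return 0;
}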
20 changes: 20 additions & 0 deletions onnxruntime/test/providers/qnn/simple_op_htp_test.cc
@@ -268,6 +268,26 @@ static void RunFP16OpTest(const std::string& op_type,
tolerance);
}

// Test Concat with empty input
TEST_F(QnnHTPBackendTests, Concat_EmptyInput) {
RunOpTest("Concat",
{TestInputDef<float>({1, 3, 4, 4}, false, -10.0f, 10.0f),
TestInputDef<float>({1, 0, 4, 4}, false, {})},
{utils::MakeAttribute("axis", static_cast<int64_t>(1))},
13,
ExpectedEPNodeAssignment::All);
}

// Test Concat with empty initializer
TEST_F(QnnHTPBackendTests, Concat_EmptyInitializer) {
RunOpTest("Concat",
{TestInputDef<float>({1, 3, 4, 4}, false, -10.0f, 10.0f),
TestInputDef<float>({1, 0, 4, 4}, true, {})}, // true makes this an initializer
{utils::MakeAttribute("axis", static_cast<int64_t>(1))},
13,
ExpectedEPNodeAssignment::All);
}

// Test the accuracy of QDQ Sigmoid.
TEST_F(QnnHTPBackendTests, UnaryOp_Sigmoid) {
RunQDQOpTest<uint8_t>("Sigmoid",