diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 6515661a2ee6a..255714054cdaa 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -764,6 +764,24 @@ bool TopKNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& n return IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath()); } +bool CumSumNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node, + const std::vector<const Node*>& dq_nodes, + const std::vector<const Node*>& q_nodes) const { + // Only the first input has DQ node + if (!CheckQDQNodes(graph_viewer, node, redundant_clip_node, dq_nodes, q_nodes, 1)) { + return false; + } + + int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + + if (dt_input != dt_output) { + return false; + } + + return true; +} + } // namespace QDQ } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index e4f4844fb88ad..6f7e153ec6ecb 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -285,6 +285,14 @@ class TopKNodeGroupSelector : public NodeGroupSelector { const std::vector<const Node*>& q_nodes) const override; }; +// one DQ node for first input -> node -> Q +class CumSumNodeGroupSelector : public NodeGroupSelector { + bool Check(const GraphViewer& graph_viewer, + const Node& node, const Node* redundant_clip_node, + const std::vector<const Node*>& dq_nodes, + const std::vector<const Node*>& q_nodes) const 
override; +}; + /* * NodeSelector instances for use in the QDQ::SelectorActionTransformer. */ diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index d3957a34dcfca..a39a6e8cc0e93 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -146,6 +146,9 @@ static const OpVersionsAndSelector::OpVersionsMap GetPadOpVersionsMap() { static const OpVersionsAndSelector::OpVersionsMap GetTopKOpVersionsMap() { return {{"TopK", {}}}; } +static const OpVersionsAndSelector::OpVersionsMap GetCumSumOpVersionsMap() { + return {{"CumSum", {}}}; +} /* Selector rules registration related */ void RegisterMiscSelectors(Selectors& qdq_selectors) { @@ -268,6 +271,13 @@ void RegisterTopKSelector(Selectors& qdq_selectors) { std::move(selector)); } +void RegisterCumSumSelector(Selectors& qdq_selectors) { + /* register selector for cumsum op */ + std::unique_ptr<NodeGroupSelector> selector = std::make_unique<CumSumNodeGroupSelector>(); + qdq_selectors.RegisterSelector(GetCumSumOpVersionsMap(), + std::move(selector)); +} + void SelectorManager::CreateSelectors() { RegisterMiscSelectors(qdq_selectors_); RegisterDropDQSelectors(qdq_selectors_); @@ -286,6 +296,7 @@ void SelectorManager::CreateSelectors() { RegisterWhereSelectors(qdq_selectors_); RegisterPadSelectors(qdq_selectors_); RegisterTopKSelector(qdq_selectors_); + RegisterCumSumSelector(qdq_selectors_); } void SelectorManager::InitializeSelectorsMap() { diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index e4d768093aa37..53fef09aec0fa 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -185,6 +185,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() { { 
CreateLSTMOpBuilder("LSTM", *this); } + + { + CreateCumSumOpBuilder("CumSum", *this); + } } const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) { diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h index c1cc61ad19341..1cc8e12068cca 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h @@ -104,5 +104,8 @@ void CreateMatMulOpBuilder(const std::string& op_type, OpBuilderRegistrations& o void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateLSTMOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); + +void CreateCumSumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); + } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h index 5b3fa6ed3b950..a83e8e064c7d0 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h @@ -230,6 +230,7 @@ class BaseOpBuilder : public IOpBuilder { {"LogSoftmax", QNN_OP_LOG_SOFTMAX}, {"Concat", QNN_OP_CONCAT}, + {"CumSum", QNN_OP_CUMULATIVE_SUM}, {"Gemm", QNN_OP_FULLY_CONNECTED}, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/cumsum_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/cumsum_op_builder.cc new file mode 100644 index 0000000000000..68d2808a91e3e --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/cumsum_op_builder.cc @@ -0,0 +1,148 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" + +namespace onnxruntime { +namespace qnn { +namespace { + +Status GetOnnxAxis(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, uint32_t& onnx_axis) { + const auto& inputs = node_unit.Inputs(); + TensorInfo axis_input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[1], axis_input_info)); + ORT_RETURN_IF_NOT(axis_input_info.is_initializer, "axis must be initializers"); + std::vector<uint8_t> axis_unpacked_tensor; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*axis_input_info.initializer_tensor, axis_unpacked_tensor)); + ORT_RETURN_IF_NOT(1 == static_cast<size_t>(axis_unpacked_tensor.size() / sizeof(axis_input_info.qnn_data_type)), + "axis should be a single element"); + + int32_t axis = 0; + if (axis_input_info.qnn_data_type == QNN_DATATYPE_INT_64) { + axis = static_cast<int32_t>(*reinterpret_cast<const int64_t*>(axis_unpacked_tensor.data())); + } else { + axis = static_cast<int32_t>(*reinterpret_cast<const int32_t*>(axis_unpacked_tensor.data())); + } + + std::vector<uint32_t> input_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input_shape), "Cannot get shape"); + + auto rank = static_cast<int32_t>(input_shape.size()); + if (axis < 0) { + axis += rank; + } + + ORT_RETURN_IF_NOT((axis >= 0 && axis < static_cast<int32_t>(input_shape.size())), "QNN requires axis range [0, rank-1]."); + + onnx_axis = static_cast<uint32_t>(axis); + + return Status::OK(); +} + +} // namespace + +class CumSumOpBuilder : public BaseOpBuilder { + public: + CumSumOpBuilder() : BaseOpBuilder("CumSumOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CumSumOpBuilder); + + Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + Status ProcessInputs(QnnModelWrapper& 
qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector<std::string>& input_names, + bool do_op_validation) const override ORT_MUST_USE_RESULT; + + Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector<std::string>&& input_names, + const logging::Logger& logger, + bool do_op_validation) const override ORT_MUST_USE_RESULT; +}; + +Status CumSumOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger) const { + const auto& inputs = node_unit.Inputs(); + ORT_RETURN_IF_NOT(qnn_model_wrapper.IsConstantInput(inputs[1].node_arg.Name()), + "QNN CumSum needs axis as a param, hence input[1] must be a constant."); + + NodeAttrHelper node_helper(node_unit); + int64_t exclusive = node_helper.Get("exclusive", static_cast<int64_t>(0)); + int64_t reverse = node_helper.Get("reverse", static_cast<int64_t>(0)); + + // QNN HTP op validation passes for non-default values of attributes but fails in finalize. + // Hence adding the checks here. 
+ ORT_RETURN_IF_NOT(exclusive == 0, "QNN only supports default value 0 for exclusive attribute"); + ORT_RETURN_IF_NOT(reverse == 0, "QNN only supports default value 0 for reverse attribute"); + + return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); +} + +Status CumSumOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector<std::string>& input_names, + bool do_op_validation) const { + ORT_UNUSED_PARAMETER(do_op_validation); + const auto& inputs = node_unit.Inputs(); + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names)); + return Status::OK(); +} + +Status CumSumOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector<std::string>&& input_names, + const logging::Logger& logger, + bool do_op_validation) const { + ORT_UNUSED_PARAMETER(do_op_validation); + + std::vector<std::string> param_tensor_names; + + // Add axis param + Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT; + uint32_t onnx_axis = 0; + ORT_RETURN_IF_ERROR(GetOnnxAxis(qnn_model_wrapper, node_unit, onnx_axis)); + axis_qnn_scalar.dataType = QNN_DATATYPE_UINT_32; + axis_qnn_scalar.uint32Value = onnx_axis; + QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(), QNN_OP_CUMULATIVE_SUM_PARAM_AXIS, axis_qnn_scalar); + param_tensor_names.push_back(axis_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(axis_param)); + + // Add exclusive param + NodeAttrHelper node_helper(node_unit); + int64_t exclusive = node_helper.Get("exclusive", static_cast<int64_t>(0)); + Qnn_Scalar_t exclusive_qnn_scalar = QNN_SCALAR_INIT; + exclusive_qnn_scalar.dataType = QNN_DATATYPE_BOOL_8; + exclusive_qnn_scalar.bool8Value = static_cast<uint8_t>(exclusive == 0 ? 
0 : 1); + QnnParamWrapper exclusive_param(node_unit.Index(), node_unit.Name(), QNN_OP_CUMULATIVE_SUM_PARAM_EXCLUSIVE, exclusive_qnn_scalar); + param_tensor_names.push_back(exclusive_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(exclusive_param)); + + // Add reverse param + int64_t reverse = node_helper.Get("reverse", static_cast<int64_t>(0)); + Qnn_Scalar_t reverse_qnn_scalar = QNN_SCALAR_INIT; + reverse_qnn_scalar.dataType = QNN_DATATYPE_BOOL_8; + reverse_qnn_scalar.bool8Value = static_cast<uint8_t>(reverse == 0 ? 0 : 1); + QnnParamWrapper reverse_param(node_unit.Index(), node_unit.Name(), QNN_OP_CUMULATIVE_SUM_PARAM_REVERSE, reverse_qnn_scalar); + param_tensor_names.push_back(reverse_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(reverse_param)); + + return ProcessOutputs(qnn_model_wrapper, node_unit, + std::move(input_names), + std::move(param_tensor_names), + logger, do_op_validation, GetQnnOpType(node_unit.OpType())); +} + +void CreateCumSumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.AddOpBuilder(op_type, std::make_unique<CumSumOpBuilder>()); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/cumsum_op_htp_test.cc b/onnxruntime/test/providers/qnn/cumsum_op_htp_test.cc new file mode 100644 index 0000000000000..f3affc18d8a9a --- /dev/null +++ b/onnxruntime/test/providers/qnn/cumsum_op_htp_test.cc @@ -0,0 +1,134 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include <string> +#include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" + +#include "test/providers/qnn/qnn_test_utils.h" + +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +// Runs a non-QDQ model on HTP and compares output to CPU EP. 
+template <typename DataType, typename AxisType> +static void RunCumSumOpTest(const std::string& op_type, + const TestInputDef<DataType>& input_def_1, + const TestInputDef<AxisType>& input_def_2, + const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs, + int opset_version, + ExpectedEPNodeAssignment expected_ep_assignment, + float fp32_abs_err = 2e-3f) { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + + // Runs model with a Q/DQ binary op and compares the outputs of the CPU and QNN EPs. + RunQnnModelTest(BuildOpTestCase<DataType, AxisType>(op_type, {input_def_1}, {input_def_2}, attrs), + provider_options, + opset_version, + expected_ep_assignment, + fp32_abs_err); +} + +// Non-QDQ model, CumSum with float input and axis input as initializer with axis 0 +TEST_F(QnnHTPBackendTests, CumSum_float_int32_e0_r0_axis_0) { + RunCumSumOpTest<float, int32_t>("CumSum", + TestInputDef<float>({3, 2}, false, {1.3f, 7.2f, 0.4f, 3.4f, 5.7f, 0.8f}), + TestInputDef<int32_t>({}, true, {0}), + {utils::MakeAttribute("exclusive", static_cast<int64_t>(0)), + utils::MakeAttribute("reverse", static_cast<int64_t>(0))}, + 17, + ExpectedEPNodeAssignment::All); +} + +// Non-QDQ model, CumSum with float input and axis input as initializer with axis -1 +TEST_F(QnnHTPBackendTests, CumSum_float_int32_e0_r0_axis_neg1) { + RunCumSumOpTest<float, int32_t>("CumSum", + TestInputDef<float>({3, 2}, false, {1.3f, 7.2f, 0.4f, 3.4f, 5.7f, 0.8f}), + TestInputDef<int32_t>({}, true, {-1}), + {utils::MakeAttribute("exclusive", static_cast<int64_t>(0)), + utils::MakeAttribute("reverse", static_cast<int64_t>(0))}, + 17, + ExpectedEPNodeAssignment::All); +} + +// Returns a function that creates a graph with a QDQ CumSum operator. 
+template <typename InputQType = uint8_t> +GetTestQDQModelFn<InputQType> BuildQDQCumSumTestCase(const TestInputDef<float>& input_def, + const TestInputDef<int32_t>& axis_def, + const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs, + bool use_contrib_qdq = false) { + return [input_def, axis_def, attrs, use_contrib_qdq](ModelTestBuilder& builder, + std::vector<QuantParams<InputQType>>& output_qparams) { + // input -> Q -> DQ -> + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams<InputQType> input_qparams = GetTestInputQuantParams<InputQType>(input_def); + NodeArg* input_qdq = AddQDQNodePair<InputQType>(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + + // axis input + NodeArg* axis_input = MakeTestInput(builder, axis_def); + + // CumSum op + NodeArg* op_output = builder.MakeIntermediate(); + Node& cumsum_node = builder.AddNode("CumSum", {input_qdq, axis_input}, {op_output}); + + for (const auto& attr : attrs) { + cumsum_node.AddAttributeProto(attr); + } + + // op_output -> Q -> DQ -> output + AddQDQNodePairWithOutputAsGraphOutput<InputQType>(builder, op_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); + }; +} + +// Test the accuracy of a QDQ CumSum model on QNN EP. Checks if the QDQ model on QNN EP is as accurate as the QDQ model on CPU EP +// (compared to float32 model). 
+template <typename InputQType = uint8_t> +static void RunQDQCumSumOpTest(const TestInputDef<float>& input_def, + const TestInputDef<int32_t>& axis_def, + const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs, + int opset, + ExpectedEPNodeAssignment expected_ep_assignment, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + + auto f32_model_builder = BuildOpTestCase<float, int32_t>("CumSum", {input_def}, {axis_def}, attrs); + auto qdq_model_builder = BuildQDQCumSumTestCase<InputQType>(input_def, axis_def, attrs, + use_contrib_qdq); + + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test creates a DQ -> CumSum -> Q -> DQ graph, and checks that all +// nodes are supported by the QNN EP, and that the inference results are as accurate as CPU EP. +// +// QDQ model, CumSum with uint8 input and axis input as initializer +TEST_F(QnnHTPBackendTests, CumSum_uint8_int32_e0_r0) { + RunQDQCumSumOpTest<uint8_t>(TestInputDef<float>({3, 2}, false, {1.3f, 7.2f, 0.4f, 3.4f, 5.7f, 0.8f}), + TestInputDef<int32_t>({}, true, {0}), + {utils::MakeAttribute("exclusive", static_cast<int64_t>(0)), + utils::MakeAttribute("reverse", static_cast<int64_t>(0))}, + 17, + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +} // namespace test +} // namespace onnxruntime + +#endif