diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc
index 7e17addf2f577..51c38b4483cb9 100644
--- a/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc
@@ -192,6 +192,50 @@ bool IsEquationMatMulBroadcastTransposeY(const Equation& equation) {
   return true;
 }
 
+bool IsEquationReduceSumMulBroadcastX(const Equation& equation) {
+  // E.g., bhwc,wkc->bhwk
+  const auto& [term_1, term_2, result] = equation;
+  if (term_1.size() != 4) {
+    return false;
+  }
+  if (term_2.size() != 3) {
+    return false;
+  }
+  if (result.size() != 4) {
+    return false;
+  }
+
+  // Check contraction over the last axis (c).
+  char c1 = term_1[3];
+  char c2 = term_2[2];
+  if (c1 != c2) {
+    return false;
+  }
+
+  // Check w-axis alignment.
+  if (term_1[2] != term_2[0]) {
+    return false;
+  }
+  if (term_1[2] != result[2]) {
+    return false;
+  }
+
+  // Check k-axis alignment.
+  if (term_2[1] != result[3]) {
+    return false;
+  }
+
+  // Check batch dimensions.
+  if (term_1[0] != result[0]) {
+    return false;
+  }
+  if (term_1[1] != result[1]) {
+    return false;
+  }
+
+  return true;
+}
+
 /**
  * @brief Sets the parameter tensor names for a MatMul op.
  *
@@ -305,6 +349,113 @@ Status CreateMatMulTransposeAll(
   return Status::OK();
 }
 
+/**
+ * @brief Creates a Multiply of the broadcast input X with the original input Y, followed by a ReduceSum over the contraction axis.
+ *
+ * @param qnn_model_wrapper Pointer to the QnnModelWrapper instance used to manage the QNN model.
+ * @param node_unit The NodeUnit representing the ONNX node to be converted.
+ * @param do_op_validation A boolean flag indicating whether to perform operation validation.
+ * @return Status indicating success or failure of the operation.
+ */
+Status CreateReduceSumMulBroadcastX(
+    onnxruntime::qnn::QnnModelWrapper* qnn_model_wrapper,
+    const onnxruntime::NodeUnit& node_unit,
+    std::vector<std::string>&& input_names,
+    bool do_op_validation) {
+  // Reshape in0 to shape (b, h, w, 1, c) to insert a dimension before the contraction axis 'c'.
+  // This allows broadcasting with in1 for the multiplication and aligns the contraction axis for the reduction.
+  onnxruntime::qnn::TensorInfo tensor_info_in0{}, tensor_info_in1{}, tensor_info_out{};
+  ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Inputs()[0], tensor_info_in0));
+  ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Inputs()[1], tensor_info_in1));
+  ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Outputs()[0], tensor_info_out));
+  const std::vector<uint32_t>& shape_in0 = tensor_info_in0.shape;
+  const std::vector<uint32_t>& shape_in1 = tensor_info_in1.shape;
+  ORT_RETURN_IF_NOT(shape_in0.size() == 4, "CreateReduceSumMulBroadcastX expects input 0 to be rank 4");
+  ORT_RETURN_IF_NOT(shape_in1.size() == 3, "CreateReduceSumMulBroadcastX expects input 1 to be rank 3");
+  const std::vector<uint32_t> new_shape_in0{shape_in0[0], shape_in0[1], shape_in0[2], 1, shape_in0[3]};
+  const std::string reshape_out_name = input_names[0] + "_reshaped";
+  ORT_RETURN_IF_ERROR(qnn_model_wrapper->AddReshapeNode(
+      /*input_name=*/input_names[0],
+      /*output_name=*/reshape_out_name,
+      /*input_shape=*/shape_in0,
+      /*output_shape=*/new_shape_in0,
+      /*tensor_data_type=*/tensor_info_in0.qnn_data_type,
+      /*quantize_param=*/tensor_info_in0.quant_param.Copy(),
+      /*do_op_validation=*/do_op_validation,
+      /*is_for_input=*/qnn_model_wrapper->IsGraphInput(input_names[0])));
+
+  // Multiply: reshaped in0 * in1.
+  // The output shape of the multiplication is determined by broadcasting the reshaped in0 of
+  // (b, h, w, 1, c) and in1 (w, k, c) along the matching axes, resulting in (b, h, w, k, c).
+  const std::string mul_out_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_mul";
+  std::vector<uint32_t> shape_out_mul{new_shape_in0[0], new_shape_in0[1], new_shape_in0[2], shape_in1[1], new_shape_in0[4]};
+  onnxruntime::qnn::QnnTensorWrapper tensor_wrapper_mul(mul_out_name,
+                                                        QNN_TENSOR_TYPE_NATIVE,
+                                                        tensor_info_in0.qnn_data_type,
+                                                        tensor_info_in0.quant_param.Copy(),
+                                                        std::move(shape_out_mul));
+  ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(tensor_wrapper_mul)),
+                    "CreateReduceSumMulBroadcastX: failed to AddTensorWrapper");
+  ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(
+                        /*qnn_node_name=*/mul_out_name,
+                        /*package_name=*/QNN_OP_PACKAGE_NAME_QTI_AISW,
+                        /*qnn_node_type=*/QNN_OP_ELEMENT_WISE_MULTIPLY,
+                        /*input_names=*/{reshape_out_name, input_names[1]},
+                        /*output_names=*/{mul_out_name},
+                        /*param_tensor_names=*/{},
+                        /*do_op_validation=*/do_op_validation),
+                    "CreateReduceSumMulBroadcastX: failed to create Mul node");
+
+  std::vector<std::string> param_tensor_names{};
+
+  // ReduceSum over axes={4} with keep_dims=false.
+  // Axis 4 corresponds to the last dimension ('c') of the multiply output (b, h, w, k, c),
+  // which is the contraction axis of the einsum equation (bhwc,wkc->bhwk).
+  std::vector<uint32_t> axes_shape{SafeInt<uint32_t>(1)};
+  std::vector<uint32_t> axes_value{SafeInt<uint32_t>(4)};
+  onnxruntime::qnn::QnnParamWrapper param_axes(node_unit.Index(),
+                                               node_unit.Name(),
+                                               QNN_OP_REDUCE_SUM_PARAM_AXES,
+                                               std::move(axes_shape),
+                                               std::move(axes_value));
+  param_tensor_names.push_back(param_axes.GetParamTensorName());
+  ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_axes)),
+                    "CreateReduceSumMulBroadcastX: failed to add param axes");
+
+  Qnn_Scalar_t keep_dims_scalar = QNN_SCALAR_INIT;
+  keep_dims_scalar.dataType = QNN_DATATYPE_BOOL_8;
+  keep_dims_scalar.bool8Value = SafeInt<uint8_t>(0);
+  onnxruntime::qnn::QnnParamWrapper param_keep_dims(node_unit.Index(),
+                                                    node_unit.Name(),
+                                                    QNN_OP_REDUCE_SUM_PARAM_KEEP_DIMS,
+                                                    keep_dims_scalar);
+  param_tensor_names.push_back(param_keep_dims.GetParamTensorName());
+  ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_keep_dims)),
+                    "CreateReduceSumMulBroadcastX: failed to add param keep_dims");
+
+  const std::string out_name = node_unit.Outputs()[0].node_arg.Name();
+  Qnn_TensorType_t out_tensor_type = qnn_model_wrapper->IsGraphOutput(out_name) ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;
+  onnxruntime::qnn::QnnTensorWrapper tensor_wrapper_out(out_name,
+                                                        out_tensor_type,
+                                                        tensor_info_out.qnn_data_type,
+                                                        tensor_info_out.quant_param.Copy(),
+                                                        std::move(tensor_info_out.shape));
+  ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(tensor_wrapper_out)),
+                    "CreateReduceSumMulBroadcastX: failed to AddTensorWrapper");
+
+  ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(
+                        /*qnn_node_name=*/out_name,
+                        /*package_name=*/QNN_OP_PACKAGE_NAME_QTI_AISW,
+                        /*qnn_node_type=*/QNN_OP_REDUCE_SUM,
+                        /*input_names=*/{mul_out_name},
+                        /*output_names=*/{out_name},
+                        /*param_tensor_names=*/std::move(param_tensor_names),
+                        /*do_op_validation=*/do_op_validation),
+                    "CreateReduceSumMulBroadcastX: failed to create ReduceSum node");
+
+  return Status::OK();
+}
+
 }  // namespace
 
 namespace onnxruntime {
@@ -356,9 +507,20 @@ Status EinsumOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
   if (!IsEquationMatMul(parsed_equation.value()) &&
       !IsEquationMatMulTransposeY(parsed_equation.value()) &&
       !IsEquationMatMulBroadcastTransposeY(parsed_equation.value()) &&
-      !IsEquationMatMulTransposeAll(parsed_equation.value())) {
+      !IsEquationMatMulTransposeAll(parsed_equation.value()) &&
+      !IsEquationReduceSumMulBroadcastX(parsed_equation.value())) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation);
   }
+  if (IsEquationReduceSumMulBroadcastX(parsed_equation.value())) {
+    if (IsGpuBackend(qnn_model_wrapper.GetQnnBackendType())) {
+      // QAIRT 3.36.1: Failed to validate on GPU.
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation + " on backend GPU");
+    }
+    if (node_unit.Inputs()[0].quant_param.has_value()) {
+      // QAIRT 3.36.1: Failed to finalize QNN graph 1002.
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation + " for quantized inputs");
+    }
+  }
 
   return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true);
 }
@@ -408,6 +570,11 @@ Status EinsumOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
                                                     /*node_unit=*/node_unit,
                                                     /*input_names=*/std::move(input_names),
                                                     /*do_op_validation=*/do_op_validation));
+  } else if (IsEquationReduceSumMulBroadcastX(parsed_equation.value())) {
+    ORT_RETURN_IF_ERROR(CreateReduceSumMulBroadcastX(/*qnn_model_wrapper=*/&qnn_model_wrapper,
+                                                     /*node_unit=*/node_unit,
+                                                     /*input_names=*/std::move(input_names),
+                                                     /*do_op_validation=*/do_op_validation));
   } else {
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation);
   }
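Note: the lowering above can be sanity-checked outside QNN. The sketch below is standalone C++ (not part of the patch); loop bounds mirror the shapes of the `EinsumReduceSumMulBroadcastX` CPU test added below. It materializes the broadcast product `X[b,h,w,1,c] * Y[w,k,c]`, reduces over `c` (axis 4, keep_dims=false), and asserts the result matches a direct evaluation of `bhwc,wkc->bhwk`.

```cpp
// Standalone check: einsum "bhwc,wkc->bhwk" equals the
// reshape (b,h,w,1,c) -> broadcast-multiply -> reduce-sum-over-c lowering.
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

int main() {
  const size_t B = 2, H = 3, W = 4, C = 5, K = 6;
  std::vector<float> x(B * H * W * C), y(W * K * C);
  for (size_t i = 0; i < x.size(); ++i) x[i] = -0.1f + 0.05f * static_cast<float>(i);
  for (size_t i = 0; i < y.size(); ++i) y[i] = -0.1f + 0.05f * static_cast<float>(i);

  // mul[b,h,w,k,c] = X[b,h,w,1,c] * Y[w,k,c]: the (b,h,w,1,c) view of X
  // broadcasts against Y along k, like the QNN ElementWiseMultiply above.
  std::vector<float> mul(B * H * W * K * C);
  for (size_t b = 0; b < B; ++b)
    for (size_t h = 0; h < H; ++h)
      for (size_t w = 0; w < W; ++w)
        for (size_t k = 0; k < K; ++k)
          for (size_t c = 0; c < C; ++c)
            mul[(((b * H + h) * W + w) * K + k) * C + c] =
                x[((b * H + h) * W + w) * C + c] * y[(w * K + k) * C + c];

  // ReduceSum over axis 4 (c) versus direct evaluation of
  // out[b,h,w,k] = sum_c X[b,h,w,c] * Y[w,k,c].
  for (size_t b = 0; b < B; ++b)
    for (size_t h = 0; h < H; ++h)
      for (size_t w = 0; w < W; ++w)
        for (size_t k = 0; k < K; ++k) {
          float reduced = 0.0f, direct = 0.0f;
          for (size_t c = 0; c < C; ++c) {
            reduced += mul[(((b * H + h) * W + w) * K + k) * C + c];
            direct += x[((b * H + h) * W + w) * C + c] * y[(w * K + k) * C + c];
          }
          assert(std::fabs(reduced - direct) < 1e-4f);
        }
  return 0;
}
```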
diff --git a/onnxruntime/test/providers/qnn/einsum_op_test.cc b/onnxruntime/test/providers/qnn/einsum_op_test.cc
index d8dbbd799a427..11a3d5a083aab 100644
--- a/onnxruntime/test/providers/qnn/einsum_op_test.cc
+++ b/onnxruntime/test/providers/qnn/einsum_op_test.cc
@@ -189,6 +189,19 @@ TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll1) {
       /*tolerance=*/1e-4f);
 }
 
+TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll2) {
+  const std::vector<int64_t> shape0{1, 7, 1, 7};
+  const std::vector<int64_t> shape1{1, 9, 1, 7};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnEinsum(
+      /*backend=*/kQnnBackendTypeCpu,
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bkhq,bchk->bchq",
+      /*tolerance=*/1e-4f);
+}
+
 TEST_F(QnnCPUBackendTests, EinsumMatMulBroadcastTransposeY) {
   const std::vector<int64_t> shape0{2, 3, 3, 4};
   const std::vector<int64_t> shape1{3, 3, 4};
@@ -202,16 +215,16 @@ TEST_F(QnnCPUBackendTests, EinsumMatMulBroadcastTransposeY) {
       /*tolerance=*/1e-4f);
 }
 
-TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll2) {
-  const std::vector<int64_t> shape0{1, 7, 1, 7};
-  const std::vector<int64_t> shape1{1, 9, 1, 7};
+TEST_F(QnnCPUBackendTests, EinsumReduceSumMulBroadcastX) {
+  const std::vector<int64_t> shape0{2, 3, 4, 5};
+  const std::vector<int64_t> shape1{4, 6, 5};
   const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
   const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
   RunQnnEinsum(
       /*backend=*/kQnnBackendTypeCpu,
       /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
       /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
-      /*equation=*/"bkhq,bchk->bchq",
+      /*equation=*/"bhwc,wkc->bhwk",
       /*tolerance=*/1e-4f);
 }
 
@@ -299,6 +312,19 @@ TEST_F(QnnHTPBackendTests, EinsumF16MatMulBroadcastTransposeY) {
       /*tolerance=*/1e-2f);
 }
 
+TEST_F(QnnHTPBackendTests, EinsumF16ReduceSumMulBroadcastX) {
+  const std::vector<int64_t> shape0{1, 3, 2, 4};
+  const std::vector<int64_t> shape1{2, 3, 4};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnEinsum(
+      /*backend=*/kQnnBackendTypeHtp,
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bhwc,wkc->bhwk",
+      /*tolerance=*/1e-2f);
+}
+
 //
 // QNN HTP QDQ
 //
@@ -375,6 +401,19 @@ TEST_F(QnnHTPBackendTests, EinsumQdqMatMulBroadcastTransposeY) {
       /*tolerance=*/QDQTolerance());
 }
 
+// TODO: Re-enable. QAIRT 3.36.1: failed to finalize QNN graph 1002.
+TEST_F(QnnHTPBackendTests, DISABLED_EinsumQdqReduceSumMulBroadcastX) {
+  const std::vector<int64_t> shape0{1, 3, 2, 4};
+  const std::vector<int64_t> shape1{2, 3, 4};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnHtpQdqEinsum(
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bhwc,wkc->bhwk",
+      /*tolerance=*/QDQTolerance());
+}
+
 #endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
 
 #if defined(_M_ARM64)
@@ -474,6 +513,20 @@ TEST_F(QnnGPUBackendTests, DISABLED_EinsumMatMulBroadcastTransposeY) {
       /*tolerance=*/1e-4f);
 }
 
+// TODO: Re-enable. Failed on QAIRT 3.36.1.
+TEST_F(QnnGPUBackendTests, DISABLED_EinsumReduceSumMulBroadcastX) {
+  const std::vector<int64_t> shape0{1, 3, 2, 4};
+  const std::vector<int64_t> shape1{2, 3, 4};
+  const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
+  const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
+  RunQnnEinsum(
+      /*backend=*/kQnnBackendTypeGpu,
+      /*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
+      /*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
+      /*equation=*/"bhwc,wkc->bhwk",
+      /*tolerance=*/1e-4f);
+}
+
 #endif  // defined(_M_ARM64) GPU tests
 
 }  // namespace test
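Assuming a standard developer build, the new coverage can be exercised through the `onnxruntime_test_all` gtest binary (name and location vary by build configuration). The GPU and QDQ variants are disabled pending QAIRT fixes, so they only run with the extra flag:

```sh
# Run the enabled ReduceSumMulBroadcastX tests (CPU and HTP fp16).
./onnxruntime_test_all --gtest_filter='*ReduceSumMulBroadcastX*'
# Also run the DISABLED_ GPU/QDQ variants.
./onnxruntime_test_all --gtest_filter='*ReduceSumMulBroadcastX*' --gtest_also_run_disabled_tests
```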