169 changes: 168 additions & 1 deletion onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc
@@ -192,6 +192,50 @@
return true;
}

bool IsEquationReduceSumMulBroadcastX(const Equation& equation) {
// E.g., bhwc,wkc->bhwk
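// i.e., out[b, h, w, k] = sum_c(term_1[b, h, w, c] * term_2[w, k, c]):
// an elementwise multiply with term_1 broadcast over 'k', then a reduce-sum
// over the shared contraction axis 'c'.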
const auto& [term_1, term_2, result] = equation;
if (term_1.size() != 4) {
return false;
}
if (term_2.size() != 3) {
return false;
}
if (result.size() != 4) {
return false;
}

// Check contraction over last axis (c)
char c1 = term_1[3];
char c2 = term_2[2];
if (c1 != c2) {
return false;
}

// Check w axis alignment
if (term_1[2] != term_2[0]) {
return false;
}
if (term_1[2] != result[2]) {
return false;
}

// Check k axis alignment
if (term_2[1] != result[3]) {
return false;
}

// Check batch dimensions
if (term_1[0] != result[0]) {
return false;
}
if (term_1[1] != result[1]) {
return false;
}

return true;
}

/**
* @brief Sets the parameter tensor names for a MatMul op.
*
@@ -305,6 +349,113 @@
return Status::OK();
}

/**
* @brief Creates a Multiply of the broadcast-reshaped input X with the original input Y, followed by a ReduceSum over the contraction axis.
*
* @param qnn_model_wrapper Pointer to the QnnModelWrapper instance used to manage the QNN model.
* @param node_unit The NodeUnit representing the ONNX node to be converted.
* @param do_op_validation A boolean flag indicating whether to perform operation validation.
* @return Status indicating success or failure of the operation.
*/
Status CreateReduceSumMulBroadcastX(
onnxruntime::qnn::QnnModelWrapper* qnn_model_wrapper,
const onnxruntime::NodeUnit& node_unit,
std::vector<std::string>&& input_names,
bool do_op_validation) {
// Reshape in0 to shape (b, h, w, 1, c), inserting a dimension before the contraction axis 'c'.
// This allows broadcasting against in1 for the multiplication and aligns the contraction axis for the reduction.
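// Overall shape flow of the decomposition below:
//   reshape:    (b, h, w, c)             -> (b, h, w, 1, c)
//   multiply:   (b, h, w, 1, c) x (w, k, c) -> (b, h, w, k, c)  (broadcast)
//   reduce-sum over axis 4 ('c'):           -> (b, h, w, k)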
onnxruntime::qnn::TensorInfo tensor_info_in0{}, tensor_info_in1{}, tensor_info_out{};
ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Inputs()[0], tensor_info_in0));
ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Inputs()[1], tensor_info_in1));
ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Outputs()[0], tensor_info_out));
const std::vector<uint32_t>& shape_in0 = tensor_info_in0.shape;
const std::vector<uint32_t>& shape_in1 = tensor_info_in1.shape;
ORT_RETURN_IF_NOT(shape_in0.size() == 4, "CreateReduceSumMulBroadcastX expects input 0 to be rank 4");
ORT_RETURN_IF_NOT(shape_in1.size() == 3, "CreateReduceSumMulBroadcastX expects input 1 to be rank 3");
const std::vector<uint32_t> new_shape_in0{shape_in0[0], shape_in0[1], shape_in0[2], 1, shape_in0[3]};
const std::string reshape_out_name = input_names[0] + "_reshaped";
ORT_RETURN_IF_ERROR(qnn_model_wrapper->AddReshapeNode(
/*input_name=*/input_names[0],
/*output_name=*/reshape_out_name,
/*input_shape=*/shape_in0,
/*output_shape=*/new_shape_in0,
/*tensor_data_type=*/tensor_info_in0.qnn_data_type,
/*quantize_param=*/tensor_info_in0.quant_param.Copy(),
/*do_op_validation=*/do_op_validation,
/*is_for_input=*/qnn_model_wrapper->IsGraphInput(input_names[0])));

// Multiply: reshaped in0 * in1
// The output shape of the multiplication is determined by broadcasting the reshaped in0 of
// (b, h, w, 1, c) and in1 (w, k, c) along the matching axes, resulting in (b, h, w, k, c).
const std::string mul_out_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_mul";
std::vector<uint32_t> shape_out_mul{new_shape_in0[0], new_shape_in0[1], new_shape_in0[2], shape_in1[1], new_shape_in0[4]};
onnxruntime::qnn::QnnTensorWrapper tensor_wrapper_mul(mul_out_name,
QNN_TENSOR_TYPE_NATIVE,
tensor_info_in0.qnn_data_type,
tensor_info_in0.quant_param.Copy(),
std::move(shape_out_mul));
ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(tensor_wrapper_mul)),
"CreateReduceSumMulBroadcastX: failed to AddTensorWrapper");
ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(
/*qnn_node_name=*/mul_out_name,
/*package_name=*/QNN_OP_PACKAGE_NAME_QTI_AISW,
/*qnn_node_type=*/QNN_OP_ELEMENT_WISE_MULTIPLY,
/*input_names=*/{reshape_out_name, input_names[1]},
/*output_names=*/{mul_out_name},
/*param_tensor_names=*/{},
/*do_op_validation=*/do_op_validation),
"CreateReduceSumMulBroadcastX: failed to create Mul node");

std::vector<std::string> param_tensor_names{};

// ReduceSum over the last axis (axes={4}) with keep_dims=false.
// Axis '4' corresponds to the last dimension ('c') of the reshaped tensor (b, h, w, k, c),
// which is the contraction axis for reduce sum op in the einsum equation (bhwc,wkc->bhwk).
std::vector<uint32_t> axes_shape{SafeInt<uint32_t>(1)};
std::vector<uint32_t> axes_value{SafeInt<uint32_t>(4)};
onnxruntime::qnn::QnnParamWrapper param_axes(node_unit.Index(),
node_unit.Name(),
QNN_OP_REDUCE_SUM_PARAM_AXES,
std::move(axes_shape),
std::move(axes_value));
param_tensor_names.push_back(param_axes.GetParamTensorName());
ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_axes)),
"CreateReduceSumMulBroadcastX: failed to add param axes");

Qnn_Scalar_t keep_dims_scalar = QNN_SCALAR_INIT;
keep_dims_scalar.dataType = QNN_DATATYPE_BOOL_8;
keep_dims_scalar.bool8Value = SafeInt<uint8_t>(0);
onnxruntime::qnn::QnnParamWrapper param_keep_dims(node_unit.Index(),
node_unit.Name(),
QNN_OP_REDUCE_SUM_PARAM_KEEP_DIMS,
keep_dims_scalar);
param_tensor_names.push_back(param_keep_dims.GetParamTensorName());
ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_keep_dims)),
"CreateReduceSumMulBroadcastX: failed to add param keep_dims");

const std::string out_name = node_unit.Outputs()[0].node_arg.Name();
Qnn_TensorType_t out_tensor_type = qnn_model_wrapper->IsGraphOutput(out_name) ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;
onnxruntime::qnn::QnnTensorWrapper tensor_wrapper_out(out_name,
out_tensor_type,
tensor_info_out.qnn_data_type,
tensor_info_out.quant_param.Copy(),
std::move(tensor_info_out.shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(tensor_wrapper_out)),
"CreateReduceSumMulBroadcastX: failed to AddTensorWrapper");

ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(
/*qnn_node_name=*/out_name,
/*package_name=*/QNN_OP_PACKAGE_NAME_QTI_AISW,
/*qnn_node_type=*/QNN_OP_REDUCE_SUM,
/*input_names=*/{mul_out_name},
/*output_names=*/{out_name},
/*param_tensor_names=*/std::move(param_tensor_names),
/*do_op_validation=*/do_op_validation),
"CreateReduceSumMulBroadcastX: failed to create ReduceSum node");

return Status::OK();
}

} // namespace

namespace onnxruntime {
@@ -356,9 +507,20 @@
if (!IsEquationMatMul(parsed_equation.value()) &&
!IsEquationMatMulTransposeY(parsed_equation.value()) &&
!IsEquationMatMulBroadcastTransposeY(parsed_equation.value()) &&
-      !IsEquationMatMulTransposeAll(parsed_equation.value())) {
+      !IsEquationMatMulTransposeAll(parsed_equation.value()) &&
+      !IsEquationReduceSumMulBroadcastX(parsed_equation.value())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation);
}
if (IsEquationReduceSumMulBroadcastX(parsed_equation.value())) {
if (IsGpuBackend(qnn_model_wrapper.GetQnnBackendType())) {
// QAIRT 3.36.1: Failed to validate on GPU.
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation + " on backend GPU");
}
if (node_unit.Inputs()[0].quant_param.has_value()) {
// QAIRT 3.36.1: Failed to finalize QNN graph 1002.
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation + " for quantized inputs");
}
}
return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true);
}

@@ -408,6 +570,11 @@
/*node_unit=*/node_unit,
/*input_names=*/std::move(input_names),
/*do_op_validation=*/do_op_validation));
} else if (IsEquationReduceSumMulBroadcastX(parsed_equation.value())) {
ORT_RETURN_IF_ERROR(CreateReduceSumMulBroadcastX(/*qnn_model_wrapper=*/&qnn_model_wrapper,
/*node_unit=*/node_unit,
/*input_names=*/std::move(input_names),
/*do_op_validation=*/do_op_validation));
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation);
}
61 changes: 57 additions & 4 deletions onnxruntime/test/providers/qnn/einsum_op_test.cc
@@ -189,6 +189,19 @@ TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll1) {
/*tolerance=*/1e-4f);
}

TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll2) {
const std::vector<int64_t> shape0{1, 7, 1, 7};
const std::vector<int64_t> shape1{1, 9, 1, 7};
const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
RunQnnEinsum<float>(
/*backend=*/kQnnBackendTypeCpu,
/*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
/*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
/*equation=*/"bkhq,bchk->bchq",
/*tolerance=*/1e-4f);
}

TEST_F(QnnCPUBackendTests, EinsumMatMulBroadcastTransposeY) {
const std::vector<int64_t> shape0{2, 3, 3, 4};
const std::vector<int64_t> shape1{3, 3, 4};
@@ -202,16 +215,16 @@ TEST_F(QnnCPUBackendTests, EinsumMatMulBroadcastTransposeY) {
/*tolerance=*/1e-4f);
}

-TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll2) {
-  const std::vector<int64_t> shape0{1, 7, 1, 7};
-  const std::vector<int64_t> shape1{1, 9, 1, 7};
+TEST_F(QnnCPUBackendTests, EinsumReduceSumMulBroadcastX) {
+  const std::vector<int64_t> shape0{2, 3, 4, 5};
+  const std::vector<int64_t> shape1{4, 6, 5};
const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
RunQnnEinsum<float>(
/*backend=*/kQnnBackendTypeCpu,
/*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
/*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
/*equation=*/"bkhq,bchk->bchq",
/*equation=*/"bhwc,wkc->bhwk",
/*tolerance=*/1e-4f);
}

@@ -299,6 +312,19 @@ TEST_F(QnnHTPBackendTests, EinsumF16MatMulBroadcastTransposeY) {
/*tolerance=*/1e-2f);
}

TEST_F(QnnHTPBackendTests, EinsumF16ReduceSumMulBroadcastX) {
const std::vector<int64_t> shape0{1, 3, 2, 4};
const std::vector<int64_t> shape1{2, 3, 4};
const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
RunQnnEinsum<float>(
/*backend=*/kQnnBackendTypeHtp,
/*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
/*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
/*equation=*/"bhwc,wkc->bhwk",
/*tolerance=*/1e-2f);
}

//
// QNN HTP QDQ
//
@@ -375,6 +401,19 @@ TEST_F(QnnHTPBackendTests, EinsumQdqMatMulBroadcastTransposeY) {
/*tolerance=*/QDQTolerance());
}

// TODO: Re-enable. QAIRT 3.36.1: failed to finalize QNN graph 1002.
TEST_F(QnnHTPBackendTests, DISABLED_EinsumQdqReduceSumMulBroadcastX) {
const std::vector<int64_t> shape0{1, 3, 2, 4};
const std::vector<int64_t> shape1{2, 3, 4};
const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
RunQnnHtpQdqEinsum<uint8_t, uint8_t>(
/*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
/*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
/*equation=*/"bhwc,wkc->bhwk",
/*tolerance=*/QDQTolerance());
}

#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)

#if defined(_M_ARM64)
@@ -474,6 +513,20 @@ TEST_F(QnnGPUBackendTests, DISABLED_EinsumMatMulBroadcastTransposeY) {
/*tolerance=*/1e-4f);
}

// TODO: Re-enable. Failed on QAIRT 3.36.1.
TEST_F(QnnGPUBackendTests, DISABLED_EinsumReduceSumMulBroadcastX) {
const std::vector<int64_t> shape0{1, 3, 2, 4};
const std::vector<int64_t> shape1{2, 3, 4};
const std::vector<float> data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f);
const std::vector<float> data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f);
RunQnnEinsum<float>(
/*backend=*/kQnnBackendTypeGpu,
/*in0=*/TestInputDef<float>(shape0, /*is_initializer=*/false, std::move(data0)),
/*in1=*/TestInputDef<float>(shape1, /*is_initializer=*/false, std::move(data1)),
/*equation=*/"bhwc,wkc->bhwk",
/*tolerance=*/1e-4f);
}

#endif // defined(_M_ARM64) GPU tests

} // namespace test