onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc
@@ -55,8 +55,12 @@
auto transB = node_helper.Get("transB", static_cast<int64_t>(0));
auto M = (transB == 0) ? inputB_shape.at(1) : inputB_shape.at(0);
if (inputC_shape.size() == 0 || (inputC_shape.size() == 1 && inputC_shape.at(0) != M) ||
(inputC_shape.size() == 2 && (inputC_shape.at(0) != 1 || inputC_shape.at(1) != M))) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN FullyConnected Op only support C with shape [M].");
(inputC_shape.size() == 2 && inputC_shape.at(1) != M)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN FullyConnected Op only support C with shape [N, M].");
}

if (inputC_shape.size() == 2 && node_unit.Inputs()[2].quant_param.has_value() && inputC_shape.at(0) != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN FullyConnected Op only support quantized C with shape [1, M].");
}
}
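As a reading aid, the shape rule enforced by the check above can be sketched as a standalone predicate. `IsSupportedGemmBiasShape` is a hypothetical name, and `M` here is the output width derived from input B, as in the code above.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical sketch of the constraint above: C must be [M] or [N, M],
// and a quantized 2D C is only accepted as [1, M].
bool IsSupportedGemmBiasShape(const std::vector<uint32_t>& c_shape, uint32_t m, bool c_is_quantized) {
  if (c_shape.size() == 1) {
    return c_shape[0] == m;  // vector bias [M]
  }
  if (c_shape.size() == 2) {
    if (c_shape[1] != m) return false;          // last dimension must match M
    return !c_is_quantized || c_shape[0] == 1;  // quantized bias must be [1, M]
  }
  return false;  // rank-0 or rank>2 bias is rejected
}
```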

@@ -133,7 +137,8 @@
qnn_model_wrapper.IsGraphInput(node_input_name)));
}

if (2 == input_i && 2 == input_shape.size()) {
// Reshape a [1, M] bias to [M].
if (2 == input_i && 2 == input_shape.size() && input_shape[0] == 1) {
input_shape[0] = input_shape[1];
input_shape.resize(1);
}
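To illustrate the branch above: it collapses a rank-2 [1, M] bias shape to a rank-1 [M] shape in place. A minimal sketch, assuming M == 8:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<uint32_t> input_shape = {1, 8};  // bias shape [1, M] with M == 8 (assumed)
  if (input_shape.size() == 2 && input_shape[0] == 1) {
    input_shape[0] = input_shape[1];  // move M to the front
    input_shape.resize(1);            // drop the leading 1 -> shape becomes [M]
  }
  assert((input_shape == std::vector<uint32_t>{8}));
  return 0;
}
```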
@@ -199,8 +204,70 @@
std::vector<std::string>&& input_names,
const logging::Logger& logger,
bool do_op_validation) const {
ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), {},
logger, do_op_validation, GetQnnOpType(node_unit.OpType())));
// FullyConnected doesn't support a 2D bias with shape [N, M]. In this case, decompose Gemm into FullyConnected + Add for compatibility.
bool split_gemm = false;
if (node_unit.Inputs().size() == 3) {
auto& input_c = node_unit.Inputs()[2];
std::vector<uint32_t> input_c_shape;
QnnModelWrapper::GetOnnxShape(input_c.node_arg, input_c_shape);

// Split when input_c has a 2D shape that is not [1, M].
split_gemm = (input_c_shape.size() == 2 && input_c_shape.at(0) != 1);
}

if (split_gemm) {
// If split_gemm, the input and output of Gemm must be at least 2D.
const std::string& org_output_name = node_unit.Outputs()[0].node_arg.Name();
TensorInfo input_info = {};
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Inputs()[0], input_info));
TensorInfo output_info = {};
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info));
std::vector<uint32_t> output_shape = output_info.shape;
QnnQuantParamsWrapper op_output_quant_param = output_info.quant_param.Copy();

const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(org_output_name);

// Create FullyConnected Node
std::vector<std::string> gemm_input_0_1;
gemm_input_0_1.push_back(input_names[0]);
gemm_input_0_1.push_back(input_names[1]);
std::string split_fully_connected_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_split_FullyConnected";
std::string split_fully_connected_output_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_split_FullyConnected_output";
QnnTensorWrapper fully_connected_output(split_fully_connected_output_name, QNN_TENSOR_TYPE_NATIVE, input_info.qnn_data_type,
QnnQuantParamsWrapper(), std::vector<uint32_t>(output_shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(fully_connected_output)),
"Failed to add FullyConnected output tensor.");
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(split_fully_connected_name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
QNN_OP_FULLY_CONNECTED,
std::move(gemm_input_0_1),
{split_fully_connected_output_name},
{},
do_op_validation),
"Failed to add FullyConnected node.");

// Create Add Node
Qnn_TensorType_t op_output_tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;
std::string split_add_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_split_add";
QnnTensorWrapper op_output_tensor_wrapper(org_output_name, op_output_tensor_type, output_info.qnn_data_type,
op_output_quant_param.Copy(), std::vector<uint32_t>(output_shape));

ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(op_output_tensor_wrapper)),
"Failed to add ElementWiseAdd output tensor.");
std::string bias_name = input_names[2];

ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(split_add_name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
QNN_OP_ELEMENT_WISE_ADD,
{split_fully_connected_output_name, bias_name}, // FullyConnected output as input
{org_output_name}, // Original output as output
{},
do_op_validation),
"Failed to add ElementWiseAdd node.");
} else {
ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), {},
logger, do_op_validation, GetQnnOpType(node_unit.OpType())));
}

return Status::OK();
}
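The split is numerically equivalent to the original Gemm (assuming alpha == beta == 1 and no transposes): the FullyConnected node computes Y = A x B, and the ElementWiseAdd node then adds the full [N, M] bias. A minimal sketch with a hypothetical helper, not the QNN API:

```cpp
#include <cstddef>
#include <vector>

// A: [N, K], B: [K, M], C: [N, M], all row-major.
// Step 1 mirrors the FullyConnected node (Y = A * B), step 2 the ElementWiseAdd node.
std::vector<float> GemmWith2dBias(const std::vector<float>& A, const std::vector<float>& B,
                                  const std::vector<float>& C, size_t N, size_t K, size_t M) {
  std::vector<float> Y(N * M, 0.0f);
  for (size_t n = 0; n < N; ++n) {      // FullyConnected: matrix multiply
    for (size_t k = 0; k < K; ++k) {
      for (size_t m = 0; m < M; ++m) {
        Y[n * M + m] += A[n * K + k] * B[k * M + m];
      }
    }
  }
  for (size_t i = 0; i < N * M; ++i) {  // ElementWiseAdd: add the [N, M] bias
    Y[i] += C[i];
  }
  return Y;
}
```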

21 changes: 12 additions & 9 deletions onnxruntime/test/providers/qnn/gemm_op_test.cc
@@ -54,18 +54,17 @@ TEST_F(QnnCPUBackendTests, Gemm_NonDefaultAlphaBeta_Unsupported) {
ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP.
}

// Test that Gemm with general 2D bias (M, N) is NOT supported (unless M == 1).
// QNN's FullyConnected operator only supports `outputVector = ( inputAsVector * weightsMatrix ) + biasesVector`
TEST_F(QnnCPUBackendTests, Gemm_2D_Bias_Unsupported) {
// Test that Gemm with a 2D bias is supported.
TEST_F(QnnCPUBackendTests, Gemm_2D_Bias) {
std::vector<float> input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6);
std::vector<float> input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 12);

// 2D matrix mul with bias not supported.
// 2D matrix mul with bias is supported.
RunGemmTest<float>({TestInputDef<float>({2, 3}, false, input_a_data),
TestInputDef<float>({3, 4}, false, input_b_data),
TestInputDef<float>({2, 4}, false, -1.0f, 1.0f)},
{},
ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP.
ExpectedEPNodeAssignment::All); // Assigned to QNN EP.

// However, 2D matrix mul without a bias is supported. Input A's 0th dimension is interpreted as `batch_size`.
RunGemmTest<float>({TestInputDef<float>({2, 3}, false, input_a_data),
@@ -525,15 +524,19 @@ TEST_F(QnnGPUBackendTests, Gemm_AlphaBetaUnsupported) {
"gpu");
}

// Gemm with matrix bias ie 2D (M, N) is NOT supported. (Note: vector bias is supported ie when M == 1).
// Gemm with a matrix bias, i.e. 2D (M, N), is supported.
// When the bias is a vector, i.e. M == 1,
// QNN's FullyConnected operator only supports `outputVector = ( inputAsVector * weightsMatrix ) + biasesVector`
TEST_F(QnnGPUBackendTests, Gemm_2DBiasUnsupported) {
// 2D matrix mul with 2D bias not supported.
// When the bias is 2D, i.e. M != 1 and N != 1,
// the Gemm will be split into FullyConnected and ElementwiseAdd.
TEST_F(QnnGPUBackendTests, Gemm_2D_Bias) {
// 2D matrix mul with 2D bias is supported when Gemm is not a QDQ node.
RunGemmTest<float>({TestInputDef<float>({2, 3}, false, -10.0f, 10.0f),
TestInputDef<float>({3, 4}, false, -10.0f, 10.0f),
TestInputDef<float>({2, 4}, false, -1.0f, 1.0f)},
{},
ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP.
ExpectedEPNodeAssignment::All, // Should be assigned to QNN EP.
"gpu");
}
