Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,33 @@ Status ReshapeOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
if (do_op_validation) {
NodeAttrHelper node_helper(node_unit);
auto allowzero = node_helper.Get("allowzero", static_cast<int64_t>(0));

// Only reject allowzero=1 if the shape input is not constant or it actually contains zeros
if (0 != allowzero) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN Reshape doesn't support dynamic shape!");
const auto& inputs = node_unit.Inputs();
const auto& initializer_tensors = qnn_model_wrapper.GetInitializerTensors();
auto shape_tensor_iter = initializer_tensors.find(inputs[1].node_arg.Name());

if (shape_tensor_iter == initializer_tensors.end()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"QNN Reshape requires a constant shape input");
}

// Check if the constant shape contains any zeros
const auto* shape_tensor = shape_tensor_iter->second;
std::vector<uint8_t> unpacked_tensor;
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*shape_tensor, unpacked_tensor));

const int64_t* shape_data = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
size_t shape_size = unpacked_tensor.size() / sizeof(int64_t);

for (size_t i = 0; i < shape_size; ++i) {
if (shape_data[i] == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"QNN Reshape does not support shapes with zero dimensions. "
"The 'allowzero' attribute is not supported by QNN.");
}
}
}
}

Expand Down
32 changes: 16 additions & 16 deletions onnxruntime/test/providers/qnn/reshape_expand_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,6 @@ TEST_F(QnnCPUBackendTests, Reshape_DynamicShape_Unsupported) {
19); // Opset
}

// Test that Reshape with an enabled 'allowzero' attribute is not supported by QNN EP.
TEST_F(QnnCPUBackendTests, Reshape_AllowZeroAttr_Unsupported) {
RunReshapeExpandTest("Reshape", TestInputDef<float>({1, 3, 4, 4}, false, -10.0f, 10.0f),
TestInputDef<int64_t>({2}, true, {1, 48}),
{utils::MakeAttribute("allowzero", static_cast<int64_t>(1))},
ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP.
"cpu", // Backend
19); // Opset
}

// Test Reshape of rank 4 -> rank 2.
TEST_F(QnnCPUBackendTests, Reshape_4D_f32) {
RunReshapeExpandTest("Reshape", TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
Expand Down Expand Up @@ -271,14 +261,24 @@ TEST_F(QnnHTPBackendTests, Reshape_DynamicShape_Unsupported) {
19); // Opset
}

// Test that QDQ Reshape with an enabled 'allowzero' attribute is not supported by QNN EP.
TEST_F(QnnHTPBackendTests, Reshape_AllowZeroAttr_Unsupported) {
// Test that QDQ Reshape with allowzero=1 and a shape containing zeros is not supported by QNN EP.
TEST_F(QnnHTPBackendTests, Reshape_AllowZeroAttr_WithZeros_Unsupported) {
RunReshapeExpandTestOnHTP<float>("Reshape",
TestInputDef<float>({2, 0, 3}, false, {}),
TestInputDef<int64_t>({2}, true, {6, 0}),
{utils::MakeAttribute("allowzero", static_cast<int64_t>(1))},
ExpectedEPNodeAssignment::None,
19);
}

// Test that QDQ Reshape with allowzero=1 but no zeros in shape IS supported by QNN EP.
TEST_F(QnnHTPBackendTests, Reshape_AllowZeroAttr_NoZeros_Supported) {
RunQDQReshapeExpandTestOnHTP<uint8_t>("Reshape",
TestInputDef<float>({1, 3, 4, 4}, false, -10.0f, 10.0f),
TestInputDef<int64_t>({2}, true, {1, 48}),
TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
TestInputDef<int64_t>({2}, true, {1, 48}), // concrete shape with no zeros
{utils::MakeAttribute("allowzero", static_cast<int64_t>(1))},
ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP.
19); // Opset
ExpectedEPNodeAssignment::All,
19);
}

// Test 8-bit QDQ Reshape of rank 4 -> rank 2.
Expand Down
Loading