Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,33 @@ Status ReshapeOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
if (do_op_validation) {
NodeAttrHelper node_helper(node_unit);
auto allowzero = node_helper.Get("allowzero", static_cast<int64_t>(0));

// Only reject allowzero=1 if dynamic shape or the shape actually contains zeros
Comment thread
qti-yuduo marked this conversation as resolved.
Outdated
if (0 != allowzero) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN Reshape doesn't support dynamic shape!");
const auto& inputs = node_unit.Inputs();
const auto& initializer_tensors = qnn_model_wrapper.GetInitializerTensors();
auto shape_tensor_iter = initializer_tensors.find(inputs[1].node_arg.Name());

if (shape_tensor_iter == initializer_tensors.end()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"QNN Reshape requires a constant shape input");
}

// Check if the constant shape contains any zeros
const auto* shape_tensor = shape_tensor_iter->second;
std::vector<uint8_t> unpacked_tensor;
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*shape_tensor, unpacked_tensor));

Comment thread
edgchen1 marked this conversation as resolved.
Comment thread
edgchen1 marked this conversation as resolved.
const int64_t* shape_data = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
size_t shape_size = unpacked_tensor.size() / sizeof(int64_t);

for (size_t i = 0; i < shape_size; ++i) {
if (shape_data[i] == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"QNN Reshape does not support shapes with zero dimensions. "
"The 'allowzero' attribute is not supported by QNN.");
Comment thread
yuslepukhin marked this conversation as resolved.
}
}
}
}

Expand Down
32 changes: 16 additions & 16 deletions onnxruntime/test/providers/qnn/reshape_expand_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,6 @@ TEST_F(QnnCPUBackendTests, Reshape_DynamicShape_Unsupported) {
19); // Opset
}

// Test that Reshape with an enabled 'allowzero' attribute is not supported by QNN EP.
TEST_F(QnnCPUBackendTests, Reshape_AllowZeroAttr_Unsupported) {
RunReshapeExpandTest("Reshape", TestInputDef<float>({1, 3, 4, 4}, false, -10.0f, 10.0f),
TestInputDef<int64_t>({2}, true, {1, 48}),
{utils::MakeAttribute("allowzero", static_cast<int64_t>(1))},
ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP.
"cpu", // Backend
19); // Opset
}

// Test Reshape of rank 4 -> rank 2.
TEST_F(QnnCPUBackendTests, Reshape_4D_f32) {
RunReshapeExpandTest("Reshape", TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
Expand Down Expand Up @@ -271,14 +261,24 @@ TEST_F(QnnHTPBackendTests, Reshape_DynamicShape_Unsupported) {
19); // Opset
}

// Test that QDQ Reshape with an enabled 'allowzero' attribute is not supported by QNN EP.
TEST_F(QnnHTPBackendTests, Reshape_AllowZeroAttr_Unsupported) {
// Test that QDQ Reshape with allowzero=1 and a shape containing zeros is not supported by QNN EP.
TEST_F(QnnHTPBackendTests, Reshape_AllowZeroAttr_WithZeros_Unsupported) {
RunReshapeExpandTestOnHTP<float>("Reshape",
TestInputDef<float>({2, 0, 3}, false, {}),
TestInputDef<int64_t>({2}, true, {6, 0}),
{utils::MakeAttribute("allowzero", static_cast<int64_t>(1))},
ExpectedEPNodeAssignment::None,
19);
}

// Test that QDQ Reshape with allowzero=1 but no zeros in shape IS supported by QNN EP.
TEST_F(QnnHTPBackendTests, Reshape_AllowZeroAttr_NoZeros_Supported) {
RunQDQReshapeExpandTestOnHTP<uint8_t>("Reshape",
TestInputDef<float>({1, 3, 4, 4}, false, -10.0f, 10.0f),
TestInputDef<int64_t>({2}, true, {1, 48}),
TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
TestInputDef<int64_t>({2}, true, {1, 48}), // concrete shape with no zeros
{utils::MakeAttribute("allowzero", static_cast<int64_t>(1))},
ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP.
19); // Opset
ExpectedEPNodeAssignment::All,
19);
}

// Test 8-bit QDQ Reshape of rank 4 -> rank 2.
Expand Down