Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
namespace onnxruntime {
namespace qnn {

// Operator which only need to hanle node inputs & outputs, no attributes or no need to handle attributes
// Operator which only need to handle node inputs & outputs, no attributes or no need to handle attributes
class SimpleOpBuilder : public BaseOpBuilder {
public:
SimpleOpBuilder() : BaseOpBuilder("SimpleOpBuilder") {}
Expand Down Expand Up @@ -38,7 +38,7 @@ class SimpleOpBuilder : public BaseOpBuilder {
const logging::Logger& logger,
bool do_op_validation) const ORT_MUST_USE_RESULT;

static constexpr std::array<std::string_view, 2> gridsample_supported_modes = {"bilinear", "nearest"};
static constexpr std::array<std::string_view, 3> gridsample_supported_modes = {"bilinear", "nearest", "linear"};
static constexpr std::array<std::string_view, 3> gridsample_supported_padding_modes = {"zeros", "border", "reflection"};
static constexpr std::array<std::string_view, 3> scatternd_supported_reduction = {"none", "add", "mul"};
};
Expand Down Expand Up @@ -233,12 +233,12 @@ Status ProcessGridSampleAttributes(QnnModelWrapper& qnn_model_wrapper,
std::string mode = node_helper.Get("mode", "linear");
Qnn_Scalar_t mode_qnn_scalar = QNN_SCALAR_INIT;
mode_qnn_scalar.dataType = QNN_DATATYPE_UINT_32;
if ("bilinear" == mode) {
if ("linear" == mode || "bilinear" == mode) {
mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_MODE_BILINEAR;
} else if ("nearest" == mode) {
mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_MODE_NEAREST;
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample mode only support bilinear & nearest.");
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample mode only support [linear, bilinear, nearest].");
}
QnnParamWrapper mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_GRID_SAMPLE_PARAM_MODE, mode_qnn_scalar);
param_tensor_names.push_back(mode_param.GetParamTensorName());
Expand All @@ -254,7 +254,7 @@ Status ProcessGridSampleAttributes(QnnModelWrapper& qnn_model_wrapper,
} else if ("reflection" == padding_mode) {
padding_mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_PADDING_MODE_REFLECTION;
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample padding_mode only support zeros, border & reflection.");
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample padding_mode only support [zeros, border, reflection].");
}
QnnParamWrapper padding_mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_GRID_SAMPLE_PARAM_PADDING_MODE, padding_mode_qnn_scalar);
param_tensor_names.push_back(padding_mode_param.GetParamTensorName());
Expand Down
32 changes: 32 additions & 0 deletions onnxruntime/test/providers/qnn/simple_op_htp_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1254,6 +1254,38 @@ TEST_F(QnnHTPBackendTests, GridSample_U16_Nearest) {
true);
}

// Test QDQ GridSample with `linear` mode on opset 20+.
TEST_F(QnnHTPBackendTests, GridSample_Linear_ZerosPadding) {
RunQDQOpTest<uint8_t>("GridSample",
{TestInputDef<float>({1, 3, 4, 6}, false, GetFloatDataInRange(-10.0f, 10.0f, 72)),
TestInputDef<float>({1, 4, 6, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))},
{utils::MakeAttribute("mode", "linear"), utils::MakeAttribute("padding_mode", "zeros")},
/*opset_version=*/20,
/*expected_ep_assignment=*/ExpectedEPNodeAssignment::All);
}

TEST_F(QnnHTPBackendTests, GridSample_Linear_AlignCorners_BorderPadding) {
RunQDQOpTest<uint8_t>("GridSample",
{TestInputDef<float>({1, 3, 4, 6}, false, GetFloatDataInRange(-10.0f, 10.0f, 72)),
TestInputDef<float>({1, 4, 6, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))},
{utils::MakeAttribute("align_corners", static_cast<int64_t>(1)),
utils::MakeAttribute("mode", "linear"),
utils::MakeAttribute("padding_mode", "border")},
/*opset_version=*/20,
/*expected_ep_assignment=*/ExpectedEPNodeAssignment::All);
}

TEST_F(QnnHTPBackendTests, GridSample_Linear_ReflectionPadding_U16) {
RunQDQOpTest<uint16_t>("GridSample",
{TestInputDef<float>({1, 3, 4, 6}, false, GetFloatDataInRange(-10.0f, 10.0f, 72)),
TestInputDef<float>({1, 4, 6, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))},
{utils::MakeAttribute("mode", "linear"), utils::MakeAttribute("padding_mode", "reflection")},
/*opset_version=*/21,
/*expected_ep_assignment=*/ExpectedEPNodeAssignment::All,
/*op_domain=*/kOnnxDomain,
/*use_contrib_qdq=*/true);
}

// Test QDQ GridSample with reflection padding mode
// Inaccuracy detected for output 'output', element 2.
// Output quant params: scale=0.024269860237836838, zero_point=0.
Expand Down
Loading