From 7c03cc9bb8906639d3f04b777810c0d46fc443f0 Mon Sep 17 00:00:00 2001 From: Yuduo Wu Date: Tue, 15 Jul 2025 13:12:14 -0700 Subject: [PATCH] [QNN-EP] Support GridSample of linear mode for ONNX opset 20+ --- .../builder/opbuilder/simple_op_builder.cc | 10 +++--- .../test/providers/qnn/simple_op_htp_test.cc | 32 +++++++++++++++++++ 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index 2650316dd07ac..1c61bda9aeb63 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -9,7 +9,7 @@ namespace onnxruntime { namespace qnn { -// Operator which only need to hanle node inputs & outputs, no attributes or no need to handle attributes +// Operator which only need to handle node inputs & outputs, no attributes or no need to handle attributes class SimpleOpBuilder : public BaseOpBuilder { public: SimpleOpBuilder() : BaseOpBuilder("SimpleOpBuilder") {} @@ -38,7 +38,7 @@ class SimpleOpBuilder : public BaseOpBuilder { const logging::Logger& logger, bool do_op_validation) const ORT_MUST_USE_RESULT; - static constexpr std::array gridsample_supported_modes = {"bilinear", "nearest"}; + static constexpr std::array gridsample_supported_modes = {"bilinear", "nearest", "linear"}; static constexpr std::array gridsample_supported_padding_modes = {"zeros", "border", "reflection"}; static constexpr std::array scatternd_supported_reduction = {"none", "add", "mul"}; }; @@ -233,12 +233,12 @@ Status ProcessGridSampleAttributes(QnnModelWrapper& qnn_model_wrapper, std::string mode = node_helper.Get("mode", "linear"); Qnn_Scalar_t mode_qnn_scalar = QNN_SCALAR_INIT; mode_qnn_scalar.dataType = QNN_DATATYPE_UINT_32; - if ("bilinear" == mode) { + if ("linear" == mode || "bilinear" == mode) { mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_MODE_BILINEAR; } else if ("nearest" == mode) { mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_MODE_NEAREST; } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample mode only support bilinear & nearest."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample mode only support [linear, bilinear, nearest]."); } QnnParamWrapper mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_GRID_SAMPLE_PARAM_MODE, mode_qnn_scalar); param_tensor_names.push_back(mode_param.GetParamTensorName()); @@ -254,7 +254,7 @@ Status ProcessGridSampleAttributes(QnnModelWrapper& qnn_model_wrapper, } else if ("reflection" == padding_mode) { padding_mode_qnn_scalar.uint32Value = QNN_OP_GRID_SAMPLE_PADDING_MODE_REFLECTION; } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample padding_mode only support zeros, border & reflection."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GridSample padding_mode only support [zeros, border, reflection]."); } QnnParamWrapper padding_mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_GRID_SAMPLE_PARAM_PADDING_MODE, padding_mode_qnn_scalar); param_tensor_names.push_back(padding_mode_param.GetParamTensorName()); diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 85f8250f70fc5..4c0a53e83e274 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -1254,6 +1254,38 @@ TEST_F(QnnHTPBackendTests, GridSample_U16_Nearest) { true); } +// Test QDQ GridSample with `linear` mode on opset 20+. +TEST_F(QnnHTPBackendTests, GridSample_Linear_ZerosPadding) { + RunQDQOpTest("GridSample", + {TestInputDef({1, 3, 4, 6}, false, GetFloatDataInRange(-10.0f, 10.0f, 72)), + TestInputDef({1, 4, 6, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))}, + {utils::MakeAttribute("mode", "linear"), utils::MakeAttribute("padding_mode", "zeros")}, + /*opset_version=*/20, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, GridSample_Linear_AlignCorners_BorderPadding) { + RunQDQOpTest("GridSample", + {TestInputDef({1, 3, 4, 6}, false, GetFloatDataInRange(-10.0f, 10.0f, 72)), + TestInputDef({1, 4, 6, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))}, + {utils::MakeAttribute("align_corners", static_cast(1)), + utils::MakeAttribute("mode", "linear"), + utils::MakeAttribute("padding_mode", "border")}, + /*opset_version=*/20, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, GridSample_Linear_ReflectionPadding_U16) { + RunQDQOpTest("GridSample", + {TestInputDef({1, 3, 4, 6}, false, GetFloatDataInRange(-10.0f, 10.0f, 72)), + TestInputDef({1, 4, 6, 2}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))}, + {utils::MakeAttribute("mode", "linear"), utils::MakeAttribute("padding_mode", "reflection")}, + /*opset_version=*/21, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*op_domain=*/kOnnxDomain, + /*use_contrib_qdq=*/true); +} + // Test QDQ GridSample with reflection padding mode // Inaccuracy detected for output 'output', element 2. // Output quant params: scale=0.024269860237836838, zero_point=0.