diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index 731cb30b74429..28d65c3ff35d5 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -43,6 +43,7 @@ OpBuilderRegistrations::OpBuilderRegistrations() { CreateSimpleOpBuilder("Elu", *this); CreateSimpleOpBuilder("Round", *this); CreateSimpleOpBuilder("Where", *this); + CreateSimpleOpBuilder("ScatterND", *this); CreateSimpleOpBuilder("Sigmoid", *this); CreateSimpleOpBuilder("Sin", *this); CreateSimpleOpBuilder("Sqrt", *this); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h index 272d226cd743d..5474db0590f92 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h @@ -155,6 +155,7 @@ class BaseOpBuilder : public IOpBuilder { {"ReduceSum", QNN_OP_REDUCE_SUM}, {"Round", QNN_OP_ELEMENT_WISE_ROUND}, {"Where", QNN_OP_ELEMENT_WISE_SELECT}, + {"ScatterND", QNN_OP_SCATTER_ND}, {"Sigmoid", QNN_OP_SIGMOID}, {"Sin", QNN_OP_ELEMENT_WISE_SIN}, {"Slice", QNN_OP_STRIDED_SLICE}, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index 229d86082f6dc..ab022df063c96 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -56,6 +56,13 @@ Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper, padding_mode.c_str()); } + // To DO: Remove once QNN CPU supports ScatterND + const auto qnn_backend_type = qnn_model_wrapper.GetQnnBackendType(); + if (op_type == "ScatterND") { + ORT_RETURN_IF_NOT(qnn_backend_type == QnnBackendType::HTP, + "QNN EP only supports ScatterND op on HTP backend. Falling back to ORT CPU."); + } + // ONNX's Min, Max, and Sum operators accept a variable number of inputs (i.e., variadic). // However, QNN's Min, Max, and Add operators must take in exactly two inputs. if (op_type == "Min" || op_type == "Max") { diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 0eec5f800916f..bfdb1a1a6afdd 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -991,6 +991,22 @@ TEST_F(QnnHTPBackendTests, BinaryOp_And4D) { ExpectedEPNodeAssignment::All); } +// Test ScatterND op on HTP +TEST_F(QnnHTPBackendTests, ScatterND_int64_int64) { + std::vector data = {0, 1, 2, 3}; + std::vector indices = {1}; + std::vector updates = {10}; + RunOpTest("ScatterND", + { + TestInputDef({4}, false, std::move(data)), + TestInputDef({1, 1}, false, std::move(indices)), + TestInputDef({1}, false, std::move(updates)), + }, + {}, + 17, + ExpectedEPNodeAssignment::All); +} + // Test that Or is not yet supported on CPU backend. TEST_F(QnnHTPBackendTests, BinaryOp_HTP_Or_Unsupported) { RunOpTest("Or",