Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,25 @@
return true;
}

bool ScatterElementsNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node,
                                             const std::vector<const Node*>& dq_nodes,
                                             const std::vector<const Node*>& q_nodes) const {
  // ScatterElements takes 3 inputs: data and updates are quantized (each fed by
  // a DQ node), while the indices input is an integer tensor and stays
  // unquantized. Hence we require exactly 2 DQ inputs here.
  if (!CheckQDQNodes(graph_viewer, node, redundant_clip_node, dq_nodes, q_nodes, 2)) {
    return false;
  }

  const int32_t dt_input_1 = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
  const int32_t dt_input_2 = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
  const int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();

  // Both quantized inputs and the quantized output must share the same element
  // type for the QDQ group to be fused.
  if (dt_input_1 != dt_input_2 || dt_input_1 != dt_output) {
    return false;
  }

  return true;
}

} // namespace QDQ
} // namespace onnxruntime

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,15 @@ class CumSumNodeGroupSelector : public NodeGroupSelector {
const std::vector<const Node*>& q_nodes) const override;
};

// Node group selector for ScatterElements.
// Input: DQ nodes for the data and updates inputs (the indices input is not quantized).
// Output: Q node for the output.
class ScatterElementsNodeGroupSelector : public NodeGroupSelector {
  bool Check(const GraphViewer& graph_viewer,
             const Node& node, const Node* redundant_clip_node,
             const std::vector<const Node*>& dq_nodes,
             const std::vector<const Node*>& q_nodes) const override;
};

/*
* NodeSelector instances for use in the QDQ::SelectorActionTransformer.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ static const OpVersionsAndSelector::OpVersionsMap GetCumSumOpVersionsMap() {
return {{"CumSum", {}}};
}

static const OpVersionsAndSelector::OpVersionsMap GetScatterElementsOpVersionsMap() {
  // An empty opset-version list registers the selector for all versions of ScatterElements.
  OpVersionsAndSelector::OpVersionsMap scatter_elements_versions{{"ScatterElements", {}}};
  return scatter_elements_versions;
}

/* Selector rules registration related */
void RegisterMiscSelectors(Selectors& qdq_selectors) {
/* register selectors for miscellaneous ops */
Expand Down Expand Up @@ -290,6 +294,13 @@ void RegisterCumSumSelector(Selectors& qdq_selectors) {
std::move(selector));
}

void RegisterScatterElementsSelector(Selectors& qdq_selectors) {
  /* register selector for ScatterElements op */
  std::unique_ptr<NodeGroupSelector> selector = std::make_unique<ScatterElementsNodeGroupSelector>();
  qdq_selectors.RegisterSelector(GetScatterElementsOpVersionsMap(),
                                 std::move(selector));
}

void SelectorManager::CreateSelectors() {
RegisterMiscSelectors(qdq_selectors_);
RegisterDropDQSelectors(qdq_selectors_);
Expand All @@ -310,6 +321,7 @@ void SelectorManager::CreateSelectors() {
RegisterPadSelectors(qdq_selectors_);
RegisterTopKSelector(qdq_selectors_);
RegisterCumSumSelector(qdq_selectors_);
RegisterScatterElementsSelector(qdq_selectors_);
}

void SelectorManager::InitializeSelectorsMap() {
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
CreateSimpleOpBuilder("GridSample", *this);

CreateSimpleOpBuilder("LpNormalization", *this);

CreateSimpleOpBuilder("ScatterElements", *this);
}

{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ class BaseOpBuilder : public IOpBuilder {

{"Pad", QNN_OP_PAD},

{"ScatterElements", QNN_OP_SCATTER_ELEMENTS},

{"Expand", QNN_OP_ELEMENT_WISE_MULTIPLY}};
auto it = onnx_op_type_to_qnn_op_type.find(onnx_op_type);
ORT_ENFORCE(it != onnx_op_type_to_qnn_op_type.end());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class SimpleOpBuilder : public BaseOpBuilder {
static constexpr std::array<std::string_view, 3> gridsample_supported_modes = {"bilinear", "nearest", "linear"};
static constexpr std::array<std::string_view, 3> gridsample_supported_padding_modes = {"zeros", "border", "reflection"};
static constexpr std::array<std::string_view, 3> scatternd_supported_reduction = {"none", "add", "mul"};
static constexpr std::array<std::string_view, 4> scatterelements_supported_reduction = {"none", "add", "mul", "max"};
};

Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper,
Expand Down Expand Up @@ -110,6 +111,14 @@ Status SimpleOpBuilder::ExplicitOpCheck(QnnModelWrapper& qnn_model_wrapper,
reduction.c_str());
}

// QNN ScatterElements doesn't support MIN reduction
if (op_type == "ScatterElements") {
NodeAttrHelper node_helper(node_unit);
std::string reduction = node_helper.Get("reduction", "none");
ORT_RETURN_IF_NOT(utils::ArrayHasString(scatterelements_supported_reduction, reduction), "ScatterElements does not support reduction ",
reduction.c_str());
}

return Status::OK();
}

Expand Down Expand Up @@ -288,6 +297,33 @@ Status ProcessScatterNDReductionAttribute(QnnModelWrapper& qnn_model_wrapper,
return Status::OK();
}

// Translates the ONNX ScatterElements `reduction` attribute into the QNN
// ScatterElements `reduction` scalar parameter and appends the created param's
// tensor name to `param_tensor_names`.
// Returns a FAIL status for reduction modes QNN does not support (e.g. "min").
Status ProcessReductionAttribute(QnnModelWrapper& qnn_model_wrapper,
                                 const NodeUnit& node_unit,
                                 std::vector<std::string>& param_tensor_names) {
  NodeAttrHelper node_helper(node_unit);
  // ONNX defines "none" as the default when the attribute is absent.
  const std::string reduction = node_helper.Get("reduction", "none");

  Qnn_Scalar_t reduction_qnn_scalar = QNN_SCALAR_INIT;
  reduction_qnn_scalar.dataType = QNN_DATATYPE_UINT_32;
  if ("none" == reduction) {
    reduction_qnn_scalar.uint32Value = QNN_OP_SCATTER_ELEMENTS_REDUCTION_NONE;
  } else if ("add" == reduction) {
    reduction_qnn_scalar.uint32Value = QNN_OP_SCATTER_ELEMENTS_REDUCTION_ADD;
  } else if ("mul" == reduction) {
    reduction_qnn_scalar.uint32Value = QNN_OP_SCATTER_ELEMENTS_REDUCTION_MUL;
  } else if ("max" == reduction) {
    reduction_qnn_scalar.uint32Value = QNN_OP_SCATTER_ELEMENTS_REDUCTION_MAX;
  } else {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ScatterElements supports only reduction:{none, add, mul, max}.");
  }

  QnnParamWrapper reduction_param(node_unit.Index(), node_unit.Name(), QNN_OP_SCATTER_ELEMENTS_PARAM_REDUCTION,
                                  reduction_qnn_scalar);
  param_tensor_names.push_back(reduction_param.GetParamTensorName());
  qnn_model_wrapper.AddParamWrapper(std::move(reduction_param));

  return Status::OK();
}

Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
Expand Down Expand Up @@ -397,6 +433,19 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
ORT_RETURN_IF_ERROR(ProcessScatterNDReductionAttribute(qnn_model_wrapper, node_unit, param_tensor_names));
}

if (op_type == "ScatterElements") {
// Process axis attribute
int32_t default_axis = 0;
Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT;
ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, default_axis));
QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(), QNN_OP_SCATTER_ELEMENTS_PARAM_AXIS, axis_qnn_scalar);
param_tensor_names.push_back(axis_param.GetParamTensorName());
qnn_model_wrapper.AddParamWrapper(std::move(axis_param));

// Process reduction attribute
ORT_RETURN_IF_ERROR(ProcessReductionAttribute(qnn_model_wrapper, node_unit, param_tensor_names));
}

return ProcessOutputs(qnn_model_wrapper, node_unit,
std::move(input_names),
std::move(param_tensor_names),
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/test/onnx/TestCase.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1453,6 +1453,7 @@ std::unique_ptr<std::set<BrokenTest>> GetBrokenTests(const std::string& provider
// Fails with QNN 2.31 on Windows x64 for CPU
broken_tests->insert({"gelu_tanh_2", "y:expected -0.0131778 (bc57e7d5), got -0.0136333 (bc5f5e38), diff: 0.000455472, tol=2.31778e-05."});
broken_tests->insert({"averagepool_2d_ceil", "result differs. expected 13.5 (41580000), got 0 (0)"});
broken_tests->insert({"scatter_elements_with_negative_indices", "unknown version"});
}

#ifdef DISABLE_CONTRIB_OPS
Expand Down
20 changes: 10 additions & 10 deletions onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ void RunTest(const std::vector<int64_t>& input_dims, const std::vector<int64_t>&
test.AddInput<TIndex>("indices", indices_dims, indices_data);
test.AddInput<T>("updates", indices_dims, updates_data);
test.AddOutput<T>("y", input_dims, output_data);
// OpenVINO doesn't support negative indices value.
// OpenVINO and QNN doesn't support negative indices value.
// Disable TensorRT due to missing int8 calibrator.
if (std::is_same<T, int8_t>::value) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider});
} else if (std::is_same<T, MLFloat16>::value) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider});
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kQnnExecutionProvider});
} else {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider});
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kQnnExecutionProvider});
}

onnxruntime::test::OpTester test1("ScatterElements", 11);
Expand All @@ -82,11 +82,11 @@ void RunTest(const std::vector<int64_t>& input_dims, const std::vector<int64_t>&
test1.AddInput<T>("updates", indices_dims, updates_data);
test1.AddOutput<T>("y", input_dims, output_data);
if (std::is_same<T, int8_t>::value) {
test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider});
} else if (std::is_same<T, MLFloat16>::value) {
test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider});
test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kQnnExecutionProvider});
} else {
test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider});
test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kQnnExecutionProvider});
}
}

Expand Down Expand Up @@ -268,7 +268,7 @@ static void scatter_invalid_index(const char* op_name, int op_version) {
test.AddOutput<float>("y", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 5.0f, 0.0f});
test.Run(OpTester::ExpectResult::kExpectFailure,
"indices element out of data bounds, idx=4 must be within the inclusive range [-4,3]",
{kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider});
{kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider});
}

TEST(Scatter, InvalidIndex) {
Expand Down Expand Up @@ -344,7 +344,7 @@ TEST(ScatterElements, AddReductionAxis1) {
test.AddInput<float>("updates", {2, 4}, {2.f, 5.f, 3.f, 6.f, 7.f, 9.f, 8.f, 10.f});
test.AddOutput<float>("y", {2, 3}, {9.f, 4.f + (2.f + 5.f + 3.f + 6.f), 1.f, 7.f, 3.f + (7.f + 9.f + 8.f + 10.f), 6.f});

test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider});
}

TEST(ScatterElements, MulReduction) {
Expand All @@ -371,7 +371,7 @@ TEST(ScatterElements, MulReductionAxis1) {
test.AddInput<float>("updates", {2, 4}, {2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f});
test.AddOutput<float>("y", {2, 3}, {9.f, 4.f * (2.f * 3.f * 4.f * 5.f), 1.f, 7.f, 3.f * (6.f * 7.f * 8.f * 9.f), 6.f});

test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider});
}

TEST(ScatterElements, MaxReduction_MLFloat16) {
Expand Down
88 changes: 88 additions & 0 deletions onnxruntime/test/providers/qnn/qnn_test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,42 @@ inline GetTestModelFn BuildOpTestCase(const std::string& op_type,
};
}

// Returns a model-building function that creates a single-operator model whose
// inputs are supplied in three ordered groups: input_defs_1 (InputType1),
// input_defs_2 (InputType2, int64_t by default), then input_defs_3 (InputType1).
// The groups are appended in that order, so the operator sees inputs in the
// positions the caller specified.
template <typename InputType1, typename InputType2 = int64_t>
inline GetTestModelFn BuildOpTestCase(const std::string& op_type,
                                      const std::vector<TestInputDef<InputType1>>& input_defs_1,
                                      const std::vector<TestInputDef<InputType2>>& input_defs_2,
                                      const std::vector<TestInputDef<InputType1>>& input_defs_3,
                                      const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs,
                                      const std::string& op_domain = kOnnxDomain,
                                      AllocatorPtr input_allocator = nullptr) {
  return [op_type, input_defs_1, input_defs_2, input_defs_3, attrs, op_domain, input_allocator](ModelTestBuilder& builder) {
    std::vector<NodeArg*> op_inputs;
    op_inputs.reserve(input_defs_1.size() + input_defs_2.size() + input_defs_3.size());

    // Groups 1 and 3 share the same element type, so one helper covers both.
    auto append_type1_inputs = [&](const std::vector<TestInputDef<InputType1>>& defs) {
      for (const auto& def : defs) {
        op_inputs.push_back(MakeTestInput<InputType1>(builder, def, input_allocator));
      }
    };

    append_type1_inputs(input_defs_1);

    for (const auto& def : input_defs_2) {
      op_inputs.push_back(MakeTestInput<InputType2>(builder, def, input_allocator));
    }

    append_type1_inputs(input_defs_3);

    // Single operator node whose output is the graph output.
    auto* output = builder.MakeOutput();
    Node& onnx_node = builder.AddNode(op_type, op_inputs, {output}, op_domain);

    for (const auto& attr : attrs) {
      onnx_node.AddAttributeProto(attr);
    }
  };
}

/**
* Returns a function that builds a model with a single QDQ operator with N float (quantizeable) inputs
* and M inputs of a potentially different type.
Expand Down Expand Up @@ -1066,7 +1102,59 @@ inline GetTestQDQModelFn<QuantType> BuildQDQOpTestCase(
output_qparams[0].zero_point, use_contrib_qdq);
};
}
// Returns a function that builds a single-operator QDQ model with three ordered
// input groups: quantized float inputs (quant_input_defs), non-quantized inputs
// (non_quant_input_defs, int64_t by default — e.g. indices), and a second group
// of quantized float inputs (quant_input_defs_2). Each quantized input is fed
// through a Q -> DQ pair, and the op output is wrapped in Q -> DQ before
// becoming the graph output.
template <typename QuantType, typename OtherInputType = int64_t>
inline GetTestQDQModelFn<QuantType> BuildQDQOpTestCase(
    const std::string& op_type,
    const std::vector<TestInputDef<float>>& quant_input_defs,
    const std::vector<TestInputDef<OtherInputType>>& non_quant_input_defs,
    const std::vector<TestInputDef<float>>& quant_input_defs_2,
    const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs,
    const std::string& op_domain = kOnnxDomain,
    bool use_contrib_qdq = false,
    AllocatorPtr input_allocator = nullptr) {
  return [op_type, quant_input_defs, non_quant_input_defs, quant_input_defs_2, attrs, op_domain,
          use_contrib_qdq, input_allocator](
             ModelTestBuilder& builder, std::vector<QuantParams<QuantType>>& output_qparams) {
    std::vector<NodeArg*> op_inputs;
    op_inputs.reserve(quant_input_defs.size() + non_quant_input_defs.size() + quant_input_defs_2.size());

    // Shared helper for both quantized groups: make the float input, then add
    // a Q -> DQ pair so the op consumes the dequantized tensor.
    auto add_quant_inputs = [&](const std::vector<TestInputDef<float>>& input_defs) {
      for (const auto& input_def : input_defs) {
        NodeArg* input = MakeTestInput<float>(builder, input_def, input_allocator);
        QuantParams<QuantType> input_qparams = GetTestInputQuantParams<QuantType>(input_def);
        NodeArg* input_after_qdq = AddQDQNodePair<QuantType>(builder, input, input_qparams.scale,
                                                             input_qparams.zero_point, use_contrib_qdq);
        op_inputs.push_back(input_after_qdq);
      }
    };

    add_quant_inputs(quant_input_defs);

    // Non-quantized inputs are passed to the op directly.
    for (const auto& input_def : non_quant_input_defs) {
      op_inputs.push_back(MakeTestInput<OtherInputType>(builder, input_def, input_allocator));
    }

    add_quant_inputs(quant_input_defs_2);

    // Op -> op_output
    auto* op_output = builder.MakeIntermediate();
    Node& onnx_node = builder.AddNode(op_type, op_inputs, {op_output}, op_domain);

    for (const auto& attr : attrs) {
      onnx_node.AddAttributeProto(attr);
    }

    // op_output -> Q -> DQ -> graph output
    AddQDQNodePairWithOutputAsGraphOutput<QuantType>(builder, op_output, output_qparams[0].scale,
                                                     output_qparams[0].zero_point, use_contrib_qdq);
  };
}
/**
* Runs a test model on the QNN EP. Checks the graph node assignment, and that inference
* outputs for QNN and CPU match.
Expand Down
Loading
Loading