Skip to content

Commit

Permalink
[QNN EP] Add more op unit tests (#17424)
Browse files Browse the repository at this point in the history
### Description
Adds more unit tests and enables HTP support for several ops:
- Exp
- Floor (enable qdq node unit)
- Min (enable qdq node unit)
- Max (enable qdq node unit)
- Neg (enable qdq node unit)
- Not
- Pow
- PRelu (enable qdq node unit)
- Relu **(Does not work!)**
- Sigmoid
- Sqrt
- Tanh
- LogSoftmax (enable qdq node unit)
- Concat
- GlobalAveragePool

Still missing (9):
- Reshape
- Flatten
- Squeeze
- Unsqueeze
- Gemm
- Clip
- Split
- TopK
- Tile

### Motivation and Context
Increase test coverage and op support
  • Loading branch information
adrianlizarraga authored Sep 7, 2023
1 parent ede339f commit 1e4bfa1
Show file tree
Hide file tree
Showing 20 changed files with 703 additions and 448 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() {
{"HardSwish", {}},
{"Sigmoid", {}},
{"Slice", {}},
{"LogSoftmax", {}},
{"Softmax", {}},
{"Sqrt", {}},
{"Atan", {}},
Expand All @@ -72,7 +73,10 @@ static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() {
{"Log", {}},
{"LRN", {}},
{"Ceil", {}},
{"Floor", {}},
{"Round", {}},
{"Abs", {}},
{"Neg", {}},
{"DepthToSpace", {}},
{"SpaceToDepth", {}}};
}
Expand All @@ -82,10 +86,13 @@ static const OpVersionsAndSelector::OpVersionsMap GetBinaryOpVersionsMap() {
{"Mul", {}},
{"Pow", {}},
{"Sub", {}},
{"PRelu", {}},
{"GridSample", {}}};
}
// Returns the variadic-input op types (Concat, Max, Min) registered with the
// QDQ node-unit selectors, each mapped to its supported opset versions.
// NOTE(review): an empty version list appears to mean "all versions", matching
// the other Get*OpVersionsMap helpers in this file — confirm against the
// OpVersionsAndSelector implementation.
static const OpVersionsAndSelector::OpVersionsMap GetVariadicOpVersionsMap() {
  // The scraped diff left both the pre-change and post-change `return`
  // statements in place; only the post-change one (with Max/Min added) is kept.
  return {{"Concat", {}},
          {"Max", {}},
          {"Min", {}}};
}
static const OpVersionsAndSelector::OpVersionsMap GetConvOpVersionsMap() {
return {{"Conv", {}}};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,37 @@ class SimpleOpBuilder : public BaseOpBuilder {
bool do_op_validation) const override ORT_MUST_USE_RESULT;

private:
Status ExplictOpCheck(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const;
Status ExplicitOpCheck(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const;

static constexpr std::array<std::string_view, 2> gridsample_supported_modes = {"bilinear", "nearest"};
static constexpr std::array<std::string_view, 3> gridsample_supported_padding_modes = {"zeros", "border", "reflection"};
};

// Returns the default value of the `axis` attribute for the given op type at
// the given opset version. Softmax and LogSoftmax changed their default axis
// from 1 to -1 in opset 13; every other op type handled here defaults to 0.
// (The scraped diff interleaved the removed pre-refactor ExplictOpCheck header
// above this helper; only the helper definition belongs at this location.)
static int32_t GetDefaultAxisAttribute(const std::string& op_type, int opset_version) {
  if (op_type == "Softmax" || op_type == "LogSoftmax") {
    // Default axis changed from 1 to -1 in opset 13.
    return opset_version < 13 ? 1 : -1;
  }

  return 0;
}

Status SimpleOpBuilder::ExplicitOpCheck(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const {
const std::string& op_type = node_unit.OpType();

// QNN Softmax and LogSoftmax only support an axis value equal to input_rank - 1 (i.e., same as -1).
if (op_type == "Softmax" || op_type == "LogSoftmax") {
int32_t axis = GetDefaultAxisAttribute(op_type, node_unit.SinceVersion());
Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT;
ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, axis));
std::vector<uint32_t> input_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(node_unit.Inputs()[0].node_arg, input_shape),
"QNN EP: Cannot get shape for Softmax input");
ORT_RETURN_IF(axis != static_cast<int32_t>(input_shape.size() - 1),
"QNN Softmax only supports an `axis` attribute equal to input_rank-1 (or -1)");
"QNN ", op_type.c_str(), " only supports an `axis` attribute equal to input_rank-1 (or -1)");
}

if (node_unit.OpType() == "GridSample") {
if (op_type == "GridSample") {
NodeAttrHelper node_helper(node_unit);
std::string mode = node_helper.Get("mode", "linear");
ORT_RETURN_IF_NOT(utils::ArrayHasString(gridsample_supported_modes, mode), "GridSample does not support mode ",
Expand All @@ -58,6 +69,13 @@ Status SimpleOpBuilder::ExplictOpCheck(const QnnModelWrapper& qnn_model_wrapper,
padding_mode.c_str());
}

// ONNX's Min and Max operators accept a variable number of inputs (i.e., variadic).
// However, QNN's Min and Max operators must take in exactly two inputs.
if (op_type == "Min" || op_type == "Max") {
ORT_RETURN_IF_NOT(node_unit.Inputs().size() == 2,
"QNN EP only supports Min and Max operators with exactly 2 inputs.");
}

return Status::OK();
}

Expand Down Expand Up @@ -207,7 +225,7 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
const std::string& op_type = node_unit.OpType();

if (do_op_validation) {
ORT_RETURN_IF_ERROR(ExplictOpCheck(qnn_model_wrapper, node_unit));
ORT_RETURN_IF_ERROR(ExplicitOpCheck(qnn_model_wrapper, node_unit));
// Skip the op validation for DepthToSpace & SpaceToDepth if it's not NHWC data layout
if (node_unit.Domain() != kMSInternalNHWCDomain && (op_type == "DepthToSpace" || op_type == "SpaceToDepth" || op_type == "GridSample")) {
return Status::OK();
Expand All @@ -217,7 +235,7 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
std::vector<std::string> param_tensor_names;
// Add attribute
if (op_type == "LogSoftmax" || op_type == "Softmax" || op_type == "Concat") {
int32_t default_axis = ("Softmax" == op_type) ? -1 : 0;
int32_t default_axis = GetDefaultAxisAttribute(op_type, node_unit.SinceVersion());
Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT;
ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, default_axis));
QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(), QNN_OP_SOFTMAX_PARAM_AXIS, axis_qnn_scalar);
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/test/providers/qnn/argmaxmin_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ static GetTestQDQModelFn<QType> BuildQDQArgMxxTestCase(const std::string& op_typ
return [op_type, input_def, attrs](ModelTestBuilder& builder,
std::vector<QuantParams<QType>>& output_qparams) {
ORT_UNUSED_PARAMETER(output_qparams);
QuantParams<QType> input_qparams = GetTestInputQuantParams(input_def);
QuantParams<QType> input_qparams = GetTestInputQuantParams<QType>(input_def);

auto* input = MakeTestInput(builder, input_def);

Expand Down Expand Up @@ -205,7 +205,7 @@ TEST_F(QnnHTPBackendTests, ArgMaxMin_AsGraphOutputUnsupported) {
auto model_builder_func = [](const std::string& op_type, const TestInputDef<float>& input_def,
const std::vector<ONNX_NAMESPACE::AttributeProto>& attrs) -> GetTestModelFn {
return [op_type, input_def, attrs](ModelTestBuilder& builder) {
QuantParams<uint8_t> input_qparams = GetTestInputQuantParams(input_def);
QuantParams<uint8_t> input_qparams = GetTestInputQuantParams<uint8_t>(input_def);

auto* input = MakeTestInput(builder, input_def);
auto* output = builder.MakeOutput();
Expand Down
Loading

0 comments on commit 1e4bfa1

Please sign in to comment.