Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
CreateGatherNDOpBuilder("GatherND", *this);
}

{
CreateQuickGeluOpBuilder("QuickGelu", *this);
}

{
CreateModOpBuilder("Mod", *this);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ void CreateThresholdedReluOpBuilder(const std::string& op_type, OpBuilderRegistr
void CreateSTFTOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

void CreateInverseOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateQuickGeluOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

} // namespace qnn
} // namespace onnxruntime
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/qnn_utils.h"

namespace onnxruntime {
namespace qnn {

class QuickGeluOpBuilder : public BaseOpBuilder {
public:
QuickGeluOpBuilder() : BaseOpBuilder("QuickGeluOpBuilder") {}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QuickGeluOpBuilder);

protected:
Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
const logging::Logger& logger,
bool do_op_validation) const override ORT_MUST_USE_RESULT;
};

Status QuickGeluOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
const logging::Logger& logger,
bool do_op_validation) const {
LOGS(logger, VERBOSE) << "Processing QuickGelu operator: " << node_unit.Name();

const std::string& input_name = input_names[0];
const auto& outputs = node_unit.Outputs();
const std::string& output_name = outputs[0].node_arg.Name();

NodeAttrHelper node_helper(node_unit);
float alpha = node_helper.Get("alpha", 1.702f);

TensorInfo input_info = {};
ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Inputs()[0], input_info));

// Skip alpha multiplication when alpha is 1.0 to reduce accumulated error
constexpr float alpha_epsilon = 1e-6f;
const bool skip_alpha_mul = std::abs(alpha - 1.0f) < alpha_epsilon;

std::string sigmoid_input_name;
std::string sigmoid_output_name = utils::GetUniqueName(node_unit.Name() + "_sigmoid");

if (skip_alpha_mul) {
sigmoid_input_name = input_name;
} else {
const std::string alpha_mul_output_name = utils::GetUniqueName(node_unit.Name() + "_alpha_mul");
sigmoid_input_name = alpha_mul_output_name;

// The alpha tensor data type should match the input data type for element-wise multiply
std::string alpha_tensor_name = utils::GetUniqueName(node_unit.Name() + "_alpha");
std::vector<uint32_t> alpha_shape{1};
Qnn_DataType_t alpha_qnn_data_type = input_info.qnn_data_type;
std::vector<uint8_t> alpha_data;

if (alpha_qnn_data_type == QNN_DATATYPE_FLOAT_16) {
alpha_data.resize(sizeof(MLFloat16));
MLFloat16 alpha_fp16(alpha);
memcpy(alpha_data.data(), &alpha_fp16.val, sizeof(MLFloat16));
} else {
alpha_data.resize(sizeof(float));
memcpy(alpha_data.data(), &alpha, sizeof(float));
}

QnnTensorWrapper alpha_tensor_wrapper(alpha_tensor_name,
QNN_TENSOR_TYPE_STATIC,
alpha_qnn_data_type,
QnnQuantParamsWrapper(),
std::move(alpha_shape),
std::move(alpha_data));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(alpha_tensor_wrapper)), "Failed to add alpha tensor.");

QnnTensorWrapper alpha_mul_output_tensor_wrapper(alpha_mul_output_name,
QNN_TENSOR_TYPE_NATIVE,
input_info.qnn_data_type,
QnnQuantParamsWrapper(),
std::vector<uint32_t>(input_info.shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(alpha_mul_output_tensor_wrapper)),
"Failed to add alpha_mul_output tensor.");

// Step 1: Create Mul node for alpha * x
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit.Name() + "_alpha_mul"),
QNN_OP_PACKAGE_NAME_QTI_AISW,
QNN_OP_ELEMENT_WISE_MULTIPLY,
{alpha_tensor_name, input_name},
{alpha_mul_output_name},
{},
do_op_validation),
"Failed to create alpha_mul node.");
}

QnnTensorWrapper sigmoid_output_tensor_wrapper(sigmoid_output_name,
QNN_TENSOR_TYPE_NATIVE,
input_info.qnn_data_type,
QnnQuantParamsWrapper(),
std::vector<uint32_t>(input_info.shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(sigmoid_output_tensor_wrapper)),
"Failed to add sigmoid_output tensor.");

Qnn_TensorType_t tensor_type = qnn_model_wrapper.IsGraphOutput(output_name) ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;
QnnTensorWrapper output_tensor_wrapper(output_name,
tensor_type,
input_info.qnn_data_type,
input_info.quant_param.Copy(),
std::vector<uint32_t>(input_info.shape));

Check warning on line 110 in onnxruntime/core/providers/qnn/builder/opbuilder/quick_gelu_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/qnn/builder/opbuilder/quick_gelu_op_builder.cc:110: Add #include <vector> for vector<> [build/include_what_you_use] [4]
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor_wrapper)),

Check warning on line 111 in onnxruntime/core/providers/qnn/builder/opbuilder/quick_gelu_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/qnn/builder/opbuilder/quick_gelu_op_builder.cc:111: Add #include <utility> for move [build/include_what_you_use] [4]
"Failed to add output tensor.");

// Step 2: Create Sigmoid node for sigmoid(alpha * x) or sigmoid(x)
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit.Name() + "_sigmoid"),
QNN_OP_PACKAGE_NAME_QTI_AISW,
QNN_OP_SIGMOID,
{sigmoid_input_name},
{sigmoid_output_name},
{},
do_op_validation),
"Failed to create sigmoid node.");

// Step 3: Create Mul node for x * sigmoid(alpha * x) or x * sigmoid(x)
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit.Name() + "_final_mul"),
QNN_OP_PACKAGE_NAME_QTI_AISW,
QNN_OP_ELEMENT_WISE_MULTIPLY,
{input_name, sigmoid_output_name},
{output_name},
{},
do_op_validation),
"Failed to create final_mul node.");

return Status::OK();
}

void CreateQuickGeluOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {

Check warning on line 137 in onnxruntime/core/providers/qnn/builder/opbuilder/quick_gelu_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/qnn/builder/opbuilder/quick_gelu_op_builder.cc:137: Add #include <string> for string [build/include_what_you_use] [4]
op_registrations.AddOpBuilder(op_type, std::make_unique<QuickGeluOpBuilder>());

Check warning on line 138 in onnxruntime/core/providers/qnn/builder/opbuilder/quick_gelu_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for make_unique<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/qnn/builder/opbuilder/quick_gelu_op_builder.cc:138: Add #include <memory> for make_unique<> [build/include_what_you_use] [4]
}

} // namespace qnn
} // namespace onnxruntime
202 changes: 202 additions & 0 deletions onnxruntime/test/providers/qnn/quick_gelu_op_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#if !defined(ORT_MINIMAL_BUILD)

#include <string>
#include "core/graph/constants.h"
#include "test/providers/qnn/qnn_test_utils.h"

#include "gtest/gtest.h"

namespace onnxruntime {
namespace test {

// Runs a model with a QuickGelu operator on the QNN CPU backend. Checks the graph node assignment
// and that inference outputs for QNN EP and CPU EP match.
template <typename DataType>
static void RunQuickGeluTest(const TestInputDef<DataType>& input_def,
float alpha,
ExpectedEPNodeAssignment expected_ep_assignment,
const std::string& backend_name = "cpu",
float fp32_abs_err = 5e-3f) {
ProviderOptions provider_options;
provider_options["backend_type"] = backend_name;

if (backend_name == "htp") {
provider_options["enable_htp_fp16_precision"] = "1";
}

auto model_builder = [input_def, alpha](ModelTestBuilder& builder) {
NodeArg* input = MakeTestInput<DataType>(builder, input_def);
auto* output = builder.MakeOutput();

Node& node = builder.AddNode("QuickGelu", {input}, {output}, kMSDomain);
node.AddAttribute("alpha", alpha);
};

RunQnnModelTest(model_builder,
provider_options,
13, // opset version for contrib ops
expected_ep_assignment,
fp32_abs_err);
}

// Tests the accuracy of a QDQ QuickGelu model on QNN EP by comparing to CPU EP.
template <typename QType>
static void RunQDQQuickGeluTest(const TestInputDef<float>& input_def,
float alpha,
ExpectedEPNodeAssignment expected_ep_assignment,
const std::string& backend_name = "htp",
bool use_contrib_qdq = false) {
ProviderOptions provider_options;
provider_options["backend_type"] = backend_name;
provider_options["offload_graph_io_quantization"] = "0";

GetTestModelFn model_builder_fn = [input_def, alpha](ModelTestBuilder& builder) {
NodeArg* input = MakeTestInput<float>(builder, input_def);
auto* output = builder.MakeOutput();

Node& node = builder.AddNode("QuickGelu", {input}, {output}, kMSDomain);
node.AddAttribute("alpha", alpha);
};

GetTestQDQModelFn<QType> qdq_model_builder_fn = [input_def, alpha, use_contrib_qdq](ModelTestBuilder& builder, std::vector<QuantParams<QType>>& output_qparams) {
NodeArg* input = MakeTestInput<float>(builder, input_def);
QuantParams<QType> input_qparams = GetTestInputQuantParams<QType>(input_def);
NodeArg* input_after_qdq = AddQDQNodePair<QType>(builder, input, input_qparams.scale,
input_qparams.zero_point, use_contrib_qdq);

// QuickGelu -> op_output
auto* op_output = builder.MakeIntermediate();
Node& node = builder.AddNode("QuickGelu", {input_after_qdq}, {op_output}, kMSDomain);
node.AddAttribute("alpha", alpha);

// op_output -> Q -> DQ -> output
AddQDQNodePairWithOutputAsGraphOutput<QType>(builder, op_output, output_qparams[0].scale,
output_qparams[0].zero_point, use_contrib_qdq);
};

TestQDQModelAccuracy(model_builder_fn,
qdq_model_builder_fn,
provider_options,
13, // opset version for contrib ops
expected_ep_assignment,
QDQTolerance(5e-3f));
}

//
// CPU tests:
//

// Test QuickGelu with default alpha value (1.0)
TEST_F(QnnCPUBackendTests, QuickGelu_Default_Alpha) {
RunQuickGeluTest<float>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
1.0f, // alpha
ExpectedEPNodeAssignment::All);
}

// Test QuickGelu with custom alpha value
TEST_F(QnnCPUBackendTests, QuickGelu_Custom_Alpha) {
RunQuickGeluTest<float>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
1.702f, // alpha
ExpectedEPNodeAssignment::All);
}

// Test QuickGelu with negative alpha value
TEST_F(QnnCPUBackendTests, QuickGelu_Negative_Alpha) {
RunQuickGeluTest<float>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
-1.702f, // alpha
ExpectedEPNodeAssignment::All);
}

#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
//
// HTP tests:
//

TEST_F(QnnHTPBackendTests, QuickGelu_Default_Alpha) {
RunQuickGeluTest<float>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
1.0f,
ExpectedEPNodeAssignment::All,
"htp",
0.01f);
}

// Test QuickGelu with custom alpha value on HTP
TEST_F(QnnHTPBackendTests, QuickGelu_Custom_Alpha) {
RunQuickGeluTest<float>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
1.702f, // alpha
ExpectedEPNodeAssignment::All,
"htp");
}

// Test QuickGelu with negative alpha value on HTP
TEST_F(QnnHTPBackendTests, QuickGelu_Negative_Alpha) {
RunQuickGeluTest<float>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
-1.702f, // alpha
ExpectedEPNodeAssignment::All,
"htp");
}

TEST_F(QnnHTPBackendTests, QuickGelu_Float16_Default_Alpha) {
RunQuickGeluTest<MLFloat16>(ConvertToFP16InputDef(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))),
1.0f,
ExpectedEPNodeAssignment::All,
"htp",
0.01f);
}

// Test QuickGelu with float16 inputs and custom alpha on HTP
TEST_F(QnnHTPBackendTests, QuickGelu_Float16_Custom_Alpha) {
RunQuickGeluTest<MLFloat16>(ConvertToFP16InputDef(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))),
1.702f, // alpha
ExpectedEPNodeAssignment::All,
"htp");
}

// Test QuickGelu with float16 inputs and negative alpha on HTP
TEST_F(QnnHTPBackendTests, QuickGelu_Float16_Negative_Alpha) {
RunQuickGeluTest<MLFloat16>(ConvertToFP16InputDef(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))),
-1.702f, // alpha
ExpectedEPNodeAssignment::All,
"htp");
}

// Test 8-bit QDQ QuickGelu with default alpha value on HTP
TEST_F(QnnHTPBackendTests, QuickGelu_QDQ_U8_Default_Alpha) {
RunQDQQuickGeluTest<uint8_t>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
1.0f, // alpha
ExpectedEPNodeAssignment::All);
}

// Test 8-bit QDQ QuickGelu with custom alpha value on HTP
TEST_F(QnnHTPBackendTests, QuickGelu_QDQ_U8_Custom_Alpha) {
RunQDQQuickGeluTest<uint8_t>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
1.702f, // alpha
ExpectedEPNodeAssignment::All);
}

// Test 16-bit QDQ QuickGelu with default alpha value on HTP
TEST_F(QnnHTPBackendTests, QuickGelu_QDQ_U16_Default_Alpha) {
RunQDQQuickGeluTest<uint16_t>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
1.0f, // alpha
ExpectedEPNodeAssignment::All,
"htp",
true); // Use com.microsoft Q/DQ ops
}

// Test 16-bit QDQ QuickGelu with custom alpha value on HTP
TEST_F(QnnHTPBackendTests, QuickGelu_QDQ_U16_Custom_Alpha) {
RunQDQQuickGeluTest<uint16_t>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
1.702f, // alpha
ExpectedEPNodeAssignment::All,
"htp",
true); // Use com.microsoft Q/DQ ops
}

#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)

} // namespace test
} // namespace onnxruntime
#endif // !defined(ORT_MINIMAL_BUILD)
Loading