diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.cc
new file mode 100644
index 0000000000000..619e3eaf5fad4
--- /dev/null
+++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.cc
@@ -0,0 +1,480 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/qnn/builder/qnn_node_group/gelu_fusion.h"
+
+#include <cassert>
+#include <cmath>
+#include <memory>
+#include <optional>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "core/providers/qnn/ort_api.h"
+#include "core/providers/qnn/builder/qnn_utils.h"
+#include "core/providers/qnn/builder/op_builder_factory.h"
+#include "core/providers/qnn/builder/qnn_model_wrapper.h"
+#include "core/providers/qnn/builder/qnn_node_group/utils.h"
+
+namespace onnxruntime {
+namespace qnn {
+
+// Helper function to extract a scalar value from raw data based on the QNN data type.
+static Status GetValueOnQnnDataType(const Qnn_DataType_t qnn_data_type,
+                                    const uint8_t* raw_ptr,
+                                    double& value) {
+  switch (qnn_data_type) {
+    case QNN_DATATYPE_INT_8:
+    case QNN_DATATYPE_SFIXED_POINT_8: {
+      value = static_cast<double>(*reinterpret_cast<const int8_t*>(raw_ptr));
+      break;
+    }
+    case QNN_DATATYPE_INT_16:
+    case QNN_DATATYPE_SFIXED_POINT_16: {
+      value = static_cast<double>(*reinterpret_cast<const int16_t*>(raw_ptr));
+      break;
+    }
+    case QNN_DATATYPE_INT_32:
+    case QNN_DATATYPE_SFIXED_POINT_32: {
+      value = static_cast<double>(*reinterpret_cast<const int32_t*>(raw_ptr));
+      break;
+    }
+    case QNN_DATATYPE_UINT_8:
+    case QNN_DATATYPE_UFIXED_POINT_8: {
+      value = static_cast<double>(*reinterpret_cast<const uint8_t*>(raw_ptr));
+      break;
+    }
+    case QNN_DATATYPE_UINT_16:
+    case QNN_DATATYPE_UFIXED_POINT_16: {
+      value = static_cast<double>(*reinterpret_cast<const uint16_t*>(raw_ptr));
+      break;
+    }
+    case QNN_DATATYPE_UINT_32:
+    case QNN_DATATYPE_UFIXED_POINT_32: {
+      value = static_cast<double>(*reinterpret_cast<const uint32_t*>(raw_ptr));
+      break;
+    }
+    case QNN_DATATYPE_FLOAT_32: {
+      value = static_cast<double>(*reinterpret_cast<const float*>(raw_ptr));
+      break;
+    }
+    case QNN_DATATYPE_FLOAT_16: {
+      value = static_cast<double>(reinterpret_cast<const MLFloat16*>(raw_ptr)->ToFloat());
+      break;
+    }
+    default:
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Qnn Data Type: ", qnn_data_type, " not supported.");
+  }
+  return Status::OK();
+}
+
+// Helper function to extract a scalar float value from a constant initializer.
+// Handles both float and quantized (INT type) constant inputs.
+static std::optional<float> GetConstantInitializerFloatScalar(QnnModelWrapper& qnn_model_wrapper,
+                                                              const NodeUnitIODef& io_def) {
+  const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer();
+  const auto& name = io_def.node_arg.Name();
+
+  if (!graph_viewer.IsConstantInitializer(name, true)) {
+    return std::nullopt;
+  }
+
+  // Get tensor info to check if it's quantized
+  TensorInfo tensor_info = {};
+  if (!qnn_model_wrapper.GetTensorInfo(io_def, tensor_info).IsOK()) {
+    return std::nullopt;
+  }
+
+  // Must be an initializer
+  if (!tensor_info.is_initializer || !tensor_info.initializer_tensor) {
+    return std::nullopt;
+  }
+
+  // Unpack the initializer data
+  std::vector<uint8_t> unpacked_tensor;
+  if (!qnn_model_wrapper.UnpackInitializerData(*tensor_info.initializer_tensor, unpacked_tensor).IsOK()) {
+    return std::nullopt;
+  }
+
+  if (unpacked_tensor.empty()) {
+    return std::nullopt;
+  }
+
+  // Extract the value using GetValueOnQnnDataType
+  double extracted_value = 0.0;
+  if (!GetValueOnQnnDataType(tensor_info.qnn_data_type, unpacked_tensor.data(), extracted_value).IsOK()) {
+    return std::nullopt;
+  }
+
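+  // Note: for quantized tensors, utils::Dequantize below undoes QNN's per-tensor
+  // scale/offset encoding (QNN stores the negated zero point as the offset), so
+  // the recovered value is approximately (raw_value + offset) * scale.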
+  // Check if quantized and dequantize if needed
+  const bool is_quantized = tensor_info.quant_param.IsQuantized();
+  if (is_quantized) {
+    // For quantized tensors, dequantize the value
+    if (!tensor_info.quant_param.IsPerTensor()) {
+      return std::nullopt;  // Only support per-tensor quantization
+    }
+
+    const Qnn_QuantizeParams_t& quant_param = tensor_info.quant_param.Get();
+    double dequantized_value = utils::Dequantize(quant_param.scaleOffsetEncoding.offset,
+                                                 quant_param.scaleOffsetEncoding.scale,
+                                                 extracted_value);
+    return static_cast<float>(dequantized_value);
+  }
+
+  // For non-quantized tensors, return the extracted value directly
+  return static_cast<float>(extracted_value);
+}
+
+// Helper function to check if a constant initializer has the expected float value.
+static bool IsInitializerWithExpectedValue(QnnModelWrapper& qnn_model_wrapper,
+                                           const NodeUnitIODef& io_def,
+                                           float expected_value,
+                                           float tolerance = 1e-5f) {
+  std::optional<float> actual_value = GetConstantInitializerFloatScalar(qnn_model_wrapper, io_def);
+  if (!actual_value.has_value()) {
+    return false;
+  }
+
+  // Compare with expected value within tolerance
+  return std::fabs(actual_value.value() - expected_value) <= tolerance;
+}
+
+// Forward declaration.
+static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper,
+                                    gsl::span<const NodeUnit* const> node_units,
+                                    const NodeUnitIODef& root_input,
+                                    const NodeUnitIODef& final_output,
+                                    bool validate);
+
+// Helper function to validate on QNN
+static Status ValidateOnQnn(QnnModelWrapper& qnn_model_wrapper,
+                            gsl::span<const NodeUnit* const> node_units,
+                            const NodeUnitIODef& root_input,
+                            const NodeUnitIODef& final_output) {
+  return CreateOrValidateOnQnn(qnn_model_wrapper, node_units, root_input, final_output, true);
+}
+
+// Helper function to create on QNN
+static Status CreateOnQnn(QnnModelWrapper& qnn_model_wrapper,
+                          gsl::span<const NodeUnit* const> node_units,
+                          const NodeUnitIODef& root_input,
+                          const NodeUnitIODef& final_output) {
+  return CreateOrValidateOnQnn(qnn_model_wrapper, node_units, root_input, final_output, false);
+}
+
+// Gets the parent and child of the Erf node. Can handle the following sequences:
+// - Parent -> Erf -> Child.
+// - Parent -> DQ -> Erf -> Q -> Child.
+//
+// Also returns the outputs of the Erf. For the sequence `DQ -> Erf -> Q`, returns the outputs of the Q.
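+// Returns true only if both a parent and a child NodeUnit were found.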
+static bool GetErfParentAndChild(QnnModelWrapper& qnn_model_wrapper,
+                                 const NodeUnit& erf_node_unit,
+                                 const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
+                                 const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
+                                 /*out*/ const NodeUnit*& parent_node_unit,
+                                 /*out*/ const NodeUnit*& child_node_unit,
+                                 /*out*/ const NodeUnit*& dq_node_unit,
+                                 /*out*/ const NodeUnit*& q_node_unit,
+                                 /*out*/ gsl::span<const NodeUnitIODef>& erf_outputs) {
+  const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer();
+
+  auto get_first_parent = [&](const NodeUnit& node_unit) -> const NodeUnit* {
+    const auto& inputs = node_unit.Inputs();
+    if (inputs.empty()) {
+      return nullptr;
+    }
+    return GetParentOfInput(graph_viewer, node_unit, inputs[0],
+                            node_to_node_unit, node_unit_to_qnn_node_group);
+  };
+
+  auto get_first_child = [&](const NodeUnit& node_unit) -> const NodeUnit* {
+    const auto& outputs = node_unit.Outputs();
+    if (outputs.empty()) {
+      return nullptr;
+    }
+
+    return GetOnlyChildOfOutput(graph_viewer, node_unit, outputs[0],
+                                node_to_node_unit, node_unit_to_qnn_node_group);
+  };
+
+  const NodeUnit* erf_parent_node_unit = get_first_parent(erf_node_unit);
+  if (erf_parent_node_unit == nullptr) {
+    return false;
+  }
+
+  const NodeUnit* erf_child_node_unit = get_first_child(erf_node_unit);
+  if (erf_child_node_unit == nullptr) {
+    return false;
+  }
+
+  if (erf_node_unit.UnitType() == NodeUnit::Type::SingleNode &&
+      erf_parent_node_unit->OpType() == "DequantizeLinear" &&
+      erf_child_node_unit->OpType() == "QuantizeLinear") {
+    // This is the explicit sequence DQ -> Erf -> Q.
+    // Look past the DQ and Q nodes to get the actual parent and child.
+    // We do this because ORT utils do not automatically group DQ -> Erf -> Q into a NodeUnit.
+    dq_node_unit = erf_parent_node_unit;
+    q_node_unit = erf_child_node_unit;
+    erf_parent_node_unit = get_first_parent(*erf_parent_node_unit);
+    erf_child_node_unit = get_first_child(*erf_child_node_unit);
+
+    erf_outputs = q_node_unit->Outputs();
+  } else {
+    erf_outputs = erf_node_unit.Outputs();
+  }
+
+  parent_node_unit = erf_parent_node_unit;
+  child_node_unit = erf_child_node_unit;
+  return parent_node_unit != nullptr && child_node_unit != nullptr;
+}
+
+std::unique_ptr<IQnnNodeGroup> GeluFusion::TryFusion(
+    QnnModelWrapper& qnn_model_wrapper,
+    const NodeUnit& erf_node_unit,
+    const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
+    const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
+    const logging::Logger& /*logger*/) {
+  if (erf_node_unit.OpType() != "Erf") {
+    return nullptr;
+  }
+
+  const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer();
+  const NodeUnit* div_node_unit = nullptr;
+  const NodeUnit* add_node_unit = nullptr;
+  const NodeUnit* dq_node_unit = nullptr;
+  const NodeUnit* q_node_unit = nullptr;
+  gsl::span<const NodeUnitIODef> erf_outputs;
+
+  if (!GetErfParentAndChild(qnn_model_wrapper, erf_node_unit, node_to_node_unit, node_unit_to_qnn_node_group,
+                            div_node_unit, add_node_unit, dq_node_unit, q_node_unit, erf_outputs)) {
+    return nullptr;
+  }
+
+  // Erf must have a Div parent.
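+  // Gelu(x) = 0.5 * x * (1 + Erf(x / sqrt(2))), so the Erf's argument must be
+  // the root input divided by sqrt(2).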
+  if (div_node_unit == nullptr || div_node_unit->OpType() != "Div") {
+    return nullptr;
+  }
+
+  // Div must have 2 inputs
+  const auto& div_inputs = div_node_unit->Inputs();
+  if (div_inputs.size() < 2) {
+    return nullptr;
+  }
+
+  // Check that the second input of Div is sqrt(2) ≈ 1.4142
+  if (!IsInitializerWithExpectedValue(qnn_model_wrapper, div_inputs[1], static_cast<float>(M_SQRT2))) {
+    return nullptr;
+  }
+
+  // Erf must have an Add child consuming its output
+  if (add_node_unit == nullptr || add_node_unit->OpType() != "Add") {
+    return nullptr;
+  }
+
+  // Add must have 2 inputs
+  const auto& add_inputs = add_node_unit->Inputs();
+  if (add_inputs.size() < 2) {
+    return nullptr;
+  }
+
+  // Check that the other input (i.e., not the Erf output) is 1.0f
+  bool is_erf_first_input = (add_inputs[0].node_arg.Name() == erf_outputs[0].node_arg.Name());
+  const auto& add_const_input = add_inputs[is_erf_first_input ? 1 : 0];
+  if (!IsInitializerWithExpectedValue(qnn_model_wrapper, add_const_input, 1.0f)) {
+    return nullptr;
+  }
+
+  // Add must have a Mul child consuming its output
+  const auto& add_outputs = add_node_unit->Outputs();
+  if (add_outputs.empty()) {
+    return nullptr;
+  }
+
+  const NodeUnit* mul_node_unit = GetOnlyChildOfOutput(graph_viewer, *add_node_unit, add_outputs[0],
+                                                       node_to_node_unit, node_unit_to_qnn_node_group);
+  if (mul_node_unit == nullptr || mul_node_unit->OpType() != "Mul") {
+    return nullptr;
+  }
+
+  // Now check which pattern we have
+  const auto& root_input_name = div_inputs[0].node_arg.Name();
+  const auto& mul_inputs = mul_node_unit->Inputs();
+
+  if (mul_inputs.size() < 2) {
+    return nullptr;
+  }
+
+  // Try to match Pattern 1: root -> Mul(0.5) -> ... -> Mul
+  // In this case, one input to the final Mul should come from a Mul node
+  const NodeUnit* mul2_node_unit = nullptr;
+
+  // Check if either input to mul_node_unit comes from a Mul node
+  for (size_t i = 0; i < 2; ++i) {
+    const auto& mul_input = mul_inputs[i];
+
+    const NodeUnit* producer_unit = GetParentOfInput(graph_viewer, *mul_node_unit, mul_input,
+                                                     node_to_node_unit, node_unit_to_qnn_node_group);
+    if (producer_unit && producer_unit->OpType() == "Mul") {
+      const auto& mul2_inputs = producer_unit->Inputs();
+      if (mul2_inputs.size() >= 2) {
+        bool has_root_input = (mul2_inputs[0].node_arg.Name() == root_input_name ||
+                               mul2_inputs[1].node_arg.Name() == root_input_name);
+        if (has_root_input) {
+          int root_index = (mul2_inputs[0].node_arg.Name() == root_input_name) ? 0 : 1;
+          const auto& mul_const_input = mul2_inputs[1 - root_index];
+
+          if (IsInitializerWithExpectedValue(qnn_model_wrapper, mul_const_input, 0.5f)) {
+            mul2_node_unit = producer_unit;
+            break;
+          }
+        }
+      }
+    }
+    if (mul2_node_unit != nullptr) break;
+  }
+
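+  // If a Mul(root, 0.5) branch feeds the final Mul, this is Pattern 1; otherwise
+  // fall through and try to match Pattern 2 below.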
+  std::vector<const NodeUnit*> node_units;
+  const NodeUnit* final_mul_node_unit = nullptr;
+
+  if (mul2_node_unit != nullptr) {
+    // Pattern 1: root -> Mul(0.5) -> ... -> Mul
+    if (dq_node_unit != nullptr) {
+      assert(q_node_unit != nullptr);
+      node_units = {div_node_unit, dq_node_unit, &erf_node_unit, q_node_unit, add_node_unit, mul2_node_unit,
+                    mul_node_unit};
+    } else {
+      node_units = {div_node_unit, &erf_node_unit, add_node_unit, mul2_node_unit, mul_node_unit};
+    }
+    final_mul_node_unit = mul_node_unit;
+  } else {
+    // Try Pattern 2: root -> ... -> Mul -> Mul(0.5)
+    // Check if one input to mul_node_unit is root
+    bool has_root_input = (mul_inputs[0].node_arg.Name() == root_input_name ||
+                           mul_inputs[1].node_arg.Name() == root_input_name);
+
+    if (!has_root_input) {
+      return nullptr;
+    }
+
+    // mul_node_unit must have a Mul child consuming its output
+    const auto& mul_outputs = mul_node_unit->Outputs();
+    if (mul_outputs.empty()) {
+      return nullptr;
+    }
+
+    const NodeUnit* mul2_node_unit_pattern2 = GetOnlyChildOfOutput(graph_viewer, *mul_node_unit, mul_outputs[0],
+                                                                   node_to_node_unit, node_unit_to_qnn_node_group);
+    if (mul2_node_unit_pattern2 == nullptr || mul2_node_unit_pattern2->OpType() != "Mul") {
+      return nullptr;
+    }
+
+    // Verify this final Mul has 2 inputs
+    const auto& mul2_inputs = mul2_node_unit_pattern2->Inputs();
+    if (mul2_inputs.size() < 2) {
+      return nullptr;
+    }
+
+    // Check that the constant input is 0.5f
+    int mul_const_input_index = 0;
+    if (mul2_inputs[0].node_arg.Name() == mul_outputs[0].node_arg.Name()) {
+      mul_const_input_index = 1;
+    }
+    const auto& mul_const_input = mul2_inputs[mul_const_input_index];
+    if (!IsInitializerWithExpectedValue(qnn_model_wrapper, mul_const_input, 0.5f)) {
+      return nullptr;
+    }
+
+    // Pattern 2
+    if (dq_node_unit != nullptr) {
+      assert(q_node_unit != nullptr);
+      node_units = {div_node_unit, dq_node_unit, &erf_node_unit, q_node_unit, add_node_unit,
+                    mul_node_unit, mul2_node_unit_pattern2};
+    } else {
+      node_units = {div_node_unit, &erf_node_unit, add_node_unit, mul_node_unit, mul2_node_unit_pattern2};
+    }
+
+    final_mul_node_unit = mul2_node_unit_pattern2;
+  }
+
+  // Validate on QNN
+  const NodeUnitIODef& root_input = div_inputs[0];
+  const NodeUnitIODef& final_output = final_mul_node_unit->Outputs()[0];
+
+  if (Status status = ValidateOnQnn(qnn_model_wrapper, node_units, root_input, final_output);
+      !status.IsOK()) {
+    return nullptr;
+  }
+
+  return std::make_unique<GeluFusion>(std::move(node_units), &erf_node_unit);
+}
+
+GeluFusion::GeluFusion(std::vector<const NodeUnit*>&& node_units, const NodeUnit* target_node_unit)
+    : node_units_(std::move(node_units)), target_node_unit_(target_node_unit) {
+}
+
+Status GeluFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& /*logger*/) const {
+  ORT_RETURN_IF_NOT(!node_units_.empty(), "GeluFusion node_units_ is empty");
+  const NodeUnitIODef& root_input = node_units_[0]->Inputs()[0];
+  const NodeUnitIODef& final_output = node_units_.back()->Outputs()[0];
+  return ValidateOnQnn(qmw, node_units_, root_input, final_output);
+}
+
+Status GeluFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& /*logger*/) const {
+  ORT_RETURN_IF_NOT(!node_units_.empty(), "GeluFusion node_units_ is empty");
+  const NodeUnitIODef& root_input = node_units_[0]->Inputs()[0];
+  const NodeUnitIODef& final_output = node_units_.back()->Outputs()[0];
+  return CreateOnQnn(qmw, node_units_, root_input, final_output);
+}
+
+gsl::span<const NodeUnit* const> GeluFusion::GetNodeUnits() const {
+  return gsl::span<const NodeUnit* const>(node_units_.data(), node_units_.size());
+}
+
+const NodeUnit* GeluFusion::GetTargetNodeUnit() const {
+  return target_node_unit_;
+}
+
+static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper,
+                                    gsl::span<const NodeUnit* const> node_units,
+                                    const NodeUnitIODef& root_input,
+                                    const NodeUnitIODef& final_output,
+                                    bool validate) {
+  assert(node_units.size() >= 4);
+  const auto& node_name = utils::GetUniqueName(*node_units[0]);
+
+  QnnTensorWrapper input_tensor;
+  QnnTensorWrapper output_tensor;
+
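+  // Only the pattern's root input and final output become QNN tensors; the
+  // intermediate tensors of the matched subgraph are dropped in favor of a
+  // single QNN Gelu op.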
+  ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(root_input, input_tensor));
+  ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(final_output, output_tensor));
+
+  if (validate) {
+    ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name,
+                                                          QNN_OP_PACKAGE_NAME_QTI_AISW,
+                                                          QNN_OP_GELU,
+                                                          {input_tensor.GetQnnTensor()},
+                                                          {output_tensor.GetQnnTensor()},
+                                                          {}));
+  } else {
+    // Only add tensor wrappers if they don't already exist
+    if (!qnn_model_wrapper.IsQnnTensorWrapperExist(root_input.node_arg.Name())) {
+      ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input");
+    }
+    if (!qnn_model_wrapper.IsQnnTensorWrapperExist(final_output.node_arg.Name())) {
+      ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output");
+    }
+    ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name,
+                                                      QNN_OP_PACKAGE_NAME_QTI_AISW,
+                                                      QNN_OP_GELU,
+                                                      {root_input.node_arg.Name()},
+                                                      {final_output.node_arg.Name()},
+                                                      {},
+                                                      validate),
+                      "Failed to add fused Gelu node.");
+  }
+
+  return Status::OK();
+}
+
+}  // namespace qnn
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.h
new file mode 100644
index 0000000000000..508b1fca48a67
--- /dev/null
+++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.h
@@ -0,0 +1,73 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h"
+#include "core/providers/qnn/ort_api.h"
+
+namespace onnxruntime {
+namespace qnn {
+
+class QnnModelWrapper;
+
+/// <summary>
+/// Represents a fusion of the Gelu pattern expanded into ONNX operators.
+/// This fusion handles two patterns:
+/// Pattern 1:
+///        +-------Mul(0.5)---------------------+
+///        |                                    |
+///        |                                    v
+/// [root] --> Div -----> Erf --> Add --> Mul ==>
+///       (B=1.4142...)          (1)
+///
+/// Pattern 2:
+///        +------------------------------------+
+///        |                                    |
+///        |                                    v
+/// [root] --> Div -----> Erf --> Add --> Mul --> Mul ==>
+///       (B=1.4142...)          (1)      (0.5)
+///
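+/// Both patterns compute Gelu(x) = 0.5 * x * (1 + Erf(x / sqrt(2))).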
+ /// + /// A valid IQnnNodeGroup on success or an empty std::unique_ptr otherwise + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& erf_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + + private: + std::vector node_units_; + const NodeUnit* target_node_unit_; +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc index 368caa518b7ba..4297801ce4cdc 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -22,6 +22,7 @@ #include "core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/reshape_transpose_rank5.h" +#include "core/providers/qnn/builder/qnn_node_group/gelu_fusion.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/ort_api.h" @@ -83,6 +84,7 @@ static std::unordered_map> fusions = { {"Gemm", {LowPowerBlockQuantizedGemmFusion::TryFusion, ReshapeGemmFusion::TryFusion}}, {"Mul", {ScaleSoftmaxFusion::TryFusion}}, {"Cast", {CastLoneQFusion::TryFusion}}, + {"Erf", {GeluFusion::TryFusion}}, {"Reshape", {Rank6ToRank5Fusion::TryFusion}}, {"Transpose", {ChannelShuffleFusion::TryFusion}}}; @@ -119,9 +121,11 @@ static std::unique_ptr TryQnnFusions( const std::unordered_map& node_to_node_unit, const std::unordered_map& node_unit_to_qnn_node_group, const logging::Logger& logger) { - // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes) except MatMul w/ LPBQ encodings and Reshape + // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes) except + // MatMul w/ LPBQ encodings, Erf and Reshape if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode && starting_node_unit.OpType() != "MatMul" && + starting_node_unit.OpType() != "Erf" && starting_node_unit.OpType() != "Reshape") { return nullptr; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc index 10e1633e4b57d..7b77164a38545 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc @@ -226,14 +226,92 @@ const NodeUnit* GetParentOfInput(const GraphViewer& graph_viewer, return nullptr; } - // parent must not already be part of a QDQ NodeUnit (i.e., be standalone). 
+const NodeUnit* GetOnlyChildOfOutput(const GraphViewer& graph_viewer,
+                                     const NodeUnit& node_unit,
+                                     const NodeUnitIODef& output,
+                                     const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,
+                                     const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& qnn_node_group_map) {
+  const Node* p_parent_node = nullptr;
+
+  for (auto node : node_unit.GetAllNodesInGroup()) {
+    for (auto node_output : node->OutputDefs()) {
+      if (node_output->Name() == output.node_arg.Name()) {
+        p_parent_node = node;
+        break;
+      }
+    }
+    // break the loop if the producer node of the output is found
+    if (p_parent_node != nullptr) {
+      break;
+    }
+  }
+
+  // return if the given output tensor is not produced by any node in the given node_unit
+  if (p_parent_node == nullptr) {
+    return nullptr;
+  }
+
+  const Node& parent_node = *p_parent_node;
+
+  if (graph_viewer.NodeProducesGraphOutput(parent_node)) {
+    // Node is producing a graph output
+    return nullptr;
+  }
+
+  // Count how many children consume this specific output
+  int child_count = 0;
+  const NodeUnit* p_child_node_unit = nullptr;
+
+  for (auto edge = parent_node.OutputEdgesBegin(); edge != parent_node.OutputEdgesEnd(); ++edge) {
+    const Node& child_node = edge->GetNode();
+
+    // Check if this edge corresponds to the output we're looking for
+    bool is_matching_output = false;
+    for (auto child_input : child_node.InputDefs()) {
+      if (child_input->Name() == output.node_arg.Name()) {
+        is_matching_output = true;
+        break;
+      }
+    }
+
+    if (!is_matching_output) {
+      continue;
+    }
+
+    if (graph_viewer.GetNode(child_node.Index()) == nullptr) {
+      // Node is not in this GraphViewer
       return nullptr;
     }
-    return p_parent_node_unit;
+    const auto child_node_unit_it = node_unit_map.find(&child_node);
+    if (child_node_unit_it == node_unit_map.end()) {
+      return nullptr;
+    }
+    const NodeUnit* current_child_node_unit = child_node_unit_it->second;
+
+    // Check if the child node has already been handled. Should not be the case if the calling
+    // fusion function has been called in topological order, but check to be safe.
+    if (qnn_node_group_map.count(current_child_node_unit) != 0) {
+      return nullptr;
+    }
+
+    // Store the child node unit and increment the count
+    p_child_node_unit = current_child_node_unit;
+    child_count++;
+
+    // If we found more than one child, return nullptr immediately
+    if (child_count > 1) {
+      return nullptr;
+    }
   }
-  return nullptr;
+
+  // Return the child only if there's exactly one child
+  return (child_count == 1) ? p_child_node_unit : nullptr;
 }
 
 }  // namespace qnn
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h
index 14e2a3f25e7db..b52cdd5fa3ec6 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h
@@ -51,5 +51,11 @@ const NodeUnit* GetParentOfInput(const GraphViewer& graph_viewer,
                                  const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,
                                  const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& qnn_node_group_map);
 
+const NodeUnit* GetOnlyChildOfOutput(const GraphViewer& graph_viewer,
+                                     const NodeUnit& node_unit,
+                                     const NodeUnitIODef& output,
+                                     const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,
+                                     const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& qnn_node_group_map);
+
 }  // namespace qnn
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/providers/qnn/qnn_node_group/gelu_fusion_test.cc b/onnxruntime/test/providers/qnn/qnn_node_group/gelu_fusion_test.cc
new file mode 100644
index 0000000000000..e28cf00aa070b
--- /dev/null
+++ b/onnxruntime/test/providers/qnn/qnn_node_group/gelu_fusion_test.cc
@@ -0,0 +1,407 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#if !defined(ORT_MINIMAL_BUILD)
+
+#include <memory>
+#include <vector>
+
+#include "core/graph/graph.h"
+#include "core/graph/node_attr_utils.h"
+#include "test/providers/qnn/qnn_test_utils.h"
+#include "test/unittest_util/qdq_test_utils.h"
+#include "gtest/gtest.h"
+
+namespace onnxruntime {
+namespace test {
+
+#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
+
+namespace {
+
+// Helper function to build GELU Pattern 1
+// Pattern 1:
+//        +-------Mul(0.5)---------------------+
+//        |                                    |
+//        |                                    v
+// [root] --> Div -----> Erf --> Add --> Mul ==>
+//       (B=1.4142...)          (1)
+GetTestModelFn BuildGeluPattern1TestCase(const TestInputDef<float>& input_def) {
+  return [input_def](ModelTestBuilder& builder) -> void {
+    constexpr float sqrt_2 = 1.4142135381698608f;
+    constexpr float half = 0.5f;
+    constexpr float one = 1.0f;
+
+    // Create input
+    NodeArg* input = MakeTestInput(builder, input_def);
+
+    // Create Mul(0.5) branch: input * 0.5
+    NodeArg* half_initializer = builder.MakeScalarInitializer(half);
+    NodeArg* mul_half_output = builder.MakeIntermediate();
+    builder.AddNode("Mul", {input, half_initializer}, {mul_half_output});
+
+    // Create main branch: input / sqrt(2)
+    NodeArg* sqrt2_initializer = builder.MakeScalarInitializer(sqrt_2);
+    NodeArg* div_output = builder.MakeIntermediate();
+    builder.AddNode("Div", {input, sqrt2_initializer}, {div_output});
+
+    // Erf
+    NodeArg* erf_output = builder.MakeIntermediate();
+    builder.AddNode("Erf", {div_output}, {erf_output});
+
+    // Add 1.0
+    NodeArg* one_initializer = builder.MakeScalarInitializer(one);
+    NodeArg* add_output = builder.MakeIntermediate();
+    builder.AddNode("Add", {erf_output, one_initializer}, {add_output});
+
+    // Final Mul: (add_output) * (mul_half_output)
+    NodeArg* output = builder.MakeOutput();
+    builder.AddNode("Mul", {add_output, mul_half_output}, {output});
+  };
+}
+
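+// Note: the scalar constants (sqrt(2), 1.0, 0.5) must be constant initializers
+// rather than graph inputs, since the fusion inspects their values while
+// matching the pattern.
+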
+// Helper function to build GELU Pattern 2: Mul(0.5) after the main sequence
+// Pattern 2:
+//        +------------------------------------+
+//        |                                    |
+//        |                                    v
+// [root] --> Div -----> Erf --> Add --> Mul --> Mul ==>
+//       (B=1.4142...)          (1)      (0.5)
+GetTestModelFn BuildGeluPattern2TestCase(const TestInputDef<float>& input_def) {
+  return [input_def](ModelTestBuilder& builder) -> void {
+    constexpr float sqrt_2 = 1.4142135381698608f;
+    constexpr float half = 0.5f;
+    constexpr float one = 1.0f;
+
+    // Create input
+    NodeArg* input = MakeTestInput(builder, input_def);
+
+    // Main branch: input / sqrt(2)
+    NodeArg* sqrt2_initializer = builder.MakeScalarInitializer(sqrt_2);
+    NodeArg* div_output = builder.MakeIntermediate();
+    builder.AddNode("Div", {input, sqrt2_initializer}, {div_output});
+
+    // Erf
+    NodeArg* erf_output = builder.MakeIntermediate();
+    builder.AddNode("Erf", {div_output}, {erf_output});
+
+    // Add 1.0
+    NodeArg* one_initializer = builder.MakeScalarInitializer(one);
+    NodeArg* add_output = builder.MakeIntermediate();
+    builder.AddNode("Add", {erf_output, one_initializer}, {add_output});
+
+    // Mul with input: input * add_output
+    NodeArg* mul_output = builder.MakeIntermediate();
+    builder.AddNode("Mul", {input, add_output}, {mul_output});
+
+    // Final Mul with 0.5: mul_output * 0.5
+    NodeArg* half_initializer = builder.MakeScalarInitializer(half);
+    NodeArg* output = builder.MakeOutput();
+    builder.AddNode("Mul", {mul_output, half_initializer}, {output});
+  };
+}
+
+// Helper function to build QDQ GELU Pattern 1
+template <typename QuantType>
+GetTestQDQModelFn<QuantType> BuildQDQGeluPattern1TestCase(const TestInputDef<float>& input_def) {
+  return [input_def](ModelTestBuilder& builder, std::vector<QuantParams<QuantType>>& output_qparams) -> void {
+    constexpr float sqrt_2 = 1.4142135381698608f;
+    constexpr float half = 0.5f;
+    constexpr float one = 1.0f;
+
+    // Create input
+    NodeArg* input = MakeTestInput(builder, input_def);
+    QuantParams<QuantType> input_qparams = GetTestInputQuantParams<QuantType>(input_def);
+
+    // Quantize input once
+    NodeArg* input_q = builder.MakeIntermediate();
+    builder.AddQuantizeLinearNode(input, input_qparams.scale, input_qparams.zero_point, input_q);
+
+    // Create quantized constants with individual quantization parameters
+    // For scalar constants, use range [0, value] to ensure proper quantization
+    QuantParams<QuantType> sqrt2_qparams = GetTestInputQuantParams<QuantType>(TestInputDef<float>({}, true, 0.0f, sqrt_2));
+    NodeArg* sqrt2_initializer = builder.MakeScalarInitializer(sqrt_2);
+    NodeArg* sqrt2_q = builder.MakeIntermediate();
+    builder.AddQuantizeLinearNode(sqrt2_initializer, sqrt2_qparams.scale, sqrt2_qparams.zero_point, sqrt2_q);
+
+    QuantParams<QuantType> one_qparams = GetTestInputQuantParams<QuantType>(TestInputDef<float>({}, true, 0.0f, one));
+    NodeArg* one_initializer = builder.MakeScalarInitializer(one);
+    NodeArg* one_q = builder.MakeIntermediate();
+    builder.AddQuantizeLinearNode(one_initializer, one_qparams.scale, one_qparams.zero_point, one_q);
+
+    QuantParams<QuantType> half_qparams = GetTestInputQuantParams<QuantType>(TestInputDef<float>({}, true, 0.0f, half));
+    NodeArg* half_initializer = builder.MakeScalarInitializer(half);
+    NodeArg* half_q = builder.MakeIntermediate();
+    builder.AddQuantizeLinearNode(half_initializer, half_qparams.scale, half_qparams.zero_point, half_q);
+
+    // Mul(0.5) branch: DQ -> Mul (input * 0.5) -> Q
+    NodeArg* input_dq_1 = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode(input_q, input_qparams.scale, input_qparams.zero_point, input_dq_1);
+    NodeArg* half_dq = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode(half_q, half_qparams.scale, half_qparams.zero_point, half_dq);
+    NodeArg* mul_half_output = builder.MakeIntermediate();
+    builder.AddNode("Mul", {input_dq_1, half_dq}, {mul_half_output});
+    NodeArg* mul_half_q = builder.MakeIntermediate();
+    builder.AddQuantizeLinearNode(mul_half_output, input_qparams.scale, input_qparams.zero_point, mul_half_q);
+
+    // Main branch: DQ -> Div -> Q
+    NodeArg* input_dq_2 = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode(input_q, input_qparams.scale, input_qparams.zero_point, input_dq_2);
+    NodeArg* sqrt2_dq = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode(sqrt2_q, sqrt2_qparams.scale, sqrt2_qparams.zero_point, sqrt2_dq);
+    NodeArg* div_output = builder.MakeIntermediate();
+    builder.AddNode("Div", {input_dq_2, sqrt2_dq}, {div_output});
+    NodeArg* div_q = builder.MakeIntermediate();
+    builder.AddQuantizeLinearNode(div_output, input_qparams.scale, input_qparams.zero_point, div_q);
+
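+    // Note: the intermediate activations in these QDQ builders reuse the input's
+    // quantization parameters throughout. That is a test-only simplification; a
+    // real quantizer would derive a separate range for each tensor.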
+    // DQ -> Erf -> Q
+    NodeArg* div_dq = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode(div_q, input_qparams.scale, input_qparams.zero_point, div_dq);
+    NodeArg* erf_output = builder.MakeIntermediate();
+    builder.AddNode("Erf", {div_dq}, {erf_output});
+    NodeArg* erf_q = builder.MakeIntermediate();
+    builder.AddQuantizeLinearNode(erf_output, input_qparams.scale, input_qparams.zero_point, erf_q);
+
+    // DQ -> Add -> Q
+    NodeArg* erf_dq = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode(erf_q, input_qparams.scale, input_qparams.zero_point, erf_dq);
+    NodeArg* one_dq = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode(one_q, one_qparams.scale, one_qparams.zero_point, one_dq);
+    NodeArg* add_output = builder.MakeIntermediate();
+    builder.AddNode("Add", {erf_dq, one_dq}, {add_output});
+    NodeArg* add_q = builder.MakeIntermediate();
+    builder.AddQuantizeLinearNode(add_output, input_qparams.scale, input_qparams.zero_point, add_q);
+
+    // Final DQ -> Mul (add_output * mul_half_output)
+    NodeArg* add_dq = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode(add_q, input_qparams.scale, input_qparams.zero_point, add_dq);
+    NodeArg* mul_half_dq = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode(mul_half_q, input_qparams.scale, input_qparams.zero_point, mul_half_dq);
+    NodeArg* mul_final_output = builder.MakeIntermediate();
+    builder.AddNode("Mul", {add_dq, mul_half_dq}, {mul_final_output});
+
+    // Add output QDQ
+    AddQDQNodePairWithOutputAsGraphOutput(builder, mul_final_output, output_qparams[0].scale,
+                                          output_qparams[0].zero_point);
+  };
+}
+
+// Helper function to build QDQ GELU Pattern 2
+template <typename QuantType>
+GetTestQDQModelFn<QuantType> BuildQDQGeluPattern2TestCase(const TestInputDef<float>& input_def) {
+  return [input_def](ModelTestBuilder& builder, std::vector<QuantParams<QuantType>>& output_qparams) -> void {
+    constexpr float sqrt_2 = 1.4142135381698608f;
+    constexpr float half = 0.5f;
+    constexpr float one = 1.0f;
+
+    // Create input
+    NodeArg* input = MakeTestInput(builder, input_def);
+    QuantParams<QuantType> input_qparams = GetTestInputQuantParams<QuantType>(input_def);
+
+    // Quantize input once
+    NodeArg* input_q = builder.MakeIntermediate();
+    builder.AddQuantizeLinearNode(input, input_qparams.scale, input_qparams.zero_point, input_q);
+
+    // Create quantized constants with individual quantization parameters
+    // For scalar constants, use range [0, value] to ensure proper quantization
+    QuantParams<QuantType> sqrt2_qparams = GetTestInputQuantParams<QuantType>(TestInputDef<float>({}, true, 0.0f, sqrt_2));
+    NodeArg* sqrt2_initializer = builder.MakeScalarInitializer(sqrt_2);
+    NodeArg* sqrt2_q = builder.MakeIntermediate();
+    builder.AddQuantizeLinearNode(sqrt2_initializer, sqrt2_qparams.scale, sqrt2_qparams.zero_point, sqrt2_q);
+
+    QuantParams<QuantType> one_qparams = GetTestInputQuantParams<QuantType>(TestInputDef<float>({}, true, 0.0f, one));
+    NodeArg* one_initializer = builder.MakeScalarInitializer(one);
+    NodeArg* one_q = builder.MakeIntermediate();
+    builder.AddQuantizeLinearNode(one_initializer, one_qparams.scale, one_qparams.zero_point, one_q);
+
+    QuantParams<QuantType> half_qparams = GetTestInputQuantParams<QuantType>(TestInputDef<float>({}, true, 0.0f, half));
+    NodeArg* half_initializer = builder.MakeScalarInitializer(half);
+    NodeArg* half_q = builder.MakeIntermediate();
+    builder.AddQuantizeLinearNode(half_initializer, half_qparams.scale, half_qparams.zero_point, half_q);
+
+    // Main branch: DQ -> Div -> Q
+    NodeArg* input_dq_1 = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode(input_q, input_qparams.scale, input_qparams.zero_point, input_dq_1);
+    NodeArg* sqrt2_dq = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode(sqrt2_q, sqrt2_qparams.scale, sqrt2_qparams.zero_point, sqrt2_dq);
+    NodeArg* div_output = builder.MakeIntermediate();
+    builder.AddNode("Div", {input_dq_1, sqrt2_dq}, {div_output});
+    NodeArg* div_q = builder.MakeIntermediate();
+    builder.AddQuantizeLinearNode(div_output, input_qparams.scale, input_qparams.zero_point, div_q);
+
+    // DQ -> Erf -> Q
+    NodeArg* div_dq = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode(div_q, input_qparams.scale, input_qparams.zero_point, div_dq);
+    NodeArg* erf_output = builder.MakeIntermediate();
+    builder.AddNode("Erf", {div_dq}, {erf_output});
+    NodeArg* erf_q = builder.MakeIntermediate();
+    builder.AddQuantizeLinearNode(erf_output, input_qparams.scale, input_qparams.zero_point, erf_q);
+
+    // DQ -> Add -> Q
+    NodeArg* erf_dq = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode(erf_q, input_qparams.scale, input_qparams.zero_point, erf_dq);
+    NodeArg* one_dq = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode(one_q, one_qparams.scale, one_qparams.zero_point, one_dq);
+    NodeArg* add_output = builder.MakeIntermediate();
+    builder.AddNode("Add", {erf_dq, one_dq}, {add_output});
+    NodeArg* add_q = builder.MakeIntermediate();
+    builder.AddQuantizeLinearNode(add_output, input_qparams.scale, input_qparams.zero_point, add_q);
+
+    // DQ -> Mul (with input) -> Q
+    NodeArg* input_dq_2 = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode(input_q, input_qparams.scale, input_qparams.zero_point, input_dq_2);
+    NodeArg* add_dq = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode(add_q, input_qparams.scale, input_qparams.zero_point, add_dq);
+    NodeArg* mul_output = builder.MakeIntermediate();
+    builder.AddNode("Mul", {input_dq_2, add_dq}, {mul_output});
+    NodeArg* mul_q = builder.MakeIntermediate();
+    builder.AddQuantizeLinearNode(mul_output, input_qparams.scale, input_qparams.zero_point, mul_q);
+
+    // Final DQ -> Mul (with 0.5)
+    NodeArg* mul_dq = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode(mul_q, input_qparams.scale, input_qparams.zero_point, mul_dq);
+    NodeArg* half_dq = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode(half_q, half_qparams.scale, half_qparams.zero_point, half_dq);
+    NodeArg* mul_final_output = builder.MakeIntermediate();
+    builder.AddNode("Mul", {mul_dq, half_dq}, {mul_final_output});
+
+    // Add output QDQ
+    AddQDQNodePairWithOutputAsGraphOutput(builder, mul_final_output, output_qparams[0].scale,
+                                          output_qparams[0].zero_point);
+  };
+}
+
+ProviderOptions GetProviderOptions() {
+  ProviderOptions provider_options;
+  provider_options["backend_type"] = "htp";
+  provider_options["offload_graph_io_quantization"] = "0";
+  return provider_options;
+}
+
+}  // namespace
+
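+// The tests below run each pattern through the QNN EP and expect all nodes to be
+// assigned to QNN (ExpectedEPNodeAssignment::All); with the fusion in place, each
+// pattern should lower to a single QNN Gelu op.
+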
+// Test GELU Pattern 1 with float32 model (for baseline comparison)
+TEST_F(QnnHTPBackendTests, GeluFusionPattern1_Float32) {
+  ProviderOptions provider_options = GetProviderOptions();
+  auto input_def = TestInputDef<float>({1, 2, 3, 4}, false, -1.0f, 1.0f);
+
+  RunQnnModelTest(BuildGeluPattern1TestCase(input_def),
+                  provider_options,
+                  /*opset_version=*/13,
+                  /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All,
+                  /*fp32_abs_err=*/1e-3f);
+}
+
+// Test GELU Pattern 2 with float32 model (for baseline comparison)
+TEST_F(QnnHTPBackendTests, GeluFusionPattern2_Float32) {
+  ProviderOptions provider_options = GetProviderOptions();
+  auto input_def = TestInputDef<float>({1, 2, 3, 4}, false, -1.0f, 1.0f);
+
+  RunQnnModelTest(BuildGeluPattern2TestCase(input_def),
+                  provider_options,
+                  /*opset_version=*/13,
+                  /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All,
+                  /*fp32_abs_err=*/1e-3f);
+}
+
+// Test GELU Pattern 1 with larger input shape
+TEST_F(QnnHTPBackendTests, GeluFusionPattern1_LargeInput) {
+  ProviderOptions provider_options = GetProviderOptions();
+  auto input_def = TestInputDef<float>({1, 128, 768}, false, -1.5f, 1.5f);
+
+  RunQnnModelTest(BuildGeluPattern1TestCase(input_def),
+                  provider_options,
+                  /*opset_version=*/13,
+                  /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All,
+                  /*fp32_abs_err=*/2e-3f);
+}
+
+// Test GELU Pattern 2 with larger input shape
+TEST_F(QnnHTPBackendTests, GeluFusionPattern2_LargeInput) {
+  ProviderOptions provider_options = GetProviderOptions();
+  auto input_def = TestInputDef<float>({1, 128, 768}, false, -1.5f, 1.5f);
+
+  RunQnnModelTest(BuildGeluPattern2TestCase(input_def),
+                  provider_options,
+                  /*opset_version=*/13,
+                  /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All,
+                  /*fp32_abs_err=*/2e-3f);
+}
+
+// Test GELU Pattern 1 with 3D input
+TEST_F(QnnHTPBackendTests, GeluFusionPattern1_3D) {
+  ProviderOptions provider_options = GetProviderOptions();
+  auto input_def = TestInputDef<float>({1, 16, 32}, false, -1.0f, 1.0f);
+
+  RunQnnModelTest(BuildGeluPattern1TestCase(input_def),
+                  provider_options,
+                  /*opset_version=*/13,
+                  /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All,
+                  /*fp32_abs_err=*/1e-3f);
+}
+
+// Test GELU Pattern 2 with 3D input
+TEST_F(QnnHTPBackendTests, GeluFusionPattern2_3D) {
+  ProviderOptions provider_options = GetProviderOptions();
+  auto input_def = TestInputDef<float>({1, 16, 32}, false, -1.0f, 1.0f);
+
+  RunQnnModelTest(BuildGeluPattern2TestCase(input_def),
+                  provider_options,
+                  /*opset_version=*/13,
+                  /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All,
+                  /*fp32_abs_err=*/1e-3f);
+}
+
+// Test GELU Pattern 1 with 2D input (typical for linear layers)
+TEST_F(QnnHTPBackendTests, GeluFusionPattern1_2D) {
+  ProviderOptions provider_options = GetProviderOptions();
+  auto input_def = TestInputDef<float>({32, 512}, false, -1.5f, 1.5f);
+
+  RunQnnModelTest(BuildGeluPattern1TestCase(input_def),
+                  provider_options,
+                  /*opset_version=*/13,
+                  /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All,
+                  /*fp32_abs_err=*/2e-3f);
+}
+
+// Test GELU Pattern 2 with 2D input (typical for linear layers)
+TEST_F(QnnHTPBackendTests, GeluFusionPattern2_2D) {
+  ProviderOptions provider_options = GetProviderOptions();
+  auto input_def = TestInputDef<float>({32, 512}, false, -1.5f, 1.5f);
+
+  RunQnnModelTest(BuildGeluPattern2TestCase(input_def),
+                  provider_options,
+                  /*opset_version=*/13,
+                  /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All,
+                  /*fp32_abs_err=*/2e-3f);
+}
+
+// Test GELU Pattern 1 with QDQ
+TEST_F(QnnHTPBackendTests, GeluFusionPattern1_QDQ_U8) {
+  ProviderOptions provider_options = GetProviderOptions();
+  auto input_def = TestInputDef<float>({1, 2, 3, 4}, false, -1.0f, 1.0f);
+
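+  // TestQDQModelAccuracy runs the float32 model on the CPU EP and the QDQ model
+  // on QNN, then verifies that QNN's output is no less accurate than running the
+  // same QDQ model on the CPU EP (relative to the float32 baseline).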
+  TestQDQModelAccuracy(BuildGeluPattern1TestCase(input_def),
+                       BuildQDQGeluPattern1TestCase<uint8_t>(input_def),
+                       provider_options,
+                       /*opset_version=*/13,
+                       /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All);
+}
+
+// Test GELU Pattern 2 with QDQ
+TEST_F(QnnHTPBackendTests, GeluFusionPattern2_QDQ_U8) {
+  ProviderOptions provider_options = GetProviderOptions();
+  auto input_def = TestInputDef<float>({1, 2, 3, 4}, false, -1.0f, 1.0f);
+
+  TestQDQModelAccuracy(BuildGeluPattern2TestCase(input_def),
+                       BuildQDQGeluPattern2TestCase<uint8_t>(input_def),
+                       provider_options,
+                       /*opset_version=*/13,
+                       /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All);
+}
+
+#endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
+
+}  // namespace test
+}  // namespace onnxruntime
+
+#endif  // !defined(ORT_MINIMAL_BUILD)