-
Notifications
You must be signed in to change notification settings - Fork 4k
Add SkipLayerNorm fusion with bias Add #27765
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 14 commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
f5d7a1f
Add BiasSkipLayerNormFusion C++ graph optimizer
Copilot f970eea
Add unit tests for BiasSkipLayerNormFusion; fix EdgeEndToMatch arg or…
Copilot 6639fa3
Add JS EP and WebGPU EP to BiasSkipLayerNormFusion compatible EPs
Copilot 936c766
Also update SkipLayerNormFusion to include JS EP and WebGPU EP
Copilot c3445f7
Potential fix for pull request finding
kunal-vaishnavi 7f9f1c2
Potential fix for pull request finding
kunal-vaishnavi 7e2a1f2
Potential fix for pull request finding
kunal-vaishnavi 10c4964
Potential fix for pull request finding
kunal-vaishnavi 44cd4ab
Apply suggestions from code review
kunal-vaishnavi c941f40
Apply suggestions from code review
kunal-vaishnavi 344b8ef
Apply suggestions from code review
kunal-vaishnavi 436d629
Apply suggestions from code review
kunal-vaishnavi 774de7f
Refactor: extract get_sln_hidden_size helper to eliminate duplication…
Copilot 67dd5b8
Fix remaining inconsistency: unify hidden-size comparison condition b…
Copilot 78448b6
Fix edge rewiring bug, extract try_accept_add helper, strengthen test…
Copilot b33e48f
Fix CI failures: sign-compare warning and MLFloat16 template arg dedu…
Copilot 6c04cde
Fix Thread 26: preserve input[0]/input[1] ordering in fusion; Thread …
Copilot feb4518
Fix CI build error: Graph::GetInitializedTensor returns bool, not Status
Copilot 8a35f34
Fix BiasSkipLayerNormFusion_WithCast_BiasHiddenSizeMismatch test: use…
Copilot 0a84632
fix: require positive proof of bias-hidden-size match in BiasSkipLaye…
Copilot 6e84003
Update onnxruntime/core/optimizer/bias_skip_layer_norm_fusion.cc
kunal-vaishnavi File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
327 changes: 327 additions & 0 deletions
327
onnxruntime/core/optimizer/bias_skip_layer_norm_fusion.cc
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,327 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "core/optimizer/bias_skip_layer_norm_fusion.h" | ||
|
|
||
| #include "core/graph/contrib_ops/contrib_defs.h" | ||
| #include "core/graph/graph_utils.h" | ||
|
|
||
| using namespace ONNX_NAMESPACE; | ||
| using namespace onnxruntime::common; | ||
|
|
||
| namespace onnxruntime { | ||
|
|
||
| /** | ||
| Skip Layer Normalization with bias will fuse Add(MatMul, bias) + SkipLayerNormalization into one node. | ||
|
|
||
| Before fusion: | ||
| MatMul [skip] | ||
| | | | ||
| Add(bias) | | ||
| \ | | ||
| SkipLayerNormalization (4 inputs: input, skip, gamma, beta) | ||
|
|
||
| After fusion: | ||
| MatMul [skip] | ||
| \ / | ||
| SkipLayerNormalization (5 inputs: input, skip, gamma, beta, bias) | ||
|
|
||
| Note: Also handles a Cast between MatMul and Add (for fp16 models): | ||
| MatMul → Cast → Add(bias) → SkipLayerNormalization | ||
| */ | ||
|
|
||
| Status BiasSkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, | ||
| const logging::Logger& logger) const { | ||
| GraphViewer graph_viewer(graph); | ||
| const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); | ||
| const size_t original_num_nodes = graph.NumberOfNodes(); | ||
|
|
||
| auto get_bias_info = [&](Graph& g, NodeArg& bias_arg, bool& is_1d_bias, int64_t& bias_hidden_size) { | ||
| is_1d_bias = false; | ||
| bias_hidden_size = -1; | ||
|
|
||
| const TensorShapeProto* bias_shape = bias_arg.Shape(); | ||
| if (bias_shape != nullptr) { | ||
| is_1d_bias = (bias_shape->dim_size() == 1); | ||
| if (is_1d_bias) { | ||
| const auto& dim0 = bias_shape->dim(0); | ||
| if (dim0.has_dim_value()) { | ||
| bias_hidden_size = dim0.dim_value(); | ||
| } | ||
| } | ||
| } else { | ||
| // For constant initializers from an outer scope, NodeArg::Shape() may be null. | ||
| // Fall back to checking the TensorProto dims to confirm that the bias is 1D. | ||
|
kunal-vaishnavi marked this conversation as resolved.
|
||
| const TensorProto* bias_initializer = | ||
| graph_utils::GetConstantInitializer(g, bias_arg.Name(), true); | ||
| if (bias_initializer != nullptr) { | ||
| is_1d_bias = (bias_initializer->dims_size() == 1); | ||
| if (is_1d_bias) { | ||
| bias_hidden_size = bias_initializer->dims(0); | ||
| } | ||
| } | ||
| } | ||
| }; | ||
|
kunal-vaishnavi marked this conversation as resolved.
kunal-vaishnavi marked this conversation as resolved.
|
||
|
|
||
| // Helper: derive the hidden size from a single SLN 1-D input (gamma or beta). | ||
| // Returns -1 when the size cannot be determined. | ||
| auto get_sln_hidden_size_from_input = [&](const Node& sln, size_t input_index) -> int64_t { | ||
| if (sln.InputDefs().size() <= input_index) { | ||
| return -1; | ||
| } | ||
| const NodeArg* arg = sln.InputDefs()[input_index]; | ||
| if (arg == nullptr) { | ||
| return -1; | ||
| } | ||
|
|
||
| const TensorShapeProto* shape = arg->Shape(); | ||
| if (shape != nullptr && shape->dim_size() == 1) { | ||
| const auto& dim0 = shape->dim(0); | ||
| if (dim0.has_dim_value()) { | ||
| return dim0.dim_value(); | ||
| } | ||
| } | ||
|
|
||
| const TensorProto* initializer = | ||
| graph_utils::GetConstantInitializer(graph, arg->Name(), true); | ||
| if (initializer != nullptr && initializer->dims_size() == 1) { | ||
| return initializer->dims(0); | ||
| } | ||
|
|
||
| return -1; | ||
| }; | ||
|
|
||
| // Helper: derive the SLN hidden size by trying gamma (input 2) then beta (input 3). | ||
| auto get_sln_hidden_size = [&](const Node& sln) -> int64_t { | ||
| int64_t size = get_sln_hidden_size_from_input(sln, 2); | ||
| if (size == -1) { | ||
| size = get_sln_hidden_size_from_input(sln, 3); | ||
| } | ||
| return size; | ||
| }; | ||
|
|
||
| for (auto node_index : node_topology_list) { | ||
| Node* p_sln = graph.GetNode(node_index); | ||
| if (p_sln == nullptr) continue; // node was removed in an earlier fusion | ||
|
|
||
| Node& sln_node = *p_sln; | ||
| ORT_RETURN_IF_ERROR(Recurse(sln_node, modified, graph_level, logger)); | ||
|
|
||
| // Must be a SkipLayerNormalization node in the Microsoft custom domain. | ||
| if (!graph_utils::IsSupportedOptypeVersionAndDomain(sln_node, "SkipLayerNormalization", {1}, kMSDomain) || | ||
| !graph_utils::IsSupportedProvider(sln_node, GetCompatibleExecutionProviders())) { | ||
| continue; | ||
| } | ||
|
|
||
| // Must have exactly 4 inputs (input, skip, gamma, beta) – bias not yet absorbed. | ||
| auto& sln_inputs = sln_node.MutableInputDefs(); | ||
| if (sln_inputs.size() != 4) { | ||
| continue; | ||
| } | ||
|
|
||
| // Try each of the first two SLN inputs (input[0] = "input", input[1] = "skip") to find an Add | ||
| // that adds a 1D constant bias to a MatMul result. Also consider a Cast between MatMul and Add | ||
| // (common in fp16 models). | ||
| Node* p_add = nullptr; | ||
| int sln_add_input_index = -1; // which SLN input (0 or 1) leads to the Add node | ||
| int add_bias_index = -1; // which Add input (0 or 1) is the 1D constant bias | ||
|
|
||
| for (int sln_input_idx = 0; sln_input_idx <= 1 && p_add == nullptr; ++sln_input_idx) { | ||
| for (int add_matmul_input_idx = 0; add_matmul_input_idx <= 1 && p_add == nullptr; | ||
| ++add_matmul_input_idx) { | ||
| // --- Path 1: SLN.input[sln_input_idx] ← Add ← MatMul (direct) --- | ||
| std::vector<graph_utils::EdgeEndToMatch> path_matmul{ | ||
| {0, sln_input_idx, "Add", {7, 13, 14}, kOnnxDomain}, | ||
| {0, add_matmul_input_idx, "MatMul", {1, 9, 13}, kOnnxDomain}}; | ||
|
kunal-vaishnavi marked this conversation as resolved.
|
||
|
|
||
| std::vector<const Node::EdgeEnd*> edges; | ||
| if (graph_utils::FindPath(sln_node, true, path_matmul, edges, logger)) { | ||
| Node* candidate_add = const_cast<Node*>(&edges[0]->GetNode()); | ||
|
|
||
| if (candidate_add->GetExecutionProviderType() == sln_node.GetExecutionProviderType() && | ||
| candidate_add->GetOutputEdgesCount() == 1 && | ||
| !graph.NodeProducesGraphOutput(*candidate_add)) { | ||
| int bias_idx = 1 - add_matmul_input_idx; | ||
| NodeArg* bias_arg = candidate_add->MutableInputDefs()[bias_idx]; | ||
|
|
||
| if (graph_utils::NodeArgIsConstant(graph, *bias_arg)) { | ||
| bool is_1d_bias = false; | ||
| int64_t bias_hidden_size = -1; | ||
| get_bias_info(graph, *bias_arg, is_1d_bias, bias_hidden_size); | ||
|
|
||
| // If we know the bias is 1D, additionally check that its length matches the | ||
| // hidden size expected by SkipLayerNormalization (gamma/beta) when that size | ||
| // can be determined. If the sizes are known and incompatible, skip fusion. | ||
| bool bias_matches_hidden = true; | ||
| if (is_1d_bias) { | ||
| int64_t sln_hidden_size = get_sln_hidden_size(sln_node); | ||
|
|
||
| if (sln_hidden_size != -1 && bias_hidden_size != -1 && sln_hidden_size != bias_hidden_size) { | ||
| bias_matches_hidden = false; | ||
| } | ||
| } | ||
|
|
||
| if (is_1d_bias && bias_matches_hidden) { | ||
| p_add = candidate_add; | ||
| sln_add_input_index = sln_input_idx; | ||
| add_bias_index = bias_idx; | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| if (p_add != nullptr) break; | ||
|
|
||
| // --- Path 2: SLN.input[sln_input_idx] ← Add ← Cast ← MatMul (fp16 models) --- | ||
| std::vector<graph_utils::EdgeEndToMatch> path_cast_matmul{ | ||
| {0, sln_input_idx, "Add", {7, 13, 14}, kOnnxDomain}, | ||
| {0, add_matmul_input_idx, "Cast", {1, 6, 9, 13, 15}, kOnnxDomain}, | ||
| {0, 0, "MatMul", {1, 9, 13}, kOnnxDomain}}; | ||
|
kunal-vaishnavi marked this conversation as resolved.
|
||
|
|
||
| if (graph_utils::FindPath(sln_node, true, path_cast_matmul, edges, logger)) { | ||
| Node* candidate_add = const_cast<Node*>(&edges[0]->GetNode()); | ||
|
|
||
| if (candidate_add->GetExecutionProviderType() == sln_node.GetExecutionProviderType() && | ||
| candidate_add->GetOutputEdgesCount() == 1 && | ||
| !graph.NodeProducesGraphOutput(*candidate_add)) { | ||
| int bias_idx = 1 - add_matmul_input_idx; | ||
| NodeArg* bias_arg = candidate_add->MutableInputDefs()[bias_idx]; | ||
|
|
||
| if (graph_utils::NodeArgIsConstant(graph, *bias_arg)) { | ||
| bool is_1d_bias = false; | ||
| int64_t bias_hidden_size = -1; | ||
|
|
||
| // Reuse common bias shape extraction logic to ensure consistent behavior. | ||
| get_bias_info(graph, *bias_arg, is_1d_bias, bias_hidden_size); | ||
|
|
||
| bool bias_matches_hidden = true; | ||
| if (is_1d_bias) { | ||
| // Derive the hidden size from SkipLayerNormalization's gamma/beta inputs, | ||
| // using the same logic as Path 1. | ||
| int64_t sln_hidden_size = get_sln_hidden_size(sln_node); | ||
|
|
||
| if (sln_hidden_size != -1 && bias_hidden_size != -1 && sln_hidden_size != bias_hidden_size) { | ||
| bias_matches_hidden = false; | ||
| } | ||
| } | ||
|
|
||
| if (is_1d_bias && bias_matches_hidden) { | ||
| p_add = candidate_add; | ||
| sln_add_input_index = sln_input_idx; | ||
| add_bias_index = bias_idx; | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
kunal-vaishnavi marked this conversation as resolved.
|
||
| } | ||
| } | ||
|
|
||
| if (p_add == nullptr) continue; | ||
|
|
||
| // Determine the non-bias Add input (MatMul / Cast output) and the SLN skip input. | ||
| int add_non_bias_input_index = 1 - add_bias_index; | ||
| int sln_skip_input_index = 1 - sln_add_input_index; | ||
|
|
||
| // Snapshot all information we need from the original nodes before modifying the graph. | ||
| // Inputs from the Add and SkipLayerNormalization nodes. | ||
| InlinedVector<NodeArg*> new_sln_inputs{ | ||
| p_add->MutableInputDefs()[add_non_bias_input_index], // input (MatMul / Cast output) | ||
| sln_inputs[sln_skip_input_index], // skip | ||
| sln_inputs[2], // gamma | ||
| sln_inputs[3], // beta | ||
| p_add->MutableInputDefs()[add_bias_index] // bias (1D constant) | ||
| }; | ||
|
kunal-vaishnavi marked this conversation as resolved.
Outdated
|
||
|
|
||
| // Snapshot the outputs of the original SkipLayerNormalization so we can safely remove it | ||
| // before creating the replacement node while preserving the same graph outputs. | ||
| InlinedVector<NodeArg*> new_sln_outputs; | ||
| { | ||
| auto& sln_output_defs = sln_node.MutableOutputDefs(); | ||
| new_sln_outputs.assign(sln_output_defs.begin(), sln_output_defs.end()); | ||
| } | ||
|
|
||
| // Snapshot attributes and execution provider type from the original SLN node. | ||
| const NodeAttributes sln_attrs = sln_node.GetAttributes(); | ||
| const std::string sln_ep = sln_node.GetExecutionProviderType(); | ||
|
|
||
| // Remove the original Add and SkipLayerNormalization nodes (and their output edges) | ||
| // before adding the fused node to maintain the single-producer invariant for NodeArgs. | ||
| graph_utils::RemoveNodeOutputEdges(graph, *p_add); | ||
| graph.RemoveNode(p_add->Index()); | ||
| graph_utils::RemoveNodeOutputEdges(graph, sln_node); | ||
| graph.RemoveNode(sln_node.Index()); | ||
|
|
||
| // Build the new 5-input SkipLayerNormalization: | ||
| // input[0] = MatMul (or Cast) output – the "input" tensor | ||
| // input[1] = skip – the "skip" tensor | ||
| // input[2] = gamma – unchanged | ||
| // input[3] = beta – unchanged | ||
| // input[4] = bias – absorbed from the Add node | ||
| Node& new_sln_node = graph.AddNode( | ||
| graph.GenerateNodeName("SkipLayerNormalization"), | ||
| "SkipLayerNormalization", | ||
| "fused SkipLayerNormalization and bias Add", | ||
| new_sln_inputs, | ||
| new_sln_outputs, | ||
| {}, | ||
| kMSDomain); | ||
|
kunal-vaishnavi marked this conversation as resolved.
|
||
|
|
||
| // Copy all attributes from the original SkipLayerNormalization node, ensuring epsilon is set. | ||
|
|
||
| // First copy all non-epsilon attributes. | ||
| for (const auto& attr_pair : sln_attrs) { | ||
| if (attr_pair.first == "epsilon") { | ||
| continue; | ||
| } | ||
| new_sln_node.AddAttributeProto(attr_pair.second); | ||
| } | ||
|
|
||
| // Then handle epsilon specifically so we can apply a default if it is missing. | ||
| auto epsilon_it = sln_attrs.find("epsilon"); | ||
| if (epsilon_it != sln_attrs.end()) { | ||
| new_sln_node.AddAttributeProto(epsilon_it->second); | ||
| } else { | ||
| new_sln_node.AddAttribute("epsilon", contrib::kDefaultSkipLayerNormEpsilon); | ||
| } | ||
|
|
||
| new_sln_node.SetExecutionProviderType(sln_ep); | ||
|
|
||
| // Rewire all downstream consumers from the original SLN node to the new fused node. | ||
| // Collect the outgoing edges first to avoid iterator invalidation or use-after-free | ||
| // when modifying edges in the graph. | ||
| std::vector<std::tuple<NodeIndex, int, int>> sln_output_edges; | ||
| sln_output_edges.reserve(std::distance(sln_node.OutputEdgesBegin(), sln_node.OutputEdgesEnd())); | ||
| for (auto it = sln_node.OutputEdgesBegin(); it != sln_node.OutputEdgesEnd(); ++it) { | ||
| auto& edge = *it; | ||
| Node& downstream_node = edge.GetNode(); | ||
| int src_arg_index = edge.GetSrcArgIndex(); | ||
| int dst_arg_index = edge.GetDstArgIndex(); | ||
| sln_output_edges.emplace_back(downstream_node.Index(), src_arg_index, dst_arg_index); | ||
| } | ||
|
|
||
| for (const auto& edge_info : sln_output_edges) { | ||
| NodeIndex downstream_node_index = std::get<0>(edge_info); | ||
| int src_arg_index = std::get<1>(edge_info); | ||
| int dst_arg_index = std::get<2>(edge_info); | ||
|
|
||
| // Add an equivalent edge from the new fused SLN node to the same downstream node. | ||
| // The original edges from the old SLN node have already been removed when that node was removed. | ||
| graph.AddEdge(new_sln_node.Index(), downstream_node_index, src_arg_index, dst_arg_index); | ||
| } | ||
|
kunal-vaishnavi marked this conversation as resolved.
kunal-vaishnavi marked this conversation as resolved.
|
||
|
|
||
| // The original Add and SkipLayerNormalization nodes have already been removed above, | ||
| // so we do not add them to nodes_to_remove here. | ||
| // Note: nodes in other fusion paths may still be collected in nodes_to_remove and | ||
| // removed after the full iteration (see below). | ||
| modified = true; | ||
| } | ||
|
|
||
| // Set 'modified' based on whether the number of nodes in the graph changed. | ||
| if (graph.NumberOfNodes() != original_num_nodes) { | ||
| modified = true; | ||
| } | ||
|
|
||
| return Status::OK(); | ||
| } | ||
|
|
||
| } // namespace onnxruntime | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include "core/optimizer/graph_transformer.h" | ||
|
|
||
| namespace onnxruntime { | ||
|
|
||
| /** | ||
| * \class BiasSkipLayerNormFusion | ||
| * \brief Rewrite graph fusing Add + SkipLayerNormalization subgraph to a single SkipLayerNormalization node, | ||
| * where the Add node adds a 1D constant bias to the output of a MatMul (or Cast after MatMul). | ||
| * | ||
| * Before fusion: | ||
| * MatMul | ||
| * | | ||
| * Add(bias) [skip] | ||
| * \ / | ||
| * SkipLayerNormalization (4 inputs: input, skip, gamma, beta) | ||
| * | ||
| * After fusion: | ||
| * MatMul [skip] | ||
| * \ / | ||
| * SkipLayerNormalization (5 inputs: input, skip, gamma, beta, bias) | ||
| */ | ||
| class BiasSkipLayerNormFusion : public GraphTransformer { | ||
| public: | ||
| explicit BiasSkipLayerNormFusion( | ||
| const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept | ||
| : GraphTransformer("BiasSkipLayerNormFusion", compatible_execution_providers) {} | ||
|
|
||
| Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; | ||
| }; | ||
|
|
||
| } // namespace onnxruntime |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.