diff --git a/onnxruntime/core/optimizer/bias_skip_layer_norm_fusion.cc b/onnxruntime/core/optimizer/bias_skip_layer_norm_fusion.cc new file mode 100644 index 0000000000000..cdd8d512c7166 --- /dev/null +++ b/onnxruntime/core/optimizer/bias_skip_layer_norm_fusion.cc @@ -0,0 +1,296 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/bias_skip_layer_norm_fusion.h" + +#include "core/graph/contrib_ops/contrib_defs.h" +#include "core/graph/graph_utils.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::common; + +namespace onnxruntime { + +/** +Skip Layer Normalization with bias will fuse Add(MatMul, bias) + SkipLayerNormalization into one node. + +Before fusion: + MatMul [skip] + | | + Add(bias) | + \ | + SkipLayerNormalization (4 inputs: input, skip, gamma, beta) + +After fusion: + MatMul [skip] + \ / + SkipLayerNormalization (5 inputs: input, skip, gamma, beta, bias) + +Note: Also handles a Cast between MatMul and Add (for fp16 models): + MatMul → Cast → Add(bias) → SkipLayerNormalization +*/ + +Status BiasSkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + auto get_bias_info = [&](Graph& g, NodeArg& bias_arg, bool& is_1d_bias, int64_t& bias_hidden_size) { + is_1d_bias = false; + bias_hidden_size = -1; + + const TensorShapeProto* bias_shape = bias_arg.Shape(); + if (bias_shape != nullptr) { + is_1d_bias = (bias_shape->dim_size() == 1); + if (is_1d_bias) { + const auto& dim0 = bias_shape->dim(0); + if (dim0.has_dim_value()) { + bias_hidden_size = dim0.dim_value(); + } + } + } else { + // For constant initializers from an outer scope, NodeArg::Shape() may be null. + // Fall back to checking the TensorProto dims to confirm that the bias is 1D. + const TensorProto* bias_initializer = + graph_utils::GetConstantInitializer(g, bias_arg.Name(), true); + if (bias_initializer != nullptr) { + is_1d_bias = (bias_initializer->dims_size() == 1); + if (is_1d_bias) { + bias_hidden_size = bias_initializer->dims(0); + } + } + } + }; + + // Helper: derive the hidden size from a single SLN 1-D input (gamma or beta). + // Returns -1 when the size cannot be determined. + auto get_sln_hidden_size_from_input = [&](const Node& sln, size_t input_index) -> int64_t { + if (sln.InputDefs().size() <= input_index) { + return -1; + } + const NodeArg* arg = sln.InputDefs()[input_index]; + if (arg == nullptr) { + return -1; + } + + const TensorShapeProto* shape = arg->Shape(); + if (shape != nullptr && shape->dim_size() == 1) { + const auto& dim0 = shape->dim(0); + if (dim0.has_dim_value()) { + return dim0.dim_value(); + } + } + + const TensorProto* initializer = + graph_utils::GetConstantInitializer(graph, arg->Name(), true); + if (initializer != nullptr && initializer->dims_size() == 1) { + return initializer->dims(0); + } + + return -1; + }; + + // Helper: derive the SLN hidden size by trying gamma (input 2) then beta (input 3). 
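+  // Gamma and beta are both 1-D tensors of length hidden_size in SkipLayerNormalization,
+  // so whichever of the two has a resolvable static size can supply the value.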
+ auto get_sln_hidden_size = [&](const Node& sln) -> int64_t { + int64_t size = get_sln_hidden_size_from_input(sln, 2); + if (size == -1) { + size = get_sln_hidden_size_from_input(sln, 3); + } + return size; + }; + + for (auto node_index : node_topology_list) { + Node* p_sln = graph.GetNode(node_index); + if (p_sln == nullptr) continue; // node was removed in an earlier fusion + + Node& sln_node = *p_sln; + ORT_RETURN_IF_ERROR(Recurse(sln_node, modified, graph_level, logger)); + + // Must be a SkipLayerNormalization node in the Microsoft custom domain. + if (!graph_utils::IsSupportedOptypeVersionAndDomain(sln_node, "SkipLayerNormalization", {1}, kMSDomain) || + !graph_utils::IsSupportedProvider(sln_node, GetCompatibleExecutionProviders())) { + continue; + } + + // Must have exactly 4 inputs (input, skip, gamma, beta) – bias not yet absorbed. + auto& sln_inputs = sln_node.MutableInputDefs(); + if (sln_inputs.size() != 4) { + continue; + } + + // Try each of the first two SLN inputs (input[0] = "input", input[1] = "skip") to find an Add + // that adds a 1D constant bias to a MatMul result. Also consider a Cast between MatMul and Add + // (common in fp16 models). + Node* p_add = nullptr; + int sln_add_input_index = -1; // which SLN input (0 or 1) leads to the Add node + int add_bias_index = -1; // which Add input (0 or 1) is the 1D constant bias + + // Helper: validate a candidate Add node and, if it is a compatible bias-add, accept it + // by setting p_add/sln_add_input_index/add_bias_index. Returns true on acceptance. + // Both Path 1 and Path 2 call this so all acceptance criteria stay in one place. + auto try_accept_add = [&](Node* candidate_add, int add_matmul_input_idx, int sln_input_idx) -> bool { + if (candidate_add->GetExecutionProviderType() != sln_node.GetExecutionProviderType() || + candidate_add->GetOutputEdgesCount() != 1 || + graph.NodeProducesGraphOutput(*candidate_add)) { + return false; + } + int bias_idx = 1 - add_matmul_input_idx; + NodeArg* bias_arg = candidate_add->MutableInputDefs()[bias_idx]; + if (!graph_utils::NodeArgIsConstant(graph, *bias_arg)) { + return false; + } + bool is_1d_bias = false; + int64_t bias_hidden_size = -1; + get_bias_info(graph, *bias_arg, is_1d_bias, bias_hidden_size); + if (!is_1d_bias) { + return false; + } + // Verify the bias length matches the SLN hidden size. + // Try gamma/beta first; if not available, fall back to the last dimension of the Add's + // non-bias input (the MatMul/Cast output, whose last dim equals the hidden size). + int64_t sln_hidden_size = get_sln_hidden_size(sln_node); + if (sln_hidden_size == -1) { + const NodeArg* non_bias_arg = candidate_add->MutableInputDefs()[add_matmul_input_idx]; + const TensorShapeProto* non_bias_shape = non_bias_arg->Shape(); + if (non_bias_shape != nullptr && non_bias_shape->dim_size() > 0) { + const auto& last_dim = non_bias_shape->dim(non_bias_shape->dim_size() - 1); + if (last_dim.has_dim_value()) { + sln_hidden_size = last_dim.dim_value(); + } + } + } + // Require positive proof that bias length == hidden size; bail if either is still unknown. 
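+      // For example, gamma/beta of length 768 and a 768-element bias both resolve to 768 and the
+      // fusion proceeds; a length-1 bias (which still broadcasts in the original Add) is rejected.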
+      if (sln_hidden_size == -1 || bias_hidden_size == -1 || sln_hidden_size != bias_hidden_size) {
+        return false;
+      }
+      p_add = candidate_add;
+      sln_add_input_index = sln_input_idx;
+      add_bias_index = bias_idx;
+      return true;
+    };
+
+    for (int sln_input_idx = 0; sln_input_idx <= 1 && p_add == nullptr; ++sln_input_idx) {
+      for (int add_matmul_input_idx = 0; add_matmul_input_idx <= 1 && p_add == nullptr;
+           ++add_matmul_input_idx) {
+        // --- Path 1: SLN.input[sln_input_idx] ← Add ← MatMul (direct) ---
+        std::vector<graph_utils::EdgeEndToMatch> path_matmul{
+            {0, sln_input_idx, "Add", {7, 13, 14}, kOnnxDomain},
+            {0, add_matmul_input_idx, "MatMul", {1, 9, 13}, kOnnxDomain}};
+
+        std::vector<const Node::EdgeEnd*> edges;
+        if (graph_utils::FindPath(sln_node, true, path_matmul, edges, logger)) {
+          try_accept_add(const_cast<Node*>(&edges[0]->GetNode()), add_matmul_input_idx, sln_input_idx);
+        }
+
+        if (p_add != nullptr) break;
+
+        // --- Path 2: SLN.input[sln_input_idx] ← Add ← Cast ← MatMul (fp16 models) ---
+        std::vector<graph_utils::EdgeEndToMatch> path_cast_matmul{
+            {0, sln_input_idx, "Add", {7, 13, 14}, kOnnxDomain},
+            {0, add_matmul_input_idx, "Cast", {1, 6, 9, 13, 15}, kOnnxDomain},
+            {0, 0, "MatMul", {1, 9, 13}, kOnnxDomain}};
+
+        if (graph_utils::FindPath(sln_node, true, path_cast_matmul, edges, logger)) {
+          try_accept_add(const_cast<Node*>(&edges[0]->GetNode()), add_matmul_input_idx, sln_input_idx);
+        }
+      }
+    }
+
+    if (p_add == nullptr) continue;
+
+    // Determine the non-bias Add input (MatMul / Cast output).
+    int add_non_bias_input_index = 1 - add_bias_index;
+
+    // Snapshot all information we need from the original nodes before modifying the graph.
+    // Build the new 5-input SkipLayerNormalization by replacing only the SLN input slot that
+    // was fed by the bias-Add with the Add's non-bias (MatMul/Cast) input. All other SLN inputs
+    // stay in their original positions. Preserving the input[0]/input[1] order is important
+    // because SkipLayerNormalization derives its output shape from input[0] while input[1]
+    // supports broadcasting; swapping them would silently change semantics.
+    InlinedVector<NodeArg*> new_sln_inputs{
+        sln_inputs[0],                             // original input[0] (replaced below if needed)
+        sln_inputs[1],                             // original input[1] (replaced below if needed)
+        sln_inputs[2],                             // gamma – unchanged
+        sln_inputs[3],                             // beta – unchanged
+        p_add->MutableInputDefs()[add_bias_index]  // bias (1D constant) – absorbed from Add
+    };
+    // Replace only the SLN slot that was connected to the bias-Add.
+    new_sln_inputs[sln_add_input_index] = p_add->MutableInputDefs()[add_non_bias_input_index];
+
+    // Snapshot the outputs of the original SkipLayerNormalization so we can safely remove it
+    // before creating the replacement node while preserving the same graph outputs.
+    InlinedVector<NodeArg*> new_sln_outputs;
+    {
+      auto& sln_output_defs = sln_node.MutableOutputDefs();
+      new_sln_outputs.assign(sln_output_defs.begin(), sln_output_defs.end());
+    }
+
+    // Snapshot attributes and execution provider type from the original SLN node.
+    const NodeAttributes sln_attrs = sln_node.GetAttributes();
+    const std::string sln_ep = sln_node.GetExecutionProviderType();
+
+    // Capture outgoing edges from the original SLN node BEFORE removing any nodes.
+    // RemoveNodeOutputEdges clears the edge list, so this must precede removal to
+    // ensure downstream consumers are correctly rewired to the new fused node.
+    std::vector<std::tuple<NodeIndex, int, int>> sln_output_edges;
+    sln_output_edges.reserve(std::distance(sln_node.OutputEdgesBegin(), sln_node.OutputEdgesEnd()));
+    for (auto it = sln_node.OutputEdgesBegin(); it != sln_node.OutputEdgesEnd(); ++it) {
+      auto& edge = *it;
+      sln_output_edges.emplace_back(edge.GetNode().Index(), edge.GetSrcArgIndex(), edge.GetDstArgIndex());
+    }
+
+    // Remove the original Add and SkipLayerNormalization nodes (and their output edges)
+    // before adding the fused node to maintain the single-producer invariant for NodeArgs.
+    graph_utils::RemoveNodeOutputEdges(graph, *p_add);
+    graph.RemoveNode(p_add->Index());
+    graph_utils::RemoveNodeOutputEdges(graph, sln_node);
+    graph.RemoveNode(sln_node.Index());
+
+    // The fused 5-input SkipLayerNormalization:
+    //   input[0] = original SLN input[0] (unless the bias-Add was at SLN input[0])
+    //   input[1] = original SLN input[1] (unless the bias-Add was at SLN input[1])
+    //   input[2] = gamma – unchanged
+    //   input[3] = beta – unchanged
+    //   input[4] = bias – absorbed from the Add node
+    Node& new_sln_node = graph.AddNode(
+        graph.GenerateNodeName("SkipLayerNormalization"),
+        "SkipLayerNormalization",
+        "fused SkipLayerNormalization and bias Add",
+        new_sln_inputs,
+        new_sln_outputs,
+        {},
+        kMSDomain);
+
+    // Copy all attributes from the original SkipLayerNormalization node, ensuring epsilon is set.
+
+    // First copy all non-epsilon attributes.
+    for (const auto& attr_pair : sln_attrs) {
+      if (attr_pair.first == "epsilon") {
+        continue;
+      }
+      new_sln_node.AddAttributeProto(attr_pair.second);
+    }
+
+    // Then handle epsilon specifically so we can apply a default if it is missing.
+    auto epsilon_it = sln_attrs.find("epsilon");
+    if (epsilon_it != sln_attrs.end()) {
+      new_sln_node.AddAttributeProto(epsilon_it->second);
+    } else {
+      new_sln_node.AddAttribute("epsilon", contrib::kDefaultSkipLayerNormEpsilon);
+    }
+
+    new_sln_node.SetExecutionProviderType(sln_ep);
+
+    // Rewire all downstream consumers from the original SLN node to the new fused node.
+    for (const auto& edge_info : sln_output_edges) {
+      graph.AddEdge(new_sln_node.Index(), std::get<0>(edge_info), std::get<1>(edge_info),
+                    std::get<2>(edge_info));
+    }
+
+    modified = true;
+  }
+
+  return Status::OK();
+}
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/bias_skip_layer_norm_fusion.h b/onnxruntime/core/optimizer/bias_skip_layer_norm_fusion.h
new file mode 100644
index 0000000000000..0f97243018da8
--- /dev/null
+++ b/onnxruntime/core/optimizer/bias_skip_layer_norm_fusion.h
@@ -0,0 +1,36 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/optimizer/graph_transformer.h"
+
+namespace onnxruntime {
+
+/**
+ * \class BiasSkipLayerNormFusion
+ * \brief Rewrite graph fusing Add + SkipLayerNormalization subgraph to a single SkipLayerNormalization node,
+ *        where the Add node adds a 1D constant bias to the output of a MatMul (or Cast after MatMul).
+ *
+ * Before fusion:
+ *       MatMul
+ *         |
+ *      Add(bias)   [skip]
+ *          \        /
+ *   SkipLayerNormalization (4 inputs: input, skip, gamma, beta)
+ *
+ * After fusion:
+ *       MatMul     [skip]
+ *          \        /
+ *   SkipLayerNormalization (5 inputs: input, skip, gamma, beta, bias)
+ */
+class BiasSkipLayerNormFusion : public GraphTransformer {
+ public:
+  explicit BiasSkipLayerNormFusion(
+      const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
+      : GraphTransformer("BiasSkipLayerNormFusion", compatible_execution_providers) {}
+
+  Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
+};
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index 640848d47fe93..8945bee6f8cbb 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -79,6 +79,7 @@
 #include "core/optimizer/relu_clip_fusion.h"
 #include "core/optimizer/reshape_fusion.h"
 #include "core/optimizer/rule_based_graph_transformer.h"
+#include "core/optimizer/bias_skip_layer_norm_fusion.h"
 #include "core/optimizer/skip_layer_norm_fusion.h"
 #include "core/optimizer/slice_elimination.h"
 #include "core/optimizer/transpose_optimizer.h"
@@ -331,6 +332,12 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
                                                            onnxruntime::kAclExecutionProvider,
                                                            onnxruntime::kCudaExecutionProvider,
                                                            onnxruntime::kDmlExecutionProvider};
+      const InlinedHashSet<std::string_view> cpu_acl_cuda_dml_js_webgpu_eps = {onnxruntime::kCpuExecutionProvider,
+                                                                               onnxruntime::kAclExecutionProvider,
+                                                                               onnxruntime::kCudaExecutionProvider,
+                                                                               onnxruntime::kDmlExecutionProvider,
+                                                                               onnxruntime::kJsExecutionProvider,
+                                                                               onnxruntime::kWebGpuExecutionProvider};
       const InlinedHashSet<std::string_view> cpu_acl_js_webgpu_eps = {onnxruntime::kCpuExecutionProvider,
                                                                       onnxruntime::kAclExecutionProvider,
                                                                       onnxruntime::kJsExecutionProvider,
@@ -385,7 +392,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
       // Run MatMulAddFusion again after *AttentionFusion transforms with `preserve_attention_pattern = false`,
       // to cleanup the remaining MatMul-Add that were part of the attention pattern but not detected or fused.
       transformers.emplace_back(std::make_unique<MatMulAddFusion>(no_limit_empty_ep_list, false));
-      transformers.emplace_back(std::make_unique<SkipLayerNormFusion>(cpu_acl_cuda_dml_eps));
+      transformers.emplace_back(std::make_unique<SkipLayerNormFusion>(cpu_acl_cuda_dml_js_webgpu_eps));
+      transformers.emplace_back(std::make_unique<BiasSkipLayerNormFusion>(cpu_acl_cuda_dml_js_webgpu_eps));
       transformers.emplace_back(std::make_unique(cpu_cuda_dml_eps));
       transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_eps));
diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc
index 4615b6a57b558..31827e76415ff 100644
--- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc
@@ -14,6 +14,7 @@
 #include "core/graph/model.h"
 #include "core/optimizer/initializer.h"
+#include "core/optimizer/bias_skip_layer_norm_fusion.h"
 #include "core/optimizer/embed_layer_norm_fusion.h"
 #include "core/optimizer/group_query_attention_fusion.h"
 #include "core/optimizer/layer_norm_fusion.h"
@@ -756,6 +757,483 @@ TEST_F(GraphTransformationTests, SkipLayerNormFusion_NoBeta) {
   TestSkipLayerNormFusionNoBeta(MODEL_FOLDER "fusion/skip_layer_norm_no_beta_with_cast.onnx", true, logger_.get());
 }
 
+// ---- BiasSkipLayerNormFusion tests ----
+//
+// All tests start with a pre-existing 4-input com.microsoft.SkipLayerNormalization node,
+// mirroring the scenario where a model was already exported with SkipLayerNormalization (e.g.,
+// via the Python transformer optimizer), and a bias Add upstream still needs to be absorbed.
+
+// Verify that Add(MatMul_out, bias_1D) → SLN(4 inputs) is fused into SLN(5 inputs).
+// Pattern: Add at SLN input[0], bias as Add input[1].
+TEST_F(GraphTransformationTests, BiasSkipLayerNormFusion_AddAtInput0) {
+  auto build_test_case = [](ModelTestBuilder& builder) {
+    auto* matmul_a = builder.MakeInput<float>({2, 4, 8}, -1.0f, 1.0f);
+    auto* matmul_b = builder.MakeInitializer<float>({8, 4}, -1.0f, 1.0f);
+    auto* skip = builder.MakeInput<float>({2, 4, 4}, -1.0f, 1.0f);
+    auto* gamma = builder.MakeInitializer<float>({4}, {1.0f, 1.0f, 1.0f, 1.0f});
+    auto* beta = builder.MakeInitializer<float>({4}, {0.0f, 0.0f, 0.0f, 0.0f});
+    auto* bias = builder.MakeInitializer<float>({4}, {0.1f, 0.2f, 0.3f, 0.4f});
+
+    auto* matmul_out = builder.MakeIntermediate();
+    auto* add_out = builder.MakeIntermediate();
+    auto* sln_out = builder.MakeOutput();
+
+    builder.AddNode("MatMul", {matmul_a, matmul_b}, {matmul_out});
+    builder.AddNode("Add", {matmul_out, bias}, {add_out});
+    // 4-input SLN: add_out at input[0], skip at input[1]
+    builder.AddNode("SkipLayerNormalization", {add_out, skip, gamma, beta}, {sln_out}, kMSDomain);
+  };
+
+  auto post_graph_checker = [](Graph& graph) {
+    auto op_count = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_count["Add"] == 0);
+    TEST_RETURN_IF_NOT(op_count["com.microsoft.SkipLayerNormalization"] == 1);
+    for (auto& node : graph.Nodes()) {
+      if (node.OpType() == "SkipLayerNormalization") {
+        // Bias absorbed as 5th input
+        TEST_RETURN_IF_NOT(node.InputDefs().size() == 5u);
+
+        // Verify wiring: input[0] is produced by MatMul, input[1] is the original skip input,
+        // and input[4] is an initializer (the fused bias).
+        const auto& input_defs = node.InputDefs();
+        auto* input0 = input_defs[0];
+        auto* input1 = input_defs[1];
+        auto* input4 = input_defs[4];
+
+        // input[0] should come from MatMul
+        const Node* input0_producer = graph.GetProducerNode(input0->Name());
+        TEST_RETURN_IF_NOT(input0_producer != nullptr);
+        TEST_RETURN_IF_NOT(input0_producer->OpType() == "MatMul");
+
+        // input[1] should be the skip connection: a graph input (no producer)
+        const Node* input1_producer = graph.GetProducerNode(input1->Name());
+        TEST_RETURN_IF_NOT(input1_producer == nullptr);
+        bool is_graph_input1 = false;
+        for (const auto* gi : graph.GetInputs()) {
+          if (gi->Name() == input1->Name()) {
+            is_graph_input1 = true;
+            break;
+          }
+        }
+        TEST_RETURN_IF_NOT(is_graph_input1);
+
+        // input[4] should be an initializer (the fused bias), identified by name
+        const Node* input4_producer = graph.GetProducerNode(input4->Name());
+        TEST_RETURN_IF_NOT(input4_producer == nullptr);
+        const ONNX_NAMESPACE::TensorProto* bias_initializer = nullptr;
+        TEST_RETURN_IF_NOT(graph.GetInitializedTensor(input4->Name(), bias_initializer));
+        TEST_RETURN_IF_NOT(bias_initializer != nullptr);
+      }
+    }
+    return Status::OK();
+  };
+
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 17, *logger_,
+                                        std::make_unique<BiasSkipLayerNormFusion>(),
+                                        TransformerLevel::Level2, 1, nullptr, post_graph_checker));
+}
+
+// Same as above, but bias is Add input[0] (not input[1]).
+TEST_F(GraphTransformationTests, BiasSkipLayerNormFusion_BiasAsFirstAddInput) {
+  auto build_test_case = [](ModelTestBuilder& builder) {
+    auto* matmul_a = builder.MakeInput<float>({2, 4, 8}, -1.0f, 1.0f);
+    auto* matmul_b = builder.MakeInitializer<float>({8, 4}, -1.0f, 1.0f);
+    auto* skip = builder.MakeInput<float>({2, 4, 4}, -1.0f, 1.0f);
+    auto* gamma = builder.MakeInitializer<float>({4}, {1.0f, 1.0f, 1.0f, 1.0f});
+    auto* beta = builder.MakeInitializer<float>({4}, {0.0f, 0.0f, 0.0f, 0.0f});
+    auto* bias = builder.MakeInitializer<float>({4}, {0.1f, 0.2f, 0.3f, 0.4f});
+
+    auto* matmul_out = builder.MakeIntermediate();
+    auto* add_out = builder.MakeIntermediate();
+    auto* sln_out = builder.MakeOutput();
+
+    builder.AddNode("MatMul", {matmul_a, matmul_b}, {matmul_out});
+    // bias is Add input[0], MatMul output is Add input[1]
+    builder.AddNode("Add", {bias, matmul_out}, {add_out});
+    builder.AddNode("SkipLayerNormalization", {add_out, skip, gamma, beta}, {sln_out}, kMSDomain);
+  };
+
+  auto post_graph_checker = [](Graph& graph) {
+    auto op_count = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_count["Add"] == 0);
+    TEST_RETURN_IF_NOT(op_count["com.microsoft.SkipLayerNormalization"] == 1);
+    for (auto& node : graph.Nodes()) {
+      if (node.OpType() == "SkipLayerNormalization") {
+        TEST_RETURN_IF_NOT(node.InputDefs().size() == 5u);
+
+        // Verify wiring for this scenario as well: input[0] from MatMul, input[1] is skip input,
+        // and input[4] is an initializer (the fused bias).
+        const auto& input_defs = node.InputDefs();
+        auto* input0 = input_defs[0];
+        auto* input1 = input_defs[1];
+        auto* input4 = input_defs[4];
+
+        // input[0] should come from MatMul
+        const Node* input0_producer = graph.GetProducerNode(input0->Name());
+        TEST_RETURN_IF_NOT(input0_producer != nullptr);
+        TEST_RETURN_IF_NOT(input0_producer->OpType() == "MatMul");
+
+        // input[1] should be the skip connection: a graph input (no producer)
+        const Node* input1_producer = graph.GetProducerNode(input1->Name());
+        TEST_RETURN_IF_NOT(input1_producer == nullptr);
+        bool is_graph_input1 = false;
+        for (const auto* gi : graph.GetInputs()) {
+          if (gi->Name() == input1->Name()) {
+            is_graph_input1 = true;
+            break;
+          }
+        }
+        TEST_RETURN_IF_NOT(is_graph_input1);
+
+        // input[4] should be an initializer (the fused bias), identified by name
+        const Node* input4_producer = graph.GetProducerNode(input4->Name());
+        TEST_RETURN_IF_NOT(input4_producer == nullptr);
+        const ONNX_NAMESPACE::TensorProto* bias_initializer = nullptr;
+        TEST_RETURN_IF_NOT(graph.GetInitializedTensor(input4->Name(), bias_initializer));
+        TEST_RETURN_IF_NOT(bias_initializer != nullptr);
+      }
+    }
+    return Status::OK();
+  };
+
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 17, *logger_,
+                                        std::make_unique<BiasSkipLayerNormFusion>(),
+                                        TransformerLevel::Level2, 1, nullptr, post_graph_checker));
+}
+
+// Add(MatMul_out, bias_1D) is connected to SLN input[1] (the "skip" input).
+TEST_F(GraphTransformationTests, BiasSkipLayerNormFusion_AddAtSkipInput) {
+  auto build_test_case = [](ModelTestBuilder& builder) {
+    auto* input = builder.MakeInput<float>({2, 4, 4}, -1.0f, 1.0f);
+    auto* matmul_a = builder.MakeInput<float>({2, 4, 8}, -1.0f, 1.0f);
+    auto* matmul_b = builder.MakeInitializer<float>({8, 4}, -1.0f, 1.0f);
+    auto* gamma = builder.MakeInitializer<float>({4}, {1.0f, 1.0f, 1.0f, 1.0f});
+    auto* beta = builder.MakeInitializer<float>({4}, {0.0f, 0.0f, 0.0f, 0.0f});
+    auto* bias = builder.MakeInitializer<float>({4}, {0.1f, 0.2f, 0.3f, 0.4f});
+
+    auto* matmul_out = builder.MakeIntermediate();
+    auto* add_out = builder.MakeIntermediate();
+    auto* sln_out = builder.MakeOutput();
+
+    builder.AddNode("MatMul", {matmul_a, matmul_b}, {matmul_out});
+    builder.AddNode("Add", {matmul_out, bias}, {add_out});
+    // add_out at SLN input[1] (skip), primary input at input[0]
+    builder.AddNode("SkipLayerNormalization", {input, add_out, gamma, beta}, {sln_out}, kMSDomain);
+  };
+
+  auto post_graph_checker = [](Graph& graph) {
+    auto op_count = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_count["Add"] == 0);
+    TEST_RETURN_IF_NOT(op_count["com.microsoft.SkipLayerNormalization"] == 1);
+    for (auto& node : graph.Nodes()) {
+      if (node.OpType() == "SkipLayerNormalization") {
+        TEST_RETURN_IF_NOT(node.InputDefs().size() == 5u);
+
+        const auto& input_defs = node.InputDefs();
+
+        // input[0] should be the original graph input (unchanged – Add fed SLN.input[1], so
+        // only SLN.input[1] is replaced with the MatMul output; input[0] keeps its original value).
+        const Node* input0_producer = graph.GetProducerNode(input_defs[0]->Name());
+        TEST_RETURN_IF_NOT(input0_producer == nullptr);
+        bool is_graph_input0 = false;
+        for (const auto* gi : graph.GetInputs()) {
+          if (gi->Name() == input_defs[0]->Name()) {
+            is_graph_input0 = true;
+            break;
+          }
+        }
+        TEST_RETURN_IF_NOT(is_graph_input0);
+
+        // input[1] should come from MatMul (the bias-Add was at SLN.input[1])
+        const Node* input1_producer = graph.GetProducerNode(input_defs[1]->Name());
+        TEST_RETURN_IF_NOT(input1_producer != nullptr);
+        TEST_RETURN_IF_NOT(input1_producer->OpType() == "MatMul");
+
+        // input[4] should be an initializer (the fused bias)
+        const Node* input4_producer = graph.GetProducerNode(input_defs[4]->Name());
+        TEST_RETURN_IF_NOT(input4_producer == nullptr);
+        const ONNX_NAMESPACE::TensorProto* bias_initializer = nullptr;
+        TEST_RETURN_IF_NOT(graph.GetInitializedTensor(input_defs[4]->Name(), bias_initializer));
+        TEST_RETURN_IF_NOT(bias_initializer != nullptr);
+      }
+    }
+    return Status::OK();
+  };
+
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 17, *logger_,
+                                        std::make_unique<BiasSkipLayerNormFusion>(),
+                                        TransformerLevel::Level2, 1, nullptr, post_graph_checker));
+}
+
+// Cast variant: MatMul → Cast → Add(bias_1D) → SLN(4 inputs).
+// Models using fp16 precision commonly insert a Cast between MatMul and the bias Add.
+TEST_F(GraphTransformationTests, BiasSkipLayerNormFusion_WithCast) {
+  auto build_test_case = [](ModelTestBuilder& builder) {
+    auto* matmul_a = builder.MakeInput<MLFloat16>({2, 4, 8}, MLFloat16(-1.0f), MLFloat16(1.0f));
+    auto* matmul_b = builder.MakeInitializer<MLFloat16>({8, 4}, MLFloat16(-1.0f), MLFloat16(1.0f));
+    auto* skip = builder.MakeInput<float>({2, 4, 4}, -1.0f, 1.0f);
+    auto* gamma = builder.MakeInitializer<float>({4}, {1.0f, 1.0f, 1.0f, 1.0f});
+    auto* beta = builder.MakeInitializer<float>({4}, {0.0f, 0.0f, 0.0f, 0.0f});
+    auto* bias = builder.MakeInitializer<float>({4}, {0.1f, 0.2f, 0.3f, 0.4f});
+
+    auto* matmul_out = builder.MakeIntermediate();
+    auto* cast_out = builder.MakeIntermediate();
+    auto* add_out = builder.MakeIntermediate();
+    auto* sln_out = builder.MakeOutput();
+
+    builder.AddNode("MatMul", {matmul_a, matmul_b}, {matmul_out});
+    builder.AddNode("Cast", {matmul_out}, {cast_out})
+        .AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
+    builder.AddNode("Add", {cast_out, bias}, {add_out});
+    builder.AddNode("SkipLayerNormalization", {add_out, skip, gamma, beta}, {sln_out}, kMSDomain);
+  };
+
+  auto post_graph_checker = [](Graph& graph) {
+    auto op_count = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_count["Add"] == 0);
+    TEST_RETURN_IF_NOT(op_count["com.microsoft.SkipLayerNormalization"] == 1);
+    for (auto& node : graph.Nodes()) {
+      if (node.OpType() == "SkipLayerNormalization") {
+        TEST_RETURN_IF_NOT(node.InputDefs().size() == 5u);
+
+        const auto& input_defs = node.InputDefs();
+
+        // input[0] should come from Cast (MatMul → Cast → fused SLN)
+        const Node* input0_producer = graph.GetProducerNode(input_defs[0]->Name());
+        TEST_RETURN_IF_NOT(input0_producer != nullptr);
+        TEST_RETURN_IF_NOT(input0_producer->OpType() == "Cast");
+
+        // input[1] should be the skip connection: a graph input (no producer)
+        const Node* input1_producer = graph.GetProducerNode(input_defs[1]->Name());
+        TEST_RETURN_IF_NOT(input1_producer == nullptr);
+        bool is_graph_input1 = false;
+        for (const auto* gi : graph.GetInputs()) {
+          if (gi->Name() == input_defs[1]->Name()) {
+            is_graph_input1 = true;
+            break;
+          }
+        }
+        TEST_RETURN_IF_NOT(is_graph_input1);
+
+        // input[4] should be an initializer (the fused bias)
+        const Node* input4_producer = graph.GetProducerNode(input_defs[4]->Name());
+        TEST_RETURN_IF_NOT(input4_producer == nullptr);
+        const ONNX_NAMESPACE::TensorProto* bias_initializer = nullptr;
+        TEST_RETURN_IF_NOT(graph.GetInitializedTensor(input_defs[4]->Name(), bias_initializer));
+        TEST_RETURN_IF_NOT(bias_initializer != nullptr);
+      }
+    }
+    return Status::OK();
+  };
+
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 17, *logger_,
+                                        std::make_unique<BiasSkipLayerNormFusion>(),
+                                        TransformerLevel::Level2, 1, nullptr, post_graph_checker));
+}
+
+// Cast variant negative test: bias is 1D but its length is incompatible with gamma/beta.
+// This guards against fusing dimension-mismatched biases when hidden-size validation is applied
+// on the Cast path.
+TEST_F(GraphTransformationTests, BiasSkipLayerNormFusion_WithCast_BiasHiddenSizeMismatch) {
+  auto build_test_case = [](ModelTestBuilder& builder) {
+    auto* matmul_a = builder.MakeInput<MLFloat16>({2, 4, 8}, MLFloat16(-1.0f), MLFloat16(1.0f));
+    auto* matmul_b = builder.MakeInitializer<MLFloat16>({8, 4}, MLFloat16(-1.0f), MLFloat16(1.0f));
+    auto* skip = builder.MakeInput<float>({2, 4, 4}, -1.0f, 1.0f);
+    auto* gamma = builder.MakeInitializer<float>({4}, {1.0f, 1.0f, 1.0f, 1.0f});
+    auto* beta = builder.MakeInitializer<float>({4}, {0.0f, 0.0f, 0.0f, 0.0f});
+    // Intentionally use a 1D bias whose length does not match gamma/beta (size 4).
+    // bias{1} broadcasts validly with cast_out{2,4,4}, but bias_hidden_size(1) != sln_hidden_size(4)
+    // so the fusion is blocked.
+    auto* bias = builder.MakeInitializer<float>({1}, {0.5f});
+
+    auto* matmul_out = builder.MakeIntermediate();
+    auto* cast_out = builder.MakeIntermediate();
+    auto* add_out = builder.MakeIntermediate();
+    auto* sln_out = builder.MakeOutput();
+
+    builder.AddNode("MatMul", {matmul_a, matmul_b}, {matmul_out});
+    builder.AddNode("Cast", {matmul_out}, {cast_out})
+        .AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
+    builder.AddNode("Add", {cast_out, bias}, {add_out});
+    builder.AddNode("SkipLayerNormalization", {add_out, skip, gamma, beta}, {sln_out}, kMSDomain);
+  };
+
+  auto post_graph_checker = [](Graph& graph) {
+    auto op_count = CountOpsInGraph(graph);
+    // Fusion should not occur: Add must remain, and SkipLayerNormalization must keep 4 inputs.
+    TEST_RETURN_IF_NOT(op_count["Add"] == 1);
+    TEST_RETURN_IF_NOT(op_count["com.microsoft.SkipLayerNormalization"] == 1);
+    for (auto& node : graph.Nodes()) {
+      if (node.OpType() == "SkipLayerNormalization") {
+        TEST_RETURN_IF_NOT(node.InputDefs().size() == 4u);
+      }
+    }
+    return Status::OK();
+  };
+
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 17, *logger_,
+                                        std::make_unique<BiasSkipLayerNormFusion>(),
+                                        TransformerLevel::Level2, 1, nullptr, post_graph_checker));
+}
+
+// Fusion must NOT occur when the bias is 2D (not 1D).
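+// The contrib SkipLayerNormalization op defines its optional bias input as a 1-D tensor of shape
+// (hidden_size), so a rank-2 bias such as {1, 4} cannot be absorbed even though the original Add
+// broadcasts it to the same values.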
+TEST_F(GraphTransformationTests, BiasSkipLayerNormFusion_NoFusion_2DBias) {
+  auto build_test_case = [](ModelTestBuilder& builder) {
+    auto* matmul_a = builder.MakeInput<float>({2, 4, 8}, -1.0f, 1.0f);
+    auto* matmul_b = builder.MakeInitializer<float>({8, 4}, -1.0f, 1.0f);
+    auto* skip = builder.MakeInput<float>({2, 4, 4}, -1.0f, 1.0f);
+    auto* gamma = builder.MakeInitializer<float>({4}, {1.0f, 1.0f, 1.0f, 1.0f});
+    auto* beta = builder.MakeInitializer<float>({4}, {0.0f, 0.0f, 0.0f, 0.0f});
+    // 2D bias – should prevent fusion
+    auto* bias_2d = builder.MakeInitializer<float>({1, 4}, {0.1f, 0.2f, 0.3f, 0.4f});
+
+    auto* matmul_out = builder.MakeIntermediate();
+    auto* add_out = builder.MakeIntermediate();
+    auto* sln_out = builder.MakeOutput();
+
+    builder.AddNode("MatMul", {matmul_a, matmul_b}, {matmul_out});
+    builder.AddNode("Add", {matmul_out, bias_2d}, {add_out});
+    builder.AddNode("SkipLayerNormalization", {add_out, skip, gamma, beta}, {sln_out}, kMSDomain);
+  };
+
+  auto post_graph_checker = [](Graph& graph) {
+    auto op_count = CountOpsInGraph(graph);
+    // Graph should be unchanged: Add and 4-input SLN both remain.
+    TEST_RETURN_IF_NOT(op_count["Add"] == 1);
+    TEST_RETURN_IF_NOT(op_count["com.microsoft.SkipLayerNormalization"] == 1);
+    for (auto& node : graph.Nodes()) {
+      if (node.OpType() == "SkipLayerNormalization") {
+        TEST_RETURN_IF_NOT(node.InputDefs().size() == 4u);
+      }
+    }
+    return Status::OK();
+  };
+
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 17, *logger_,
+                                        std::make_unique<BiasSkipLayerNormFusion>(),
+                                        TransformerLevel::Level2, 1, nullptr, post_graph_checker));
+}
+
+// Fusion must NOT occur when the SLN node already has 5 inputs (bias already absorbed).
+TEST_F(GraphTransformationTests, BiasSkipLayerNormFusion_NoFusion_SLNHas5Inputs) {
+  auto build_test_case = [](ModelTestBuilder& builder) {
+    auto* input = builder.MakeInput<float>({2, 4, 4}, -1.0f, 1.0f);
+    auto* skip = builder.MakeInput<float>({2, 4, 4}, -1.0f, 1.0f);
+    auto* gamma = builder.MakeInitializer<float>({4}, {1.0f, 1.0f, 1.0f, 1.0f});
+    auto* beta = builder.MakeInitializer<float>({4}, {0.0f, 0.0f, 0.0f, 0.0f});
+    auto* bias = builder.MakeInitializer<float>({4}, {0.1f, 0.2f, 0.3f, 0.4f});
+    auto* sln_out = builder.MakeOutput();
+
+    // SLN already has 5 inputs – no further fusion should happen.
+    builder.AddNode("SkipLayerNormalization", {input, skip, gamma, beta, bias}, {sln_out}, kMSDomain);
+  };
+
+  auto post_graph_checker = [](Graph& graph) {
+    auto op_count = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_count["com.microsoft.SkipLayerNormalization"] == 1);
+    for (auto& node : graph.Nodes()) {
+      if (node.OpType() == "SkipLayerNormalization") {
+        TEST_RETURN_IF_NOT(node.InputDefs().size() == 5u);
+      }
+    }
+    return Status::OK();
+  };
+
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 17, *logger_,
+                                        std::make_unique<BiasSkipLayerNormFusion>(),
+                                        TransformerLevel::Level2, 1, nullptr, post_graph_checker));
+}
+
+// Fusion must NOT occur when the Add node feeds multiple consumers (the output is used both by
+// SLN and by another node, so removing Add would drop the other consumer's input).
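+// The transformer enforces this by requiring the candidate Add to have exactly one output edge and
+// to not produce a graph output before it is considered for fusion.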
+TEST_F(GraphTransformationTests, BiasSkipLayerNormFusion_NoFusion_AddHasMultipleConsumers) {
+  auto build_test_case = [](ModelTestBuilder& builder) {
+    auto* matmul_a = builder.MakeInput<float>({2, 4, 8}, -1.0f, 1.0f);
+    auto* matmul_b = builder.MakeInitializer<float>({8, 4}, -1.0f, 1.0f);
+    auto* skip = builder.MakeInput<float>({2, 4, 4}, -1.0f, 1.0f);
+    auto* gamma = builder.MakeInitializer<float>({4}, {1.0f, 1.0f, 1.0f, 1.0f});
+    auto* beta = builder.MakeInitializer<float>({4}, {0.0f, 0.0f, 0.0f, 0.0f});
+    auto* bias = builder.MakeInitializer<float>({4}, {0.1f, 0.2f, 0.3f, 0.4f});
+
+    auto* matmul_out = builder.MakeIntermediate();
+    auto* add_out = builder.MakeIntermediate();
+    auto* sln_out = builder.MakeOutput();
+    auto* identity_out = builder.MakeOutput();
+
+    builder.AddNode("MatMul", {matmul_a, matmul_b}, {matmul_out});
+    builder.AddNode("Add", {matmul_out, bias}, {add_out});
+    // add_out feeds both SLN and an Identity node – Add has 2 consumers.
+    builder.AddNode("SkipLayerNormalization", {add_out, skip, gamma, beta}, {sln_out}, kMSDomain);
+    builder.AddNode("Identity", {add_out}, {identity_out});
+  };
+
+  auto post_graph_checker = [](Graph& graph) {
+    auto op_count = CountOpsInGraph(graph);
+    // Add must NOT be removed because it has multiple consumers.
+    TEST_RETURN_IF_NOT(op_count["Add"] == 1);
+    TEST_RETURN_IF_NOT(op_count["com.microsoft.SkipLayerNormalization"] == 1);
+    for (auto& node : graph.Nodes()) {
+      if (node.OpType() == "SkipLayerNormalization") {
+        TEST_RETURN_IF_NOT(node.InputDefs().size() == 4u);
+      }
+    }
+    return Status::OK();
+  };
+
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 17, *logger_,
+                                        std::make_unique<BiasSkipLayerNormFusion>(),
+                                        TransformerLevel::Level2, 1, nullptr, post_graph_checker));
+}
+
+// Verify that fusion preserves downstream edges when the SLN output feeds another node.
+// This exercises the edge-rewiring code path: the fused SLN node must inherit all consumers
+// of the original SLN node.
+TEST_F(GraphTransformationTests, BiasSkipLayerNormFusion_DownstreamConsumerPreserved) {
+  auto build_test_case = [](ModelTestBuilder& builder) {
+    auto* matmul_a = builder.MakeInput<float>({2, 4, 8}, -1.0f, 1.0f);
+    auto* matmul_b = builder.MakeInitializer<float>({8, 4}, -1.0f, 1.0f);
+    auto* skip = builder.MakeInput<float>({2, 4, 4}, -1.0f, 1.0f);
+    auto* gamma = builder.MakeInitializer<float>({4}, {1.0f, 1.0f, 1.0f, 1.0f});
+    auto* beta = builder.MakeInitializer<float>({4}, {0.0f, 0.0f, 0.0f, 0.0f});
+    auto* bias = builder.MakeInitializer<float>({4}, {0.1f, 0.2f, 0.3f, 0.4f});
+
+    auto* matmul_out = builder.MakeIntermediate();
+    auto* add_out = builder.MakeIntermediate();
+    auto* sln_out = builder.MakeIntermediate();  // intermediate: SLN output feeds Identity
+    auto* identity_out = builder.MakeOutput();
+
+    builder.AddNode("MatMul", {matmul_a, matmul_b}, {matmul_out});
+    builder.AddNode("Add", {matmul_out, bias}, {add_out});
+    builder.AddNode("SkipLayerNormalization", {add_out, skip, gamma, beta}, {sln_out}, kMSDomain);
+    // Downstream consumer of the SLN output: must be preserved after fusion.
+    builder.AddNode("Identity", {sln_out}, {identity_out});
+  };
+
+  auto post_graph_checker = [](Graph& graph) {
+    auto op_count = CountOpsInGraph(graph);
+    TEST_RETURN_IF_NOT(op_count["Add"] == 0);
+    TEST_RETURN_IF_NOT(op_count["com.microsoft.SkipLayerNormalization"] == 1);
+    TEST_RETURN_IF_NOT(op_count["Identity"] == 1);
+
+    // The Identity node must still be wired to the fused SLN output.
+    for (const auto& node : graph.Nodes()) {
+      if (node.OpType() == "Identity") {
+        TEST_RETURN_IF_NOT(node.InputDefs().size() == 1u);
+        const Node* identity_input_producer = graph.GetProducerNode(node.InputDefs()[0]->Name());
+        TEST_RETURN_IF_NOT(identity_input_producer != nullptr);
+        TEST_RETURN_IF_NOT(identity_input_producer->OpType() == "SkipLayerNormalization");
+        // The fused SLN must have 5 inputs.
+        TEST_RETURN_IF_NOT(identity_input_producer->InputDefs().size() == 5u);
+      }
+    }
+    return Status::OK();
+  };
+
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 17, *logger_,
+                                        std::make_unique<BiasSkipLayerNormFusion>(),
+                                        TransformerLevel::Level2, 1, nullptr, post_graph_checker));
+}
+
 TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat1) {
   constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format1.onnx";
   std::shared_ptr<Model> p_model;