diff --git a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc
index 655364357999a..3727ac0918115 100644
--- a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc
+++ b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc
@@ -254,6 +254,21 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
       continue;
     }
 
+    // SkipLayerNormalization kernel requires gamma and beta to be 1D, so skip fusion when
+    // either has a known shape whose rank is not exactly 1 (unknown shapes are not rejected).
+    const NodeArg* gamma_arg = ln_node.InputDefs()[1];
+    const TensorShapeProto* gamma_shape = gamma_arg->Shape();
+    if (gamma_shape != nullptr && gamma_shape->dim_size() != 1) {
+      continue;
+    }
+    if (ln_node.InputDefs().size() > 2) {
+      const NodeArg* beta_arg = ln_node.InputDefs()[2];
+      const TensorShapeProto* beta_shape = beta_arg->Shape();
+      if (beta_shape != nullptr && beta_shape->dim_size() != 1) {
+        continue;
+      }
+    }
+
     NodeArg beta_place_holder("", nullptr);
 
     // Get the inputs for the new SkipLayerNormalization node.
diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc
index 0afb836192b0a..4615b6a57b558 100644
--- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc
@@ -638,6 +638,37 @@ TEST_F(GraphTransformationTests, SkipLayerNormFusionTest) {
   TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_graph_output.onnx", 1, 1, 0, 0, logger_.get());
 }
 
+// SkipLayerNorm fusion should not be applied when gamma/beta have more than 1 dimension,
+// because the SkipLayerNormalization kernel requires 1D gamma/beta.
+TEST_F(GraphTransformationTests, SkipLayerNormFusion_3DGamma_NoFusion) {
+  auto build_test_case = [](ModelTestBuilder& builder) {
+    // Inputs: A and B are 3D [16, 32, 4] and feed the Add node.
+    auto* input_a = builder.MakeInput<float>({16, 32, 4}, -1.0f, 1.0f);
+    auto* input_b = builder.MakeInput<float>({16, 32, 4}, -1.0f, 1.0f);
+    // gamma and beta are deliberately 3D [1, 1, 4] rather than 1D [4].
+    auto* gamma = builder.MakeInitializer<float>({1, 1, 4}, {1.0f, 2.0f, 3.0f, 4.0f});
+    auto* beta = builder.MakeInitializer<float>({1, 1, 4}, {0.1f, 0.2f, 0.3f, 0.4f});
+    auto* add_out = builder.MakeIntermediate();
+    auto* ln_out = builder.MakeOutput();
+
+    builder.AddNode("Add", {input_a, input_b}, {add_out});
+    builder.AddNode("LayerNormalization", {add_out, gamma, beta}, {ln_out})
+        .AddAttribute("axis", static_cast<int64_t>(-1));
+  };
+
+  auto post_graph_checker = [](Graph& graph) {
+    // The Add + LayerNormalization pair must survive untouched: no SkipLayerNormalization node.
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 1);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["LayerNormalization"] == 1);
+    TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["com.microsoft.SkipLayerNormalization"] == 0);
+    return Status::OK();
+  };
+
+  ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 17, *logger_,
+                                        std::make_unique<SkipLayerNormFusion>(),
+                                        TransformerLevel::Level2, 1, nullptr, post_graph_checker));
+}
+
 TEST_F(GraphTransformationTests, GroupQueryAttentionFusionTest) {
   TestGQAFusion(MODEL_FOLDER "fusion/gqa_fusion_quantized_simple.onnx", 1, 0, logger_.get());
   TestGQAFusion(MODEL_FOLDER "fusion/gqa_fusion_different_head_sizes.onnx", 0, 1, logger_.get());