Support broadcasting skip shapes in SkipLayerNorm fusion #27489
Merged
tianleiwu merged 4 commits into microsoft:main · Mar 3, 2026
Conversation
Allow SkipLayerNormalization fusion when the Add inputs have broadcast-compatible shapes. The C++ kernel already supports skip tensors of shape (1, S, H) or (S, H) when input is (B, S, H), but the Python fusion was rejecting these due to an exact shape equality check. Add get_skip_index() to identify which Add input is the skip tensor and ensure correct input ordering in the fused node. Add tests for 2D and 3D broadcast skip shapes in both input positions.
tianleiwu
reviewed
Feb 27, 2026
onnxruntime/test/python/transformers/test_skip_layer_norm_fusion.py
tianleiwu
reviewed
Feb 27, 2026
Contributor
tianleiwu
left a comment
The logic is correct and conservative. The fusion only activates for the well-defined broadcast patterns that the C++ kernel explicitly supports.
Suggestions
| # | Severity | Description |
|---|---|---|
| 1 | Low | TODO partial removal: The TODO about subgraph shape inference is still valid. Consider updating it to just # TODO(tianleiwu): support subgraph in shape inference. instead of removing entirely. |
| 2 | Low | Redundant length check: len(shape_a) == 3 in get_skip_index() is redundant with the check inside _is_broadcast_skip(). Could simplify to just call _is_broadcast_skip(shape_a, shape_b) directly. |
| 3 | Low | Typo in test: epsion → epsilon in create_broadcast_test_model(). (Pre-existing issue in create_test_model() too.) |
| 4 | Informational | The C++ graph optimizer (skip_layer_norm_fusion.cc CheckFirstAdd()) still requires both Add inputs to be 3D with identical shapes. This PR only updates the Python optimizer. If the C++ optimizer also encounters broadcast models, it would still reject them. This is out of scope for this PR but worth noting for follow-up. |
| 5 | Informational | The _broadcast return value from get_skip_index() is unused. If there's no planned use for it, it could be simplified to just return skip_index. However, maintaining parity with FusionSkipGroupNorm.get_skip_index() (which also returns broadcast) is a reasonable design choice. |
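Suggestions 2 and 5 above can be sketched together: since `_is_broadcast_skip` already validates rank, `get_skip_index` can call it directly, and the `broadcast` flag is kept for parity with `FusionSkipGroupNorm.get_skip_index()`. The helper bodies below are illustrative assumptions based on the review discussion, not the repository code.

```python
def _is_broadcast_skip(full_shape, skip_shape):
    """True if skip_shape is a supported broadcast of a 3D full_shape."""
    if len(full_shape) != 3:
        return False
    _, s, h = full_shape
    return list(skip_shape) in ([1, s, h], [s, h])

def get_skip_index(shape_a, shape_b):
    """Return (skip_index, broadcast): which Add input is the skip tensor.

    skip_index is -1 when the shapes are neither identical nor a
    supported broadcast pair, signalling that fusion must be rejected.
    """
    if shape_a == shape_b:
        return 1, False          # identical shapes: default skip position
    if _is_broadcast_skip(shape_a, shape_b):
        return 1, True           # input b is the broadcast skip
    if _is_broadcast_skip(shape_b, shape_a):
        return 0, True           # input a is the broadcast skip
    return -1, False             # incompatible shapes: reject fusion
```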
Missing Test Coverage (minor)
- No test for the negative case: a model where shapes are incompatible (e.g., different hidden sizes) to verify fusion is correctly rejected.
- No test for `add_graph_output=True` with broadcast shapes (existing tests cover this for non-broadcast cases).
- No test for `SimplifiedLayerNormalization` (RMS LayerNorm) with broadcast shapes.
- Keep subgraph shape inference TODO (separate concern from broadcasting) - Remove redundant len==3 checks in get_skip_index (already in _is_broadcast_skip) - Fix epsion typo in test model generators (epsilon) - Add test: broadcast with add_graph_output=True - Add test: incompatible shapes correctly rejected - Add test: SimplifiedLayerNormalization (RMS LayerNorm) with broadcast skip
Contributor
@Rishi-Dave, there is a test error in the CI pipeline: transformers/test_optimizer_huggingface_bert.py::TestHuggingfaceBertModelOptimization::test_distillbert. You can reproduce it by:
The broadcast-compatible shape check was too eager: it fused the embedding Add + LayerNorm (word_embedding (B,S,H) + position_embedding (1,S,H)) as SkipLayerNormalization, consuming nodes that EmbedLayerNormalization needs later in the pipeline.

Changes:
- Add Gather guard: when broadcast is detected, skip fusion if either Add input comes from a Gather node (embedding lookup pattern)
- Remove incorrect `else: return` when shape_infer_helper is None; restore old fallthrough behavior (fuse with default skip_index=1)
- Fix incompatible shapes test to use (1,1,4) instead of (3,5) so shape inference succeeds and the shape rejection logic is actually exercised

Verified: all 4 tests in test_optimizer_huggingface_bert.py pass (BERT, DistilBERT, RoBERTa, XLM-RoBERTa), plus all 13 unit tests.
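The Gather guard can be sketched as below. The `model.get_parent(node, index)` helper follows the style of the onnxruntime fusion utilities, but treat the exact API and function name here as assumptions for illustration.

```python
def skip_fusion_for_embedding_pattern(model, add_node, broadcast):
    """Return True when fusion should be skipped.

    When broadcasting is detected, an Add fed by a Gather is almost
    certainly the embedding lookup pattern (word + position embeddings),
    which belongs to EmbedLayerNormalization fusion later in the
    pipeline, not to SkipLayerNormalization.
    """
    if not broadcast:
        return False
    for input_index in (0, 1):
        parent = model.get_parent(add_node, input_index)
        if parent is not None and parent.op_type == "Gather":
            return True
    return False
```

Note the guard only fires in the broadcast case, so ordinary residual-connection fusions (identical shapes) are unaffected even when a Gather appears upstream.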
Contributor
/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).
tianleiwu
reviewed
Mar 2, 2026
Restore the conservative guard that skips SkipLayerNormalization fusion when shape_infer_helper is None (shape inference failed). This was present in the original code and prevents fusing unsupported patterns in models with subgraphs or complex ops where shape inference cannot resolve all shapes.
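The restored guard can be sketched as a single resolution step. `shape_infer_helper` is assumed to expose `get_edge_shape(name) -> list | None`, modeled on the onnxruntime symbolic shape inference helper; the function name below is hypothetical.

```python
def resolve_add_input_shapes(shape_infer_helper, input_names):
    """Return (shape_a, shape_b), or None when fusion must be skipped.

    Conservative guard: if shape inference failed entirely (helper is
    None, e.g. models with subgraphs or ops it cannot resolve), or if
    either edge shape is unresolved, we cannot verify the broadcast
    pattern, so the safe choice is to not fuse at all.
    """
    if shape_infer_helper is None:
        return None
    shapes = [shape_infer_helper.get_edge_shape(name) for name in input_names]
    if any(s is None for s in shapes):
        return None
    return tuple(shapes)
```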
Contributor
/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).
tianleiwu
approved these changes
Mar 2, 2026
Summary
- Allow `SkipLayerNormalization` fusion when Add inputs have broadcast-compatible shapes
- Add `get_skip_index()` to identify the skip tensor and ensure correct input ordering
- Add tests for 2D `(S, H)` and 3D `(1, S, H)` broadcast skip shapes

Motivation
Fixes #27488
The Python optimizer's `FusionSkipLayerNormalization` rejects the `Add → LayerNormalization` fusion when the two Add inputs have different but broadcast-compatible shapes. The C++ `SkipLayerNormalization` kernel already supports broadcasting (the skip can be 2D or 3D with the last two dims matching the input), but the Python fusion used an exact shape equality check that blocked these cases. This resolves the existing TODO comments from @tianleiwu.

Changes
- `fusion_skiplayernorm.py`: Added `_is_broadcast_skip()` helper and `get_skip_index()` method (following the pattern from `FusionSkipGroupNorm`). Replaced the strict `compare_shape()` equality check with broadcast-aware logic. Used `skip_index` to ensure the full-sized input goes to position 0 and the skip to position 1 in the fused node.
- `test_skip_layer_norm_fusion.py`: Added `create_broadcast_test_model()` and 4 new test cases covering 2D/3D broadcast skip shapes in both Add input positions.

Test Plan
- All `test_skip_layer_norm_fusion` tests pass (no regressions)
- `ruff check` passes on changed files