
Support broadcasting skip shapes in SkipLayerNorm fusion#27489

Merged
tianleiwu merged 4 commits into microsoft:main from Rishi-Dave:rishidave/fix/skiplayernorm-broadcast-shapes
Mar 3, 2026

Conversation

@Rishi-Dave
Contributor

Summary

  • Allow SkipLayerNormalization fusion when Add inputs have broadcast-compatible shapes
  • Add get_skip_index() to identify skip tensor and ensure correct input ordering
  • Add tests for 2D (S, H) and 3D (1, S, H) broadcast skip shapes

Motivation

Fixes #27488

The Python optimizer's FusionSkipLayerNormalization rejects the Add → LayerNormalization fusion when the two Add inputs have different but broadcast-compatible shapes. The C++ SkipLayerNormalization kernel already supports broadcasting (skip can be 2D or 3D with last two dims matching input), but the Python fusion used an exact shape equality check that blocked these cases. This resolves the existing TODO comments from @tianleiwu.
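The broadcast patterns the C++ kernel accepts can be captured in a small shape check. A minimal sketch (hypothetical helper name and plain-list shapes for illustration; the actual fusion code works on ONNX shape-inference results):

```python
def is_broadcast_skip(input_shape, skip_shape):
    """Return True if skip_shape is a broadcastable skip for input_shape.

    For an input of shape (B, S, H), the SkipLayerNormalization kernel
    accepts a skip of shape (1, S, H) or (S, H): the last two dimensions
    must match the input exactly.
    """
    if len(input_shape) != 3:
        return False
    if len(skip_shape) == 3:
        return skip_shape[0] == 1 and skip_shape[1:] == input_shape[1:]
    if len(skip_shape) == 2:
        return list(skip_shape) == list(input_shape[1:])
    return False


# With input (B, S, H) = (8, 128, 768):
assert is_broadcast_skip([8, 128, 768], [1, 128, 768])   # 3D broadcast skip
assert is_broadcast_skip([8, 128, 768], [128, 768])      # 2D broadcast skip
assert not is_broadcast_skip([8, 128, 768], [8, 1, 768])  # S mismatch
```

Shapes that are identical to the input are handled by the pre-existing equality path in the fusion, so this helper only needs to recognize the two broadcast forms.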

Changes

  • fusion_skiplayernorm.py: Added _is_broadcast_skip() helper and get_skip_index() method (following the pattern from FusionSkipGroupNorm). Replaced strict compare_shape() equality check with broadcast-aware logic. Used skip_index to ensure the full-sized input goes to position 0 and the skip to position 1 in the fused node.
  • test_skip_layer_norm_fusion.py: Added create_broadcast_test_model() and 4 new test cases covering 2D/3D broadcast skip shapes in both Add input positions.
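The input-ordering rule from the first bullet can be illustrated with a standalone sketch (hypothetical signature; the real `get_skip_index()` operates on an ONNX Add node and the shape-inference helper):

```python
def get_skip_index(shape_0, shape_1):
    """Pick which Add input is the skip tensor.

    Returns (skip_index, broadcast). When one input broadcasts over the
    other, the smaller tensor is the skip; for identical shapes the
    default index 1 is kept. The fused SkipLayerNormalization node then
    takes the full-sized tensor as input 0 and the skip as input 1.
    """
    if shape_0 == shape_1:
        return 1, False  # identical shapes: keep the default order
    if len(shape_0) < len(shape_1) or (
        len(shape_0) == len(shape_1) == 3 and shape_0[0] == 1
    ):
        return 0, True   # input 0 is the broadcast skip
    return 1, True       # input 1 is the broadcast skip


skip_index, broadcast = get_skip_index([1, 128, 768], [8, 128, 768])
assert (skip_index, broadcast) == (0, True)
```

Returning the `broadcast` flag alongside the index mirrors `FusionSkipGroupNorm.get_skip_index()`, which the PR description cites as the pattern being followed.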

Test Plan

  • All 6 existing test_skip_layer_norm_fusion tests pass (no regressions)
  • All 4 new broadcast tests pass
  • ruff check passes on changed files

Allow SkipLayerNormalization fusion when the Add inputs have
broadcast-compatible shapes. The C++ kernel already supports skip
tensors of shape (1, S, H) or (S, H) when input is (B, S, H), but
the Python fusion was rejecting these due to an exact shape equality
check.

Add get_skip_index() to identify which Add input is the skip tensor
and ensure correct input ordering in the fused node. Add tests for
2D and 3D broadcast skip shapes in both input positions.
Contributor

@tianleiwu tianleiwu left a comment


The logic is correct and conservative. The fusion only activates for the well-defined broadcast patterns that the C++ kernel explicitly supports.

Suggestions

| # | Severity | Description |
|---|----------|-------------|
| 1 | Low | TODO partial removal: the TODO about subgraph shape inference is still valid. Consider updating it to `# TODO(tianleiwu): support subgraph in shape inference.` instead of removing it entirely. |
| 2 | Low | Redundant length check: `len(shape_a) == 3` in `get_skip_index()` is redundant with the check inside `_is_broadcast_skip()`. Could simplify to just call `_is_broadcast_skip(shape_a, shape_b)` directly. |
| 3 | Low | Typo in test: `epsion` → `epsilon` in `create_broadcast_test_model()`. (Pre-existing issue in `create_test_model()` too.) |
| 4 | Informational | The C++ graph optimizer (`skip_layer_norm_fusion.cc`, `CheckFirstAdd()`) still requires both Add inputs to be 3D with identical shapes. This PR only updates the Python optimizer, so the C++ optimizer would still reject broadcast models. Out of scope for this PR but worth noting for follow-up. |
| 5 | Informational | The `_broadcast` return value from `get_skip_index()` is unused. If there is no planned use for it, the method could return just `skip_index`. However, maintaining parity with `FusionSkipGroupNorm.get_skip_index()` (which also returns `broadcast`) is a reasonable design choice. |

Missing Test Coverage (minor)

  • No test for the negative case: a model where shapes are incompatible (e.g., different hidden sizes) to verify fusion is correctly rejected.
  • No test for add_graph_output=True with broadcast shapes (existing tests cover this for non-broadcast cases).
  • No test for SimplifiedLayerNormalization (RMS LayerNorm) with broadcast shapes.

- Keep subgraph shape inference TODO (separate concern from broadcasting)
- Remove redundant len==3 checks in get_skip_index (already in _is_broadcast_skip)
- Fix epsion typo in test model generators (epsilon)
- Add test: broadcast with add_graph_output=True
- Add test: incompatible shapes correctly rejected
- Add test: SimplifiedLayerNormalization (RMS LayerNorm) with broadcast skip
@tianleiwu
Contributor

@Rishi-Dave, there is a test error in the CI pipeline: transformers/test_optimizer_huggingface_bert.py::TestHuggingfaceBertModelOptimization::test_distillbert

You can reproduce it by:

```shell
pip install -r tools/ci_build/requirements/transformers-test/requirements.txt
cd test/python/transformers

# May need to replace fusion_skiplayernorm.py under site-packages/onnxruntime/transformers/ in the installed wheel with yours before running the test.
python transformers/test_optimizer_huggingface_bert.py
```

The broadcast-compatible shape check was too eager — it fused the
embedding Add+LayerNorm (word_embedding (B,S,H) + position_embedding
(1,S,H)) as SkipLayerNormalization, consuming nodes that
EmbedLayerNormalization needs later in the pipeline.

Changes:
- Add Gather guard: when broadcast is detected, skip fusion if either
  Add input comes from a Gather node (embedding lookup pattern)
- Remove incorrect else:return when shape_infer_helper is None — restore
  old fallthrough behavior (fuse with default skip_index=1)
- Fix incompatible shapes test to use (1,1,4) instead of (3,5) so shape
  inference succeeds and the shape rejection logic is actually exercised
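The Gather guard from the first bullet can be sketched as follows. This is an illustrative standalone version with a toy node representation; the actual fusion uses the optimizer's model object to look up parent nodes:

```python
from collections import namedtuple

# Minimal stand-in for an ONNX node: op type plus input/output tensor names.
Node = namedtuple("Node", ["op_type", "input", "output"])


def producer_of(nodes, tensor_name):
    """Find the node that produces tensor_name (None for graph inputs)."""
    for node in nodes:
        if tensor_name in node.output:
            return node
    return None


def is_embedding_add(nodes, add_node):
    """True if either Add input comes from a Gather (embedding lookup).

    In BERT-style models, word/position embeddings are Gather nodes, and
    the Add + LayerNormalization that combines them belongs to
    EmbedLayerNormalization fusion. When broadcasting is detected, the
    SkipLayerNormalization fusion bails out on this pattern so those
    nodes stay available for the later fusion pass.
    """
    for name in add_node.input:
        parent = producer_of(nodes, name)
        if parent is not None and parent.op_type == "Gather":
            return True
    return False


# Toy graph: word_emb = Gather(emb_table, ids); emb_sum = Add(word_emb, pos_emb)
gather = Node("Gather", ["emb_table", "ids"], ["word_emb"])
add = Node("Add", ["word_emb", "pos_emb"], ["emb_sum"])
assert is_embedding_add([gather, add], add)  # guard triggers: skip fusion
```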

Verified: all 4 tests in test_optimizer_huggingface_bert.py pass
(BERT, DistilBERT, RoBERTa, XLM-RoBERTa) plus all 13 unit tests.
@tianleiwu
Contributor

/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

Restore the conservative guard that skips SkipLayerNormalization fusion
when shape_infer_helper is None (shape inference failed). This was
present in the original code and prevents fusing unsupported patterns
in models with subgraphs or complex ops where shape inference cannot
resolve all shapes.
@tianleiwu
Contributor

/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

@tianleiwu tianleiwu enabled auto-merge (squash) March 2, 2026 07:40
@tianleiwu tianleiwu merged commit 12814f7 into microsoft:main Mar 3, 2026
93 of 101 checks passed

Development

Successfully merging this pull request may close these issues.

SkipLayerNorm fusion rejects broadcast-compatible skip shapes