
Add pre-layer normalization support to attention fusion #27418

Merged
tianleiwu merged 2 commits into microsoft:main from Rishi-Dave:rishidave/feat/pre-layer-norm-fusion
Mar 2, 2026

Conversation

@Rishi-Dave
Contributor

Description

Add support for pre-layer normalization (pre-LN) transformer architectures in the Python attention fusion optimizer.

Motivation and Context

Fixes #11684

Pre-LN models (used in GPT-3, ViT variants, and many modern architectures) apply LayerNormalization before attention rather than after. The first block of a pre-LN model has no Add node before its first LayerNormalization — its input comes directly from a graph input. This caused FusionAttention.fuse() to bail out early because it assumed every LayerNormalization anchor has an Add parent (the residual connection from the previous block).
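The structural difference can be sketched in plain Python (a minimal illustration; `layer_norm`, `post_ln_block`, and `pre_ln_block` are illustrative names, not code from this PR):

```python
import math

def layer_norm(xs, eps=1e-5):
    # Standard LayerNormalization over a single vector.
    mean = sum(xs) / len(xs)
    var = sum((v - mean) ** 2 for v in xs) / len(xs)
    return [(v - mean) / math.sqrt(var + eps) for v in xs]

def post_ln_block(xs, attention):
    # Post-LN (original BERT): the residual Add feeds LayerNorm *after*
    # attention, so every LayerNorm anchor in the graph has an Add parent.
    attended = attention(xs)
    return layer_norm([a + b for a, b in zip(xs, attended)])

def pre_ln_block(xs, attention):
    # Pre-LN (GPT-3 style): LayerNorm runs *before* attention; the first
    # block's LayerNorm consumes the raw graph input, with no Add parent.
    attended = attention(layer_norm(xs))
    return [a + b for a, b in zip(xs, attended)]
```

In the pre-LN version, the first `layer_norm` call has no preceding Add, which is exactly the graph shape that made `FusionAttention.fuse()` bail out.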

This PR makes four surgical changes to fusion_attention.py so that pre-LN first-block models fuse correctly, while preserving all existing post-LN behavior:

  1. Allow LN with graph-input parent — instead of returning early when no Add parent is found, check whether the input is a graph input and continue
  2. Include graph inputs in residual collection — the other_inputs loop previously skipped anything not in output_name_to_node; graph inputs are now recognized
  3. Extend child-LN resolution to SkipLN anchors — after fuse_skip_layer_norm() runs, the anchor becomes SkipLayerNormalization; the redirect from root_input (graph input) to the first LN's output now fires for SkipLN anchors too
  4. Guard output_name_to_node lookup — graph inputs are not node outputs, so the dictionary access is now guarded
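Changes 1, 2, and 4 share one underlying pattern: treat a tensor that is a graph input as a legitimate parent rather than an error. A hypothetical sketch of that guard (the names `resolve_parent` and `graph_input_names` are simplified stand-ins, not the actual helpers in fusion_attention.py):

```python
def resolve_parent(input_name, output_name_to_node, graph_input_names):
    # Usual post-LN case: the tensor is produced by a node
    # (the residual Add from the previous block).
    if input_name in output_name_to_node:
        return output_name_to_node[input_name]
    # Pre-LN first block: the tensor comes directly from a graph input,
    # so there is no producing node -- continue instead of bailing out.
    if input_name in graph_input_names:
        return None
    raise KeyError(f"unknown tensor: {input_name}")
```

The guarded dictionary lookup in Change 4 is the same idea applied at the access site: graph inputs are never keys of `output_name_to_node`, so the lookup must be conditional.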

Changes

  • onnxruntime/python/tools/transformers/fusion_attention.py — 4 targeted edits to FusionAttention.fuse()
  • onnxruntime/test/python/transformers/bert_model_generator.py — new create_bert_attention_pre_ln() test graph generator
  • onnxruntime/test/python/transformers/test_attention_fusion.py — new test_attention_fusion_pre_ln() test

Test Plan

  • New unit test test_attention_fusion_pre_ln passes — verifies Attention fused op appears in the optimized graph
  • Lintrunner passes on all changed files (no lint issues)
  • Changes are minimal and scoped to the pre-LN first-block gap

@Rishi-Dave
Contributor Author

@microsoft-github-policy-service agree

@tianleiwu
Contributor

Please merge main to fix conflict.

Issues and Concerns

Medium

  1. Change 1 is misleading / potentially dead code
    Change 1 - Allow LN with graph-input parent (Line ~896)
    When fuse() is called for the first LN anchor and start_node = normalize_node, the match_parent_path for qkv_nodes will immediately fail because the LN's inputs are [input_1, LN_weight, LN_bias], not the QKV path. The function returns via the qkv_nodes is None path. The new elif branch changes where the function returns, not whether it returns. This is functionally harmless, but will confuse future maintainers who expect this branch to enable fusion. Either:

    • Add a comment like # Note: qkv_nodes matching will still fail from the first LN anchor; the real fix is Changes #2-#4 acting on the second LN/SkipLN anchor
    • Or remove the branch entirely (the else: return it replaces was already correct for the first LN).
  2. Test coverage gap — SkipLayerNorm anchor path (Change 3) is untested
Change 3 - Extend child-LN resolution to SkipLN anchors (Line ~952)
    The test model goes through LayerNormalization as the second anchor. Change 3 targets the SkipLayerNormalization case (when fuse_skip_layer_norm runs first). There is no test that exercises this path. Consider adding a second test variant that enables SkipLayerNorm fusion before attention fusion, or document why it is deferred.

Minor

  1. Test assertions are weak
    All peer tests in the file validate both the presence of Attention and its attributes (e.g. num_heads). test_attention_fusion_pre_ln only checks assertIn("Attention", op_types). Recommend adding:

    attention_nodes = [n for n in optimized_model.model.graph.node if n.op_type == "Attention"]
    self.assertEqual(len(attention_nodes), 1)
    num_heads_attr = next(a for a in attention_nodes[0].attribute if a.name == "num_heads")
    self.assertEqual(num_heads_attr.i, 2)  # matches create_bert_attention_pre_ln default
  2. switch_add_inputs is accepted but never passed from the test
    The generator supports switch_add_inputs=True/False for both QKV matmul Add inputs. The test only calls create_bert_attention_pre_ln() with defaults. Both orderings should be tested (as is done for create_bert_attention). This is particularly important because match_parent_path uses [None, None, ...] wildcards for Add inputs.
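A `subTest`-based pattern would cover both orderings in one test. A runnable sketch, where `fuse_and_inspect` is a stub standing in for the real `optimize_model` pipeline in test_attention_fusion.py (the real test would build the model with `create_bert_attention_pre_ln(switch_add_inputs=...)` instead):

```python
import unittest

def fuse_and_inspect(switch_add_inputs):
    # Stub only: pretend the optimizer fused exactly one Attention node
    # with num_heads == 2, regardless of the Add input ordering.
    return {"op_types": ["Attention"], "num_heads": 2}

class PreLnAddOrderingTest(unittest.TestCase):
    def test_both_add_orderings(self):
        # Both Add-input orderings must fuse, because match_parent_path
        # uses [None, None, ...] wildcards for the Add inputs.
        for switch_add_inputs in (False, True):
            with self.subTest(switch_add_inputs=switch_add_inputs):
                result = fuse_and_inspect(switch_add_inputs)
                self.assertEqual(result["op_types"].count("Attention"), 1)
                self.assertEqual(result["num_heads"], 2)
```

`subTest` keeps both orderings in a single test method while still reporting each failure separately.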

  3. Temporary file path uses "." (current directory)

    dir = "."
    model_path = os.path.join(dir, "pre_ln_attention.onnx")

    Other tests in the same file use patterns like os.path.dirname(__file__) or tempfile.mkdtemp(). Using "." means the file lands wherever the test runner's working directory is, which can cause permission errors or conflicts in CI environments.

The attention fusion in FusionAttention.fuse() assumed a post-LN
architecture where every LayerNormalization has an Add parent (the
residual connection from the previous block).  In pre-LN models the
first block's LayerNormalization is fed directly by a graph input,
causing the fusion to bail out.

This commit makes four changes to fusion_attention.py:
- Allow LayerNormalization with a graph-input parent instead of
  returning early
- Include graph inputs when collecting the residual "other_inputs"
- Extend the child-LN resolution to SkipLayerNormalization anchors
  so root_input is redirected to the first LN's output
- Guard the output_name_to_node lookup for graph inputs that are
  not node outputs

A new test graph generator (create_bert_attention_pre_ln) and test
(test_attention_fusion_pre_ln) verify the fusion fires on a minimal
pre-LN first-block model.

Fixes microsoft#11684
…erse-add tests, strengthen assertions, use tempfile

- Clarify Change 1 comment to explain that QKV matching still fails from
  the first LN anchor; real fusion happens from the second anchor.
- Add test_attention_fusion_pre_ln_with_skiplayernorm exercising Change 3
  (SkipLayerNormalization anchor path when fuse_skip_layer_norm runs first).
- Add test_attention_fusion_pre_ln_reverse_add_order exercising
  switch_add_inputs=True for both Add input orderings.
- Strengthen all pre-LN test assertions: verify exactly 1 Attention node
  and num_heads=2 attribute (matching BART SDPA test pattern).
- Replace dir="." with tempfile.mkdtemp() to avoid CI path issues.
@Rishi-Dave Rishi-Dave force-pushed the rishidave/feat/pre-layer-norm-fusion branch from 9c0b270 to 9547f14 on February 28, 2026 05:23
@Rishi-Dave
Contributor Author

Thanks for the thorough review! I've rebased onto main and addressed all five items:

Medium #1 — Change 1 comment: Updated the comment to clearly explain that QKV matching will still fail from this (first) LN anchor because its inputs are weights, not the QKV projection path. The real fusion happens when fuse() is called again from the second LN/SkipLN anchor.

Medium #2 — SkipLN anchor test: Added test_attention_fusion_pre_ln_with_skiplayernorm which enables enable_skip_layer_norm = True so the optimizer fuses Add + LayerNorm into SkipLayerNormalization before attention fusion runs. This directly exercises the Change 3 code path. The test passes — attention fusion correctly anchors on the SkipLayerNormalization node.

Minor #3 — Stronger assertions: All three pre-LN tests now verify len(attention_nodes) == 1 and num_heads == 2, following the BART SDPA test pattern.

Minor #4 — switch_add_inputs: Added test_attention_fusion_pre_ln_reverse_add_order which calls create_bert_attention_pre_ln(switch_add_inputs=True) to test both Add input orderings.

Minor #5 — Temp file path: Replaced dir = "." with tempfile.mkdtemp() in all three pre-LN tests.

All 13 tests in test_attention_fusion.py pass, ruff clean.


@tianleiwu tianleiwu left a comment


Test temp-dir cleanup (Nice to have):

  • The new tests switched from dir = "." to tempfile.mkdtemp(), which is a good fix for CI path stability.
  • To avoid leaving empty temp directories behind, consider using tempfile.TemporaryDirectory() context managers instead of mkdtemp() + manual file delete.
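A minimal sketch of the suggested pattern (`save_to_temp_model` is a hypothetical helper; in the real tests the `f.write(...)` line would be `model.save_model_to_file(model_path)`):

```python
import os
import tempfile

def save_to_temp_model(data: bytes) -> str:
    # TemporaryDirectory removes the file and the directory when the
    # with-block exits, even if the test body raises, so nothing is
    # left behind in the test runner's working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "pre_ln_attention.onnx")
        with open(model_path, "wb") as f:
            f.write(data)  # stands in for model.save_model_to_file(...)
        assert os.path.exists(model_path)  # file exists inside the block
    return model_path

leftover = save_to_temp_model(b"onnx-bytes")
```

Unlike `mkdtemp()`, no manual `os.remove` or `shutil.rmtree` call is needed.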

@tianleiwu
Contributor

/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

@tianleiwu tianleiwu enabled auto-merge (squash) March 2, 2026 07:37
@tianleiwu tianleiwu merged commit 4612613 into microsoft:main Mar 2, 2026
89 checks passed


Development

Successfully merging this pull request may close these issues.

Add transformer optimization for pre layer normalization?
