Add pre-layer normalization support to attention fusion #27418

tianleiwu merged 2 commits into microsoft:main
Conversation
@microsoft-github-policy-service agree
Please merge main to fix conflict.

Issues and Concerns: Medium, Minor
The attention fusion in `FusionAttention.fuse()` assumed a post-LN architecture where every `LayerNormalization` has an `Add` parent (the residual connection from the previous block). In pre-LN models the first block's `LayerNormalization` is fed directly by a graph input, causing the fusion to bail out. This commit makes four changes to `fusion_attention.py`:

- Allow `LayerNormalization` with a graph-input parent instead of returning early
- Include graph inputs when collecting the residual `other_inputs`
- Extend the child-LN resolution to `SkipLayerNormalization` anchors so `root_input` is redirected to the first LN's output
- Guard the `output_name_to_node` lookup for graph inputs that are not node outputs

A new test graph generator (`create_bert_attention_pre_ln`) and test (`test_attention_fusion_pre_ln`) verify the fusion fires on a minimal pre-LN first-block model.

Fixes microsoft#11684
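The guard in the last bullet can be illustrated with a minimal sketch. The helper below is hypothetical (it mirrors the names `output_name_to_node` and "graph input" from the commit message, not the actual ONNX Runtime code); it shows why an unguarded dictionary lookup fails for a pre-LN first block:

```python
def resolve_parent(tensor_name, output_name_to_node, graph_input_names):
    """Return the node producing `tensor_name`, or None for a graph input.

    Graph inputs are not produced by any node, so an unguarded
    output_name_to_node[tensor_name] lookup would raise KeyError.
    (Illustrative helper, not the PR's actual code.)
    """
    if tensor_name in graph_input_names:
        return None  # graph input: legitimate, it just has no producing node
    return output_name_to_node.get(tensor_name)

# Post-LN style: the LayerNorm input is produced by a residual Add node.
producers = {"add_out": {"op_type": "Add"}}
assert resolve_parent("add_out", producers, {"input_ids"}) == {"op_type": "Add"}

# Pre-LN first block: the LayerNorm input is the graph input itself.
assert resolve_parent("input_ids", producers, {"input_ids"}) is None
print("graph-input guard handles both cases")
```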
…erse-add tests, strengthen assertions, use tempfile

- Clarify Change 1 comment to explain that QKV matching still fails from the first LN anchor; real fusion happens from the second anchor.
- Add `test_attention_fusion_pre_ln_with_skiplayernorm` exercising Change 3 (`SkipLayerNormalization` anchor path when `fuse_skip_layer_norm` runs first).
- Add `test_attention_fusion_pre_ln_reverse_add_order` exercising `switch_add_inputs=True` for both `Add` input orderings.
- Strengthen all pre-LN test assertions: verify exactly 1 `Attention` node and `num_heads=2` attribute (matching the BART SDPA test pattern).
- Replace `dir="."` with `tempfile.mkdtemp()` to avoid CI path issues.
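The reversed `Add` ordering that `test_attention_fusion_pre_ln_reverse_add_order` covers can be pictured with a tiny helper. The function and tensor names below are hypothetical, a sketch of the general idea: a fusion pass that accepts either input order has to identify which operand of the residual `Add` is the skip connection.

```python
def find_residual_input(add_inputs, attention_output_name):
    """Given the two inputs of a residual Add, return the skip-connection
    tensor (the one that is NOT the attention subgraph's output).
    Works for both input orderings. Illustrative helper only."""
    a, b = add_inputs
    if a == attention_output_name:
        return b
    if b == attention_output_name:
        return a
    raise ValueError("neither input comes from the attention path")

# Both orderings resolve to the same residual tensor.
assert find_residual_input(["hidden_in", "attn_out"], "attn_out") == "hidden_in"
assert find_residual_input(["attn_out", "hidden_in"], "attn_out") == "hidden_in"
print("both Add input orderings handled")
```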
Force-pushed 9c0b270 to 9547f14
Thanks for the thorough review! I've rebased onto main and addressed all five items:

- Medium #1 — Change 1 comment: Updated the comment to clearly explain that QKV matching will still fail from this (first) LN anchor because its inputs are weights, not the QKV projection path. The real fusion happens from the second LN anchor.
- Medium #2 — SkipLN anchor test: Added `test_attention_fusion_pre_ln_with_skiplayernorm`.
- Minor #3 — Stronger assertions: All three pre-LN tests now verify exactly 1 `Attention` node with `num_heads=2`.
- Minor #4 —
- Minor #5 — Temp file path: Replaced `dir="."` with `tempfile.mkdtemp()`.

All 13 tests in test_attention_fusion.py pass, ruff clean.
Test temp-dir cleanup (Nice to have):

- The new tests switched from `dir="."` to `tempfile.mkdtemp()`, which is a good fix for CI path stability.
- To avoid leaving empty temp directories behind, consider using `tempfile.TemporaryDirectory()` context managers instead of `mkdtemp()` + manual file delete.
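The suggested pattern, sketched with the standard library (the file name and placeholder bytes are illustrative, not from the actual tests):

```python
import os
import tempfile

# tempfile.mkdtemp() creates a directory but never deletes it; each test
# would have to clean up manually or leave the directory behind.
# tempfile.TemporaryDirectory() deletes the directory and its contents
# automatically when the with-block exits, even if the test raises.
with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = os.path.join(tmp_dir, "pre_ln_attention.onnx")  # illustrative name
    with open(model_path, "wb") as f:
        f.write(b"placeholder bytes standing in for a serialized model")
    assert os.path.exists(model_path)

# The whole directory is gone once the context manager exits.
assert not os.path.exists(tmp_dir)
print("temp dir cleaned up automatically")
```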
/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).
Description
Add support for pre-layer normalization (pre-LN) transformer architectures in the Python attention fusion optimizer.
Motivation and Context
Fixes #11684
Pre-LN models (used in GPT-3, ViT variants, and many modern architectures) apply LayerNormalization before attention rather than after. The first block of a pre-LN model has no `Add` node before its first `LayerNormalization` — its input comes directly from a graph input. This caused `FusionAttention.fuse()` to bail out early because it assumed every `LayerNormalization` anchor has an `Add` parent (the residual connection from the previous block).

This PR makes four surgical changes to `fusion_attention.py` so that pre-LN first-block models fuse correctly, while preserving all existing post-LN behavior:

1. When no `Add` parent is found, check whether the input is a graph input and continue
2. The `other_inputs` loop previously skipped anything not in `output_name_to_node`; graph inputs are now recognized
3. When `fuse_skip_layer_norm()` runs, the anchor becomes `SkipLayerNormalization`; the redirect from `root_input` (graph input) to the first LN's output now fires for SkipLN anchors too
4. Guard the `output_name_to_node` lookup — graph inputs are not node outputs, so the dictionary access is now guarded

Changes
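The structural difference the fusion must handle can be sketched in NumPy with a toy residual block (this is an illustration of the two architectures, not the model or fusion code): in post-LN the norm's parent is the residual `Add`, while in a pre-LN first block the norm's parent is the raw graph input.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Simplified LayerNormalization (no learned scale/bias).
    return (x - x.mean(axis=-1, keepdims=True)) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def attention(x):
    # Stand-in for the whole attention subgraph; shape-preserving.
    return x

def post_ln_first_block(graph_input):
    # Post-LN: LayerNorm is applied AFTER the residual Add, so every
    # LayerNormalization anchor has an Add parent -- the case the
    # fusion already handled.
    h = attention(graph_input)
    return layer_norm(graph_input + h)

def pre_ln_first_block(graph_input):
    # Pre-LN: LayerNorm is applied BEFORE attention. In the first block
    # its parent is the graph input itself; there is no Add node, which
    # is what made the fusion bail out before this PR.
    h = attention(layer_norm(graph_input))
    return graph_input + h

x = np.random.randn(2, 4).astype(np.float32)
assert post_ln_first_block(x).shape == x.shape
assert pre_ln_first_block(x).shape == x.shape
print("both block styles are shape-preserving")
```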
- `onnxruntime/python/tools/transformers/fusion_attention.py` — 4 targeted edits to `FusionAttention.fuse()`
- `onnxruntime/test/python/transformers/bert_model_generator.py` — new `create_bert_attention_pre_ln()` test graph generator
- `onnxruntime/test/python/transformers/test_attention_fusion.py` — new `test_attention_fusion_pre_ln()` test

Test Plan

- `test_attention_fusion_pre_ln` passes — verifies the `Attention` fused op appears in the optimized graph
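The assertion pattern the strengthened tests describe (exactly one `Attention` node, with `num_heads=2`) can be sketched without ONNX Runtime, using plain dictionaries to stand in for graph nodes. All names and structure here are illustrative, not the actual test code:

```python
def check_fused(nodes, expected_num_heads=2):
    # Collect all fused Attention nodes from the optimized graph.
    attention_nodes = [n for n in nodes if n["op_type"] == "Attention"]
    assert len(attention_nodes) == 1, f"expected 1 Attention node, got {len(attention_nodes)}"
    # Verify the num_heads attribute was carried through the fusion.
    assert attention_nodes[0]["attributes"].get("num_heads") == expected_num_heads
    return True

# Stand-in for an optimized pre-LN first-block graph.
optimized_nodes = [
    {"op_type": "LayerNormalization", "attributes": {}},
    {"op_type": "Attention", "attributes": {"num_heads": 2}},
    {"op_type": "Add", "attributes": {}},
]
assert check_fused(optimized_nodes)
print("fusion assertions pass on the stand-in graph")
```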