Add RotaryEmbedding fusion for Qwen3 on-the-fly RoPE patterns#27590
Conversation
Below is an AI analysis (some of it may not be correct): Modeling Code Alignment Check (vs
| Category | Issue | Severity | Action |
|---|---|---|---|
| max_seq_len = 2048 | Causes OOB memory access for sequences > 2048 | Critical | Increase to 131072 or make configurable |
| inv_freq tracing fallback | Uses wrong name when Cast/Expand nodes present | Medium (latent) | Track leaf input during traversal |
| Test coverage gap | Test model skips Expand/Cast in inv_freq path | Low | Add variant test |
| Redundant Concat + truncation | Dead computation in cache generation | Low | Simplify for clarity |
| Missing return (pre-existing) | Fallthrough after mismatched cache paths | Low | Not introduced in this PR |
The fusion logic is structurally sound and the modeling code alignment is verified correct. The critical blocker is the max_seq_len = 2048 limit which will cause runtime crashes for real Qwen3 inference beyond 2048 tokens.
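For context, the cache that the fusion materializes can be sketched as follows. This is a minimal NumPy sketch, not the PR's actual code; the helper name and the base frequency 10000.0 are assumptions for illustration.

```python
import numpy as np

def build_cos_sin_cache(inv_freq: np.ndarray, max_seq_len: int):
    # freqs[p, i] = p * inv_freq[i]; cos/sin of this gives the cache rows
    positions = np.arange(max_seq_len, dtype=np.float32)
    freqs = np.outer(positions, inv_freq.astype(np.float32))
    return np.cos(freqs), np.sin(freqs)

# head_dim = 128 -> 64 inverse frequencies; at max_seq_len = 131072 the two
# float32 caches take 131072 * 64 * 4 bytes * 2 = 64 MiB, matching the
# "~64 MB" estimate discussed above (base frequency 10000.0 assumed)
head_dim = 128
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_dim, 2, dtype=np.float32) / head_dim))
cos_cache, sin_cache = build_cos_sin_cache(inv_freq, max_seq_len=4096)
```

Any sequence position beyond `max_seq_len` would index past the end of these caches, which is why the hardcoded 2048 was flagged as a critical out-of-bounds risk.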
|
Thanks for the thorough analysis! Addressed all items in 3a2cd4b:
- Critical (#1) — `max_seq_len = 2048` → 131072: Increased to 131072 to cover most LLM contexts (Qwen3 default is 32768; many models go up to 128k). Memory cost for head_dim=128 is ~64 MB — very modest.
- Medium (#2) — inv_freq tracing bug: Fixed by tracking the leaf input name during traversal through Cast/Expand/Where/Unsqueeze nodes, so the correct initializer name is used even when
- Low (#3) — Test coverage gap: Added
- Low (#4) — Redundant Concat + truncation: Simplified to use
- Code scanning (#5): Removed the unnecessary premature assignment of

All 16 tests pass (14 existing + 2 new Qwen3 RoPE tests). Lintrunner clean.
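The leaf-tracking fix for item #2 can be illustrated with a small sketch. The names here are hypothetical; the real traversal lives in the fusion's inv_freq tracing logic and may differ in detail.

```python
from types import SimpleNamespace

# Ops that only reshape/retype the tensor; the inv_freq name passes through
PASS_THROUGH = {"Cast", "Expand", "Unsqueeze"}

def trace_leaf_input(start_name: str, output_to_node: dict) -> str:
    """Walk producer nodes upward, remembering the name actually consumed,
    so the initializer lookup uses the leaf name rather than an
    intermediate Cast/Expand output name."""
    name = start_name
    node = output_to_node.get(name)
    while node is not None:
        if node.op_type in PASS_THROUGH:
            name = node.input[0]
        elif node.op_type == "Where":
            name = node.input[1]  # Where(cond, x, y): data flows through x
        else:
            break
        node = output_to_node.get(name)
    return name

# Toy graph: Cast(Expand(Where(cond, inv_freq, fallback))) feeding the MatMul
where_n = SimpleNamespace(op_type="Where", input=["cond", "inv_freq", "fallback"])
expand_n = SimpleNamespace(op_type="Expand", input=["where_out", "shape"])
cast_n = SimpleNamespace(op_type="Cast", input=["expand_out"])
graph = {"where_out": where_n, "expand_out": expand_n, "cast_out": cast_n}
```

Without the fix, the fallback would return the starting name (here `cast_out`), which names no initializer; the traversal above resolves to `inv_freq`.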
Pull request overview
This PR extends the ONNX Runtime transformer optimizer (FusionRotaryEmbeddings) to handle Qwen3's on-the-fly rotary position embedding (RoPE) computation pattern, where cos/sin are computed from inv_freq at runtime via MatMul rather than being looked up from a pre-computed cache.
Changes:
- Add new path patterns (`sin_path_5`/`cos_path_5`, `rotate_half_x2_path_2_3`/`_2_4`, `rotate_half_x1_path_2_3`/`_2_4`) for Qwen3's Cast-tolerant RoPE pattern
- Add `create_cos_sin_cache_from_on_the_fly_rope()` helper that extracts `inv_freq`, computes cos/sin caches at optimization time, and adds them as model initializers
- Add `create_qwen3_decoder_layer()` test generator and new tests for both on-the-fly RoPE and the Cast+Expand+Where variant
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| `fusion_rotary_attention.py` | Core changes: new path patterns for Qwen3 RoPE, on-the-fly cache generation helper, updated `fuse()` to handle on-the-fly vs. cache-based paths separately |
| `optimizer.py` | Registers "qwen3" model type with `Gpt2OnnxModel` and `opt_level=0` |
| `fusion_options.py` | Sets Qwen3-specific defaults: disables EmbedLayerNorm, uses NoMask attention |
| `fusion_skiplayernorm.py` | Removes early return when symbolic shape inference fails, allowing Skip LN fusion with `skip_index=1` default |
| `qwen3_model_generator.py` | New test graph generator for Qwen3 decoder layer with optional on-the-fly RoPE and Expand/Where inv_freq path |
| `test_attention_fusion.py` | Adds 2 new tests for on-the-fly RoPE fusion and its Cast+Expand+Where variant |
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
tianleiwu left a comment:
The core logic is correct and well-tested. The fusion_skiplayernorm.py change is safe given the pre-existing checks. The on-the-fly RoPE fusion is a valuable addition for Qwen3 model optimization.
Minor Suggestions
- `fusion_skiplayernorm.py` behavioral change (when `shape_infer_helper` is `None` because shape inference failed, previously the fusion was skipped entirely with an early return; now it proceeds with `skip_index=1` as a safe default) affects all model types. Consider adding more specific comments about the safety guarantee.
- Hardcoded `max_seq_len=131072` — acceptable for now but may need to be configurable in the future.
- No negative test cases — consider adding tests for malformed graphs or unsupported variants to ensure graceful fallback.
- No numerical validation in tests — the tests only count fused nodes but don't verify that the generated cos/sin caches match expected values.
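The numerical-validation suggestion could look roughly like this. A hedged sketch only: the helper name, tolerance, and checked positions are assumptions, not the test that was actually added.

```python
import numpy as np

def check_cache_values(cos_cache, sin_cache, inv_freq, positions=(0, 1, 100, 1000)):
    # Spot-check cache rows against the reference formula: angle = p * inv_freq
    for p in positions:
        angles = np.float32(p) * inv_freq.astype(np.float32)
        assert np.allclose(cos_cache[p], np.cos(angles), atol=1e-5)
        assert np.allclose(sin_cache[p], np.sin(angles), atol=1e-5)

# Build a small reference cache and validate it (base frequency 10000.0 assumed)
inv_freq = 1.0 / (10000.0 ** (np.arange(0, 128, 2, dtype=np.float32) / 128))
freqs = np.outer(np.arange(2048, dtype=np.float32), inv_freq)
check_cache_values(np.cos(freqs), np.sin(freqs), inv_freq)
```

In a real test, `cos_cache`/`sin_cache` would be read back from the initializers the fusion added to the model, rather than recomputed.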
@Rishi-Dave, some CI pipelines failed, please take a look. For example, there is a test error in transformers/test_attention_fusion.py::TestFusion::test_gpt2_attention_no_past_fusion. You can reproduce it locally (you need to build and reinstall the wheel first):
Force-pushed 09507b2 to 190c66e
Thanks for flagging! The CI failure in
The fallback was incorrectly fusing nodes in the GPT-2 no-past graph where shape inference fails. Fix: Rebased onto latest
Force-pushed 0ed8785 to bfc0ffc
Thanks for the suggestions! Addressed in 78dc48f:
- Negative test (#3): Added
- Numerical validation (#4): Added

All 19 tests pass (17 existing + 2 new). Lintrunner clean.
Extend FusionRotaryEmbeddings to handle Qwen3's on-the-fly rotary position embedding computation, where cos/sin values are computed from inv_freq at runtime instead of being looked up from a pre-computed cache.

Changes:
- Add Cast-tolerant rotate_half path patterns for TorchScript exports that insert Cast nodes between Unsqueeze and Div
- Add sin_path_5/cos_path_5 patterns matching the on-the-fly computation: MatMul → Transpose → Concat → Cos/Sin → Mul(scaling) → Unsqueeze → Mul, with optional Cast variant
- Add create_cos_sin_cache_from_on_the_fly_rope() helper that extracts inv_freq weights, computes cos/sin caches as initializers, and traces position_ids from the graph
- Handle per-layer vs shared node removal correctly (only remove per-layer Unsqueeze/outer Mul; shared MatMul/Cos/Sin nodes are pruned automatically)
- Update qwen3_model_generator.py with full RoPE computation graph
- Add test_qwen3_rotary_embedding_fusion verifying 2 RotaryEmbedding nodes are fused

Verified on real Qwen3-Embedding-0.6B: 56 RotaryEmbedding fused (28 layers × 2), reducing 7416 → 4661 nodes (37% reduction).
…cache
- Increase max_seq_len from 2048 to 131072 to prevent OOB memory access for sequences beyond 2048 tokens (Qwen3 default is 32768)
- Fix inv_freq tracing to track the leaf input name through Cast/Expand/Where/Unsqueeze nodes, preventing a wrong fallback name when intermediate nodes are present
- Simplify cache computation: use freqs directly instead of redundant Concat(freqs, freqs) followed by truncation to the first half
- Remove unnecessary premature variable assignments flagged by code scanning (position_ids_from_sin/cos_path)
- Add test_qwen3_rotary_embedding_fusion_with_expand covering the Cast → Expand → Where traversal path in inv_freq tracing
- Fix inv_freq tracing through Where nodes: follow input[1] (true branch / data path) instead of input[0] (condition). Where has 3 inputs [condition, x, y] and inv_freq flows through x.
- Use numpy_helper.from_array for cache tensor serialization instead of flatten().tolist(), avoiding an intermediate Python list for ~8M float values
- Remove unused sin_path parameter from create_cos_sin_cache_from_on_the_fly_rope (only cos_path is used)
- Remove unused max_seq_len parameter from _on_the_fly_rope_nodes
- Early-return from create_cos_sin_cache_from_on_the_fly_rope when cos/sin cache initializers already exist, avoiding redundant 131072 × head_dim/2 cos/sin computation on every layer's fusion
- Guard per-layer node removal with a single-consumer check to prevent removing shared nodes that still have other consumers
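The single-consumer guard might be sketched like this. A hypothetical helper for illustration; the PR's actual check likely uses OnnxModel graph utilities instead.

```python
from types import SimpleNamespace

def has_single_consumer(output_name: str, nodes) -> bool:
    # Count graph nodes that read this output. Shared nodes (e.g. the
    # Cos/Sin chain feeding every layer) have more than one consumer and
    # must not be removed while fusing a single layer.
    return sum(output_name in n.input for n in nodes) == 1

# Toy graph: one Cos output shared by two layers, one per-layer Unsqueeze input
mul_layer0 = SimpleNamespace(op_type="Mul", input=["cos_out", "q0"])
mul_layer1 = SimpleNamespace(op_type="Mul", input=["cos_out", "q1"])
unsqueeze0 = SimpleNamespace(op_type="Unsqueeze", input=["u0_in"])
nodes = [mul_layer0, mul_layer1, unsqueeze0]
```

Nodes whose outputs still feed other layers are left in place; once the last consumer is fused, dead-node pruning removes them automatically.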
Remove the shape-inference-failed fallback in FusionSkipLayerNormalization that was causing test_gpt2_attention_no_past_fusion to fail — the fallback allowed SkipLayerNorm fusion without shape validation, which incorrectly fused nodes in the GPT-2 no-past graph. Update Qwen3 test assertions to expect 4 SimplifiedLayerNormalization (no SkipSLN fusion when shape inference is unavailable for the synthetic test graph).
Address reviewer feedback: add test for graceful fallback when inv_freq is a dynamic graph input (not an extractable initializer), and add numerical validation that verifies cos/sin cache values match the expected mathematical computation at multiple positions.
Force-pushed 78dc48f to 41f3a17
@tianleiwu — Rebased onto latest upstream/main and force-pushed. All feedback has been addressed:

Your suggestions:
CI fix:
Test results (post-rebase):

Ready for re-review when you have a chance.
/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).
Pull request overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated no new comments.
Description
Extend `FusionRotaryEmbeddings` to handle Qwen3's on-the-fly rotary position embedding computation, where cos/sin values are computed from `inv_freq` at runtime instead of being looked up from a pre-computed cache.

This is a follow-up to #27556 (Qwen3 basic model type support). Depends on #27556.
Part of #25083.
Motivation and Context
Qwen3 models (ranked 4th on MTEB) compute RoPE differently from existing supported models (Phi, LLaMA, etc.). Instead of pre-computing cos/sin caches and looking them up via `Gather(cache, position_ids)`, Qwen3 computes them on-the-fly:

Additionally, TorchScript exports of Qwen3 insert `Cast` nodes in the `rotate_half` pattern (from `torch.floor_divide` tracing), which the existing path patterns don't account for.

Changes
`fusion_rotary_attention.py`:
- `rotate_half` path patterns (`rotate_half_x2_path_2_3`, `_2_4`, `rotate_half_x1_path_2_3`, `_2_4`) that allow 1-2 Cast nodes between Unsqueeze and Div in the dynamic Slice index computation
- `sin_path_5`/`cos_path_5` patterns matching the on-the-fly computation: MatMul → Transpose → Concat → Cos/Sin → Mul(scaling) → Unsqueeze → Mul, with an optional Cast variant (the optimizer's earlier Cast fusion pass may remove the Cast)
- `create_cos_sin_cache_from_on_the_fly_rope()` helper that extracts `inv_freq` weights, computes cos/sin caches as model initializers, and traces `position_ids` from the graph

`qwen3_model_generator.py`:
- `include_rope=True` parameter to `create_qwen3_decoder_layer()`
- `inv_freq` initializer, `position_ids` input, MatMul/Transpose/Concat/Cos/Sin/Mul nodes, and `rotate_half` pattern with dynamic Slice indices (including Cast nodes from floor division)

`test_attention_fusion.py`:
- `test_qwen3_rotary_embedding_fusion` verifying 2 RotaryEmbedding nodes are fused along with 3 SimplifiedLayerNormalization and 1 SkipSimplifiedLayerNormalization

Verification
- `test_attention_fusion.py` tests pass (14 existing + 1 new)
- `lintrunner -a` clean on all modified files
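The on-the-fly pattern described under Motivation and Context can be sketched in NumPy. Shapes and the base frequency 10000.0 are assumptions; the exported graph expresses the same math with MatMul/Transpose/Concat/Cos/Sin nodes.

```python
import numpy as np

head_dim, seq_len = 128, 16
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_dim, 2, dtype=np.float32) / head_dim))

# MatMul: inv_freq [1, head_dim/2, 1] @ position_ids [1, 1, seq_len]
position_ids = np.arange(seq_len, dtype=np.float32)[None, None, :]
freqs = inv_freq[None, :, None] @ position_ids   # [1, head_dim/2, seq_len]
freqs = np.transpose(freqs, (0, 2, 1))           # Transpose -> [1, seq_len, head_dim/2]
emb = np.concatenate([freqs, freqs], axis=-1)    # Concat    -> [1, seq_len, head_dim]
cos, sin = np.cos(emb), np.sin(emb)              # Cos / Sin, applied to q/k at runtime
```

Because `inv_freq` is a constant initializer, the fusion can evaluate this subgraph once at optimization time and replace it with cos/sin cache lookups, which is exactly what `create_cos_sin_cache_from_on_the_fly_rope()` does.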