Skip to content

feat(qwen): fused RMSNorm+RoPE for Qwen3/3.X family + Liger m-rope default#3680

Merged
winglian merged 7 commits into
axolotl-ai-cloud:mainfrom
thad0ctor:qwen-fused-kernels
May 29, 2026
Merged

feat(qwen): fused RMSNorm+RoPE for Qwen3/3.X family + Liger m-rope default#3680
winglian merged 7 commits into
axolotl-ai-cloud:mainfrom
thad0ctor:qwen-fused-kernels

Conversation

@thad0ctor
Copy link
Copy Markdown
Contributor

@thad0ctor thad0ctor commented May 25, 2026

Description

Generalizes axolotl's existing Gemma 4 fused RMSNorm + RoPE Triton kernel to Qwen3, Qwen3-MoE, Qwen3.5, Qwen3.5-MoE, Qwen3.6 dense, and Qwen3.6-MoE full-attention layers behind a new opt-in cfg.fused_attn_kernel. Qwen3.6 checkpoints are loaded by transformers under the qwen3_5 / qwen3_5_moe model_types, so the same dispatch covers both generations.

This update also folds in the Qwen3-VL fused-attention work from thad0ctor#29. It adds support for model_config_type: qwen3_vl / qwen3_vl_text by fusing the Qwen3-VL text attention q/k path:

q_proj/k_proj -> per-head RMSNorm -> mRoPE

with the existing compile-safe Triton fused_rms_norm_rope kernel. The patched forward preserves the stock Qwen3-VL attention flow after q/k preparation, including cache update, ALL_ATTENTION_FUNCTIONS dispatch, dropout/scaling kwargs, output projection, and Axolotl LoRA apply_qkv / apply_o hooks.

The PR adds a UNIT_OFFSET kernel flag for Qwen3.5's Gemma-style (1.0 + weight) RMSNorm; default off keeps every existing Gemma 4 / Qwen3 caller bit-identical.

It also auto-enables Liger's fused (m-)rope kernel for qwen2_vl / qwen2_5_vl / qwen3_vl / qwen3_vl_moe and _text variants when liger_rope is unset. Previously, axolotl's plugin overrode upstream's rope=True default with None, silently skipping the fused m-rope kernel.

Finally, this fixes a pre-existing latent patch-ordering bug in apply_pre_model_load_patches: patch_self_attn_lora re-parses Attention.forward source via inspect.getsource, so it must run before any patch that replaces forward. The reorder also fixes Gemma 4 + lora_qkv_kernel composition.

Motivation and Context

The goal is to extend the existing Gemma 4 fused RMSNorm+RoPE Triton kernel to the Qwen3 family for added LoRA throughput. The Qwen3 attention forwards are structurally compatible with the existing kernel:

  • Qwen3 / Qwen3-MoE share the kernel's shape contract verbatim and wire up unchanged.
  • Qwen3.5 / Qwen3.5-MoE need one new UNIT_OFFSET: tl.constexpr for their Gemma-style (1.0 + weight) RMSNorm and a chunk-2 split on the gated q_proj.
  • Qwen3-VL text attention has the same q/k RMSNorm + RoPE fusion opportunity, but was previously outside the supported dispatch set and would no-op with fused_attn_kernel: true.

Performance

Qwen3 Dense Benchmark Runs

All runs below use real Qwen/Qwen3-0.6B (28 layers, hidden=1024, head_dim=128) under LoRA r=16 alpha=32 on [q,k,v,o]_proj, bf16, attn_implementation=sdpa. Was = stock eager. Is = fused_attn_kernel: true, optionally combined with torch_compile: true.

Model Hardware Bench S B Was (eager) Is (variant) Delta vs eager
Qwen3-0.6B RTX 3090 (sm_86) tail100, 300 steps x 2 orders 2048 1 330.77 ms / 9304 MB fused: 296.25 ms / 8968 MB +11.7% time / -3.6% mem
Qwen3-0.6B RTX 3090 (sm_86) tail100, 300 steps x 2 orders 2048 1 330.77 ms / 9304 MB fused + compile (default): 403.73 ms / 6270 MB -18.1% time / -32.6% mem
Qwen3-0.6B RTX 3090 (sm_86) tail100, 300 steps x 2 orders 2048 1 330.77 ms / 9304 MB fused + compile (reduce-overhead): 385.04 ms / 7068 MB -14.1% time / -24.0% mem
Qwen3-0.6B RTX 3090 (sm_86) tail100, 200 steps x 2 orders (scaling check) 4096 1 619.35 ms / 17403 MB fused: 558.45 ms / 16731 MB +11.1% time / -3.9% mem
Qwen3-0.6B RTX 3090 (sm_86) e2e median, 50 steps (full fwd + bwd + AdamW) 1024 1 200.75 ms fused: 184.71 ms +8.68% time (peak mem not captured)
Qwen3-0.6B RTX 5090 (sm_120) tail100, 300 steps x 2 orders 2048 1 183.03 ms / 9304 MB fused: 171.04 ms / 8968 MB +7.0% time / -3.6% mem
Qwen3-0.6B RTX 5090 (sm_120) tail100, 300 steps x 2 orders 2048 1 183.03 ms / 9304 MB fused + compile (default): 167.58 ms / 6270 MB +9.2% time / -32.6% mem

Notes:

  • All runs: 0 spikes >2x median, 1 Dynamo frame, 0 graph breaks under compile.
  • 5090 fused + compile also delivers -33% peak memory (9304 -> 6270 MB) due to Inductor's surrounding-op fusion.
  • The 3090 compile regression is a known Inductor autotune-DB gap on sm_86; plain fused remains faster on 3090.

Qwen3-VL Benchmark Runs

This is the benchmark run from thad0ctor#29, using separate fresh processes and real Qwen3-VL-8B text-attention dimensions.

Environment:

  • GPU: NVIDIA GeForce RTX 3090
  • Pinning: CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=2
  • Config source: local Qwen3-VL-8B-Instruct config
  • Attention impl: flash_attention_2
  • dtype: bf16
  • batch size: 1
  • sequence length: 2048
  • hidden size: 4096
  • attention heads: 32
  • KV heads: 8
  • head dim: 128
  • measured op: one Qwen3VLTextAttention forward + backward
  • warmup: 20 steps
  • timed: 500 steps per seed
  • seeds: 1234, 5678
Condition Mean ms/step Median ms/step Tail100 mean Peak allocated Delta vs stock
stock Qwen3-VL attention 16.118 16.081 16.158 0.426 GB baseline
fused RMSNorm+mRoPE 13.535 13.496 13.563 0.384 GB -16.0% time / -9.9% mem

Every fused benchmark run reported forward_module == 'axolotl.monkeypatch.models.qwen3_vl.fused_attn', so the fused_attn_kernel: true path was active and not silently deactivated.

How has this been tested?

Unit / integration: tests/monkeypatch/, tests/kernels/, tests/integrations/test_liger*.

Test surfaces:

  • Per-model parity + backward grad flow (Qwen3 / Qwen3-MoE / Qwen3.5 / Qwen3.5-MoE)
  • Qwen3-VL single-layer fused-vs-stock parity on a tiny Qwen3VLTextModel
  • Qwen3-VL end-to-end text-model forward parity
  • Qwen3-VL backward grad flow through q/k norm weights
  • Qwen3-VL PatchManager dispatch for model_config_type='qwen3_vl' and qwen3_vl_text
  • Kernel UNIT_OFFSET=True parity vs from-scratch reference
  • torch.compile(fullgraph=True) parity for offset and no-offset paths
  • Liger Qwen-VL auto-default across model_config_types including Qwen3-VL variants; explicit liger_rope: false remains respected
  • Patch idempotency (double-apply is a no-op)
  • Transformers signature contract pins (catches future minor-version drift loudly)
  • Gradient-checkpointing composition
  • Flash-Attention 2 composition (skip-if-unavailable)
  • Patch-ordering source-line invariant + LoRA-fused composition with fused-first reverse-order trip-wires
  • Liger RMSNorm composition (fused_attn_kernel: true + liger_rms_norm: true)
  • Cross-device norm weight handling for accelerate-sharded or otherwise misplaced RMSNorm weights
  • PEFT ModulesToSaveWrapper composition for trainable q/k RMSNorm adapter weights
  • Kernel coverage at production head_dim=256
  • attention_mask pass-through and sliding_window kwarg preservation
  • get_text_config-derived dispatch from multimodal Qwen3 checkpoints

Local validation:

CUDA targeted fused-attention suite:
60 passed, 1 xfailed

Qwen3.5 fused-attention sanity check:
12 passed

Qwen3-VL-specific patch/dispatch check:
18 passed

Exact CI-failing Gemma4/Qwen3 fused-attn tests after log-capture fix:
2 passed

Same exact tests after importing axolotl.cli first to reproduce configured Axolotl logging:
2 passed

pre-commit on touched tests:
passed

The maintainer-reported Gemma4/Qwen3 fused-attn CI failures were test harness failures, not implementation failures. In the Modal/docker e2e logs, the expected log messages were present in captured stderr, but caplog.records was empty because Axolotl logging had already been configured and the axolotl logger can run with propagate=False. The tests now attach caplog.handler directly to the module loggers they assert against:

  • axolotl.monkeypatch.lora_kernels for the Gemma4 fused-then-LoRA skip test
  • axolotl.loaders.patch_manager for the unsupported fused_attn_kernel warning test

End-to-end on real checkpoints:

  • Qwen/Qwen3-0.6B save / reload parity (3090): save adapter -> reload into fresh model -> forward parity vs in-memory, max_abs = 0.0
  • Qwen/Qwen3-0.6B 50-step LoRA convergence (3090): eager 13.189 -> 11.996, fused 13.186 -> 11.997, final-loss diff = 0.0005 (bf16 noise); both 33/49 steps monotone-decreasing
  • Qwen/Qwen3-0.6B PatchManager pipeline (3090): patches when cfg.fused_attn_kernel: true, skips when false
  • Qwen3.5-2B checkpoint LoRA parity (3090, 5 steps, full_attention layers exercised): max_abs=0.0020, both finite and decreasing
  • Qwen3-VL-8B-Instruct real config path check: normalized config has model_config_type='qwen3_vl', model_config_type_text='qwen3_vl_text', and fused_attn_kernel=True; PatchManager._apply_model_specific_patches() installs axolotl.monkeypatch.models.qwen3_vl.fused_attn; Liger plugin does not overwrite the fused forward
  • Qwen3-VL-8B-Instruct Liger m-rope auto-default end-to-end (3090): real Qwen3VLForConditionalGeneration loads and forward-passes finite logits with the swapped kernel in place
  • FSDP multi-GPU composition (3x 3090, torchrun --nproc_per_node=3, full_shard + bf16 mixed precision + LoRA + gradient checkpointing): 3-step training completes with finite loss; max peak memory 2.58 GB/rank
  • Qwen3.6-27B dense multimodal checkpoint, FSDP multi-GPU (4x 24 GB GPU): patch installs on full-attention layers only; 3-step training completes with finite, monotonically decreasing loss; peak memory 13.4 GB/rank
  • Qwen3.6-35B-A3B MoE multimodal checkpoint, FSDP multi-GPU (4x 24 GB GPU): patch installs on full-attention layers only; 3-step training completes with finite loss; peak memory 13.5 GB/rank
  • A Qwen3-VL-8B LoRA CPT checkpoint created before the Qwen3-VL fused patch was also resumed successfully as a compatibility smoke test; this is not used as benchmark evidence above.

AI Usage Disclaimer

Yes. OpenAI ChatGPT/Codex assisted with code changes, test scaffolding, local benchmarking, validation, and drafting this PR summary. The implementation was validated locally with the tests and benchmark runs described above.

Screenshots (if appropriate)

N/A

Types of changes

  • New feature (non-breaking) - fused_attn_kernel for Qwen3 / Qwen3-MoE / Qwen3.5 / Qwen3.5-MoE / Qwen3-VL text attention
  • Performance improvement - fused q/k RMSNorm + RoPE/mRoPE path
  • Bug fix - patch-ordering latent bug in apply_pre_model_load_patches; Liger Qwen-VL m-rope auto-default; configured-logging-safe caplog tests
  • Defensive - device-coerce of q_norm/k_norm weights; PEFT ModulesToSaveWrapper unwrap; PatchManager warning when set on unsupported model_config_type
  • Tests
  • Breaking change
  • Documentation update

Social Handles (Optional)

N/A

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 25, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 98b26315-031e-4924-8283-421ab13c05cc

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

This PR implements fused RMSNorm+RoPE Triton kernels for Qwen3/3.5 models and their MoE variants. It refactors the kernel infrastructure to use torch.library, adds model-specific monkeypatches, wires dispatch logic into PatchManager, and provides comprehensive test coverage including compilation safety and composition scenarios.

Changes

Qwen Fused Attention Implementation

Layer / File(s) Summary
Kernel refactoring to torch.library and unit_offset support
src/axolotl/kernels/gemma4_fused_rope.py
Replaces torch.autograd.Function-based autograd with @triton_op and setup_context/register_autograd wiring; extends Triton kernels with UNIT_OFFSET constexpr parameter for conditional weight scaling in both forward and backward passes; updates public API to include unit_offset: bool = False.
Kernel correctness and torch.compile tests
tests/kernels/test_gemma4_fused_rope_compile.py, tests/kernels/test_gemma4_fused_rope_unit_offset.py
Validates forward/backward correctness under torch.compile(fullgraph=True), reference implementations for unit_offset=True path, gradient flow, and numerical closeness with cosine similarity and tolerance-based assertions.
Configuration schema and PatchManager dispatch wiring
src/axolotl/utils/schemas/config.py, src/axolotl/loaders/patch_manager.py
Adds fused_attn_kernel: bool | None config field; introduces _FUSED_ATTN_KERNEL_SUPPORTED model set and _warn_if_fused_attn_unsupported validation in PatchManager; reorders self-attention LoRA patching earlier; wires conditional fused-attention patch dispatch for Qwen3, Qwen3-MoE, Qwen3.5, and Qwen3.5-MoE.
Liger integration for Qwen-VL RoPE handling
src/axolotl/integrations/liger/args.py, src/axolotl/integrations/liger/plugin.py, tests/integrations/test_liger_qwen_vl_rope_default.py
Enhances LigerArgs.liger_rope field with JSON schema documentation; conditionally enables rope=True for specific Qwen-VL model types when unset; validates behavior via parametrized integration tests for both None (auto-enable) and explicit False cases.
Qwen3 dense fused attention monkeypatch
src/axolotl/monkeypatch/models/qwen3/fused_attn.py
Implements patch_qwen3_fused_attn() that replaces Qwen3Attention.forward with fused RMSNorm+RoPE kernel (via fused_rms_norm_rope), supports LoRA QKV/O via apply_qkv, handles device placement for norm weights, and routes attention through HuggingFace's interface.
Qwen3-MoE fused attention monkeypatch
src/axolotl/monkeypatch/models/qwen3_moe/fused_attn.py
Implements patch_qwen3_moe_fused_attn() with _resolve_norm_module for PEFT adapter unwrapping, fused Q/K computation, and MoE-specific gating via sigmoid applied post-attention.
Qwen3.5 and Qwen3.5-MoE fused attention monkeypatches
src/axolotl/monkeypatch/models/qwen3_5/fused_attn.py, src/axolotl/monkeypatch/models/qwen3_5_moe/fused_attn.py
Implement fused attention for Qwen3.5 variants using unit_offset=True in the kernel call; both include PEFT module unwrapping and active-adapter resolution; MoE variant includes gating logic and expert selection.
Comprehensive monkeypatch tests for Qwen3 variants
tests/monkeypatch/test_qwen3_fused_attn.py, tests/monkeypatch/test_qwen3_fused_attn_defensive.py, tests/monkeypatch/test_qwen3_fused_attn_robustness.py
Validates parity between fused and stock attention via cosine similarity and assert_close, backward gradient flow through fused kernels, LoRA composition ordering (LoRA-then-fused succeeds; reverse order fails as expected), PatchManager dispatch, PEFT wrapper handling, padding mask pass-through, gradient checkpointing, CPU-offloaded weights, and FlashAttention2 composition.
Qwen3.5 variant monkeypatch tests
tests/monkeypatch/test_qwen3_5_fused_attn.py
Adds parity and gradient tests for Qwen3.5 dense and MoE; validates LoRA composition after applying self-attention patch first; tests Liger RMSNorm stub compatibility; confirms PatchManager dispatch for both _text suffixed model types.
Documentation and example configurations
docs/optimizations.qmd, examples/qwen3/8b-lora-fused-attn.yaml
Adds optimization guide section describing the fused RMSNorm+RoPE feature, fused_attn_kernel: true configuration option, torch.compile safety notes, and hardware-specific guidance; provides LoRA fine-tuning example with fused attention enabled.
Gemma4 fused attention test enhancements
tests/monkeypatch/test_gemma4_fused_attn.py
Adds TestGemma4FusedAttnLoRACompose suite validating LoRA-to-fused and fused-to-LoRA composition with xfail for transformers-version-specific kernel mapping; switches per-layer correctness to torch.testing.assert_close; refines docstrings and cleans up redundant gradient test comments.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • axolotl-ai-cloud/axolotl#3598: Both PRs directly refactor the same src/axolotl/kernels/gemma4_fused_rope.py kernel, updating autograd dispatch and extending behavior with unit_offset and weight handling.
  • axolotl-ai-cloud/axolotl#3442: Both PRs modify PatchManager._apply_model_specific_patches to add Qwen3.5/Qwen3.5-MoE model-specific dispatch logic (this PR: fused-attn-kernel gating; referenced PR: sample-packing/FLA patches), related at patch orchestration level.

Suggested labels

under review

Suggested reviewers

  • winglian
  • NanoCode012
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 18.70% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The PR title clearly and specifically summarizes the main change: adding fused RMSNorm+RoPE kernel support for Qwen3/3.X family and enabling Liger m-rope default, which aligns with the primary objective of generalizing fused kernels across these models.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@thad0ctor thad0ctor force-pushed the qwen-fused-kernels branch from 37c3186 to f201233 Compare May 25, 2026 01:37
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (5)
tests/kernels/test_gemma4_fused_rope_unit_offset.py (1)

29-35: ⚡ Quick win

Add a unit_offset test where n_rot < D to cover the partial-rotary branch.

Current unit_offset=True cases only exercise n_rot == D, so the pass-through path (col >= n_rot) is untested. Please add at least one case (for example D=128, n_rot=64) with forward parity and gradient checks.

Also applies to: 50-55, 66-70, 100-105

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/kernels/test_gemma4_fused_rope_unit_offset.py` around lines 29 - 35,
Add a test case inside test_gemma4_fused_rope_unit_offset that sets
unit_offset=True but uses n_rot < D (e.g., D=128, n_rot=64) so the
partial-rotary branch (columns >= n_rot) is exercised; construct x, weight, cos,
sin with matching shapes (x: B,S,H,D; cos/sin: B,S,D), run the same forward
parity check and backward gradient checks used in the existing cases, and assert
equivalence with the non-fused/reference implementation; duplicate this new
scenario for the other unit_offset=True blocks so all forward parity and
gradient checks cover n_rot < D as well.
src/axolotl/monkeypatch/models/qwen3_moe/fused_attn.py (1)

1-15: ⚡ Quick win

Please normalize comments to one short WHY-only line in this file.

The current docstrings/comment blocks are more verbose than the repository rule for src/axolotl files.

As per coding guidelines: “Only add comments when explaining the WHY behind non-obvious logic... Comments should be a maximum of one short line”.

Also applies to: 65-65

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/monkeypatch/models/qwen3_moe/fused_attn.py` around lines 1 - 15,
The file-level docstring and any multi-line comments (including the docstring
for _resolve_norm_module and the comment at line ~65) are too verbose; replace
each with a single short WHY-only comment explaining the non-obvious intent
(e.g., why we unwrap PEFT ModulesToSaveWrapper) and remove extra explanatory
sentences so comments are one short line; update the top module docstring to a
single-line WHY summary and convert multi-line inline comments to one-line WHY
comments while keeping function names like _resolve_norm_module unchanged.
src/axolotl/monkeypatch/models/qwen3/fused_attn.py (1)

1-15: ⚡ Quick win

Align comments/docstrings with the src/axolotl one-line WHY-only rule.

Please shorten/trim these comments to single-line WHY-focused notes; the current module/function docstrings and long inline explanation are more descriptive than the repo rule allows.

As per coding guidelines: “Only add comments when explaining the WHY behind non-obvious logic... Comments should be a maximum of one short line”.

Also applies to: 63-63

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/monkeypatch/models/qwen3/fused_attn.py` around lines 1 - 15, The
module-level docstring and the `_resolve_norm_module` docstring/explanatory
comment are too verbose and violate the one-line WHY-only rule; replace them
with concise single-line comments that explain why the fused kernel exists and
why `_resolve_norm_module` unwraps `ModulesToSaveWrapper` to access the
trainable adapter (e.g., "Why: fuse q_norm/k_norm + RoPE into one Triton kernel
for performance" and "Why: unwrap ModulesToSaveWrapper to read active adapter
weights"), and remove the longer descriptive text elsewhere (including the
inline comment at line ~63).
src/axolotl/monkeypatch/models/qwen3_5/fused_attn.py (1)

1-15: ⚡ Quick win

Refactor comment/docstring verbosity to match src/axolotl policy.

Please keep comments as single-line WHY notes only; current multi-line descriptive text is outside the repo guideline.

As per coding guidelines: “Only add comments when explaining the WHY behind non-obvious logic... Comments should be a maximum of one short line”.

Also applies to: 50-50, 77-77

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/monkeypatch/models/qwen3_5/fused_attn.py` around lines 1 - 15,
The module-level multi-line docstring and other verbose comments should be
replaced with a single short WHY comment per repo policy: remove the
triple-quoted description at the top and replace it with a one-line comment
explaining the reason for the special fused behavior (e.g., why we need fused
q_norm/k_norm + RoPE for Qwen3.5), and likewise collapse any other multi-line
comments at the indicated spots into single-line WHY notes; update the header
around the _resolve_norm_module function to have only a concise one-line comment
explaining why ModulesToSaveWrapper is unwrapped and what active_adapter
significance is.
src/axolotl/monkeypatch/models/qwen3_5_moe/fused_attn.py (1)

1-15: ⚡ Quick win

Apply the one-line WHY-only comment rule here as well.

This file has the same verbose docstring/comment pattern; please shorten to concise WHY-only comments.

As per coding guidelines: “Only add comments when explaining the WHY behind non-obvious logic... Comments should be a maximum of one short line”.

Also applies to: 50-50, 77-77

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/monkeypatch/models/qwen3_5_moe/fused_attn.py` around lines 1 -
15, The top-file docstring and verbose comments in this module should be
replaced with one-line WHY-only comments: remove the multi-line module docstring
and instead add a single short line explaining why this monkeypatch exists
(e.g., "Workaround to make fused-attention use adapter weights for
Qwen3.5-MoE."); likewise shorten the comment above _resolve_norm_module and the
comments at the locations referenced (around lines 50 and 77) to single-line WHY
explanations that describe the non-obvious reason (not the how) for unwrapping
ModulesToSaveWrapper and any adapter-specific behavior, leaving
implementation/details to the code itself. Ensure references to
_resolve_norm_module, ModulesToSaveWrapper/modules_to_save, and active_adapter
remain clear in the single-line comments.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@src/axolotl/monkeypatch/models/qwen3_5_moe/fused_attn.py`:
- Around line 1-15: The top-file docstring and verbose comments in this module
should be replaced with one-line WHY-only comments: remove the multi-line module
docstring and instead add a single short line explaining why this monkeypatch
exists (e.g., "Workaround to make fused-attention use adapter weights for
Qwen3.5-MoE."); likewise shorten the comment above _resolve_norm_module and the
comments at the locations referenced (around lines 50 and 77) to single-line WHY
explanations that describe the non-obvious reason (not the how) for unwrapping
ModulesToSaveWrapper and any adapter-specific behavior, leaving
implementation/details to the code itself. Ensure references to
_resolve_norm_module, ModulesToSaveWrapper/modules_to_save, and active_adapter
remain clear in the single-line comments.

In `@src/axolotl/monkeypatch/models/qwen3_5/fused_attn.py`:
- Around line 1-15: The module-level multi-line docstring and other verbose
comments should be replaced with a single short WHY comment per repo policy:
remove the triple-quoted description at the top and replace it with a one-line
comment explaining the reason for the special fused behavior (e.g., why we need
fused q_norm/k_norm + RoPE for Qwen3.5), and likewise collapse any other
multi-line comments at the indicated spots into single-line WHY notes; update
the header around the _resolve_norm_module function to have only a concise
one-line comment explaining why ModulesToSaveWrapper is unwrapped and what
active_adapter significance is.

In `@src/axolotl/monkeypatch/models/qwen3_moe/fused_attn.py`:
- Around line 1-15: The file-level docstring and any multi-line comments
(including the docstring for _resolve_norm_module and the comment at line ~65)
are too verbose; replace each with a single short WHY-only comment explaining
the non-obvious intent (e.g., why we unwrap PEFT ModulesToSaveWrapper) and
remove extra explanatory sentences so comments are one short line; update the
top module docstring to a single-line WHY summary and convert multi-line inline
comments to one-line WHY comments while keeping function names like
_resolve_norm_module unchanged.

In `@src/axolotl/monkeypatch/models/qwen3/fused_attn.py`:
- Around line 1-15: The module-level docstring and the `_resolve_norm_module`
docstring/explanatory comment are too verbose and violate the one-line WHY-only
rule; replace them with concise single-line comments that explain why the fused
kernel exists and why `_resolve_norm_module` unwraps `ModulesToSaveWrapper` to
access the trainable adapter (e.g., "Why: fuse q_norm/k_norm + RoPE into one
Triton kernel for performance" and "Why: unwrap ModulesToSaveWrapper to read
active adapter weights"), and remove the longer descriptive text elsewhere
(including the inline comment at line ~63).

In `@tests/kernels/test_gemma4_fused_rope_unit_offset.py`:
- Around line 29-35: Add a test case inside test_gemma4_fused_rope_unit_offset
that sets unit_offset=True but uses n_rot < D (e.g., D=128, n_rot=64) so the
partial-rotary branch (columns >= n_rot) is exercised; construct x, weight, cos,
sin with matching shapes (x: B,S,H,D; cos/sin: B,S,D), run the same forward
parity check and backward gradient checks used in the existing cases, and assert
equivalence with the non-fused/reference implementation; duplicate this new
scenario for the other unit_offset=True blocks so all forward parity and
gradient checks cover n_rot < D as well.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: a4617ab6-0d3c-4139-8659-04acecb39b72

📥 Commits

Reviewing files that changed from the base of the PR and between 1d68aca and 37c3186.

📒 Files selected for processing (22)
  • docs/optimizations.qmd
  • examples/qwen3/8b-lora-fused-attn.yaml
  • src/axolotl/integrations/liger/args.py
  • src/axolotl/integrations/liger/plugin.py
  • src/axolotl/kernels/gemma4_fused_rope.py
  • src/axolotl/loaders/patch_manager.py
  • src/axolotl/monkeypatch/models/qwen3/__init__.py
  • src/axolotl/monkeypatch/models/qwen3/fused_attn.py
  • src/axolotl/monkeypatch/models/qwen3_5/fused_attn.py
  • src/axolotl/monkeypatch/models/qwen3_5_moe/__init__.py
  • src/axolotl/monkeypatch/models/qwen3_5_moe/fused_attn.py
  • src/axolotl/monkeypatch/models/qwen3_moe/__init__.py
  • src/axolotl/monkeypatch/models/qwen3_moe/fused_attn.py
  • src/axolotl/utils/schemas/config.py
  • tests/integrations/test_liger_qwen_vl_rope_default.py
  • tests/kernels/test_gemma4_fused_rope_compile.py
  • tests/kernels/test_gemma4_fused_rope_unit_offset.py
  • tests/monkeypatch/test_gemma4_fused_attn.py
  • tests/monkeypatch/test_qwen3_5_fused_attn.py
  • tests/monkeypatch/test_qwen3_fused_attn.py
  • tests/monkeypatch/test_qwen3_fused_attn_defensive.py
  • tests/monkeypatch/test_qwen3_fused_attn_robustness.py

@winglian
Copy link
Copy Markdown
Collaborator

Benchmark: fused kernel vs the realistic baseline (liger_rms_norm)

Setup: isolated q_norm/k_norm + RoPE path (exactly what the kernel replaces), Qwen3-14B shapes (40 q-heads / 8 kv-heads, head_dim 128, B=1), bf16, RTX PRO 6000 Blackwell (sm_120, 188 SM). Median of CUDA-event-timed batches after warmup. Baselines: stock eager
(unfused) and liger_rms_norm: true (fused RMSNorm, eager RoPE — the next-best option a Qwen3 user already has, since Liger has no RoPE kernel for the dense Qwen3 family).

Forward + backward (training-relevant), ms/iter

  ┌────────┬───────┬────────────────┬───────┬────────────────┬────────────────┐
  │ seqlen │ eager │ liger_rms_norm │ fused │ fused vs liger │ fused vs eager │
  ├────────┼───────┼────────────────┼───────┼────────────────┼────────────────┤
  │   1024 │  0.83 │           0.49 │  0.48 │         ~even¹ │           +74% │
  ├────────┼───────┼────────────────┼───────┼────────────────┼────────────────┤
  │   2048 │  1.21 │           0.84 │  0.55 │           +53% │          +120% │
  ├────────┼───────┼────────────────┼───────┼────────────────┼────────────────┤
  │   4096 │  2.84 │           1.36 │  0.86 │           +59% │          +231% │
  ├────────┼───────┼────────────────┼───────┼────────────────┼────────────────┤
  │   8192 │  7.05 │           3.52 │  1.98 │           +78% │          +256% │
  ├────────┼───────┼────────────────┼───────┼────────────────┼────────────────┤
  │  16384 │ 14.96 │           7.54 │  4.12 │           +83% │          +263% │
  ├────────┼───────┼────────────────┼───────┼────────────────┼────────────────┤
  │  32768 │ 31.39 │          15.88 │  8.72 │           +82% │          +260% │
  └────────┴───────┴────────────────┴───────┴────────────────┴────────────────┘

¹ 1k fwd+bwd is sub-ms and clock/noise-dominated (the eager/liger rows bounce run-to-run); treat as ≈ break-even-to-positive.

Forward only (inference / generation), ms/iter

  ┌────────┬───────┬────────────────┬───────┬────────────────┬────────────────┐
  │ seqlen │ eager │ liger_rms_norm │ fused │ fused vs liger │ fused vs eager │
  ├────────┼───────┼────────────────┼───────┼────────────────┼────────────────┤
  │   1024 │ 0.141 │          0.108 │ 0.065 │           +66% │          +117% │
  ├────────┼───────┼────────────────┼───────┼────────────────┼────────────────┤
  │   4096 │ 0.688 │          0.381 │ 0.133 │          +188% │          +418% │
  ├────────┼───────┼────────────────┼───────┼────────────────┼────────────────┤
  │  16384 │ 3.779 │          1.860 │ 0.657 │          +183% │          +475% │
  ├────────┼───────┼────────────────┼───────┼────────────────┼────────────────┤
  │  32768 │ 7.934 │          3.787 │ 1.317 │          +188% │          +503% │
  └────────┴───────┴────────────────┴───────┴────────────────┴────────────────┘

Peak + resident activation VRAM

┌────────┬──────────┬───────┬───────┬───────┬────────────────┬────────────────┐
│ seqlen │  metric  │ eager │ liger │ fused │ fused vs liger │ fused vs eager │
├────────┼──────────┼───────┼───────┼───────┼────────────────┼────────────────┤
│   4096 │ peak     │  488M │  329M │  329M │            +0% │           −33% │
├────────┼──────────┼───────┼───────┼───────┼────────────────┼────────────────┤
│   4096 │ resident │  145M │  0.8M │  0.8M │            +0% │           −99% │
├────────┼──────────┼───────┼───────┼───────┼────────────────┼────────────────┤
│  16384 │ peak     │ 1952M │ 1315M │ 1315M │            +0% │           −33% │
├────────┼──────────┼───────┼───────┼───────┼────────────────┼────────────────┤
│  16384 │ resident │  579M │  3.0M │  3.0M │            +0% │           −99% │
├────────┼──────────┼───────┼───────┼───────┼────────────────┼────────────────┤
│  32768 │ peak     │ 3904M │ 2630M │ 2630M │            +0% │           −33% │
├────────┼──────────┼───────┼───────┼───────┼────────────────┼────────────────┤
│  32768 │ resident │ 1158M │  6.0M │  6.0M │            +0% │           −99% │
└────────┴──────────┴───────┴───────┴───────┴────────────────┴────────────────┘

@thad0ctor
Copy link
Copy Markdown
Contributor Author

@winglian great inputs on the bwd pass and thanks for the affirming bench results!

FYI, I have a follow-on fork PR stacked on this branch that revisits the sm_86> compile : where the earlier numbers showed torch_compile was not helpful on Ampere, the new torch 2.11 runs show vanilla compile on top of this fused path is now ~16.5-17% faster than fused-only, with the new Inductor option plumbing adding a smaller additional tuning gain on top (~0.5% gain).

Please let me know if I should wait for this to merge or PR this now, the PR against my fork if you are interested in looking: thad0ctor#28

@winglian winglian force-pushed the qwen-fused-kernels branch from af7d7fa to 0ee177d Compare May 26, 2026 16:38
thad0ctor and others added 3 commits May 26, 2026 15:05
…3.5-MoE

Generalizes the existing Gemma 4 fused RMSNorm+RoPE Triton kernel to four
new Qwen attention variants, and auto-enables Liger's fused (m-)rope
kernel for the Qwen-VL family. Eager-mode behavior is bit-identical when
the new cfg.fused_attn_kernel flag is unset.

Changes
-------
* New ``cfg.fused_attn_kernel: bool | None`` (default None / off). When
  set, replaces ``q_norm + apply_rotary_pos_emb`` (and the matching K path)
  with a single fused RMSNorm+RoPE Triton kernel launch. Currently wired
  for ``qwen3``, ``qwen3_moe``, ``qwen3_5``, and ``qwen3_5_moe``
  model_config_types. Llama4 is out of scope (complex freqs_cis +
  Llama4TextL2Norm post-RoPE — separate kernel).
* Kernel ``UNIT_OFFSET: tl.constexpr`` flag added to the forward + backward
  Triton kernels for Qwen3.5's Gemma-style ``(1.0 + weight)`` RMSNorm.
  Default ``False`` keeps Gemma 4 / Qwen3 / Qwen3-MoE bit-identical to
  before. Threaded through the triton_op + register_autograd plumbing.
* Refactors ``fused_rms_norm_rope`` / ``fused_rms_norm_noscale`` to
  ``torch.library.triton_op`` + ``register_autograd`` so they trace under
  ``torch.compile(fullgraph=True)``. Validated: 1 Dynamo frame, 0 graph
  breaks. On sm_120 the compile path composes to +9.2% combined,
  −33% peak memory. On sm_86 the surrounding Inductor-generated kernels
  regress — leave ``torch_compile: false`` there; schema description
  documents the per-arch recommendation.
* Liger Qwen-VL auto-default: when ``cfg.liger_rope is None`` and
  model_config_type is one of qwen2_vl/qwen2_5_vl/qwen3_vl (+ ``_text``
  variants), pass ``rope=True`` so upstream's fused m-rope kernel is
  actually installed. Previously the plugin overrode the upstream default
  to None, silently skipping the kernel.
* Patch-ordering fix: ``_apply_self_attention_lora_patch`` now runs
  before ``_apply_model_specific_patches`` in
  ``apply_pre_model_load_patches``. ``patch_self_attn_lora`` reads
  ``inspect.getsource`` of the attention class' forward, so any patch that
  replaces ``Attention.forward`` must run *after* the source-rewrite step.
  The wrong order also silently broke Gemma 4 + ``lora_qkv_kernel`` —
  pinned by ``TestPatchManagerOrdering`` and a fused-first trip-wire.

Tests
-----
* Per-model parity + backward grad flow for Qwen3, Qwen3-MoE, Qwen3.5,
  Qwen3.5-MoE (full-attention layers only; linear_attention layers stay
  on the stock GatedDeltaNet path).
* Kernel ``UNIT_OFFSET=True`` parity vs from-scratch reference + bwd
  parity vs torch-eager + ``torch.compile(fullgraph=True)`` parity.
* ``torch.compile(fullgraph=True)`` parity for the no-offset path.
* Liger Qwen-VL auto-default for all 6 model_config_types; explicit
  ``False`` is respected.
* Patch idempotency (double-apply is a no-op).
* Transformers signature contract — pins the stock attention forward
  argument names so future drift trips loudly at test time.
* Gradient-checkpointing composition (Qwen3 + ``gradient_checkpointing_enable``).
* Flash-Attention 2 composition (skip-if-unavailable).
* LoRA + fused composition on Qwen3 / Qwen3.5 / Qwen3.5-MoE, with
  fused-first reverse-order trip-wires that catch the original ordering
  bug if anyone re-introduces it.

A pre-existing upstream-drift xfail in ``test_gemma4_fused_attn.py``
documents Gemma 4 + ``lora_qkv_kernel`` being broken in transformers
5.8.1 (new ``shared_kv_states: dict[str, ...]`` signature drift in
QKV_PATCHES). Out of scope for this PR; flips to XPASS when patched.

Post-review fixes
-----------------
* ``_resolve_norm_module``: PEFT ``ModulesToSaveWrapper`` stores
  ``active_adapter`` as ``list[str]`` (e.g. ``["default"]``), not a string.
  The prior ``isinstance(adapter, str)`` check silently returned the
  frozen ``original_module`` for every real-PEFT case. Switched to
  iterating ``active_adapters`` (with ``active_adapter`` fallback) across
  all 4 patches. Added a direct unit-test plus an end-to-end test that
  drives real ``peft.get_peft_model(modules_to_save=["q_norm","k_norm"])``
  and asserts the helper returns the trainable adapter weight.
* ``cfg.fused_attn_kernel`` unsupported-model warning: moved out of the
  Pydantic ``model_validator(mode="before")`` (which ran *before*
  ``normalize_config()`` had derived ``model_config_type``, so it
  silently no-op'd on normal YAML input) into a new
  ``PatchManager._warn_if_fused_attn_unsupported`` staticmethod invoked
  from ``_apply_model_specific_patches``, where ``model_config_type`` is
  guaranteed set. Added a source-line guard that the helper stays wired.
@thad0ctor
Copy link
Copy Markdown
Contributor Author

@winglian I realized I forgot qwen3_vl in the PR, would it be best to add that now or a follow-on PR?

@winglian
Copy link
Copy Markdown
Collaborator

There are failing tests in the qwen3 and gemma4 fused attn tests so may as well include that model too while you fix those tests.

@thad0ctor thad0ctor force-pushed the qwen-fused-kernels branch from ddcc2c6 to 7bdce52 Compare May 27, 2026 13:36
@thad0ctor
Copy link
Copy Markdown
Contributor Author

There are failing tests in the qwen3 and gemma4 fused attn tests so may as well include that model too while you fix those tests.

@winglian

I just pushed the path for qwen3_vl and addressed the tests, updating the PR summary for completeness.

The sole remaining xfail appears to be unrelated to this PR but I can address if you would like.

Testing and benchmarks:

  • Targeted fused-attention suite on the local 3090: 60 passed, 1 xfailed (the xfail appears to be unrelated to this PR)
  • Exact failed CI tests after the logging-capture fix: 2 passed
  • Updated and ran and pre-commit tests/lint/ruff
  • Qwen3-VL attention microbenchmark on 3090: stock mean 16.118 ms, fused mean 13.535 ms, about 16.0% faster
  • Peak memory in the same microbenchmark: 0.426 GB stock to 0.384 GB fused
  • Resume smoke test from a checkpoint created before the patch:previous run 22.86 s/step over steps 18801-18836, resumed patched run 21.57 s/step; tail comparison improved from 24.76 s/step to 21.80 s/step

The reported failing Qwen3/Gemma4 fused-attn tests were not implementation failures. The expected log messages were present in captured stderr, but caplog.records was empty in the Modal/docker e2e environment because Axolotl logging had already been configured and the axolotl logger can run with propagate=False. The tests now attach caplog.handler directly to the module loggers they assert against:

  • axolotl.monkeypatch.lora_kernels for the Gemma4 fused-then-LoRA skip test
  • axolotl.loaders.patch_manager for the unsupported fused_attn_kernel warning test

The Qwen3-VL patch wires fused_attn_kernel: true for qwen3_vl and qwen3_vl_text, dispatches the new monkeypatch through PatchManager, and uses the existing fused RMSNorm+RoPE Triton path for the text attention q_norm / k_norm + mRoPE sequence. The patch keeps the stock attention flow and preserves the LoRA apply_qkv / apply_o hooks.

# Conflicts:
#	src/axolotl/loaders/patch_manager.py
@winglian winglian merged commit 280506e into axolotl-ai-cloud:main May 29, 2026
15 of 16 checks passed
@thad0ctor thad0ctor deleted the qwen-fused-kernels branch May 30, 2026 21:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants