feat(qwen): fused RMSNorm+RoPE for Qwen3/3.X family + Liger m-rope default#3680
Conversation
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
📝 WalkthroughWalkthroughThis PR implements fused RMSNorm+RoPE Triton kernels for Qwen3/3.5 models and their MoE variants. It refactors the kernel infrastructure to use ChangesQwen Fused Attention Implementation
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
37c3186 to
f201233
Compare
There was a problem hiding this comment.
🧹 Nitpick comments (5)
tests/kernels/test_gemma4_fused_rope_unit_offset.py (1)
29-35: ⚡ Quick winAdd a
unit_offsettest wheren_rot < Dto cover the partial-rotary branch.Current
unit_offset=Truecases only exercisen_rot == D, so the pass-through path (col >= n_rot) is untested. Please add at least one case (for exampleD=128, n_rot=64) with forward parity and gradient checks.Also applies to: 50-55, 66-70, 100-105
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/kernels/test_gemma4_fused_rope_unit_offset.py` around lines 29 - 35, Add a test case inside test_gemma4_fused_rope_unit_offset that sets unit_offset=True but uses n_rot < D (e.g., D=128, n_rot=64) so the partial-rotary branch (columns >= n_rot) is exercised; construct x, weight, cos, sin with matching shapes (x: B,S,H,D; cos/sin: B,S,D), run the same forward parity check and backward gradient checks used in the existing cases, and assert equivalence with the non-fused/reference implementation; duplicate this new scenario for the other unit_offset=True blocks so all forward parity and gradient checks cover n_rot < D as well.src/axolotl/monkeypatch/models/qwen3_moe/fused_attn.py (1)
1-15: ⚡ Quick winPlease normalize comments to one short WHY-only line in this file.
The current docstrings/comment blocks are more verbose than the repository rule for
src/axolotlfiles.As per coding guidelines: “Only add comments when explaining the WHY behind non-obvious logic... Comments should be a maximum of one short line”.
Also applies to: 65-65
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/monkeypatch/models/qwen3_moe/fused_attn.py` around lines 1 - 15, The file-level docstring and any multi-line comments (including the docstring for _resolve_norm_module and the comment at line ~65) are too verbose; replace each with a single short WHY-only comment explaining the non-obvious intent (e.g., why we unwrap PEFT ModulesToSaveWrapper) and remove extra explanatory sentences so comments are one short line; update the top module docstring to a single-line WHY summary and convert multi-line inline comments to one-line WHY comments while keeping function names like _resolve_norm_module unchanged.src/axolotl/monkeypatch/models/qwen3/fused_attn.py (1)
1-15: ⚡ Quick winAlign comments/docstrings with the
src/axolotlone-line WHY-only rule.Please shorten/trim these comments to single-line WHY-focused notes; the current module/function docstrings and long inline explanation are more descriptive than the repo rule allows.
As per coding guidelines: “Only add comments when explaining the WHY behind non-obvious logic... Comments should be a maximum of one short line”.
Also applies to: 63-63
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/monkeypatch/models/qwen3/fused_attn.py` around lines 1 - 15, The module-level docstring and the `_resolve_norm_module` docstring/explanatory comment are too verbose and violate the one-line WHY-only rule; replace them with concise single-line comments that explain why the fused kernel exists and why `_resolve_norm_module` unwraps `ModulesToSaveWrapper` to access the trainable adapter (e.g., "Why: fuse q_norm/k_norm + RoPE into one Triton kernel for performance" and "Why: unwrap ModulesToSaveWrapper to read active adapter weights"), and remove the longer descriptive text elsewhere (including the inline comment at line ~63).src/axolotl/monkeypatch/models/qwen3_5/fused_attn.py (1)
1-15: ⚡ Quick winRefactor comment/docstring verbosity to match
src/axolotlpolicy.Please keep comments as single-line WHY notes only; current multi-line descriptive text is outside the repo guideline.
As per coding guidelines: “Only add comments when explaining the WHY behind non-obvious logic... Comments should be a maximum of one short line”.
Also applies to: 50-50, 77-77
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/monkeypatch/models/qwen3_5/fused_attn.py` around lines 1 - 15, The module-level multi-line docstring and other verbose comments should be replaced with a single short WHY comment per repo policy: remove the triple-quoted description at the top and replace it with a one-line comment explaining the reason for the special fused behavior (e.g., why we need fused q_norm/k_norm + RoPE for Qwen3.5), and likewise collapse any other multi-line comments at the indicated spots into single-line WHY notes; update the header around the _resolve_norm_module function to have only a concise one-line comment explaining why ModulesToSaveWrapper is unwrapped and what active_adapter significance is.src/axolotl/monkeypatch/models/qwen3_5_moe/fused_attn.py (1)
1-15: ⚡ Quick winApply the one-line WHY-only comment rule here as well.
This file has the same verbose docstring/comment pattern; please shorten to concise WHY-only comments.
As per coding guidelines: “Only add comments when explaining the WHY behind non-obvious logic... Comments should be a maximum of one short line”.
Also applies to: 50-50, 77-77
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/monkeypatch/models/qwen3_5_moe/fused_attn.py` around lines 1 - 15, The top-file docstring and verbose comments in this module should be replaced with one-line WHY-only comments: remove the multi-line module docstring and instead add a single short line explaining why this monkeypatch exists (e.g., "Workaround to make fused-attention use adapter weights for Qwen3.5-MoE."); likewise shorten the comment above _resolve_norm_module and the comments at the locations referenced (around lines 50 and 77) to single-line WHY explanations that describe the non-obvious reason (not the how) for unwrapping ModulesToSaveWrapper and any adapter-specific behavior, leaving implementation/details to the code itself. Ensure references to _resolve_norm_module, ModulesToSaveWrapper/modules_to_save, and active_adapter remain clear in the single-line comments.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@src/axolotl/monkeypatch/models/qwen3_5_moe/fused_attn.py`:
- Around line 1-15: The top-file docstring and verbose comments in this module
should be replaced with one-line WHY-only comments: remove the multi-line module
docstring and instead add a single short line explaining why this monkeypatch
exists (e.g., "Workaround to make fused-attention use adapter weights for
Qwen3.5-MoE."); likewise shorten the comment above _resolve_norm_module and the
comments at the locations referenced (around lines 50 and 77) to single-line WHY
explanations that describe the non-obvious reason (not the how) for unwrapping
ModulesToSaveWrapper and any adapter-specific behavior, leaving
implementation/details to the code itself. Ensure references to
_resolve_norm_module, ModulesToSaveWrapper/modules_to_save, and active_adapter
remain clear in the single-line comments.
In `@src/axolotl/monkeypatch/models/qwen3_5/fused_attn.py`:
- Around line 1-15: The module-level multi-line docstring and other verbose
comments should be replaced with a single short WHY comment per repo policy:
remove the triple-quoted description at the top and replace it with a one-line
comment explaining the reason for the special fused behavior (e.g., why we need
fused q_norm/k_norm + RoPE for Qwen3.5), and likewise collapse any other
multi-line comments at the indicated spots into single-line WHY notes; update
the header around the _resolve_norm_module function to have only a concise
one-line comment explaining why ModulesToSaveWrapper is unwrapped and what
active_adapter significance is.
In `@src/axolotl/monkeypatch/models/qwen3_moe/fused_attn.py`:
- Around line 1-15: The file-level docstring and any multi-line comments
(including the docstring for _resolve_norm_module and the comment at line ~65)
are too verbose; replace each with a single short WHY-only comment explaining
the non-obvious intent (e.g., why we unwrap PEFT ModulesToSaveWrapper) and
remove extra explanatory sentences so comments are one short line; update the
top module docstring to a single-line WHY summary and convert multi-line inline
comments to one-line WHY comments while keeping function names like
_resolve_norm_module unchanged.
In `@src/axolotl/monkeypatch/models/qwen3/fused_attn.py`:
- Around line 1-15: The module-level docstring and the `_resolve_norm_module`
docstring/explanatory comment are too verbose and violate the one-line WHY-only
rule; replace them with concise single-line comments that explain why the fused
kernel exists and why `_resolve_norm_module` unwraps `ModulesToSaveWrapper` to
access the trainable adapter (e.g., "Why: fuse q_norm/k_norm + RoPE into one
Triton kernel for performance" and "Why: unwrap ModulesToSaveWrapper to read
active adapter weights"), and remove the longer descriptive text elsewhere
(including the inline comment at line ~63).
In `@tests/kernels/test_gemma4_fused_rope_unit_offset.py`:
- Around line 29-35: Add a test case inside test_gemma4_fused_rope_unit_offset
that sets unit_offset=True but uses n_rot < D (e.g., D=128, n_rot=64) so the
partial-rotary branch (columns >= n_rot) is exercised; construct x, weight, cos,
sin with matching shapes (x: B,S,H,D; cos/sin: B,S,D), run the same forward
parity check and backward gradient checks used in the existing cases, and assert
equivalence with the non-fused/reference implementation; duplicate this new
scenario for the other unit_offset=True blocks so all forward parity and
gradient checks cover n_rot < D as well.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: a4617ab6-0d3c-4139-8659-04acecb39b72
📒 Files selected for processing (22)
docs/optimizations.qmdexamples/qwen3/8b-lora-fused-attn.yamlsrc/axolotl/integrations/liger/args.pysrc/axolotl/integrations/liger/plugin.pysrc/axolotl/kernels/gemma4_fused_rope.pysrc/axolotl/loaders/patch_manager.pysrc/axolotl/monkeypatch/models/qwen3/__init__.pysrc/axolotl/monkeypatch/models/qwen3/fused_attn.pysrc/axolotl/monkeypatch/models/qwen3_5/fused_attn.pysrc/axolotl/monkeypatch/models/qwen3_5_moe/__init__.pysrc/axolotl/monkeypatch/models/qwen3_5_moe/fused_attn.pysrc/axolotl/monkeypatch/models/qwen3_moe/__init__.pysrc/axolotl/monkeypatch/models/qwen3_moe/fused_attn.pysrc/axolotl/utils/schemas/config.pytests/integrations/test_liger_qwen_vl_rope_default.pytests/kernels/test_gemma4_fused_rope_compile.pytests/kernels/test_gemma4_fused_rope_unit_offset.pytests/monkeypatch/test_gemma4_fused_attn.pytests/monkeypatch/test_qwen3_5_fused_attn.pytests/monkeypatch/test_qwen3_fused_attn.pytests/monkeypatch/test_qwen3_fused_attn_defensive.pytests/monkeypatch/test_qwen3_fused_attn_robustness.py
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
|
Benchmark: fused kernel vs the realistic baseline (liger_rms_norm) Setup: isolated q_norm/k_norm + RoPE path (exactly what the kernel replaces), Qwen3-14B shapes (40 q-heads / 8 kv-heads, head_dim 128, B=1), bf16, RTX PRO 6000 Blackwell (sm_120, 188 SM). Median of CUDA-event-timed batches after warmup. Baselines: stock eager Forward + backward (training-relevant), ms/iter ¹ 1k fwd+bwd is sub-ms and clock/noise-dominated (the eager/liger rows bounce run-to-run); treat as ≈ break-even-to-positive. Forward only (inference / generation), ms/iter Peak + resident activation VRAM |
|
@winglian great inputs on the bwd pass and thanks for the affirming bench results! FYI, I have a follow-on fork PR stacked on this branch that revisits the sm_86> compile : where the earlier numbers showed torch_compile was not helpful on Ampere, the new torch 2.11 runs show vanilla compile on top of this fused path is now ~16.5-17% faster than fused-only, with the new Inductor option plumbing adding a smaller additional tuning gain on top (~0.5% gain). Please let me know if I should wait for this to merge or PR this now, the PR against my fork if you are interested in looking: thad0ctor#28 |
af7d7fa to
0ee177d
Compare
…3.5-MoE Generalizes the existing Gemma 4 fused RMSNorm+RoPE Triton kernel to four new Qwen attention variants, and auto-enables Liger's fused (m-)rope kernel for the Qwen-VL family. Eager-mode behavior is bit-identical when the new cfg.fused_attn_kernel flag is unset. Changes ------- * New ``cfg.fused_attn_kernel: bool | None`` (default None / off). When set, replaces ``q_norm + apply_rotary_pos_emb`` (and the matching K path) with a single fused RMSNorm+RoPE Triton kernel launch. Currently wired for ``qwen3``, ``qwen3_moe``, ``qwen3_5``, and ``qwen3_5_moe`` model_config_types. Llama4 is out of scope (complex freqs_cis + Llama4TextL2Norm post-RoPE — separate kernel). * Kernel ``UNIT_OFFSET: tl.constexpr`` flag added to the forward + backward Triton kernels for Qwen3.5's Gemma-style ``(1.0 + weight)`` RMSNorm. Default ``False`` keeps Gemma 4 / Qwen3 / Qwen3-MoE bit-identical to before. Threaded through the triton_op + register_autograd plumbing. * Refactors ``fused_rms_norm_rope`` / ``fused_rms_norm_noscale`` to ``torch.library.triton_op`` + ``register_autograd`` so they trace under ``torch.compile(fullgraph=True)``. Validated: 1 Dynamo frame, 0 graph breaks. On sm_120 the compile path composes to +9.2% combined, −33% peak memory. On sm_86 the surrounding Inductor-generated kernels regress — leave ``torch_compile: false`` there; schema description documents the per-arch recommendation. * Liger Qwen-VL auto-default: when ``cfg.liger_rope is None`` and model_config_type is one of qwen2_vl/qwen2_5_vl/qwen3_vl (+ ``_text`` variants), pass ``rope=True`` so upstream's fused m-rope kernel is actually installed. Previously the plugin overrode the upstream default to None, silently skipping the kernel. * Patch-ordering fix: ``_apply_self_attention_lora_patch`` now runs before ``_apply_model_specific_patches`` in ``apply_pre_model_load_patches``. ``patch_self_attn_lora`` reads ``inspect.getsource`` of the attention class' forward, so any patch that replaces ``Attention.forward`` must run *after* the source-rewrite step. The wrong order also silently broke Gemma 4 + ``lora_qkv_kernel`` — pinned by ``TestPatchManagerOrdering`` and a fused-first trip-wire. Tests ----- * Per-model parity + backward grad flow for Qwen3, Qwen3-MoE, Qwen3.5, Qwen3.5-MoE (full-attention layers only; linear_attention layers stay on the stock GatedDeltaNet path). * Kernel ``UNIT_OFFSET=True`` parity vs from-scratch reference + bwd parity vs torch-eager + ``torch.compile(fullgraph=True)`` parity. * ``torch.compile(fullgraph=True)`` parity for the no-offset path. * Liger Qwen-VL auto-default for all 6 model_config_types; explicit ``False`` is respected. * Patch idempotency (double-apply is a no-op). * Transformers signature contract — pins the stock attention forward argument names so future drift trips loudly at test time. * Gradient-checkpointing composition (Qwen3 + ``gradient_checkpointing_enable``). * Flash-Attention 2 composition (skip-if-unavailable). * LoRA + fused composition on Qwen3 / Qwen3.5 / Qwen3.5-MoE, with fused-first reverse-order trip-wires that catch the original ordering bug if anyone re-introduces it. A pre-existing upstream-drift xfail in ``test_gemma4_fused_attn.py`` documents Gemma 4 + ``lora_qkv_kernel`` being broken in transformers 5.8.1 (new ``shared_kv_states: dict[str, ...]`` signature drift in QKV_PATCHES). Out of scope for this PR; flips to XPASS when patched. Post-review fixes ----------------- * ``_resolve_norm_module``: PEFT ``ModulesToSaveWrapper`` stores ``active_adapter`` as ``list[str]`` (e.g. ``["default"]``), not a string. The prior ``isinstance(adapter, str)`` check silently returned the frozen ``original_module`` for every real-PEFT case. Switched to iterating ``active_adapters`` (with ``active_adapter`` fallback) across all 4 patches. Added a direct unit-test plus an end-to-end test that drives real ``peft.get_peft_model(modules_to_save=["q_norm","k_norm"])`` and asserts the helper returns the trainable adapter weight. * ``cfg.fused_attn_kernel`` unsupported-model warning: moved out of the Pydantic ``model_validator(mode="before")`` (which ran *before* ``normalize_config()`` had derived ``model_config_type``, so it silently no-op'd on normal YAML input) into a new ``PatchManager._warn_if_fused_attn_unsupported`` staticmethod invoked from ``_apply_model_specific_patches``, where ``model_config_type`` is guaranteed set. Added a source-line guard that the helper stays wired.
0ee177d to
3250b92
Compare
|
@winglian I realized I forgot qwen3_vl in the PR, would it be best to add that now or a follow-on PR? |
|
There are failing tests in the qwen3 and gemma4 fused attn tests so may as well include that model too while you fix those tests. |
ddcc2c6 to
7bdce52
Compare
I just pushed the path for qwen3_vl and addressed the tests, updating the PR summary for completeness. The sole remaining xfail appears to be unrelated to this PR but I can address if you would like. Testing and benchmarks:
The reported failing Qwen3/Gemma4 fused-attn tests were not implementation failures. The expected log messages were present in captured stderr, but
The Qwen3-VL patch wires |
# Conflicts: # src/axolotl/loaders/patch_manager.py
Description
Generalizes axolotl's existing Gemma 4 fused RMSNorm + RoPE Triton kernel to Qwen3, Qwen3-MoE, Qwen3.5, Qwen3.5-MoE, Qwen3.6 dense, and Qwen3.6-MoE full-attention layers behind a new opt-in
cfg.fused_attn_kernel. Qwen3.6 checkpoints are loaded by transformers under theqwen3_5/qwen3_5_moemodel_types, so the same dispatch covers both generations.This update also folds in the Qwen3-VL fused-attention work from thad0ctor#29. It adds support for
model_config_type: qwen3_vl/qwen3_vl_textby fusing the Qwen3-VL text attention q/k path:q_proj/k_proj -> per-head RMSNorm -> mRoPEwith the existing compile-safe Triton
fused_rms_norm_ropekernel. The patched forward preserves the stock Qwen3-VL attention flow after q/k preparation, including cache update,ALL_ATTENTION_FUNCTIONSdispatch, dropout/scaling kwargs, output projection, and Axolotl LoRAapply_qkv/apply_ohooks.The PR adds a
UNIT_OFFSETkernel flag for Qwen3.5's Gemma-style(1.0 + weight)RMSNorm; default off keeps every existing Gemma 4 / Qwen3 caller bit-identical.It also auto-enables Liger's fused (m-)rope kernel for
qwen2_vl/qwen2_5_vl/qwen3_vl/qwen3_vl_moeand_textvariants whenliger_ropeis unset. Previously, axolotl's plugin overrode upstream'srope=Truedefault withNone, silently skipping the fused m-rope kernel.Finally, this fixes a pre-existing latent patch-ordering bug in
apply_pre_model_load_patches:patch_self_attn_lorare-parsesAttention.forwardsource viainspect.getsource, so it must run before any patch that replacesforward. The reorder also fixes Gemma 4 +lora_qkv_kernelcomposition.Motivation and Context
The goal is to extend the existing Gemma 4 fused RMSNorm+RoPE Triton kernel to the Qwen3 family for added LoRA throughput. The Qwen3 attention forwards are structurally compatible with the existing kernel:
UNIT_OFFSET: tl.constexprfor their Gemma-style(1.0 + weight)RMSNorm and a chunk-2 split on the gatedq_proj.fused_attn_kernel: true.Performance
Qwen3 Dense Benchmark Runs
All runs below use real
Qwen/Qwen3-0.6B(28 layers, hidden=1024, head_dim=128) under LoRA r=16 alpha=32 on[q,k,v,o]_proj, bf16,attn_implementation=sdpa. Was = stock eager. Is =fused_attn_kernel: true, optionally combined withtorch_compile: true.Notes:
Qwen3-VL Benchmark Runs
This is the benchmark run from thad0ctor#29, using separate fresh processes and real Qwen3-VL-8B text-attention dimensions.
Environment:
CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=2Qwen3-VL-8B-Instructconfigflash_attention_2Qwen3VLTextAttentionforward + backwardEvery fused benchmark run reported
forward_module == 'axolotl.monkeypatch.models.qwen3_vl.fused_attn', so thefused_attn_kernel: truepath was active and not silently deactivated.How has this been tested?
Unit / integration:
tests/monkeypatch/,tests/kernels/,tests/integrations/test_liger*.Test surfaces:
Qwen3VLTextModelPatchManagerdispatch formodel_config_type='qwen3_vl'andqwen3_vl_textUNIT_OFFSET=Trueparity vs from-scratch referencetorch.compile(fullgraph=True)parity for offset and no-offset pathsliger_rope: falseremains respectedfused_attn_kernel: true+liger_rms_norm: true)ModulesToSaveWrappercomposition for trainable q/k RMSNorm adapter weightshead_dim=256attention_maskpass-through andsliding_windowkwarg preservationget_text_config-derived dispatch from multimodal Qwen3 checkpointsLocal validation:
The maintainer-reported Gemma4/Qwen3 fused-attn CI failures were test harness failures, not implementation failures. In the Modal/docker e2e logs, the expected log messages were present in captured stderr, but
caplog.recordswas empty because Axolotl logging had already been configured and theaxolotllogger can run withpropagate=False. The tests now attachcaplog.handlerdirectly to the module loggers they assert against:axolotl.monkeypatch.lora_kernelsfor the Gemma4 fused-then-LoRA skip testaxolotl.loaders.patch_managerfor the unsupportedfused_attn_kernelwarning testEnd-to-end on real checkpoints:
Qwen/Qwen3-0.6Bsave / reload parity (3090): save adapter -> reload into fresh model -> forward parity vs in-memory,max_abs = 0.0Qwen/Qwen3-0.6B50-step LoRA convergence (3090): eager 13.189 -> 11.996, fused 13.186 -> 11.997, final-loss diff =0.0005(bf16 noise); both 33/49 steps monotone-decreasingQwen/Qwen3-0.6BPatchManager pipeline (3090): patches whencfg.fused_attn_kernel: true, skips whenfalseQwen3.5-2Bcheckpoint LoRA parity (3090, 5 steps, full_attention layers exercised):max_abs=0.0020, both finite and decreasingQwen3-VL-8B-Instructreal config path check: normalized config hasmodel_config_type='qwen3_vl',model_config_type_text='qwen3_vl_text', andfused_attn_kernel=True;PatchManager._apply_model_specific_patches()installsaxolotl.monkeypatch.models.qwen3_vl.fused_attn; Liger plugin does not overwrite the fused forwardQwen3-VL-8B-InstructLiger m-rope auto-default end-to-end (3090): realQwen3VLForConditionalGenerationloads and forward-passes finite logits with the swapped kernel in placetorchrun --nproc_per_node=3, full_shard + bf16 mixed precision + LoRA + gradient checkpointing): 3-step training completes with finite loss; max peak memory 2.58 GB/rankQwen3.6-27Bdense multimodal checkpoint, FSDP multi-GPU (4x 24 GB GPU): patch installs on full-attention layers only; 3-step training completes with finite, monotonically decreasing loss; peak memory 13.4 GB/rankQwen3.6-35B-A3BMoE multimodal checkpoint, FSDP multi-GPU (4x 24 GB GPU): patch installs on full-attention layers only; 3-step training completes with finite loss; peak memory 13.5 GB/rankAI Usage Disclaimer
Yes. OpenAI ChatGPT/Codex assisted with code changes, test scaffolding, local benchmarking, validation, and drafting this PR summary. The implementation was validated locally with the tests and benchmark runs described above.
Screenshots (if appropriate)
N/A
Types of changes
fused_attn_kernelfor Qwen3 / Qwen3-MoE / Qwen3.5 / Qwen3.5-MoE / Qwen3-VL text attentionapply_pre_model_load_patches; Liger Qwen-VL m-rope auto-default; configured-logging-safe caplog testsq_norm/k_normweights; PEFTModulesToSaveWrapperunwrap; PatchManager warning when set on unsupportedmodel_config_typeSocial Handles (Optional)
N/A