feat(compile): traceable GDN decoder loop (Qwen3.5 + Qwen3-Next) + model-agnostic FA2 varlen collator#3734
Conversation
…eable packing With torch_compile, a single in-loop graph break made dynamo skip the entire Qwen3_5TextModel.forward frame, so every decoder layer ran eagerly (no norm/rope/gating/residual fusion, full per-kernel launch overhead). The in-loop blockers were: axolotl's get_cu_seqlens() nonzero in the packing patch, FLA's @torch.compiler.disable'd chunk_gated_delta_rule, causal_conv1d's .item(), and FusedRMSNormGated's untraceable device-property probe. - New fla_ops.py: opaque torch.library custom ops (axolotl_qwen3_5::gdn_conv / gdn_chunk + backward) wrapping the same FLA host functions with identical saved tensors (no recompute, no kernel changes). They take position_ids and derive cu_seqlens eagerly inside the op, so no data-dependent op enters the graph and every fake impl is static. - modeling.py: the GDN training/no-cache path routes through the ops; a FusedRMSNormGated custom-op wrapper installs whenever torch_compile is on; decoder self-attention is traced when gradient_checkpointing and training. - patch_manager.py threads torch_compile=bool(cfg.torch_compile) into the packing patches. Bit-exactness: eager FLA's dg lands on the bf16 grid (reproduced via an explicit f32->bf16->f32 round-trip); v reaches the kernels as a non-contiguous split/reshape view (the bwd op mirrors FLA's input_guard contiguization). Verified (tiny 4-layer hybrid): 0 graph breaks / 1 graph for the full text model under GC; eager ops-vs-legacy bitwise identical; compiled-vs-eager within compile noise. monkeypatch compiled-loop suite 4/4, qwen3_5 fused-attn 12/12. Note: flash_attention_2 configs still don't get the compiled loop — the remaining breaker is transformers' per-layer varlen derivation on the FA2 path, fixable purely axolotl-side by emitting precomputed cu_seq_lens/max_length from the multipack collator (follow-up, not in this change).
… GDN ops _FLA_COMPILED_OPS now follows the torch_compile flag, so it is set to False when torch_compile=False instead of always enabling the opaque GDN ops.
… the compiled loop The compiled decoder loop (the preceding commit) works under sdpa but not flash_attention_2: with packed sequences signalled only via position_ids, transformers re-derives the varlen metadata per layer inside _flash_attention_forward (a data-dependent _is_packed_sequence branch + (position_ids==0).nonzero()), which graph-breaks inside the decoder loop and makes dynamo skip the whole frame. transformers already ships the escape hatch: if the caller passes the FlashAttentionKwargs (cu_seq_lens_q/k + max_length_q/k), is_fa_with_varlen_kwargs short-circuits before the data-dependent branch and the per-layer derivation is skipped. This emits exactly those kwargs from the multipack collator, computed once via transformers' own prepare_fa_kwargs_from_position_ids (so the metadata is bit-identical to what it would derive per layer). Gated to FA2 + sample packing + qwen3_5/qwen3_5_moe collators; max_length stays a python int so no capture_scalar_outputs is needed. Verified: FA2 + torch_compile now traces the loop with 0 graph breaks (was 1 break / 4 fragmented graphs); eager FA2 loss with-vs-without the kwargs is bitwise identical. compiled-loop suite 6/6.
…acks, formatting - build_collator: gate emit_fa_varlen_kwargs on the eval-aware packing state (training_args.eval_sample_packing for eval loaders) instead of cfg.sample_packing, so eval-only packed mode also gets the precomputed kwargs. - collator + FusedRMSNormGated boundary: replace silent exception swallows with warnings so a disabled compiled-loop optimization is visible. - ruff-format the new test assert; collapse a two-line comment.
…lback, B>1 guard, weight-None norm fallback, saved-v contiguity; add MoE/FA2-parity/opcheck/B>1 tests
…undary (torch 2.11-only, unreproduced on 2.9/2.10)
…GC on torch<2.11 (NaN gradients)
- collator: emit FA2 varlen kwargs for ALL multipack models under torch_compile (was gated to qwen3_5/qwen3_5_moe); model-agnostic, transformers consumes them natively - share the GatedDeltaNet opaque ops: move fla_ops -> gated_delta_net_ops (neutral axolotl_gdn namespace), add cast_g flag so qwen3_5 (casts g) and qwen3_next (f32 g) both stay bit-exact; move the FusedRMSNormGated compile boundary into the shared module - wire qwen3_next: route conv+chunk through the shared ops under compile, install the rmsnorm boundary + self-attn dynamo boundary, thread torch_compile; also fix a pre-existing crash (decoder passed cache_position to GatedDeltaNet.forward) - extend the lora-kernel NaN guard to qwen3_next - tests: qwen3_next compiled-loop (0 breaks + bitwise eager parity); opcheck both cast_g
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
📝 WalkthroughWalkthroughAdds Changestorch.compile for Qwen3.5/Qwen3-Next Packed Decoder Loops
Estimated code review effort🎯 5 (Critical) | ⏱️ ~120 minutes Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 3
🧹 Nitpick comments (1)
src/axolotl/monkeypatch/models/qwen3_next/modeling.py (1)
62-140: 💤 Low valueMinor inconsistency with qwen3_5 decoder forward.
In qwen3_5, the
_call_self_attn_disabledwrapper is always used for the non-gradient-checkpointing self-attention path (lines 134-143 in qwen3_5/modeling.py), regardless oftorch_compile. Here, it's only applied whentorch_compile=True.This inconsistency is harmless (
_call_self_attn_disabledis essentially a no-op when dynamo isn't active), but aligning the behavior would improve maintainability.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/monkeypatch/models/qwen3_next/modeling.py` around lines 62 - 140, The patched_decoder_forward function has an inconsistency with the qwen3_5 implementation. In qwen3_5, the _call_self_attn_disabled wrapper is always used for the non-gradient-checkpointing self-attention path, but here it's only applied when torch_compile=True. To align the behavior, remove the torch_compile check from the conditional statement in the "full_attention" token mixer path so that _call_self_attn_disabled is always called when gradient checkpointing is not enabled and training is not active, regardless of the torch_compile parameter value.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/axolotl/utils/collators/batching.py`:
- Around line 144-150: The FA varlen kwargs precomputation exception handler in
the batching.py file is catching the overly broad Exception class, which can
mask unexpected errors and make debugging difficult. Replace the generic
Exception catch with specific exception types that are actually expected to fail
during the FA varlen precompute operation (such as AttributeError, TypeError, or
RuntimeError, depending on what failures are genuinely expected), so that
unexpected exceptions will properly propagate and alert developers to actual
collator regressions rather than silently falling back to the slow path.
In `@tests/monkeypatch/test_qwen3_5_compiled_loop.py`:
- Around line 27-44: The fixture teardown for _axolotl_compile_boundary in
tests/monkeypatch/test_qwen3_5_compiled_loop.py lines 27-44 conflates a missing
attribute with one that has a False value by using getattr with a False default.
Instead of using getattr with a default False, capture both the presence of the
attribute (using hasattr) and its actual value separately in the saved
dictionary during setup. During teardown, check the presence flag: if the
attribute was originally present, restore its saved value using setattr; if it
was originally absent, delete it using delattr. Apply the identical
presence-aware snapshot and restore logic to
tests/monkeypatch/test_qwen3_next_compiled_loop.py lines 26-43 to ensure
consistent behavior across both test files.
- Around line 422-423: The pytest.raises context manager at line 422 is catching
any Exception type, which can mask unrelated failures that happen to match the
same message text. Change the exception type from Exception to ValueError in the
pytest.raises call to be more specific about what exception should actually be
raised when the batch size contract is violated. This narrower exception type
will ensure only the intended ValueError with the "batch size is expected to be
1" message is caught, preventing unrelated exceptions from being silently
masked.
---
Nitpick comments:
In `@src/axolotl/monkeypatch/models/qwen3_next/modeling.py`:
- Around line 62-140: The patched_decoder_forward function has an inconsistency
with the qwen3_5 implementation. In qwen3_5, the _call_self_attn_disabled
wrapper is always used for the non-gradient-checkpointing self-attention path,
but here it's only applied when torch_compile=True. To align the behavior,
remove the torch_compile check from the conditional statement in the
"full_attention" token mixer path so that _call_self_attn_disabled is always
called when gradient checkpointing is not enabled and training is not active,
regardless of the torch_compile parameter value.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: b35d2e01-4abc-43aa-9288-ab03bdb5f258
📒 Files selected for processing (8)
src/axolotl/core/builders/causal.pysrc/axolotl/loaders/patch_manager.pysrc/axolotl/monkeypatch/models/gated_delta_net_ops.pysrc/axolotl/monkeypatch/models/qwen3_5/modeling.pysrc/axolotl/monkeypatch/models/qwen3_next/modeling.pysrc/axolotl/utils/collators/batching.pytests/monkeypatch/test_qwen3_5_compiled_loop.pytests/monkeypatch/test_qwen3_next_compiled_loop.py
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
- batching.py: narrow the FA varlen fallback catch to the realistic exception set (ImportError/AttributeError/RuntimeError/TypeError/ValueError/IndexError) instead of bare Exception (Ruff BLE001) - tests: presence-aware restore of FusedRMSNormGated._axolotl_compile_boundary (snapshot hasattr + value) in both compiled-loop fixtures - tests: narrow B>1 pytest.raises to ValueError (verified: the op raises a clean ValueError through the model forward) - qwen3_next: route non-GC self-attn through the dynamo.disable boundary unconditionally (a no-op when not compiling), matching qwen3_5
Description
Under
torch_compile, one graph break inside the GatedDeltaNet decoder for-loop made dynamo skip the whole text-modelforwardframe (a break inside a loop is unresumable), so every layer (attention included) ran eagerly.This makes the loop traceable for every GatedDeltaNet model (Qwen3.5 dense + MoE, Qwen3-Next), plus a model-agnostic collator path that fixes the same break for any flash-attention + packed model under compile:
gated_delta_net_ops.py(new, shared) — opaquetorch.libraryops (axolotl_gdn::gdn_conv/gdn_chunk+ backward) wrapping the same FLA kernels with identical saved tensors. They derivecu_seqlensfromposition_idsinside the op, so noaten.nonzeroenters the graph. Acast_gflag matches each architecture's eager call (qwen3_5 castsg; qwen3_next keeps f32) for bitwise parity; packed input enforces batch-1 with an explicitValueError. Also hosts the sharedFusedRMSNormGatedcompile boundary.qwen3_5/+qwen3_next/modeling.py— route the GDN training path through the ops under compile, and install theFusedRMSNormGatedopaque wrapper (its backward recomputes viaautograd.grad, since the FLA backward isn't meta-traceable). Self-attention stays in-graph under GC+training, behind adynamo.disableboundary otherwise (guards an Inductor FA2-backward fusion that corrupts packed grads; pinned bytest_fa2_compiled_matches_eager_grads). The qwen3_next wiring also fixes a pre-existing crash (its decoder passedcache_positionto aGatedDeltaNet.forwardthat rejects it).collators/batching.py+builders/causal.py— undertorch_compile, the multipack collator precomputes the FA2 varlen kwargs (cu_seq_lens_q/k,max_length_q/k) once per batch, so transformers skips its per-layernonzeroderivation. Model-agnostic — gated onSUPPORTED_MULTIPACK_MODEL_TYPES, so any packed FA2 model compiles clean.patch_manager.py— threadstorch_compileinto the qwen3_5/qwen3_5_moe/qwen3_next packing patches.Opt-in via
torch_compile; inert otherwise.Version coupling — the
==0.4.1pin is load-bearing. The ops wrap FLA internal functions by fixed arity and the fakes hardcode 0.4.1 shapes/dtypes; FLA 0.4.2/0.5.0 already changedchunk_gated_delta_rule_fwd's arity (crash at first forward). So a bump needs revalidating the ops — guarded bytest_opcheck_custom_ops+test_eager_parity_ops_vs_legacy_bitwise(the call-time crash isn't caught by the build-time warning). The upstream fix fla-org/flash-linear-attention#909 removes only the index-helper breaks, not thenonzero/@torch.compiler.disableones these ops target, so the wrapper can't just be dropped on a newer FLA.Known limit: auto-enabled LoRA triton kernels + a GDN packing patch +
torch_compile+ GC NaN gradients on torch ≤ 2.10 (pre-existing on main; clean on 2.11, and with any one ingredient removed).PatchManagerforce-disables those kernels with a warning on torch < 2.11.Motivation and Context
The decoder loop is the bulk of training compute. Skipping its compilation forfeits the largest fusion payoff: the eager RMSNorm fp32 round-trips (~870 GB/step of traffic at 9B/32k, interposer-measured) + per-kernel Python launch overhead. Once the loop traces, Inductor fuses those.
Scope: the GDN ops cover Qwen3.5/Qwen3.6 (
model_typeqwen3_5/qwen3_5_moe, dense + MoE, text + VL — every shipped checkpoint is*ForConditionalGeneration) and Qwen3-Next (qwen3_next, same FLA GatedDeltaNet kernels). The collator change is model-agnostic. Dense Qwen3/Qwen2.5/Qwen3-VL have no GatedDeltaNet, so the loop-internal breaks don't exist for them — but they still benefit from the collator path under FA2+packing+compile.How has this been tested?
Stack: torch 2.11.0+cu130, transformers 5.9.0, FLA 0.4.1, liger-kernel 0.8.0, real flash_attn 2.8.3 (the validated stack). RTX 3090 / 3090 Ti / 5090.
Functional — 12 tests, all passing:
tests/monkeypatch/test_qwen3_5_compiled_loop.py(10): zero-breaks-with-GC (1 graph), FA2-loop-with-varlen-kwargs, VLForConditionalGeneration3-D MRoPE,aten.nonzero-break-gone, bitwise eager parity (ops vs legacy), compiled-vs-eager, FA2+GC compiled grad parity, toy MoE (0 breaks), B>1-packed-raises,opcheck(packed/MRoPE/dense × bias × bothcast_g).tests/monkeypatch/test_qwen3_next_compiled_loop.py(2): qwen3_next decoder loop 0 breaks + bitwise eager parity (cast_g=False, g stays f32).Existing
test_qwen3_5_fused_attn.py: 12/12 (no regression). Suite also passes one major version down (torch 2.9.1 / 2.10.0 + transformers 5.5.4).Multimodal: Multimodal training is non-packed in axolotl (SFT vision examples + MM-CPT (pending PR #3629) ship
sample_packing: false), so the GDN ops are inert there by construction this was validated with real images (llava-instruct through the Qwen3.5-2B vision tower: packing patches absent from logs, eager-vs-compiled step-1 loss Δ0.004, eval clean).Multiple test training runs and benches (below).
Benchmarks (all on the validated stack above)
Qwen3.5-2B/9B =
Qwen3_5ForConditionalGeneration, LoRA r=16 q/k/v/o, sample packing, bf16, gradient checkpointing,attn_implementation: flash_attention_2. 60 steps, steady-state window after warmup,TORCH_LOGS=recompiles. Throughput = tok/s; VRAM = reserved (GiB). Zero NaN and step-1/final-10 loss parity in every cell.A. Core compile speedup On real FA2,
torch_compilewithout this PR is near-neutral (the loop graph-breaks); this PR is what makes compile worth it (RTX 3090):B. Feature compatibility / composition (Qwen3.5-2B, compile + PR, varying feature): (RTX 3090)
fused_attn_kernel(#3680)cut_cross_entropy)fused_linear_cross_entropyfused_attn_kernel)LoRA triton kernels and
fused_attn_kernelare subsumed by Inductor once the loop compiles (consistent with #3680's own "fused+compile −18%" dense-Qwen3 numbers); liger composes cleanly but adds no value under compile.C. Generalization — the model-agnostic collator (compile on/off, FA2 + packing): (RTX 5090)
Neither model has a GatedDeltaNet so speedups come from the collator removing the FA2 per-layer break. Isolation (toy Llama, FA2 + packed, under compile): 3 graph breaks (4 graphs) without the varlen kwargs → 0 breaks (1 graph) with them. 0 recompiles in steady state; loss parity holds.
Convergence: loss is bit-stable across every cell above — step-1 identical to eager and final-10 within float noise (e.g. 2B eager 1.1275 vs compile+PR 1.1275; 9B 0.9264 vs 0.9266). The opaque-op path is bitwise identical to legacy eager (regression-tested) for both qwen3_5 and qwen3_next.
Composition with #3732 (fused-LoRA GDN routing): merges cleanly, both touch the GDN forward, the projection call sites route through
_la_proj_fwdwhile the compiled-ops gating stays intact, and the fusedLoRA_Oautograd Function traces inside the compiled loop with 0 graph breaks (earlier 2B bench: both-merged+compile +44% vs eager, #3732's +10% marginal surviving compilation because it substitutes bf16 adapter GEMMs Inductor can't infer).AI Usage Disclaimer
Opus 4.8 / Fable 5 used throughout.
Types of changes
torch_compilesupport for the GatedDeltaNet decoder loop — Qwen3.5 dense/MoE + Qwen3-Next)Summary by CodeRabbit
Release Notes
New Features
torch.compilesupport for Qwen3.5/Qwen3-Next models with optimized compiled execution pathsBug Fixes
Tests