feat(compile): traceable GatedDeltaNet decoder loop (Qwen3.5 + Qwen3-Next) + model-agnostic FA2 varlen collator #49
thad0ctor wants to merge 20 commits into
…eable packing

With torch_compile, a single in-loop graph break made dynamo skip the entire Qwen3_5TextModel.forward frame, so every decoder layer ran eagerly (no norm/rope/gating/residual fusion, full per-kernel launch overhead). The in-loop blockers were: axolotl's get_cu_seqlens() nonzero in the packing patch, FLA's @torch.compiler.disable'd chunk_gated_delta_rule, causal_conv1d's .item(), and FusedRMSNormGated's untraceable device-property probe.

- New fla_ops.py: opaque torch.library custom ops (axolotl_qwen3_5::gdn_conv / gdn_chunk + backward) wrapping the same FLA host functions with identical saved tensors (no recompute, no kernel changes). They take position_ids and derive cu_seqlens eagerly inside the op, so no data-dependent op enters the graph and every fake impl is static.
- modeling.py: the GDN training/no-cache path routes through the ops; a FusedRMSNormGated custom-op wrapper installs whenever torch_compile is on; decoder self-attention is traced when gradient_checkpointing and training.
- patch_manager.py threads torch_compile=bool(cfg.torch_compile) into the packing patches.

Bit-exactness: eager FLA's dg lands on the bf16 grid (reproduced via an explicit f32->bf16->f32 round-trip); v reaches the kernels as a non-contiguous split/reshape view (the bwd op mirrors FLA's input_guard contiguization).

Verified (tiny 4-layer hybrid): 0 graph breaks / 1 graph for the full text model under GC; eager ops-vs-legacy bitwise identical; compiled-vs-eager within compile noise. monkeypatch compiled-loop suite 4/4, qwen3_5 fused-attn 12/12.

Note: flash_attention_2 configs still don't get the compiled loop. The remaining breaker is transformers' per-layer varlen derivation on the FA2 path, fixable purely axolotl-side by emitting precomputed cu_seq_lens/max_length from the multipack collator (follow-up, not in this change).
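The f32->bf16->f32 round-trip named under bit-exactness can be illustrated without torch. The following is a minimal sketch, assuming round-to-nearest-even conversion and finite inputs; `round_trip_bf16` is a hypothetical helper name for illustration, not a function from axolotl or FLA.

```python
import struct


def round_trip_bf16(x: float) -> float:
    """Round a float32 value to the nearest bfloat16, then widen back to float32.

    Assumption: round-to-nearest-even, finite inputs only (NaN/inf not handled).
    """
    # Reinterpret the float32 bit pattern as an unsigned 32-bit integer.
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    # bf16 keeps the top 16 bits. The bias rounds to nearest and breaks
    # exact ties toward the value whose last kept bit is 0 (even).
    bias = 0x7FFF + ((bits >> 16) & 1)
    rounded = (bits + bias) & 0xFFFF0000
    (y,) = struct.unpack("<f", struct.pack("<I", rounded & 0xFFFFFFFF))
    return y


# A value already on the bf16 grid is unchanged; one between grid points moves.
on_grid = round_trip_bf16(1.0)
tie_case = round_trip_bf16(1.0 + 2**-8)  # exactly halfway between two bf16 values
```

The round-trip is idempotent: once a value sits on the bf16 grid, repeating the conversion leaves it unchanged, which is why reproducing it explicitly can keep two code paths bit-identical.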
… GDN ops _FLA_COMPILED_OPS now follows the torch_compile flag, so it is set to False when torch_compile=False instead of always enabling the opaque GDN ops.
… the compiled loop

The compiled decoder loop (the preceding commit) works under sdpa but not flash_attention_2: with packed sequences signalled only via position_ids, transformers re-derives the varlen metadata per layer inside _flash_attention_forward (a data-dependent _is_packed_sequence branch + (position_ids==0).nonzero()), which graph-breaks inside the decoder loop and makes dynamo skip the whole frame.

transformers already ships the escape hatch: if the caller passes the FlashAttentionKwargs (cu_seq_lens_q/k + max_length_q/k), is_fa_with_varlen_kwargs short-circuits before the data-dependent branch and the per-layer derivation is skipped. This emits exactly those kwargs from the multipack collator, computed once via transformers' own prepare_fa_kwargs_from_position_ids (so the metadata is bit-identical to what it would derive per layer). Gated to FA2 + sample packing + qwen3_5/qwen3_5_moe collators; max_length stays a python int so no capture_scalar_outputs is needed.

Verified: FA2 + torch_compile now traces the loop with 0 graph breaks (was 1 break / 4 fragmented graphs); eager FA2 loss with-vs-without the kwargs is bitwise identical. compiled-loop suite 6/6.
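To make the precomputed metadata concrete, here is a minimal pure-Python sketch of how cumulative sequence lengths and a maximum length can be derived once from a packed row's position_ids. It is an illustration only: `varlen_kwargs_from_position_ids` is a hypothetical helper, it works on a plain list for a single packed row, and it is not transformers' `prepare_fa_kwargs_from_position_ids`.

```python
from itertools import accumulate


def varlen_kwargs_from_position_ids(position_ids: list[int]) -> dict:
    """Derive FA2-style varlen metadata for one packed row (batch size 1).

    Assumption: every packed sequence restarts its positions at 0.
    """
    if not position_ids or position_ids[0] != 0:
        raise ValueError("expected a non-empty packed row starting at position 0")
    # A new sequence begins wherever the position counter resets to 0.
    starts = [i for i, p in enumerate(position_ids) if p == 0]
    bounds = starts + [len(position_ids)]
    lengths = [end - start for start, end in zip(bounds[:-1], bounds[1:])]
    cu_seq_lens = [0, *accumulate(lengths)]
    max_length = max(lengths)  # kept as a plain Python int
    return {
        "cu_seq_lens_q": list(cu_seq_lens),
        "cu_seq_lens_k": list(cu_seq_lens),
        "max_length_q": max_length,
        "max_length_k": max_length,
    }


# Three sequences of lengths 3, 2 and 4 packed into one row.
kwargs = varlen_kwargs_from_position_ids([0, 1, 2, 0, 1, 0, 1, 2, 3])
```

Doing this once in the collator means the data-dependent scan happens outside the traced model code, instead of once per decoder layer.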
…acks, formatting

- build_collator: gate emit_fa_varlen_kwargs on the eval-aware packing state (training_args.eval_sample_packing for eval loaders) instead of cfg.sample_packing, so eval-only packed mode also gets the precomputed kwargs.
- collator + FusedRMSNormGated boundary: replace silent exception swallows with warnings so a disabled compiled-loop optimization is visible.
- ruff-format the new test assert; collapse a two-line comment.
…lback, B>1 guard, weight-None norm fallback, saved-v contiguity; add MoE/FA2-parity/opcheck/B>1 tests
…undary (torch 2.11-only, unreproduced on 2.9/2.10)
…GC on torch<2.11 (NaN gradients)
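A version guard of this kind is easy to get wrong with string comparison ("2.9" sorts after "2.11"). The sketch below shows one way to compare parsed version tuples; `parse_major_minor` and `lora_kernels_allowed` are hypothetical names, and the exact conditions are an assumption based on this commit title, not axolotl's actual PatchManager logic.

```python
import re
import warnings


def parse_major_minor(version: str) -> tuple[int, int]:
    """Extract (major, minor) from strings such as '2.11.0+cu130'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def lora_kernels_allowed(
    torch_version: str,
    torch_compile: bool,
    gradient_checkpointing: bool,
    gdn_packing_patch: bool,
) -> bool:
    """Return False (with a warning) for the assumed NaN-prone combination."""
    risky = torch_compile and gradient_checkpointing and gdn_packing_patch
    # Tuple comparison orders (2, 9) before (2, 11); string comparison would not.
    if risky and parse_major_minor(torch_version) < (2, 11):
        warnings.warn(
            "disabling LoRA kernels: NaN gradients reported on torch < 2.11 "
            "with torch_compile + gradient checkpointing + GDN packing",
            RuntimeWarning,
            stacklevel=2,
        )
        return False
    return True
```

Removing any one ingredient of the combination leaves the kernels enabled in this sketch, so the guard stays as narrow as the reported failure.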
- collator: emit FA2 varlen kwargs for ALL multipack models under torch_compile (was gated to qwen3_5/qwen3_5_moe); model-agnostic, transformers consumes them natively
- share the GatedDeltaNet opaque ops: move fla_ops -> gated_delta_net_ops (neutral axolotl_gdn namespace), add cast_g flag so qwen3_5 (casts g) and qwen3_next (f32 g) both stay bit-exact; move the FusedRMSNormGated compile boundary into the shared module
- wire qwen3_next: route conv+chunk through the shared ops under compile, install the rmsnorm boundary + self-attn dynamo boundary, thread torch_compile; also fix a pre-existing crash (decoder passed cache_position to GatedDeltaNet.forward)
- extend the lora-kernel NaN guard to qwen3_next
- tests: qwen3_next compiled-loop (0 breaks + bitwise eager parity); opcheck both cast_g
- batching.py: narrow the FA varlen fallback catch to the realistic exception set (ImportError/AttributeError/RuntimeError/TypeError/ValueError/IndexError) instead of bare Exception (Ruff BLE001)
- tests: presence-aware restore of FusedRMSNormGated._axolotl_compile_boundary (snapshot hasattr + value) in both compiled-loop fixtures
- tests: narrow B>1 pytest.raises to ValueError (verified: the op raises a clean ValueError through the model forward)
- qwen3_next: route non-GC self-attn through the dynamo.disable boundary unconditionally (a no-op when not compiling), matching qwen3_5
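The narrowed fallback described in the first item follows a common pattern: catch a fixed tuple of expected failures, warn, and continue without the optional metadata. A minimal sketch, assuming a dict-shaped batch; `with_optional_kwargs` and `compute_extra` are hypothetical names, not the code in batching.py.

```python
import warnings

# The exception set listed in the commit message; anything else propagates.
FALLBACK_ERRORS = (
    ImportError,
    AttributeError,
    RuntimeError,
    TypeError,
    ValueError,
    IndexError,
)


def with_optional_kwargs(batch: dict, compute_extra) -> dict:
    """Merge optional precomputed kwargs into a batch, or fall back loudly."""
    try:
        extra = compute_extra(batch)
    except FALLBACK_ERRORS as exc:
        # Visible fallback: the optimization is skipped, not silently lost.
        warnings.warn(
            f"optional kwargs unavailable, continuing without them: {exc!r}",
            RuntimeWarning,
            stacklevel=2,
        )
        return batch
    return {**batch, **extra}


merged = with_optional_kwargs({"input_ids": [1, 2]}, lambda b: {"max_length_q": 2})
fallback = with_optional_kwargs({"input_ids": [1, 2]}, lambda b: int("not a number"))
```

An exception outside the tuple (for example `KeyError`) still surfaces, which is the point of avoiding a bare `except Exception`.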
* feat: update cce for new models * feat: update transformers * chore: remove dead code * feat: add liger for gemma4 unified * fix: hybrid attn for latest transformers * feat: add gemma4 unified following gemma4 * chore: refactor old logger to axolotl get_logger * fix: update legacy env to new xet env * feat: add missed files * fix: handling of lora kernels * feat: update vision and readme yaml * chore: update numbers from latest run * feat: add text config * chore: update correct number * fix: update cce commit * fix: packing leak * use transformers patch release * 2 parallel jobs for pytests * fix: gate attention_mask for gemma4_unified * fix: restore prior gemma4 e2b shared kv layer helper * chore: refactor gemma4 hybrid attn * feat: update gemma4 results and config * chore: simplify config * fix: update unified results and docs * chore: swap to hybrid attn * feat: add tests * fix: swap to FA2 text * fix: ci logging * fix: generalize rotary patch * fix: deleted file for docs * fix: fsdp defaulted to v2 * fix: support simplenamespace for test * fix: update quarto to include all current scripts * fix: drop quarto doc entry for untracked deepseek_v4 module --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
* fix version mismatch 4 pirate * add whitlist for collab --------- Co-authored-by: NanoCode012 <nano@axolotl.ai>
…tent in qwen3_5 template (axolotl-ai-cloud#3725) [skip ci] The inline-<think> assistant branch reassigned content (stripping it to the post-</think> answer) before reasoning_content was extracted from it. Since reasoning_content reads from the already-truncated content, the reasoning trace was dropped and the answer leaked into the <think> block. Swap the two set statements to match the official Qwen3.5 template order.
…onfig (axolotl-ai-cloud#3730) [skip ci] * fix: KTO user_defined dataset transform crashes on every documented config The user_defined.default KTO strategy was broken in all configurations: - when completion_format was provided, the default was assigned to a misnamed chosen_format variable only in the fallback branch, so the transform raised NameError: chosen_format - when completion_format was omitted, the generated placeholder name did not match the .format() keyword (chosen= vs {completion}), raising KeyError - prompt formatting read sample['prompt'] instead of sample[field_prompt], breaking custom field_prompt configs Also surface the underlying exception when a prompt strategy fails to load instead of silently returning None, which previously crashed later with the unhelpful 'TypeError: None is not a callable object'. Fixes axolotl-ai-cloud#2757 Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> * docs: add docstrings to KTO user_defined tests for coverage check Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> * fix: wrong pydantic type * feat: updated test to handle e2e validation --------- Co-authored-by: NanoCode012 <nano@axolotl.ai>
…ip ci] * fix: fail early for CI if meet CUDA error * fix: switch to clean abort
* feat: add interactive multi-turn chat mode (--chat) to inference CLI * fix: apply_chat_template returns BatchEncoding in transformers v5 * docs: document interactive chat mode for inference * feat: diffusion turn generation for chat mode; fix fp8 probe on CPU-only torch * feat: suggest command aliases on typo; save chat sessions in multimodal parts format * feat: collapse thinking blocks in chat with /expand, /think template toggle, thinking token stats * fix: suppress unauthenticated HF Hub nag warning in logging config * fix: harden chat REPL against interrupts and command errors; store assistant turns in parse_response format - Ctrl+C during a (diffusion) turn no longer crashes the REPL; the session survives and the user message is kept - exceptions in slash-command handlers no longer kill the session - consecutive user messages merge so strict templates never see two user turns after a failed generation - assistant turns are stored without special tokens, with thinking under reasoning_content (tokenizer parse_response schema when available, think-marker split otherwise); EOS markers no longer leak into the streamed display * perf(chat): lighten the turn loop - /new now drops the cross-turn KV cache instead of leaving it on device until the next generation - throttle live thinking-tail rerenders to the 12 Hz repaint rate (was O(n^2) splitlines over the full think text per chunk) - split think markers once per turn and reuse for counts and the stored message, dropping the redundant full decode * refactor(chat): share the live thinking-tail FPS as a class constant * fix: interrupt cache race condition and parse edge case
Description
Under `torch_compile`, one graph break inside the GatedDeltaNet decoder for-loop made dynamo skip the whole text-model `forward` frame (a break inside a loop is unresumable), so every layer (attention included) ran eagerly.

This makes the loop traceable for every GatedDeltaNet model (Qwen3.5 dense + MoE, Qwen3-Next), plus a model-agnostic collator path that fixes the same break for any flash-attention + packed model under compile:

- `gated_delta_net_ops.py` (new, shared): opaque `torch.library` ops (`axolotl_gdn::gdn_conv` / `gdn_chunk` + backward) wrapping the same FLA kernels with identical saved tensors. They derive `cu_seqlens` from `position_ids` inside the op, so no `aten.nonzero` enters the graph. A `cast_g` flag matches each architecture's eager call (qwen3_5 casts `g`; qwen3_next keeps f32) for bitwise parity; packed input enforces batch-1 with an explicit `ValueError`. Also hosts the shared `FusedRMSNormGated` compile boundary.
- `qwen3_5/` + `qwen3_next/modeling.py`: route the GDN training path through the ops under compile, and install the `FusedRMSNormGated` opaque wrapper (its backward recomputes via `autograd.grad`, since the FLA backward isn't meta-traceable). Self-attention stays in-graph under GC+training, behind a `dynamo.disable` boundary otherwise (guards an Inductor FA2-backward fusion that corrupts packed grads; pinned by `test_fa2_compiled_matches_eager_grads`). The qwen3_next wiring also fixes a pre-existing crash (its decoder passed `cache_position` to a `GatedDeltaNet.forward` that rejects it).
- `collators/batching.py` + `builders/causal.py`: under `torch_compile`, the multipack collator precomputes the FA2 varlen kwargs (`cu_seq_lens_q/k`, `max_length_q/k`) once per batch, so transformers skips its per-layer `nonzero` derivation. Model-agnostic: gated on `SUPPORTED_MULTIPACK_MODEL_TYPES`, so any packed FA2 model compiles clean.
- `patch_manager.py`: threads `torch_compile` into the qwen3_5/qwen3_5_moe/qwen3_next packing patches.

Opt-in via `torch_compile`; inert otherwise.

Version coupling: the `==0.4.1` pin is load-bearing. The ops wrap FLA internal functions by fixed arity and the fakes hardcode 0.4.1 shapes/dtypes; FLA 0.4.2/0.5.0 already changed `chunk_gated_delta_rule_fwd`'s arity (crash at first forward). So a bump needs revalidating the ops, guarded by `test_opcheck_custom_ops` + `test_eager_parity_ops_vs_legacy_bitwise` (the call-time crash isn't caught by the build-time warning). The upstream fix fla-org/flash-linear-attention#909 removes only the index-helper breaks, not the `nonzero` / `@torch.compiler.disable` ones these ops target, so the wrapper can't just be dropped on a newer FLA.

Known limit: auto-enabled LoRA triton kernels + a GDN packing patch + `torch_compile` + GC yield NaN gradients on torch ≤ 2.10 (pre-existing on main; clean on 2.11, and with any one ingredient removed). `PatchManager` force-disables those kernels with a warning on torch < 2.11.

Motivation and Context
The decoder loop is the bulk of training compute. Skipping its compilation forfeits the largest fusion payoff: the eager RMSNorm fp32 round-trips (~870 GB/step of traffic at 9B/32k, interposer-measured) + per-kernel Python launch overhead. Once the loop traces, Inductor fuses those.
Scope: the GDN ops cover Qwen3.5/Qwen3.6 (`model_type` `qwen3_5` / `qwen3_5_moe`, dense + MoE, text + VL; every shipped checkpoint is `*ForConditionalGeneration`) and Qwen3-Next (`qwen3_next`, same FLA GatedDeltaNet kernels). The collator change is model-agnostic. Dense Qwen3/Qwen2.5/Qwen3-VL have no GatedDeltaNet, so the loop-internal breaks don't exist for them, but they still benefit from the collator path under FA2+packing+compile.

How has this been tested?
Stack: torch 2.11.0+cu130, transformers 5.9.0, FLA 0.4.1, liger-kernel 0.8.0, real flash_attn 2.8.3 (the validated stack). RTX 3090 / 3090 Ti / 5090.
Functional — 12 tests, all passing:
- `tests/monkeypatch/test_qwen3_5_compiled_loop.py` (10): zero-breaks-with-GC (1 graph), FA2-loop-with-varlen-kwargs, VL `ForConditionalGeneration` 3-D MRoPE, `aten.nonzero`-break-gone, bitwise eager parity (ops vs legacy), compiled-vs-eager, FA2+GC compiled grad parity, toy MoE (0 breaks), B>1-packed-raises, `opcheck` (packed/MRoPE/dense × bias × both `cast_g`).
- `tests/monkeypatch/test_qwen3_next_compiled_loop.py` (2): qwen3_next decoder loop 0 breaks + bitwise eager parity (`cast_g=False`, g stays f32).

Existing `test_qwen3_5_fused_attn.py`: 12/12 (no regression). Suite also passes one major version down (torch 2.9.1 / 2.10.0 + transformers 5.5.4).

Multimodal: multimodal training is non-packed in axolotl (SFT vision examples + MM-CPT (pending PR axolotl-ai-cloud#3629) ship `sample_packing: false`), so the GDN ops are inert there by construction. This was validated with real images (llava-instruct through the Qwen3.5-2B vision tower: packing patches absent from logs, eager-vs-compiled step-1 loss Δ0.004, eval clean).

Multiple test training runs and benches (below).
Benchmarks (all on the validated stack above)
Qwen3.5-2B/9B = `Qwen3_5ForConditionalGeneration`, LoRA r=16 q/k/v/o, sample packing, bf16, gradient checkpointing, `attn_implementation: flash_attention_2`. 60 steps, steady-state window after warmup, `TORCH_LOGS=recompiles`. Throughput = tok/s; VRAM = reserved (GiB). Zero NaN and step-1/final-10 loss parity in every cell.

A. Core compile speedup (RTX 3090). On real FA2, `torch_compile` without this PR is near-neutral (the loop graph-breaks); this PR is what makes compile worth it.

[Table A: contents not preserved.]

B. Feature compatibility / composition (Qwen3.5-2B, compile + PR, varying feature; RTX 3090).

[Table B: contents not preserved. Surviving row-label fragments: `fused_attn_kernel` (axolotl-ai-cloud#3680), `cut_cross_entropy`, `fused_linear_cross_entropy`, `fused_attn_kernel`.]

LoRA triton kernels and `fused_attn_kernel` are subsumed by Inductor once the loop compiles (consistent with axolotl-ai-cloud#3680's own "fused+compile −18%" dense-Qwen3 numbers); liger composes cleanly but adds no value under compile.

C. Generalization: the model-agnostic collator (compile on/off, FA2 + packing; RTX 5090).

[Table C: contents not preserved.]

Neither model has a GatedDeltaNet, so the speedups come from the collator removing the FA2 per-layer break. Isolation (toy Llama, FA2 + packed, under compile): 3 graph breaks (4 graphs) without the varlen kwargs → 0 breaks (1 graph) with them. 0 recompiles in steady state; loss parity holds.
Convergence: loss is bit-stable across every cell above — step-1 identical to eager and final-10 within float noise (e.g. 2B eager 1.1275 vs compile+PR 1.1275; 9B 0.9264 vs 0.9266). The opaque-op path is bitwise identical to legacy eager (regression-tested) for both qwen3_5 and qwen3_next.
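The two parity notions used here, bitwise identity and agreement within float noise, are different checks. A minimal stdlib sketch of the distinction follows; the helper names and the `atol` value are assumptions for illustration and are not taken from the PR's test suite.

```python
import struct


def f32_bits(x: float) -> int:
    """Return the IEEE-754 float32 bit pattern of x as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bitwise_equal(a: list[float], b: list[float]) -> bool:
    """True only if every pair has an identical float32 bit pattern."""
    return len(a) == len(b) and all(
        f32_bits(x) == f32_bits(y) for x, y in zip(a, b)
    )


def within_tolerance(a: list[float], b: list[float], atol: float) -> bool:
    """True if every pair differs by at most atol (absolute tolerance)."""
    return len(a) == len(b) and all(abs(x - y) <= atol for x, y in zip(a, b))


# Loss values quoted above: the 2B pair matches exactly, the 9B pair only approximately.
same_2b = bitwise_equal([1.1275], [1.1275])
same_9b = bitwise_equal([0.9264], [0.9266])
close_9b = within_tolerance([0.9264], [0.9266], atol=1e-3)
```

Bitwise comparison is stricter than `==`: it distinguishes `0.0` from `-0.0`, which is why it suits regression tests that claim two code paths run the same arithmetic.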
Composition with axolotl-ai-cloud#3732 (fused-LoRA GDN routing): merges cleanly. Both touch the GDN forward; the projection call sites route through `_la_proj_fwd` while the compiled-ops gating stays intact, and the fused `LoRA_O` autograd Function traces inside the compiled loop with 0 graph breaks (earlier 2B bench: both-merged+compile +44% vs eager, with axolotl-ai-cloud#3732's +10% marginal surviving compilation because it substitutes bf16 adapter GEMMs Inductor can't infer).

(Orthogonal note: any LoRA + GC + compile run carries one pre-loop graph break from transformers' `enable_input_require_grads` embedding hook; on torch 2.11 it fires before the decoder loop and doesn't affect it.)

AI Usage Disclaimer
Opus 4.8 / Fable 5 used throughout.
Types of changes
- `torch_compile` support for the GatedDeltaNet decoder loop (Qwen3.5 dense/MoE + Qwen3-Next)

Summary by CodeRabbit

New Features
- `torch.compile` optimization support for Qwen3.5 and Qwen3Next models with improved kernel handling.

Bug Fixes

Tests
- `torch.compile` compatibility and performance parity for Qwen model architectures.