
feat(compile): traceable GDN decoder loop (Qwen3.5 + Qwen3-Next) + model-agnostic FA2 varlen collator#3734

Open
thad0ctor wants to merge 11 commits into axolotl-ai-cloud:main from thad0ctor:feat/gdn-compiled-decoder-loop


@thad0ctor thad0ctor commented Jun 14, 2026

Contributor

Description

Under torch_compile, one graph break inside the GatedDeltaNet decoder for-loop made dynamo skip the whole text-model forward frame (a break inside a loop is unresumable), so every layer (attention included) ran eagerly.

This makes the loop traceable for every GatedDeltaNet model (Qwen3.5 dense + MoE, Qwen3-Next), plus a model-agnostic collator path that fixes the same break for any flash-attention + packed model under compile:

  • gated_delta_net_ops.py (new, shared) — opaque torch.library ops (axolotl_gdn::gdn_conv/gdn_chunk + backward) wrapping the same FLA kernels with identical saved tensors. They derive cu_seqlens from position_ids inside the op, so no aten.nonzero enters the graph. A cast_g flag matches each architecture's eager call (qwen3_5 casts g; qwen3_next keeps f32) for bitwise parity; packed input enforces batch-1 with an explicit ValueError. Also hosts the shared FusedRMSNormGated compile boundary.
  • qwen3_5/ + qwen3_next/modeling.py — route the GDN training path through the ops under compile, and install the FusedRMSNormGated opaque wrapper (its backward recomputes via autograd.grad, since the FLA backward isn't meta-traceable). Self-attention stays in-graph under GC+training, behind a dynamo.disable boundary otherwise (guards an Inductor FA2-backward fusion that corrupts packed grads; pinned by test_fa2_compiled_matches_eager_grads). The qwen3_next wiring also fixes a pre-existing crash (its decoder passed cache_position to a GatedDeltaNet.forward that rejects it).
  • collators/batching.py + builders/causal.py — under torch_compile, the multipack collator precomputes the FA2 varlen kwargs (cu_seq_lens_q/k, max_length_q/k) once per batch, so transformers skips its per-layer nonzero derivation. Model-agnostic — gated on SUPPORTED_MULTIPACK_MODEL_TYPES, so any packed FA2 model compiles clean.
  • patch_manager.py — threads torch_compile into the qwen3_5/qwen3_5_moe/qwen3_next packing patches.
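To make the cu_seqlens derivation in the first bullet concrete, here is a minimal pure-Python sketch. It assumes packed position_ids restart at 0 at every sequence boundary. The function name and the list-based input are illustrative only (the real ops work on tensors inside the custom-op boundary), and the error text is indicative rather than the op's exact message.

```python
def cu_seqlens_from_position_ids(position_ids):
    """Derive cumulative sequence lengths from packed position ids.

    position_ids: list of rows (batch x seq_len). Each packed sequence is
    assumed to restart its positions at 0.
    """
    if len(position_ids) != 1:
        # Packed input only makes sense for a single row in this sketch.
        raise ValueError(
            f"batch size is expected to be 1 for packed input, got {len(position_ids)}"
        )
    row = position_ids[0]
    if not row:
        return [0]
    # Indices where a new sequence starts (position resets to 0).
    starts = [i for i, p in enumerate(row) if p == 0]
    if not starts or starts[0] != 0:
        # Treat a row that does not begin at position 0 as starting a sequence.
        starts.insert(0, 0)
    # Append the total length so consecutive pairs bound each sequence.
    return starts + [len(row)]


# Three packed sequences of lengths 3, 2 and 4.
print(cu_seqlens_from_position_ids([[0, 1, 2, 0, 1, 0, 1, 2, 3]]))
```

Doing this scan inside the op is what keeps the data-dependent "find the zeros" step out of the traced graph: the graph only sees an opaque call that takes position_ids.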

Opt-in via torch_compile; inert otherwise.

Version coupling — the ==0.4.1 pin is load-bearing. The ops wrap FLA internal functions by fixed arity and the fakes hardcode 0.4.1 shapes/dtypes; FLA 0.4.2/0.5.0 already changed chunk_gated_delta_rule_fwd's arity (crash at first forward). So a bump needs revalidating the ops — guarded by test_opcheck_custom_ops + test_eager_parity_ops_vs_legacy_bitwise (the call-time crash isn't caught by the build-time warning). The upstream fix fla-org/flash-linear-attention#909 removes only the index-helper breaks, not the nonzero/@torch.compiler.disable ones these ops target, so the wrapper can't just be dropped on a newer FLA.
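As an illustration of why a fixed-arity wrapper is fragile, the sketch below shows one generic way to surface an arity change at import time instead of at the first forward. This is not something the PR description says it does (the PR relies on the pin plus the two tests named above). The helper name and the stand-in kernel functions are hypothetical, and inspect.signature may not work on every kind of wrapped or jitted callable.

```python
import inspect


def check_positional_arity(fn, expected):
    """Raise early if `fn` no longer takes `expected` positional parameters."""
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    positional = [
        p for p in inspect.signature(fn).parameters.values()
        if p.kind in positional_kinds
    ]
    if len(positional) != expected:
        raise RuntimeError(
            f"{fn.__name__} takes {len(positional)} positional parameters, "
            f"expected {expected}; revalidate the wrapper before bumping the pin"
        )


# Hypothetical stand-ins for a wrapped entry point before and after a bump.
def kernel_fwd_pinned(q, k, v, g):
    return q


def kernel_fwd_bumped(q, k, v, g, extra_flag):
    return q


check_positional_arity(kernel_fwd_pinned, 4)  # passes silently
```

Calling check_positional_arity(kernel_fwd_bumped, 4) raises RuntimeError immediately, which is the failure mode one would want in place of a crash deep inside the first training step.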

Known limit: auto-enabled LoRA triton kernels + a GDN packing patch + torch_compile + GC together produce NaN gradients on torch ≤ 2.10 (pre-existing on main; clean on 2.11, and clean with any one ingredient removed). PatchManager force-disables those kernels with a warning on torch < 2.11.
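A minimal sketch of the version gate described above, assuming the config is a plain dict and using the three kernel flag names listed in the review walkthrough. The function names, the dict-based config, and the warning text are illustrative; the real PatchManager check also conditions on model type and adapter, which is omitted here.

```python
import re
import warnings

LORA_KERNEL_FLAGS = ("lora_qkv_kernel", "lora_o_kernel", "lora_mlp_kernel")


def parse_major_minor(version):
    """Extract (major, minor) from strings like '2.10.0+cu130'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def maybe_disable_lora_kernels(cfg, torch_version):
    """Turn the LoRA kernel flags off when the risky combination is present."""
    risky = (
        bool(cfg.get("torch_compile"))
        and bool(cfg.get("gradient_checkpointing"))
        and bool(cfg.get("sample_packing"))
        and any(cfg.get(flag) for flag in LORA_KERNEL_FLAGS)
        and parse_major_minor(torch_version) < (2, 11)
    )
    if risky:
        warnings.warn(
            "Disabling LoRA triton kernels: NaN gradients under torch_compile "
            "on torch < 2.11"
        )
        for flag in LORA_KERNEL_FLAGS:
            cfg[flag] = False
    return cfg
```

Comparing integer tuples rather than raw strings matters here: as strings, "2.9" sorts after "2.11", so a lexicographic check would wrongly treat 2.9 as new enough.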

Motivation and Context

The decoder loop is the bulk of training compute. Skipping its compilation forfeits the largest fusion payoff: the eager RMSNorm fp32 round-trips (~870 GB/step of traffic at 9B/32k, interposer-measured) + per-kernel Python launch overhead. Once the loop traces, Inductor fuses those.

Scope: the GDN ops cover Qwen3.5/Qwen3.6 (model_type qwen3_5/qwen3_5_moe, dense + MoE, text + VL — every shipped checkpoint is *ForConditionalGeneration) and Qwen3-Next (qwen3_next, same FLA GatedDeltaNet kernels). The collator change is model-agnostic. Dense Qwen3/Qwen2.5/Qwen3-VL have no GatedDeltaNet, so the loop-internal breaks don't exist for them — but they still benefit from the collator path under FA2+packing+compile.

How has this been tested?

Stack: torch 2.11.0+cu130, transformers 5.9.0, FLA 0.4.1, liger-kernel 0.8.0, real flash_attn 2.8.3 (the validated stack). RTX 3090 / 3090 Ti / 5090.

Functional — 12 tests, all passing:

tests/monkeypatch/test_qwen3_5_compiled_loop.py (10): zero-breaks-with-GC (1 graph), FA2-loop-with-varlen-kwargs, VL ForConditionalGeneration 3-D MRoPE, aten.nonzero-break-gone, bitwise eager parity (ops vs legacy), compiled-vs-eager, FA2+GC compiled grad parity, toy MoE (0 breaks), B>1-packed-raises, opcheck (packed/MRoPE/dense × bias × both cast_g).
tests/monkeypatch/test_qwen3_next_compiled_loop.py (2): qwen3_next decoder loop 0 breaks + bitwise eager parity (cast_g=False, g stays f32).
Existing test_qwen3_5_fused_attn.py: 12/12 (no regression). Suite also passes on older releases (torch 2.9.1 / 2.10.0 + transformers 5.5.4).

Multimodal: multimodal training is non-packed in axolotl (the SFT vision examples and MM-CPT (pending PR #3629) ship sample_packing: false), so the GDN ops are inert there by construction. This was validated with real images (llava-instruct through the Qwen3.5-2B vision tower: packing patches absent from logs, eager-vs-compiled step-1 loss Δ0.004, eval clean).

Multiple test training runs and benches (below).


Benchmarks (all on the validated stack above)

Qwen3.5-2B/9B = Qwen3_5ForConditionalGeneration, LoRA r=16 q/k/v/o, sample packing, bf16, gradient checkpointing, attn_implementation: flash_attention_2. 60 steps, steady-state window after warmup, TORCH_LOGS=recompiles. Throughput = tok/s; VRAM = reserved (GiB). Zero NaN and step-1/final-10 loss parity in every cell.

A. Core compile speedup

On real FA2, torch_compile without this PR is near-neutral (the loop graph-breaks); this PR is what makes compile worth it (RTX 3090):

| model | eager | compile (main, no PR) | compile + PR | PR vs eager | PR vs compile-main |
|---|---|---|---|---|---|
| Qwen3.5-2B (seq 4096) | 3593 tok/s | 3788 (+5.4%) | 4617 | +28.5% | +21.9% |
| Qwen3.5-9B (seq 4096, CCE) | 2724 tok/s | 2732 (+0.3%) | 3330 | +22.2% | +21.9% |
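For readers who want to re-derive the percentage columns, the short check below recomputes them from the tok/s figures in table A. The helper name and dict layout are just for illustration.

```python
def pct_gain(new_tok_s, base_tok_s):
    """Relative throughput gain in percent, rounded to one decimal."""
    return round((new_tok_s / base_tok_s - 1.0) * 100.0, 1)


# tok/s figures copied from table A.
rows = {
    "Qwen3.5-2B": {"eager": 3593, "compile_main": 3788, "compile_pr": 4617},
    "Qwen3.5-9B": {"eager": 2724, "compile_main": 2732, "compile_pr": 3330},
}

for name, r in rows.items():
    print(
        name,
        "compile-main vs eager:", pct_gain(r["compile_main"], r["eager"]),
        "PR vs eager:", pct_gain(r["compile_pr"], r["eager"]),
        "PR vs compile-main:", pct_gain(r["compile_pr"], r["compile_main"]),
    )
```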

B. Feature compatibility / composition (Qwen3.5-2B, compile + PR, varying one feature; RTX 3090):

| config | tok/s | reserved VRAM | note |
|---|---|---|---|
| compile + PR (LoRA kernels on, default) | 4617 | 14.66 GiB | baseline |
| + LoRA kernels off | 4585 | 14.40 GiB | kernels ≈ neutral under compile |
| + fused_attn_kernel (#3680) | 4394 | 14.67 GiB | −4.8%; redundant under compile |
| + axolotl CCE (cut_cross_entropy) | 4781 | 6.03 GiB | fastest and ~2.4× less memory |
| + liger fused_linear_cross_entropy | 4593 | 14.66 GiB | no-op on Qwen3.5 (= no memory benefit); requires overlooked gating bug fix following liger 0.8.0 bump |
| + liger full stack (rms_norm/gated/swiglu/rope) | 4137 | 10.60 GiB | −10%; kernels redundant under compile |
| (eager + fused_attn_kernel) | 3648 | 18.07 GiB | eager: fused_attn only +1.5% |

LoRA triton kernels and fused_attn_kernel are subsumed by Inductor once the loop compiles (consistent with #3680's own "fused+compile −18%" dense-Qwen3 numbers); liger composes cleanly but adds no value under compile.

C. Generalization: the model-agnostic collator (compile on/off, FA2 + packing; RTX 5090):

| model | eager | compile + PR | speedup | reserved (eager → compile) |
|---|---|---|---|---|
| Llama-3.2-1B (seq 2048) | 5181 tok/s | 7572 | +46.1% | 6.09 → 4.54 GiB |
| Qwen3-0.6B (seq 2048) | 5445 tok/s | 11375 | +108.9% | 5.54 → 5.01 GiB |

Neither model has a GatedDeltaNet, so the speedups come from the collator removing the FA2 per-layer break. Isolation (toy Llama, FA2 + packed, under compile): 3 graph breaks (4 graphs) without the varlen kwargs → 0 breaks (1 graph) with them. 0 recompiles in steady state; loss parity holds.


Convergence: loss is stable across every cell above, with step-1 identical to eager and final-10 within float noise (e.g. 2B eager 1.1275 vs compile+PR 1.1275; 9B 0.9264 vs 0.9266). The opaque-op path is bitwise identical to legacy eager (regression-tested) for both qwen3_5 and qwen3_next.
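The paragraph above uses two different notions of parity: "bitwise identical" for the opaque-op path and "within float noise" for compiled-vs-eager losses. The sketch below shows the difference on plain floats. The helper names and the 1e-3 relative tolerance are assumptions for illustration, not the thresholds the test suite uses.

```python
import math
import struct


def float32_bits(x):
    """Return the IEEE-754 float32 bit pattern of x as an int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bitwise_equal(a, b):
    """True only if both values have the exact same float32 encoding."""
    return float32_bits(a) == float32_bits(b)


def within_noise(a, b, rel_tol=1e-3):
    """True if the values agree up to a relative tolerance (assumed 1e-3)."""
    return math.isclose(a, b, rel_tol=rel_tol)


# 9B final-10 losses quoted above: close, but not the same bits.
print(bitwise_equal(0.9264, 0.9266), within_noise(0.9264, 0.9266))

# Bitwise comparison is stricter than ==: signed zeros compare equal by value
# but have different encodings.
print(0.0 == -0.0, bitwise_equal(0.0, -0.0))
```

Note that the losses quoted in the description are rounded to four decimals, so equal printed values alone do not demonstrate bit-identity; that claim rests on the regression tests.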

Composition with #3732 (fused-LoRA GDN routing): merges cleanly. Both PRs touch the GDN forward; the projection call sites route through _la_proj_fwd while the compiled-ops gating stays intact, and the fused LoRA_O autograd Function traces inside the compiled loop with 0 graph breaks (earlier 2B bench: both-merged+compile +44% vs eager, with #3732's +10% marginal gain surviving compilation because it substitutes bf16 adapter GEMMs Inductor can't infer).

AI Usage Disclaimer

Opus 4.8 / Fable 5 used throughout.

Types of changes

  • New feature (torch_compile support for the GatedDeltaNet decoder loop — Qwen3.5 dense/MoE + Qwen3-Next)
  • New feature (model-agnostic FA2 varlen collator kwargs for any packed model under compile)
  • Performance improvement (non-breaking; numerics-neutral, bitwise eager fallback)
  • New tests

Summary by CodeRabbit

Release Notes

  • New Features

    • Added torch.compile support for Qwen3.5/Qwen3-Next models with optimized compiled execution paths
    • Enhanced FlashAttention v2 varlen kwargs precomputation for improved performance
    • Implemented custom compiled operators for kernel optimization
  • Bug Fixes

    • Fixed NaN gradients in LoRA kernels for specific Qwen3.5 configurations
    • Eliminated unwanted graph breaks in compiled execution paths
  • Tests

    • Added comprehensive CUDA regression tests for compiled decoder loops with eager/compiled parity validation

thad0ctor added 10 commits June 12, 2026 12:43
…eable packing

With torch_compile, a single in-loop graph break made dynamo skip the entire
Qwen3_5TextModel.forward frame, so every decoder layer ran eagerly (no
norm/rope/gating/residual fusion, full per-kernel launch overhead). The
in-loop blockers were: axolotl's get_cu_seqlens() nonzero in the packing
patch, FLA's @torch.compiler.disable'd chunk_gated_delta_rule, causal_conv1d's
.item(), and FusedRMSNormGated's untraceable device-property probe.

- New fla_ops.py: opaque torch.library custom ops (axolotl_qwen3_5::gdn_conv /
  gdn_chunk + backward) wrapping the same FLA host functions with identical
  saved tensors (no recompute, no kernel changes). They take position_ids and
  derive cu_seqlens eagerly inside the op, so no data-dependent op enters the
  graph and every fake impl is static.
- modeling.py: the GDN training/no-cache path routes through the ops; a
  FusedRMSNormGated custom-op wrapper installs whenever torch_compile is on;
  decoder self-attention is traced when gradient_checkpointing and training.
- patch_manager.py threads torch_compile=bool(cfg.torch_compile) into the
  packing patches.

Bit-exactness: eager FLA's dg lands on the bf16 grid (reproduced via an
explicit f32->bf16->f32 round-trip); v reaches the kernels as a non-contiguous
split/reshape view (the bwd op mirrors FLA's input_guard contiguization).

Verified (tiny 4-layer hybrid): 0 graph breaks / 1 graph for the full text
model under GC; eager ops-vs-legacy bitwise identical; compiled-vs-eager
within compile noise. monkeypatch compiled-loop suite 4/4, qwen3_5 fused-attn 12/12.

Note: flash_attention_2 configs still don't get the compiled loop — the
remaining breaker is transformers' per-layer varlen derivation on the FA2
path, fixable purely axolotl-side by emitting precomputed cu_seq_lens/max_length
from the multipack collator (follow-up, not in this change).
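The bit-exactness note in this commit message refers to an explicit f32->bf16->f32 round-trip. As a rough illustration of what that round-trip does numerically, here is a pure-Python sketch on a single scalar. It assumes finite inputs and round-to-nearest-even conversion; the function name is made up, and the PR itself operates on tensors rather than Python floats.

```python
import struct


def round_trip_bf16(x):
    """Round a finite float to the nearest bf16 value, returned as a float.

    bf16 keeps the top 16 bits of the float32 pattern, so the conversion is a
    round-to-nearest-even on the 16 bits that get dropped.
    """
    # Packing as "<f" first rounds the Python double to float32.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1  # lowest mantissa bit that survives in bf16
    # Adding 0x7FFF + lsb carries into bit 16 exactly when rounding up.
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# 1 + 2**-8 sits exactly halfway between two bf16 values and ties to even.
print(round_trip_bf16(1.0 + 2**-8))
# A value already on the bf16 grid is unchanged by a second round-trip.
y = round_trip_bf16(3.14159)
print(y, round_trip_bf16(y) == y)
```

The second property is what "lands on the bf16 grid" means in practice: once a value has been through the round-trip, repeating it is a no-op, so reproducing the eager rounding once is enough to match bit for bit.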
… GDN ops

_FLA_COMPILED_OPS now follows the torch_compile flag, so it is set to False when torch_compile=False instead of always enabling the opaque GDN ops.
… the compiled loop

The compiled decoder loop (the preceding commit) works under sdpa but not
flash_attention_2: with packed sequences signalled only via position_ids,
transformers re-derives the varlen metadata per layer inside
_flash_attention_forward (a data-dependent _is_packed_sequence branch +
(position_ids==0).nonzero()), which graph-breaks inside the decoder loop and
makes dynamo skip the whole frame.

transformers already ships the escape hatch: if the caller passes the
FlashAttentionKwargs (cu_seq_lens_q/k + max_length_q/k), is_fa_with_varlen_kwargs
short-circuits before the data-dependent branch and the per-layer derivation is
skipped. This emits exactly those kwargs from the multipack collator, computed
once via transformers' own prepare_fa_kwargs_from_position_ids (so the metadata
is bit-identical to what it would derive per layer). Gated to FA2 + sample
packing + qwen3_5/qwen3_5_moe collators; max_length stays a python int so no
capture_scalar_outputs is needed.

Verified: FA2 + torch_compile now traces the loop with 0 graph breaks (was 1
break / 4 fragmented graphs); eager FA2 loss with-vs-without the kwargs is
bitwise identical. compiled-loop suite 6/6.
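The collator-side idea in this commit message can be sketched without any framework: compute the varlen metadata once per batch and attach it under the four kwarg names mentioned above, so nothing has to be re-derived per layer. This is a hedged, list-based approximation. The function name is hypothetical, and the real collator delegates to transformers' own helper on tensors rather than scanning Python lists.

```python
def add_fa_varlen_kwargs(batch):
    """Attach precomputed FA2 varlen metadata to a packed batch (in place).

    batch["position_ids"] is a list of rows. As an assumption of this sketch,
    rows are flattened into one packed stream and every position equal to 0
    starts a new sequence.
    """
    flat = [p for row in batch["position_ids"] for p in row]
    if not flat or flat[0] != 0:
        raise ValueError("packed position_ids must be non-empty and start at 0")
    starts = [i for i, p in enumerate(flat) if p == 0]
    cu_seq_lens = starts + [len(flat)]
    # Longest packed sequence, kept as a plain Python int.
    max_length = max(b - a for a, b in zip(cu_seq_lens, cu_seq_lens[1:]))
    # Self-attention over packed text: queries and keys share the same layout.
    batch["cu_seq_lens_q"] = cu_seq_lens
    batch["cu_seq_lens_k"] = list(cu_seq_lens)
    batch["max_length_q"] = max_length
    batch["max_length_k"] = max_length
    return batch


batch = add_fa_varlen_kwargs({"position_ids": [[0, 1, 2, 0, 1, 0, 1, 2, 3]]})
print(batch["cu_seq_lens_q"], batch["max_length_q"])
```

Keeping max_length as a Python int mirrors the point made in the commit message: a plain int reaches the compiled function as an ordinary argument, so no scalar has to be read back out of a tensor during tracing.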
…acks, formatting

- build_collator: gate emit_fa_varlen_kwargs on the eval-aware packing state
  (training_args.eval_sample_packing for eval loaders) instead of
  cfg.sample_packing, so eval-only packed mode also gets the precomputed kwargs.
- collator + FusedRMSNormGated boundary: replace silent exception swallows with
  warnings so a disabled compiled-loop optimization is visible.
- ruff-format the new test assert; collapse a two-line comment.
…lback, B>1 guard, weight-None norm fallback, saved-v contiguity; add MoE/FA2-parity/opcheck/B>1 tests
…undary (torch 2.11-only, unreproduced on 2.9/2.10)
- collator: emit FA2 varlen kwargs for ALL multipack models under torch_compile
  (was gated to qwen3_5/qwen3_5_moe); model-agnostic, transformers consumes them natively
- share the GatedDeltaNet opaque ops: move fla_ops -> gated_delta_net_ops (neutral
  axolotl_gdn namespace), add cast_g flag so qwen3_5 (casts g) and qwen3_next (f32 g)
  both stay bit-exact; move the FusedRMSNormGated compile boundary into the shared module
- wire qwen3_next: route conv+chunk through the shared ops under compile, install the
  rmsnorm boundary + self-attn dynamo boundary, thread torch_compile; also fix a
  pre-existing crash (decoder passed cache_position to GatedDeltaNet.forward)
- extend the lora-kernel NaN guard to qwen3_next
- tests: qwen3_next compiled-loop (0 breaks + bitwise eager parity); opcheck both cast_g

coderabbitai Bot commented Jun 14, 2026


Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 8a26740f-81b7-44ac-9ad3-9761122523b5

📝 Walkthrough


Adds torch.compile support for Qwen3.5 and Qwen3-Next packed decoder loops. Introduces a new module with opaque torch.library.custom_op wrappers for GatedDeltaNet FLA kernels to prevent aten.nonzero from entering the compile graph, updates both model monkeypatches to route through these compiled ops, precomputes FlashAttention varlen kwargs in the collator, guards LoRA triton kernels on PyTorch < 2.11.0, and adds CUDA regression tests.

Changes

torch.compile for Qwen3.5/Qwen3-Next Packed Decoder Loops

| Layer / File(s) | Summary |
|---|---|
| GatedDeltaNet custom ops and FusedRMSNormGated compile boundary<br>src/axolotl/monkeypatch/models/gated_delta_net_ops.py | New module registers opaque torch.library.custom_op wrappers axolotl_gdn.gdn_conv and axolotl_gdn.gdn_chunk that eagerly compute cu_seqlens from position_ids inside the op boundary, add register_fake/autograd implementations, and install a compile boundary for FusedRMSNormGated.forward with a torch._dynamo.disable fallback. |
| Qwen3.5 monkeypatch torch_compile routing<br>src/axolotl/monkeypatch/models/qwen3_5/modeling.py | Adds _FLA_COMPILED_OPS tracking and _init_fla_compiled_ops(); decoder forward routes self-attention through a dynamo.disable wrapper outside GC; gated-delta forward selects between axolotl_gdn.gdn_conv/gdn_chunk (compiled) and eager FLA paths based on availability; _apply_packing_patches and public entrypoints gain torch_compile parameter. |
| Qwen3-Next monkeypatch torch_compile routing<br>src/axolotl/monkeypatch/models/qwen3_next/modeling.py | Same compiled-op routing pattern applied to Qwen3-Next: module-level _FLA_COMPILED_OPS, dynamo.disable self-attn wrapper, position_ids forwarded to linear_attn without cache_position, GatedDeltaNet forward branches through custom ops when available; all four public patch functions gain torch_compile. |
| FA varlen kwargs precomputation in collator and trainer builder<br>src/axolotl/utils/collators/batching.py, src/axolotl/core/builders/causal.py | DataCollatorForSeq2Seq gains emit_fa_varlen_kwargs: bool = False; when set, precomputes cu_seq_lens_q/k and max_length_q/k from position_ids with warning-on-failure. HFCausalTrainerBuilder.build_collator sets this flag for batch-sampler collators under FA2 + torch_compile + multipack model types. |
| PatchManager: torch_compile threading and LoRA kernel guard<br>src/axolotl/loaders/patch_manager.py | New pre-setup helper disables LoRA triton kernel flags (lora_qkv_kernel, lora_o_kernel, lora_mlp_kernel) for the Qwen3-family + packed + torch_compile + adapter + GC combination on PyTorch < 2.11.0; Qwen3-Next/Qwen3.5/Qwen3.5-MoE patch invocations now forward torch_compile=bool(self.cfg.torch_compile). |
| CUDA regression tests for compiled decoder loops<br>tests/monkeypatch/test_qwen3_5_compiled_loop.py, tests/monkeypatch/test_qwen3_next_compiled_loop.py | CUDA-only tests assert zero Dynamo graph breaks, unique_graphs >= 1, bitwise parity between compiled-op and eager paths, loss/grad tolerances under torch.compile, FA2 varlen kwarg paths, VL conditional generation, MoE loop compilation, batch-size > 1 error, and torch.library.opcheck for gdn_chunk/gdn_conv. |

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

  • axolotl-ai-cloud/axolotl#3150: Introduces Qwen3-Next packing patch activation in patch_manager, directly connected to this PR's extension of patch_qwen3_next_modeling_packing with torch_compile support.
  • axolotl-ai-cloud/axolotl#3442: Introduces the Qwen3.5 modeling packing monkeypatch (qwen3_5/modeling.py) that this PR extends with compiled-op routing and torch_compile parameter.
  • axolotl-ai-cloud/axolotl#3561: Modifies patch_manager.py to gate Qwen3.5 sample-packing patches, directly related to this PR's additions passing torch_compile into those same patch invocations.

Suggested labels

ready to merge

Suggested reviewers

  • winglian
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 40.79% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title accurately describes the main change: adding torch_compile support for the GatedDeltaNet decoder loop in Qwen models with a model-agnostic FlashAttention2 varlen collator enhancement. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |


@thad0ctor thad0ctor changed the title from "Feat/gdn compiled decoder loop" to "feat(compile): traceable GDN decoder loop (Qwen3.5 + Qwen3-Next) + model-agnostic FA2 varlen collator" Jun 14, 2026

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (1)
src/axolotl/monkeypatch/models/qwen3_next/modeling.py (1)

62-140: 💤 Low value

Minor inconsistency with qwen3_5 decoder forward.

In qwen3_5, the _call_self_attn_disabled wrapper is always used for the non-gradient-checkpointing self-attention path (lines 134-143 in qwen3_5/modeling.py), regardless of torch_compile. Here, it's only applied when torch_compile=True.

This inconsistency is harmless (_call_self_attn_disabled is essentially a no-op when dynamo isn't active), but aligning the behavior would improve maintainability.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/monkeypatch/models/qwen3_next/modeling.py` around lines 62 - 140,
The patched_decoder_forward function has an inconsistency with the qwen3_5
implementation. In qwen3_5, the _call_self_attn_disabled wrapper is always used
for the non-gradient-checkpointing self-attention path, but here it's only
applied when torch_compile=True. To align the behavior, remove the torch_compile
check from the conditional statement in the "full_attention" token mixer path so
that _call_self_attn_disabled is always called when gradient checkpointing is
not enabled and training is not active, regardless of the torch_compile
parameter value.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/axolotl/utils/collators/batching.py`:
- Around line 144-150: The FA varlen kwargs precomputation exception handler in
the batching.py file is catching the overly broad Exception class, which can
mask unexpected errors and make debugging difficult. Replace the generic
Exception catch with specific exception types that are actually expected to fail
during the FA varlen precompute operation (such as AttributeError, TypeError, or
RuntimeError, depending on what failures are genuinely expected), so that
unexpected exceptions will properly propagate and alert developers to actual
collator regressions rather than silently falling back to the slow path.

In `@tests/monkeypatch/test_qwen3_5_compiled_loop.py`:
- Around line 27-44: The fixture teardown for _axolotl_compile_boundary in
tests/monkeypatch/test_qwen3_5_compiled_loop.py lines 27-44 conflates a missing
attribute with one that has a False value by using getattr with a False default.
Instead of using getattr with a default False, capture both the presence of the
attribute (using hasattr) and its actual value separately in the saved
dictionary during setup. During teardown, check the presence flag: if the
attribute was originally present, restore its saved value using setattr; if it
was originally absent, delete it using delattr. Apply the identical
presence-aware snapshot and restore logic to
tests/monkeypatch/test_qwen3_next_compiled_loop.py lines 26-43 to ensure
consistent behavior across both test files.
- Around line 422-423: The pytest.raises context manager at line 422 is catching
any Exception type, which can mask unrelated failures that happen to match the
same message text. Change the exception type from Exception to ValueError in the
pytest.raises call to be more specific about what exception should actually be
raised when the batch size contract is violated. This narrower exception type
will ensure only the intended ValueError with the "batch size is expected to be
1" message is caught, preventing unrelated exceptions from being silently
masked.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: b35d2e01-4abc-43aa-9288-ab03bdb5f258

📥 Commits

Reviewing files that changed from the base of the PR and between 22bcb9a and 02b6a7f.

📒 Files selected for processing (8)
  • src/axolotl/core/builders/causal.py
  • src/axolotl/loaders/patch_manager.py
  • src/axolotl/monkeypatch/models/gated_delta_net_ops.py
  • src/axolotl/monkeypatch/models/qwen3_5/modeling.py
  • src/axolotl/monkeypatch/models/qwen3_next/modeling.py
  • src/axolotl/utils/collators/batching.py
  • tests/monkeypatch/test_qwen3_5_compiled_loop.py
  • tests/monkeypatch/test_qwen3_next_compiled_loop.py

Comment thread src/axolotl/utils/collators/batching.py
Comment thread tests/monkeypatch/test_qwen3_5_compiled_loop.py Outdated
Comment thread tests/monkeypatch/test_qwen3_5_compiled_loop.py Outdated

codecov Bot commented Jun 14, 2026


- batching.py: narrow the FA varlen fallback catch to the realistic exception
  set (ImportError/AttributeError/RuntimeError/TypeError/ValueError/IndexError)
  instead of bare Exception (Ruff BLE001)
- tests: presence-aware restore of FusedRMSNormGated._axolotl_compile_boundary
  (snapshot hasattr + value) in both compiled-loop fixtures
- tests: narrow B>1 pytest.raises to ValueError (verified: the op raises a clean
  ValueError through the model forward)
- qwen3_next: route non-GC self-attn through the dynamo.disable boundary
  unconditionally (a no-op when not compiling), matching qwen3_5
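The "presence-aware restore" item in this commit follows the review suggestion to distinguish an attribute that was absent from one that was set to False. A minimal sketch of that pattern is below. It uses a sentinel instead of a separate hasattr flag (equivalent for this purpose), the helper name is hypothetical, Dummy stands in for FusedRMSNormGated, and it assumes the attribute is set directly on the object being patched.

```python
from contextlib import contextmanager

_MISSING = object()  # sentinel meaning "attribute was not present"


@contextmanager
def preserve_attr(obj, name):
    """Snapshot obj.<name>, including its absence, and restore it on exit."""
    saved = getattr(obj, name, _MISSING)
    try:
        yield
    finally:
        if saved is _MISSING:
            # Originally absent: remove whatever the block installed.
            if hasattr(obj, name):
                delattr(obj, name)
        else:
            # Originally present (even if False): put the old value back.
            setattr(obj, name, saved)


class Dummy:
    """Stand-in for the class whose compile-boundary flag gets patched."""


with preserve_attr(Dummy, "_axolotl_compile_boundary"):
    Dummy._axolotl_compile_boundary = True

print(hasattr(Dummy, "_axolotl_compile_boundary"))
```

With a plain getattr(obj, name, False) snapshot, the teardown would leave the attribute behind as False on a class that never had it, which is the state leak the review comment points out.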