feat(mm-cpt): broaden multimodal CPT dataset paths#31
Conversation
* cp fix for nemo * nemo and flcon patch * import patch * Revert "import patch" This reverts commit ef42d1f. * undo falcon * pakcing + mamba support for nemo , falcon , grenite zamba * training run bugs * docks * doc string coverage + test * mamba guard * 2k n 2*1k test * not is_cp_active() * seq_len fix * model list * val ring atten fix * disable double spliting in hf * less comments * undo zamba and bamba * new configs * lint
* feat: update transformers to 5.8.1 * ignore uv.lock for now --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
* feat-qgalore * spell check --------- Co-authored-by: Your Name <you@example.com>
* support with autoprocessor * simple add dict --------- Co-authored-by: Your Name <you@example.com>
… [skip ci] * fix AssertionError: Original QKV code not found * skip ig gemma for lor a * fix misleading commentsT_T'
Co-authored-by: Your Name <you@example.com>
* rmv skip * test verison * lint * undo
* fix: ep test missed teardown * fix: change hardcoded ports
…#3679) [skip ci] * fix broken MX tests from transformers 5.8.1 upgrade * test isolation * wrap for torchao possible import error * isolate reward model test more * fix PRM
…olotl-ai-cloud#3670) * feat(fsdp2): add fp32_norms for keeping RMSNorm/LayerNorm in fp32 Add an opt-in config flag that shards norm modules under their own FSDP2 MixedPrecisionPolicy (fp32) before the standard decoder-layer wrap, so the norm and decoder shard groups stay independent. This lets models that declare fp32 norms for training stability train under FSDP2 while the rest of the model runs in bf16/fp16. FSDP1 enforces flat-param dtype uniformity within each wrap group, which is incompatible with keeping norms in fp32; the validator therefore requires fsdp_version: 2. Matching: patterns without a "." match type(module).__name__ as a suffix (catches LlamaRMSNorm, Qwen3RMSNorm, AfmoeRMSNorm, nn.LayerNorm, etc.); patterns containing a "." match the fully qualified class path exactly. Defaults to ["RMSNorm", "LayerNorm"]. Signed-off-by: Wing Lian <wing@axolotl.ai> * fixup! feat(fsdp2): address review findings + fix CI caplog assertions - matcher: skip empty/whitespace-only patterns (cls_name.endswith("") is True for any class, which would silently match everything). - validator: also require fsdp_config to be set, not just fsdp_version==2. fsdp_config is the canonical "is_fsdp" signal elsewhere in the codebase (used by check_fsdp_torch_version, sample_packing validators, etc.). - tests: temporarily flip propagate=True on the `axolotl` logger so pytest caplog can see the warnings. axolotl.cli.configure_logging() sets propagate=False at import time, which is the documented reason the assertions were failing in CI even though the warnings were firing visibly in stdout. - comment: replace multi-line rationale near the fp32_norms helpers with a one-line summary (the longer version lives in the PR description). Signed-off-by: Wing Lian <wing@axolotl.ai> * test(fsdp2): multi-GPU e2e for fp32_norms with dtype-preservation assertion The existing fp32_norms tests are pure-CPU and monkeypatch fully_shard — they cover the matcher logic and validator guard rails but never exercise the actual FSDP2 path that motivated this PR. Adds tests/e2e/multigpu/test_fsdp2_fp32_norms.py: spawns a 2-GPU `axolotl train` subprocess with `fp32_norms: true` + `fsdp_version: 2` + `bf16: true` on tiny-qwen3-129m (full FT, 2 steps) and asserts: 1. Training completes — the original FSDP1 flat-param dtype crash can't recur because we're on FSDP2 with the per-module MixedPrecisionPolicy. 2. All RMSNorm params are float32 after step 1 — captured via a test-only TrainerCallback in tests/e2e/multigpu/_fp32_norms_dtype_capture.py, dumped to JSON at $FP32_NORMS_DTYPE_DUMP_PATH on rank 0. 3. At least one non-norm param is bfloat16 — proves the two FSDP2 MixedPrecisionPolicy groups are independent (catches a silent globally-fp32 fallback that would technically satisfy assertion 2 but defeat the point of the feature). The dtype-capture plugin is plumbed in via the test's yaml `plugins:` list, with PYTHONPATH=<repo_root> on the subprocess env so the tests.e2e.multigpu._fp32_norms_dtype_capture module resolves. Signed-off-by: Wing Lian <wing@axolotl.ai> * chore: lint --------- Signed-off-by: Wing Lian <wing@axolotl.ai>
* latest typer breaks HF CLI * wrong comparison
…tl-ai-cloud#3687) transformers decorates Gemma4VisionAttention with @use_kernelized_func(apply_rotary_pos_emb) where the target is a bare function. Under use_kernels=True (force-enabled by KernelsArgs for the ScatterMoE path), from_pretrained calls model.kernelize(), whose attach_hidden_kernels step does register_module(name, fn) for each _hidden_kernels entry. register_module rejects the non-Module function: TypeError: ...apply_rotary_pos_emb is not a Module subclass with a follow-on AttributeError from the cleanup path. The MoE itself is accelerated via the transformers ExpertsInterface (experts_implementation), independent of this path, and the vision forward uses apply_multidimensional_rope, never apply_rotary_pos_emb -- so the registered entry is dead weight. Add monkeypatch gemma4_kernelize that strips non-Module _hidden_kernels entries from Gemma4VisionAttention, wired in patch_manager._apply_model_specific_patches for gemma4 when use_kernels is set. state_dict is unchanged, so the fix is behavior-neutral. Also add ddp_find_unused_parameters: true to the 26b-a4b MoE QLoRA example (multi-GPU only -- text-backbone LoRA plus KV-sharing layers leave some adapter params gradient-less under DDP).
…xolotl-ai-cloud#3651) * fix: refactor kernels patch to drop routing and inject into Expert registry * chore: add to optim doc * feat: update sonicmoe version * chore: cleanup with DEEPEP and kernels compat * gate/guard model expert setup --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
… ci]
* feat(scattermoe-lora): selective dequant for mxfp4 expert weights
Add an MXFP4 branch to `selective_expert_weights()` that detects a
torchao `MXTensor` parameter (elem_dtype=float4_e2m1fn_x2) and
dequantizes only the active experts via index-then-construct of a
compact sub-MXTensor. The K-axis OCP block layout (last storage dim)
matches `experts.gate_up_proj` natural shape `[E, N, K]`, so the
caller's existing `.transpose(2, 1)` post-step keeps producing the
kernel's `[E, K, N]` weight tile unchanged.
`HFScatterMoEGatedMLP.forward` now also routes through the selective
path whenever the experts hold MXFP4 weights — full-tensor MX dequant
of 256-expert models is prohibitive and the kernel needs bf16 input.
Tests (CUDA-only) compare against a bf16 baseline produced by the
same MXTensor's full dequant; outputs are bitwise identical for both
forward and backward (dX, dA, dB) across small [E=8,K=128,N=256] and
representative [E=32,K=2048,N=1024] shapes, and across all four
combinations of \`use_fused_dX\` / \`use_fused_gather\`.
Signed-off-by: Wing Lian <wing@axolotl.ai>
* feat(scattermoe-lora): fused mxfp4 dequant in triton kernel
Add MX-aware forward and dX kernels that consume an ``MXWeights``
container (packed uint8 + E8M0 scales) directly, so the base-weight
tile is dequantized inside the K-loop instead of through a materialized
bf16 buffer. The K-loop loads two FP4 values per uint8 byte, looks them
up in a 16-entry codebook tensor (``±{0, 0.5, 1, 1.5, 2, 3, 4, 6}``),
multiplies by ``2^(scale_byte - 127)``, and casts to bf16 for the
matmul. ``BLOCK_K`` is constrained to a multiple of the OCP block size
(32) so each tile aligns with whole scale blocks; an MX-aware autotune
pruner accounts for the extra packed/scale SMEM.
The dX kernel reuses the *forward* MX layout (block axis = K, the dX
output axis) — for each (K_tile, N_tile) sub-tile, nibbles decode
along the K rows (the byte is shared by two adjacent K rows) and
scales broadcast within their MX block. This avoids the
dequant + re-quantize "pre-transpose" the spec suggested and the
extra MX-rounding error that round-trip would have introduced.
``ScatterMoELoRA.forward`` now accepts either a dense tensor or an
``MXWeights``; the MX branch always selects the fused-dX and
fused-gather backward kernels (the non-fused dX path would have to
materialize a bf16 weight tile, defeating the win).
Unit tests cover forward, dX, dA, dB parity for small
[E=8, K=128, N=256] and representative [E=32, K=2048, N=1024] shapes;
tolerances are calibrated to bf16 MMA noise (atomic-add ordering and
FMA reordering between the full-E baseline and compact-active MX path).
Integration test exercises a tiny synthetic DeepSeek-V4-style MoE
block (E=8, hidden=512, intermediate=256, top_k=2) end-to-end through
both Strategy A and Strategy B with LoRA disabled.
Signed-off-by: Wing Lian <wing@axolotl.ai>
* chore(scattermoe-lora): mxfp4 forward/backward benchmark
Add ``bench_mxfp4.py`` and committed results for the representative
DeepSeek-V4-style shape (E=128, K=2048, N=1024, top_k=8, M=4096,
rank=16). Reports ms/iter, tokens/s, peak GPU memory, and HBM
bandwidth utilisation for three configurations: bf16 baseline,
Strategy A (selective dequant), Strategy B (fused MX).
On the RTX PRO 6000 Blackwell, the all-active-experts shape used
here doesn't exercise selective dequant's memory savings (active = E
= 128) — A pays the cost of materialising the full bf16 dequant
buffer per step (~9 GB peak vs 1.9 GB for B) while still routing
through the bf16 kernel. B halves A's wall time (~12 ms vs 30 ms) by
eliminating the buffer, but stays slower than the bf16 baseline (5
ms) which assumes the bf16 weights already exist in memory.
Signed-off-by: Wing Lian <wing@axolotl.ai>
* bench(scattermoe-lora): mxfp4 sparse-routing benchmark numbers
Signed-off-by: Wing Lian <wing@axolotl.ai>
* bench(scattermoe-lora): mxfp4 seqlen sweep with load-balanced routing
Signed-off-by: Wing Lian <wing@axolotl.ai>
* fix(scattermoe-lora): correct mx forward smem accounting
The MX-aware autotune pruner for the forward kernel under-accounted
SMEM: it computed the packed-tile cost as BLOCK_N * BLOCK_K/2 and the
scale-tile cost as BLOCK_N * BLOCK_K/MX_BLOCK_SIZE, but the actual
tl.load issues a full [BLOCK_N, BLOCK_K]-shaped uint8 fetch for both
buffers (the packed buffer reads each byte twice because K_byte =
K // 2 indexes a [BLOCK_K]-wide vector; the scale buffer broadcasts
within each MX_BLOCK_SIZE K-block). Bring the forward pruner up to the
same conservative full-tile accounting already used by
_prune_dX_mx_configs. Without this, on the [E=128, K=2048, N=1024]
shape with the typical GPU SMEM caps, two to six high-stage configs
that were previously selectable would have overflowed SMEM at launch
under correct accounting — a silent OOM-in-the-future risk.
Signed-off-by: Wing Lian <wing@axolotl.ai>
* docs(scattermoe-lora): align mx dx kernel docstring with implementation
The file-level docstring for the MXFP4 kernels described the dX kernel
as using a pre-transposed [E, K, N/2] layout produced by a
'mx_pre_transpose_for_dx' helper. That helper doesn't exist; the dX
kernel actually reuses the forward [E, N, K/2] layout, iterating the N
reduction in outer tiles and decoding nibbles along the K rows of each
tile. Rewrite the docstring to describe what the code actually does,
including the rationale — reusing the forward buffer avoids the
dequant + re-quantize round-trip that a pre-transpose would require
and keeps dX numerics free of a second MX rounding error stacked on
top of the forward quantization.
Signed-off-by: Wing Lian <wing@axolotl.ai>
* chore(scattermoe-lora): mx code-review nit cleanup
F4: Hoist 'is_mxfp4_param' import from inside 'HFScatterMoEGatedMLP.forward'
to the top of layers.py — it was being re-imported every step on the hot
path.
F5: Add a thin compatibility shim for torchao MXTensor internals access in
mx_weights.py. The MX paths in selective_dequant.py / mx_weights.py used
to reach into 'mx_param.qdata', 'mx_param.scale',
'mx_param.kernel_preference' and call 'MXTensor(...)' with positional
args directly. That works at the pinned torchao 0.17.0 but is fragile to
internal renames in future torchao releases. Funnel through three
helpers — '_mx_qdata', '_mx_scale', '_construct_mxtensor_subset' — that
use 'getattr' fallbacks for the buffer attributes and pass the
constructor's optional args via 'getattr' too. Single point of pain,
no API change.
F7: Remove the unused 'NO_K_MASK' heuristic + tl.constexpr param from
the dX MX kernel '_scatter2scatter_lora_dX_mx'. The dX kernel never
references it (its inner loop masks N, not K), so the constexpr just
forced extra autotune key entries.
F8: Consolidate the duplicate '_torchao_mxtensor_cls()' definitions
(one in selective_dequant.py, one in mx_weights.py) into a single
definition in mx_weights.py. selective_dequant.py imports it.
Signed-off-by: Wing Lian <wing@axolotl.ai>
* test(scattermoe-lora): strengthen mx backward test coverage
F3: 'test_strategy_a_backward_fused_variants' previously used
'torch.ones_like(output)' as the grad input and asserted only on dX.
A uniform grad zeros out cross-token differences in the fused-gather
accumulation, masking reordering bugs; restricting the assertion to dX
silently let the dA/dB paths go unchecked across the four
'(use_fused_dX, use_fused_gather)' production variants.
* Drive the backward with 'torch.randn_like(output) * 0.1'.
* Capture and assert dA and dB parity across all four variants
using the same 'row_idx' gather pattern as
'test_strategy_a_backward_matches_bf16'.
* Forward and dX are still asserted bitwise via 'torch.equal'. dA/dB
fall back to atol/rtol = 1e-3 because the fused dA/dB kernel uses
'atomic_add' across N-block programs and the in-flight program
count differs between the full-E baseline and the compact-active
path; combined with FMA reordering, the 'use_fused_dX=True'
variants accumulate ~1 bf16 ULP of unavoidable atomic-order noise.
The new bound is still an order of magnitude below that noise
floor, so it catches real bugs.
F9: The 'test_strategy_b_backward_matches_bf16' dX comparison runs at
'atol=0.5, rtol=2e-2' (small) / 'atol=2.0, rtol=3e-2' (representative)
to allow for accumulated bf16 MMA noise over the N reduction. Those
bounds are appropriate for legitimate per-element drift but would also
admit a uniform multiplicative bug — e.g. an off-by-one on the E8M0
exponent that scales every dX element by 2x.
Add a guard alongside the existing 'torch.allclose': mask out
near-zero baseline elements (relative to 'bf16_dX.abs().max()'), then
require the per-element ratio 'mx_dX / bf16_dX' to have std < 0.5. A
uniform multiplicative bug pushes that std to ~0 while the mean shifts;
a real-bug per-element drift pushes the std up. This crosscuts the
allclose check rather than replacing it.
Signed-off-by: Wing Lian <wing@axolotl.ai>
* bench(scattermoe-lora): exclude per-iter setup from timed window
The previous bench harness did a fresh '.clone()' of x and a
'requires_grad_(True)' on cloned lora A/B tensors every iter inside
the timed window. That accounts for buffer allocation, not kernel
cost, and biases the numbers toward whichever path produced the
smallest activations. Restructure the runners so:
* 'x' is cloned once into a leaf tensor with 'requires_grad_(True)'
inside 'bench()' (outside the timed warmup + timed loop).
* LoRA A/B leaf tensors are constructed once in the runner factory,
not per iter.
* Each iter calls the runner which sets 'x.grad = A.grad = B.grad =
None' (cheap, no GPU sync) so the autograd graph for the timed
iteration is fresh and grads don't accumulate.
Re-run all three configs end-to-end after this change (dense E=128,
sparse E=256 / 10-active, balanced E=256 M-sweep at M ∈ {256, 1024,
4096, 16384}) and refresh the numbers in bench_mxfp4_results.md.
Headers and table structure are unchanged. The qualitative ordering
holds (Strategy A wins at low active/E, Strategy B wins near
active/E ≈ 1, and Strategy A still OOMs across the balanced sweep on
the workstation with vLLM colocated), with per-cell numbers within
single-digit percent of the prior runs.
Signed-off-by: Wing Lian <wing@axolotl.ai>
* style(scattermoe-lora): apply pre-commit auto-fixes and mypy fixes
Signed-off-by: Wing Lian <wing@axolotl.ai>
* fix(scattermoe-lora): mxfp4 shape validation + torchao version messages
Signed-off-by: Wing Lian <wing@axolotl.ai>
* lint and PR review fixes
* fix(scattermoe-lora): restore lint task fixes + add missing forward parity assertions
Wing's "lint and PR review fixes" commit (9007a82) reverted three fixes
from the prior lint pass. Restore them:
1. parallel_linear_lora.py: use isinstance(expert_weights, MXWeights)
directly so mypy can narrow the union — the `is_mx` boolean alias
blocks narrowing and re-introduces 2 union-attr errors.
2. bench_mxfp4.py: assert template is not None before the MXTensor(...)
constructor — the chunked converter initializes template to None
then sets it inside the loop, which mypy can't prove non-None at
the call site (6 None-attr errors).
3. test_mxfp4_expert_weights.py: the F841 on fwd_tol was actually a
smell of dropped logic. Both backward tests
(test_strategy_a_backward_matches_bf16 and
test_strategy_b_backward_matches_bf16) compute the forward outputs
out_b/out_a/out_s, run backward, and assert gradients match — but
never assert that the forward outputs match. A forward bug
producing a constant offset (and therefore zero gradient delta)
would slip past the bwd-only checks. Add the missing
torch.equal(out_b, out_a) for Strategy A (bitwise contract) and
torch.allclose(out_b, out_s, **fwd_tol) for Strategy B (MX tol).
Signed-off-by: Wing Lian <wing@axolotl.ai>
* don't worry about flash-attn direct patches for now
---------
Signed-off-by: Wing Lian <wing@axolotl.ai>
… fp32 fix + AC-vs-tiled gap analysis (axolotl-ai-cloud#3666) [skip ci] * feat(scattermoe-lora): selective dequant for mxfp4 expert weights Add an MXFP4 branch to `selective_expert_weights()` that detects a torchao `MXTensor` parameter (elem_dtype=float4_e2m1fn_x2) and dequantizes only the active experts via index-then-construct of a compact sub-MXTensor. The K-axis OCP block layout (last storage dim) matches `experts.gate_up_proj` natural shape `[E, N, K]`, so the caller's existing `.transpose(2, 1)` post-step keeps producing the kernel's `[E, K, N]` weight tile unchanged. `HFScatterMoEGatedMLP.forward` now also routes through the selective path whenever the experts hold MXFP4 weights — full-tensor MX dequant of 256-expert models is prohibitive and the kernel needs bf16 input. Tests (CUDA-only) compare against a bf16 baseline produced by the same MXTensor's full dequant; outputs are bitwise identical for both forward and backward (dX, dA, dB) across small [E=8,K=128,N=256] and representative [E=32,K=2048,N=1024] shapes, and across all four combinations of \`use_fused_dX\` / \`use_fused_gather\`. Signed-off-by: Wing Lian <wing@axolotl.ai> * feat(scattermoe-lora): fused mxfp4 dequant in triton kernel Add MX-aware forward and dX kernels that consume an ``MXWeights`` container (packed uint8 + E8M0 scales) directly, so the base-weight tile is dequantized inside the K-loop instead of through a materialized bf16 buffer. The K-loop loads two FP4 values per uint8 byte, looks them up in a 16-entry codebook tensor (``±{0, 0.5, 1, 1.5, 2, 3, 4, 6}``), multiplies by ``2^(scale_byte - 127)``, and casts to bf16 for the matmul. ``BLOCK_K`` is constrained to a multiple of the OCP block size (32) so each tile aligns with whole scale blocks; an MX-aware autotune pruner accounts for the extra packed/scale SMEM. The dX kernel reuses the *forward* MX layout (block axis = K, the dX output axis) — for each (K_tile, N_tile) sub-tile, nibbles decode along the K rows (the byte is shared by two adjacent K rows) and scales broadcast within their MX block. This avoids the dequant + re-quantize "pre-transpose" the spec suggested and the extra MX-rounding error that round-trip would have introduced. ``ScatterMoELoRA.forward`` now accepts either a dense tensor or an ``MXWeights``; the MX branch always selects the fused-dX and fused-gather backward kernels (the non-fused dX path would have to materialize a bf16 weight tile, defeating the win). Unit tests cover forward, dX, dA, dB parity for small [E=8, K=128, N=256] and representative [E=32, K=2048, N=1024] shapes; tolerances are calibrated to bf16 MMA noise (atomic-add ordering and FMA reordering between the full-E baseline and compact-active MX path). Integration test exercises a tiny synthetic DeepSeek-V4-style MoE block (E=8, hidden=512, intermediate=256, top_k=2) end-to-end through both Strategy A and Strategy B with LoRA disabled. Signed-off-by: Wing Lian <wing@axolotl.ai> * chore(scattermoe-lora): mxfp4 forward/backward benchmark Add ``bench_mxfp4.py`` and committed results for the representative DeepSeek-V4-style shape (E=128, K=2048, N=1024, top_k=8, M=4096, rank=16). Reports ms/iter, tokens/s, peak GPU memory, and HBM bandwidth utilisation for three configurations: bf16 baseline, Strategy A (selective dequant), Strategy B (fused MX). On the RTX PRO 6000 Blackwell, the all-active-experts shape used here doesn't exercise selective dequant's memory savings (active = E = 128) — A pays the cost of materialising the full bf16 dequant buffer per step (~9 GB peak vs 1.9 GB for B) while still routing through the bf16 kernel. B halves A's wall time (~12 ms vs 30 ms) by eliminating the buffer, but stays slower than the bf16 baseline (5 ms) which assumes the bf16 weights already exist in memory. Signed-off-by: Wing Lian <wing@axolotl.ai> * bench(scattermoe-lora): mxfp4 sparse-routing benchmark numbers Signed-off-by: Wing Lian <wing@axolotl.ai> * bench(scattermoe-lora): mxfp4 seqlen sweep with load-balanced routing Signed-off-by: Wing Lian <wing@axolotl.ai> * fix(scattermoe-lora): correct mx forward smem accounting The MX-aware autotune pruner for the forward kernel under-accounted SMEM: it computed the packed-tile cost as BLOCK_N * BLOCK_K/2 and the scale-tile cost as BLOCK_N * BLOCK_K/MX_BLOCK_SIZE, but the actual tl.load issues a full [BLOCK_N, BLOCK_K]-shaped uint8 fetch for both buffers (the packed buffer reads each byte twice because K_byte = K // 2 indexes a [BLOCK_K]-wide vector; the scale buffer broadcasts within each MX_BLOCK_SIZE K-block). Bring the forward pruner up to the same conservative full-tile accounting already used by _prune_dX_mx_configs. Without this, on the [E=128, K=2048, N=1024] shape with the typical GPU SMEM caps, two to six high-stage configs that were previously selectable would have overflowed SMEM at launch under correct accounting — a silent OOM-in-the-future risk. Signed-off-by: Wing Lian <wing@axolotl.ai> * docs(scattermoe-lora): align mx dx kernel docstring with implementation The file-level docstring for the MXFP4 kernels described the dX kernel as using a pre-transposed [E, K, N/2] layout produced by a 'mx_pre_transpose_for_dx' helper. That helper doesn't exist; the dX kernel actually reuses the forward [E, N, K/2] layout, iterating the N reduction in outer tiles and decoding nibbles along the K rows of each tile. Rewrite the docstring to describe what the code actually does, including the rationale — reusing the forward buffer avoids the dequant + re-quantize round-trip that a pre-transpose would require and keeps dX numerics free of a second MX rounding error stacked on top of the forward quantization. Signed-off-by: Wing Lian <wing@axolotl.ai> * chore(scattermoe-lora): mx code-review nit cleanup F4: Hoist 'is_mxfp4_param' import from inside 'HFScatterMoEGatedMLP.forward' to the top of layers.py — it was being re-imported every step on the hot path. F5: Add a thin compatibility shim for torchao MXTensor internals access in mx_weights.py. The MX paths in selective_dequant.py / mx_weights.py used to reach into 'mx_param.qdata', 'mx_param.scale', 'mx_param.kernel_preference' and call 'MXTensor(...)' with positional args directly. That works at the pinned torchao 0.17.0 but is fragile to internal renames in future torchao releases. Funnel through three helpers — '_mx_qdata', '_mx_scale', '_construct_mxtensor_subset' — that use 'getattr' fallbacks for the buffer attributes and pass the constructor's optional args via 'getattr' too. Single point of pain, no API change. F7: Remove the unused 'NO_K_MASK' heuristic + tl.constexpr param from the dX MX kernel '_scatter2scatter_lora_dX_mx'. The dX kernel never references it (its inner loop masks N, not K), so the constexpr just forced extra autotune key entries. F8: Consolidate the duplicate '_torchao_mxtensor_cls()' definitions (one in selective_dequant.py, one in mx_weights.py) into a single definition in mx_weights.py. selective_dequant.py imports it. Signed-off-by: Wing Lian <wing@axolotl.ai> * test(scattermoe-lora): strengthen mx backward test coverage F3: 'test_strategy_a_backward_fused_variants' previously used 'torch.ones_like(output)' as the grad input and asserted only on dX. A uniform grad zeros out cross-token differences in the fused-gather accumulation, masking reordering bugs; restricting the assertion to dX silently let the dA/dB paths go unchecked across the four '(use_fused_dX, use_fused_gather)' production variants. * Drive the backward with 'torch.randn_like(output) * 0.1'. * Capture and assert dA and dB parity across all four variants using the same 'row_idx' gather pattern as 'test_strategy_a_backward_matches_bf16'. * Forward and dX are still asserted bitwise via 'torch.equal'. dA/dB fall back to atol/rtol = 1e-3 because the fused dA/dB kernel uses 'atomic_add' across N-block programs and the in-flight program count differs between the full-E baseline and the compact-active path; combined with FMA reordering, the 'use_fused_dX=True' variants accumulate ~1 bf16 ULP of unavoidable atomic-order noise. The new bound is still an order of magnitude below that noise floor, so it catches real bugs. F9: The 'test_strategy_b_backward_matches_bf16' dX comparison runs at 'atol=0.5, rtol=2e-2' (small) / 'atol=2.0, rtol=3e-2' (representative) to allow for accumulated bf16 MMA noise over the N reduction. Those bounds are appropriate for legitimate per-element drift but would also admit a uniform multiplicative bug — e.g. an off-by-one on the E8M0 exponent that scales every dX element by 2x. Add a guard alongside the existing 'torch.allclose': mask out near-zero baseline elements (relative to 'bf16_dX.abs().max()'), then require the per-element ratio 'mx_dX / bf16_dX' to have std < 0.5. A uniform multiplicative bug pushes that std to ~0 while the mean shifts; a real-bug per-element drift pushes the std up. This crosscuts the allclose check rather than replacing it. Signed-off-by: Wing Lian <wing@axolotl.ai> * bench(scattermoe-lora): exclude per-iter setup from timed window The previous bench harness did a fresh '.clone()' of x and a 'requires_grad_(True)' on cloned lora A/B tensors every iter inside the timed window. That accounts for buffer allocation, not kernel cost, and biases the numbers toward whichever path produced the smallest activations. Restructure the runners so: * 'x' is cloned once into a leaf tensor with 'requires_grad_(True)' inside 'bench()' (outside the timed warmup + timed loop). * LoRA A/B leaf tensors are constructed once in the runner factory, not per iter. * Each iter calls the runner which sets 'x.grad = A.grad = B.grad = None' (cheap, no GPU sync) so the autograd graph for the timed iteration is fresh and grads don't accumulate. Re-run all three configs end-to-end after this change (dense E=128, sparse E=256 / 10-active, balanced E=256 M-sweep at M ∈ {256, 1024, 4096, 16384}) and refresh the numbers in bench_mxfp4_results.md. Headers and table structure are unchanged. The qualitative ordering holds (Strategy A wins at low active/E, Strategy B wins near active/E ≈ 1, and Strategy A still OOMs across the balanced sweep on the workstation with vLLM colocated), with per-cell numbers within single-digit percent of the prior runs. Signed-off-by: Wing Lian <wing@axolotl.ai> * style(scattermoe-lora): apply pre-commit auto-fixes and mypy fixes Signed-off-by: Wing Lian <wing@axolotl.ai> * fix(scattermoe-lora): mxfp4 shape validation + torchao version messages Signed-off-by: Wing Lian <wing@axolotl.ai> * lint and PR review fixes * fix(scattermoe-lora): restore lint task fixes + add missing forward parity assertions Wing's "lint and PR review fixes" commit (9007a82) reverted three fixes from the prior lint pass. Restore them: 1. parallel_linear_lora.py: use isinstance(expert_weights, MXWeights) directly so mypy can narrow the union — the `is_mx` boolean alias blocks narrowing and re-introduces 2 union-attr errors. 2. bench_mxfp4.py: assert template is not None before the MXTensor(...) constructor — the chunked converter initializes template to None then sets it inside the loop, which mypy can't prove non-None at the call site (6 None-attr errors). 3. test_mxfp4_expert_weights.py: the F841 on fwd_tol was actually a smell of dropped logic. Both backward tests (test_strategy_a_backward_matches_bf16 and test_strategy_b_backward_matches_bf16) compute the forward outputs out_b/out_a/out_s, run backward, and assert gradients match — but never assert that the forward outputs match. A forward bug producing a constant offset (and therefore zero gradient delta) would slip past the bwd-only checks. Add the missing torch.equal(out_b, out_a) for Strategy A (bitwise contract) and torch.allclose(out_b, out_s, **fwd_tol) for Strategy B (MX tol). Signed-off-by: Wing Lian <wing@axolotl.ai> * don't worry about flash-attn direct patches for now * feat(tiled-mlp): support MoE block classes in patcher Extend patch_tiled_mlp to discover MoE block classes ({prefix}SparseMoeBlock / MoeMLP / MoE) and patch the routing+expert forward when scattermoe-lora is active. The kernels library installs HFScatterMoEGatedMLP.forward per instance during model.kernelize(), which shadows class-level patches. Add a post-model-load step (patch_tiled_mlp_moe_instances) that re-wraps each MoE block instance so tiling layers on top of the kernels-installed forward instead of being bypassed. Falls back to the existing dense {prefix}MLP / {prefix}TextMLP path when no MoE block class exists. The gpt_oss special case for DeepSpeedTiledMLPMoE is preserved and extended to every MoE block. Signed-off-by: Wing Lian <wing@axolotl.ai> * fix(tiled-mlp): defer FSDP2 reshard + correct per-shard grad accumulation Two backward-pass correctness fixes in TiledMLP.backward. 1) Defer FSDP2 post-backward reshard across the tile loop. The backward issues one torch.autograd.backward per shard. Under FSDP2 (torch.distributed.fsdp.fully_shard), the first inner backward triggers the wrapping FSDPModule's post-backward hook, which reshards parameters; subsequent shards then recompute against only-local DTensor shards. Silent gradient corruption at best, crash at worst. LinkedIn's Liger-Kernel PR axolotl-ai-cloud#1128 fixed this for FSDP1 with FSDP.summon_full_params(writeback=True). That API does not exist in FSDP2. The PyTorch 2.11 FSDP2 surface is FSDPModule.set_reshard_after_backward(False) — toggle off around the tile loop, restore the prior value, and issue one explicit reshard() afterwards. The wrapping FSDPModule is discovered by walking the global _module_state_mapping registry (FSDP2 is typically applied at the decoder-layer level, so the MLP itself is rarely the FSDPModule). Result is cached on the MLP instance so the walk runs once. No-op under DDP, single-GPU, or DeepSpeed. DeepSpeedTiledMLPMoE is left alone — DeepSpeed coordinates its own gather and the two backends are mutually exclusive. 2) Replace the hook-based GradientAccumulator with inline fp32 accumulation. The previous implementation called grad_accumulator.install_hooks() inside every shard iteration, so the N-th shard ran N stacked hooks that each accumulated the same shard contribution — and on the last shard the manually-set param.grad was then re-added by AccumulateGrad, doubling it. The accumulator also scaled by 1/N, but sequence-dim sharded gradients are additive (not averaged). Combined, param.grad came out ~2x-2.5x the analytical value. Inline accumulation captures param.grad after each shard's inner backward, sums into a per-param fp32 accumulator, clears the running grad, and writes the total back once at the end (preserving any pre-existing .grad from earlier graph segments). Signed-off-by: Wing Lian <wing@axolotl.ai> * test(tiled-mlp): single-gpu MoE + scattermoe-lora coverage Three parity checks (tiled vs un-tiled forward+backward) plus two patcher-internals tests. All gated on CUDA. - Dense LlamaMLP-shape (hidden=64, intermediate=128, seq=64): tight atol=1e-5 on outputs, dX, and every parameter grad. Uses batch=1 to match the sequence-packed inputs production sees. - Hand-rolled MoE block (E=8, hidden=64, intermediate=128, top_k=2): same shape + same tolerances against an index_add-based reference. - ScatterMoEGatedMLP in bf16: norm-relative tolerance < 1%, matching the established bar in tests/integrations/test_scattermoe_lora_kernels.py (bf16 + tiled reduction order makes max abs error a noisy signal). - Patcher unit tests: MoE block class discovery prefers SparseMoeBlock / MoeMLP over MoE, and returns None for dense models. Synthetic-shape modules only — no transformers checkpoints loaded. Signed-off-by: Wing Lian <wing@axolotl.ai> * test(tiled-mlp): FSDP2 multi-rank correctness Two parity tests (dense + scattermoe-lora) that wrap a tiny MLP / ScatterMoEGatedMLP with FSDP2 (`fully_shard`) and compare tiled forward+backward against a non-tiled FSDP2 reference. Both must run through the FSDPModule's __call__ so FSDP2's pre-forward hooks materialize the unsharded params before TiledMLP.apply chunks the input; the helper _install_tiled_forward mirrors what the production patcher does instance-side. Designed to be launched with `torchrun --nproc-per-node=2 -m pytest tests/e2e/multigpu/test_tiled_mlp_fsdp2.py`. Skips with a clear reason on a 1-GPU executor or when launched without torchrun. Verified to pass on a 2-GPU runner (RTX PRO 6000 Blackwell). Signed-off-by: Wing Lian <wing@axolotl.ai> * feat(scattermoe-lora): shared dequant buffer across tile shards Adds shared_dequant_across_shards() which hoists the MXFP4 dequant out of the per-shard selective path. The orthogonal tiled wrapper calls selective_expert_weights once per shard; when active-expert sets overlap (the common case under softmax routing) the dequant is wasted work. The helper computes the union of active experts across all shards, dequantizes that union once, and returns per-shard remaps so each shard's parallel_linear_lora call uses the correct slice. Bitwise contract: a shard's gathered slice is byte-identical to the per-shard selective_expert_weights output, verified by test_shared_dequant_helper.py with N=4 overlapping shards plus disjoint and single-shard regression cases. Signed-off-by: Wing Lian <wing@axolotl.ai> * fix(tiled-mlp): default grad accumulator to param dtype, skip redundant casts The orthogonal TiledMLP wrapper pre-allocated an fp32 accumulator the size of every compute param, then cast each shard's bf16 ``param.grad`` to fp32 inside the loop before adding it. For E=128 / hidden=2048 / intermediate=8192 MoE training in bf16 that's roughly 17 GiB of fp32 buffer on the gate_up_proj alone — net 2x parameter-side memory regression vs. simply accumulating at the param's own dtype. The per-shard ``grad.to(fp32)`` cast was also a per-shard HBM bandwidth tax that dominated the wall-clock regression at intermediate=8192. Match what AccumulateGrad does in the unsharded backward: accumulate at the param's own dtype, skip the cast when shard-grad dtype matches the accumulator dtype, and only cast back to param dtype at write-back when the buffer dtype differs. fp32 accumulation is opt-in via AXOLOTL_TILED_MLP_ACCUM_FP32=1 for callers who care about bf16 round-off in very-large-N-shard sums. The dead ``GradientAccumulator`` class (no longer called after the inline-accumulation refactor in b13375a0) is updated to the same defaults — param-dtype accumulator, gradient_scale=1.0 — so it is in a coherent state if anyone re-introduces a hook-based path. Signed-off-by: Wing Lian <wing@axolotl.ai> * test(tiled-mlp): strengthen tiled-vs-untiled grad parity Add three regression guards for the TiledMLP gradient-accumulator fix: 1) ``test_tiled_dense_mlp_grad_parity_nonuniform_weights`` and ``test_tiled_moe_grad_parity_nonuniform_weights`` exercise shards in {1, 2, 4} with non-uniform per-token upstream weights. A mean-vs-sum scaling bug in the per-shard accumulator (the historical ``gradient_scale = 1/total_shards``) would show up as roughly ``(N-1)/N`` relative drift in the param grads. The old tests used a single shard count and uniform-magnitude upstream, which allowed the bug to slip through. 2) ``test_tiled_dense_mlp_grad_parity_bf16`` runs the same parity at bf16 to lock the default param-dtype accumulator path (no fp32 buffer) against regression. 3) ``test_tiled_grad_accumulator_dtype_matches_param_dtype`` is an allocation-side guard: spy on ``torch.zeros_like`` during a bf16 tiled backward and assert none of the per-param accumulator allocations request fp32. A future change that re-introduces the fp32 buffer by default would fail this check without needing a memory-resident bench. Signed-off-by: Wing Lian <wing@axolotl.ai> * fix(tiled-mlp): default to ~32K tokens/shard, not ceil(seq/hidden) The previous heuristic put only ~2K tokens/shard at long context — well below the MoE Triton kernel's BLOCK_M sweet spot. An empirical sweep at seq ∈ {64K, 128K, 256K, 512K} showed 3.2× speed-up at 64–256K and 2.1× at 512K from raising per-shard tokens to ~32K, with only a modest peak-mem cost (~5–10 GiB extra at seq=256K) because the routed intermediate buffer dominates and scales linearly with per-shard tokens. Bench data is operator-archived locally; the headline numbers are included in the PR description. The 32K target is empirical, not theoretical — it's the largest tokens-per-shard that fits at seq up to 256K without OOM and stays inside the cuBLAS large-batch_count safe regime that surfaces a separate bug at seq=512K + s=16. Operators can override via cfg_num_shards for niche cases (smaller intermediate, larger top_k). Also includes ruff-format cleanup of cherry-picked commits. Signed-off-by: Wing Lian <wing@axolotl.ai> --------- Signed-off-by: Wing Lian <wing@axolotl.ai>
…oud#3660) LigerFusedLinearKLTopKLogprobFunction.forward's loss_fn_for_grad returned (soft_loss, ce_loss) directly to torch.func.grad_and_value(has_aux=True), which treats only the first element as the grad target. CE was silently dropped from the backward graph: CE-only training had grad_norm=0 every step, and KD-mix training updated parameters with KD-only gradients despite reported loss showing both terms. Combine the two losses inside loss_fn_for_grad so both contribute to backward, and keep (soft_loss, ce_loss) as aux for reporting. Outer accumulators, temperature scaling, and the final reported loss formula are unchanged. Adds regression tests in tests/integrations/test_kd_liger.py.
… (supersedes chunking workaround) (axolotl-ai-cloud#3667) * test(scattermoe-lora): repro CUBLAS_STATUS_EXECUTION_FAILED at large batch_count The tiled-MLP long-context bench surfaces a hard failure at seq=524288 with 16 shards: ``cublasGemmStridedBatchedEx`` raises ``CUBLAS_STATUS_EXECUTION_FAILED`` at parallel_experts.py:72's ``gates.unsqueeze(1) @ output_expanded``. The crash is reproducible on the bench shape (T=32K tokens/shard, top_k=8, hidden=2048, intermediate=8192) and is a downstream symptom of an int32 pointer-offset overflow in the upstream ``scatter2scatter`` Triton kernel during the up-projection — at that shape its output buffer is 2**32 elements, so ``M_block * stride_ym`` overflows int32 for the trailing rows. Add three repro tests in tests/integrations covering: * fast-path bit-identity vs the raw kernel below the threshold, * non-corruption at the overflow shape via the int32-safe wrapper, * end-to-end ``parallel_linear`` smoke at the failing bench shape. The tests are marked ``pytest.mark.skip`` pending the fix in the next commit; the follow-up "un-mark" commit re-enables them so they guard the fix going forward. Symbols imported inside each test body (``_scatter2scatter_int32_safe``, ``_SCATTER2SCATTER_INT32_LIMIT``) land alongside the fix in the next commit, so the skipped tests do not fail collection at this point in the history. Signed-off-by: Wing Lian <wing@axolotl.ai> fix(scattermoe-lora): work around cuBLAS large-batch_count failure in gates @ output_expanded The originally-reported symptom — ``CUBLAS_STATUS_EXECUTION_FAILED`` at parallel_experts.py:72 — is NOT a cuBLAS bug. The cuBLAS bmm shape at the failing seq=512K / 16-shard config is tiny (batch_count=32768, M=1, K=8, N=2048) and works in isolation. The crash is a downstream symptom of an int32 pointer-offset overflow in the upstream ``scatter2scatter`` Triton kernel during the up-projection, surfaced at the next CUDA-sync point (the bmm). Diagnosis (verified by inserting a ``torch.cuda.synchronize()`` immediately after the up-projection's scatter2scatter — that sync itself raises "an illegal memory access was encountered", proving the fault is upstream of the bmm): The Triton kernel computes output pointer offsets as ``Y_ptr + M_block * stride_ym + N_block * stride_yn`` with int32 ``M_block`` / ``stride_ym``. At seq=524288 / shards=16 the up-projection's output is ``[L_scattered=262144, y_dim=2*INTERMEDIATE=16384]`` = ``2**32`` elements; the trailing rows whose ``M_block * stride_ym`` overflows int32 have their masked stores silently drop (rows come back as zeros) or land at bogus pointers, which then trips a delayed ``CUDA illegal memory access`` that the next kernel surfaces. Workaround at the smallest scope appropriate to the actual root cause (NOT at parallel_experts.py:72, which is downstream): wrap ``kernels.ops.scatter2scatter`` with ``_scatter2scatter_int32_safe`` and route both call sites (``ParallelLinear.forward`` and ``ParallelLinear.backward``) through it. The wrapper: * Fast path (common case): when ``L_scattered * y_dim < 2**31``, dispatches a single direct kernel call — no overhead vs the pre-fix code. Verified at seq=524288 / shards=64: 36741 tokens/s post-fix vs ~37512 tokens/s pre-fix (within noise, no regression). * Slow path: when the output would overflow AND ``y_grouped=True``, allocates the full output and chunks along the L_scattered axis. Each sub-call writes to ``out[chunk_start:chunk_end]`` with the matching sei / ssi slice; the chunk size is the largest BLOCK_M-aligned row count keeping ``rows * y_dim < 2**31``. The chunked path drops into ``kernels.ops.scatter2scatter_compileable`` directly to bypass the high-level wrapper's ``sorted_scattered_idxs.size(0) == X.size(0) * k`` assertion that only holds for full calls. * When ``x_grouped=True`` (the down-proj backward), X is sliced in lockstep so the kernel's ``M_in_idx = M_block`` correctly reads ``X_chunk[0..chunk_size-1]``. When ``x_grouped=False`` (the up-proj forward) X stays full because the kernel indexes X via global ``M_idx // FAN_OUT`` from the per-position ``sorted_scattered_idxs`` values. * For ``y_grouped=False`` at overflow scale, the wrapper hard-raises ``RuntimeError`` — the kernel uses per-position scattered indices as output row indices so the wrapper cannot tile that case safely; the kernel itself needs an int64 pointer-arithmetic fix before that path is callable at this scale. Production paths today are all ``y_grouped=True`` so this branch is unreachable in the bench. Silent corruption is strictly worse than a clear raise. * Includes ``assert L_scattered % chunk_rows == 0`` for the ``x_grouped=False`` chunked path, since the kernel's ``M_boundary_mask`` uses the full (unchunked) X size and a partial last chunk would let the final tile read past sei_chunk / ssi_chunk. The assertion holds for all realistic power-of-2 shapes and fires loudly if a future caller hits a non-aligned one. The ``x_grouped=True`` chunked path is naturally bounded because X is chunked in lockstep. Before / after on the bench config: * pre-fix: CUBLAS_STATUS_EXECUTION_FAILED (no result) * post-fix: 10084 ms/iter, 51989 tokens/s, peak 64.16 GiB (matches the previously-fastest non-failing rows of the s=64 / s=256 sweep; s=16 was the predicted fastest row that the bug had been hiding) Constraints honoured: no changes to ``kernels.ops.scatter2scatter`` or any other Triton kernel; no public-API change on ``parallel_experts.py``; scope confined to scattermoe-lora's ParallelLinear; common-case fast path untouched. The LoRA-path counterpart (``parallel_linear_lora.py`` → ``scatter2scatter_lora``) has the same architectural risk but is not exercised by the failing bench config and is left for follow-up. Signed-off-by: Wing Lian <wing@axolotl.ai> test(scattermoe-lora): enable large-batch repro tests now the fix has landed Remove the ``pytest.mark.skip`` marker added in the ``repro CUBLAS_STATUS_EXECUTION_FAILED`` commit. The fix in the previous ``work around cuBLAS large-batch_count failure`` commit provides ``_scatter2scatter_int32_safe`` and the matching ``_SCATTER2SCATTER_INT32_LIMIT`` constant referenced by these tests, so the three tests now run and guard the fix going forward: * ``test_int32_safe_wrapper_matches_direct_call_below_threshold`` — fast-path equivalence (no overhead in the common case). * ``test_int32_safe_wrapper_no_corruption_at_overflow_shape`` — chunked slow-path correctness at the bench shape. * ``test_parallel_linear_long_seq_routing_combination`` — end-to-end smoke through ``ScatterMoEGatedMLP.forward`` shape sequence at seq=524288 / shards=16. All three pass on CUDA hardware; they self-skip when CUDA is unavailable. Signed-off-by: Wing Lian <wing@axolotl.ai> * feat(scattermoe-lora): add INT64_INDICES tl.constexpr to dense scatter2scatter The Triton scatter2scatter kernel computes output pointer offsets as ``Y_ptr + M_block * stride_ym + N_block * stride_yn`` where M_block / M_idx are int32 by default. At seq=512K with coarse shards the ``L_scattered * y_dim`` product exceeds 2**31 elements and the int32 arithmetic overflows; PR axolotl-ai-cloud#3667 worked around this by chunking the call along the L_scattered axis when y_grouped=True, but that workaround doesn't cover y_grouped=False (raises) or the LoRA-path kernels. Add an ``INT64_INDICES: tl.constexpr = False`` knob to the dense ``_scatter2scatter`` kernel signature. When True, the M_block range and the scattered-index lookup ``M_idx`` are cast to int64 before they enter the pointer-offset multiplication, so all downstream pointer arithmetic propagates int64. Strides themselves stay as the kernel sees them (coming from ``tensor.stride()`` they're already int64 at the Python level); only the *index* values change type. Triton will JIT a separate variant per constexpr value, so the existing int32 fast path is unaffected. The wrapper-level auto-dispatch (compute ``needs_int64`` from tensor sizes and forward to the kernel) lands in a follow-up commit; this commit just exposes the constexpr and a Python-side ``int64_indices`` kwarg on ``scatter2scatter`` / ``scatter2scatter_compileable``. Signed-off-by: Wing Lian <wing@axolotl.ai> * feat(scattermoe-lora): add INT64_INDICES to LoRA forward + dX kernels (bf16 + MX) Adds the same ``INT64_INDICES: tl.constexpr = False`` knob as the dense kernel to the four LoRA-path scatter2scatter kernels: * ``_scatter2scatter_lora`` — bf16 fused base+LoRA forward * ``_scatter2scatter_lora_dX`` — bf16 fused dX backward * ``_scatter2scatter_lora_mx`` — MXFP4 fused base+LoRA forward * ``_scatter2scatter_lora_dX_mx`` — MXFP4 fused dX backward The cast pattern matches the dense kernel: when ``INT64_INDICES=True``, the per-launch ``M_block`` range and the scattered ``M_idx`` lookup are cast to int64 before they enter the ``M_*_idx * stride_*m`` pointer arithmetic. That promotes the multiplication to int64 and prevents the silent overflow at ``L_scattered * y_dim >= 2**31`` that the chunking workaround on the dense path was guarding against. The Python-side wrappers (``scatter2scatter_lora``, ``scatter2scatter_lora_dX``, ``scatter2scatter_lora_mx``, ``scatter2scatter_lora_dX_mx``) gain an ``int64_indices: bool = False`` kwarg and forward it to the kernel via the constexpr. Auto-dispatch from tensor sizes lands in a follow-up commit. PR axolotl-ai-cloud#3667's chunking workaround only covered the bf16 dense forward; the LoRA path had the same architectural risk and wasn't covered. With these constexprs in place and the wrapper-side dispatch coming next, the kernel itself becomes int64-safe for all five variants and the chunking wrapper can be retired. Signed-off-by: Wing Lian <wing@axolotl.ai> * feat(scattermoe-lora): add INT64_INDICES to group_bwd_lora kernels Adds ``INT64_INDICES: tl.constexpr = False`` to the three LoRA gradient kernels that index into the grouped M dimension: * ``_group_bwd_lora`` — non-split LoRA-grad kernel (used by autotune-collector mocks; kept in sync for telemetry / future callers) * ``_group_bwd_lora_split`` — split dA/dB kernel that the public ``group_bwd_lora`` wrapper actually dispatches today * ``_group_bwd_lora_fused`` — fused gather + dA/dB kernel used by the LoRA path in ``parallel_linear_lora.py`` In each kernel, when ``INT64_INDICES=True`` we cast: - the per-expert ``start_idx`` / ``end_idx`` (and the fused kernel's ``real_*`` variants) to int64 on load, - ``M_block = tl.arange(0, BLOCK_M)`` to int64 so the per-iter ``M_idx = start_idx + i * BLOCK_M + M_block`` propagates int64, - and (in the fused kernel) ``scatter_idx`` from sorted-index lookups to int64 so ``scatter_idx * stride_dym`` and the ``X_token_idx = scatter_idx // FAN_OUT`` arithmetic stay int64. Strides themselves stay as the kernel receives them (already int64 at the Python level via ``tensor.stride()``). Triton JITs a separate variant per constexpr value, so the int32 fast path is unchanged. Python wrappers (``group_bwd_lora`` and ``group_bwd_lora_fused``) gain an ``int64_indices: bool = False`` kwarg that forwards to the kernel. Wrapper-level auto-dispatch from tensor sizes lands in the next commit. Signed-off-by: Wing Lian <wing@axolotl.ai> * feat(scattermoe-lora): auto-dispatch INT64_INDICES based on tensor sizes Adds ``_needs_int64_indices(*tensors)`` in ``parallel_experts.py``: True iff any input/output tensor's ``numel() >= 2**31 - 1``. That's a sufficient condition for the kernel's ``M_idx * stride_*m`` pointer arithmetic to overflow int32 somewhere in the buffer. Wires the result through the autograd Functions: * ``ParallelLinear`` (``parallel_experts.py``): forward computes ``needs_int64 = (L_scattered * y_dim) >= INT_MAX or _needs_int64_indices(x)`` and forwards via the new ``int64_indices`` kwarg on ``_scatter2scatter_int32_safe``. The wrapper's fast path now passes ``int64_indices`` through to ``kernels.ops.scatter2scatter`` so the kernel takes the int64 path at overflow scale. The wrapper also adds a new branch above the chunking path that routes directly to the int64 kernel when ``int64_indices=True`` is requested — notably this covers the y_grouped=False overflow case that the chunking workaround used to raise on. Backward follows the same pattern using ``L_scattered * K`` for the dX-axis bound. * ``ScatterMoELoRA`` (``parallel_linear_lora.py``): forward computes ``needs_int64`` from ``L_scattered * N`` and forwards to ``scatter2scatter_lora`` / ``scatter2scatter_lora_mx``. Backward computes a single ``needs_int64_bwd`` from ``M_total * max(N, K)`` (covering both the dX and the dA/dB kernels' index ranges) and forwards to ``group_bwd_lora`` / ``group_bwd_lora_fused`` and to the dX kernels (bf16 + MX, fused and non-fused). The auto-dispatch is cheap (one ``Tensor.numel()`` per check) and Triton JITs a separate kernel variant per constexpr value, so the int32 fast path is unaffected for small/medium shapes. Signed-off-by: Wing Lian <wing@axolotl.ai> * test(scattermoe-lora): int32-vs-int64 parity and overflow correctness Adds ``tests/integrations/test_scattermoe_lora_int64_indices.py`` covering two properties of the new INT64_INDICES path: * **Bitwise parity at non-overflow shapes.** For each of the modified kernels, ``INT64_INDICES=False`` and ``INT64_INDICES=True`` compute the same MMA in the same accumulation order — only the index *type* changes. The tests assert ``torch.equal`` between the two variants for the dense forward (both y_grouped=True and y_grouped=False), the LoRA forward, and the LoRA dX backward. For ``_group_bwd_lora_split`` the assertion is bitwise; for ``_group_bwd_lora_fused`` it's ``torch.allclose`` within bf16 tolerance because that kernel uses ``tl.atomic_add`` whose ordering is non-deterministic across launches (so bit-equality is not achievable between any two runs of the same variant, let alone across variants). * **Overflow correctness at the failing bench shape.** At L_scattered=262144 / y_dim=16384 (2**32 element output), the ``INT64_INDICES=True`` kernel populates every row of the output (including rows past the int32 overflow boundary) and matches the chunked workaround within a generous bf16 tolerance. A second bench-shape test runs the real ``ParallelLinear`` forward and uses a monkeypatched spy on ``scatter2scatter_compileable`` to assert the auto-dispatcher routes through the *direct* int64 kernel call (one launch) and **not** the chunking workaround (>=2 launches). Also folds in a kernel-side fix that the parity tests caught: the group_bwd_lora kernels' ``if E_idx == 0: start_idx = 0`` branch produced a plain int32 zero in Triton, which clashes with the int64 ``start_idx`` produced by the else-branch under ``INT64_INDICES=True``, firing ``AssertionError: Mismatched type for start_idx between then block (int32) and else block (int64)`` at compile. Switching the zero-initialisation to ``tl.zeros([], dtype=tl.int64/tl.int32)`` keeps both branches' types consistent. The bench-shape tests are skipped when free GPU memory is below 80 GiB. Signed-off-by: Wing Lian <wing@axolotl.ai> * bench(scattermoe-lora): int64-vs-int32 indexing overhead Adds ``tests/integrations/kernels/scattermoe_lora/bench_int64_kernel.py``, a stand-alone script (not pytest) that times the dense ``kernels.ops.scatter2scatter`` at three representative shapes and reports ms/iter for both ``INT64_INDICES=False`` (int32 fast path) and ``INT64_INDICES=True`` (int64 safe path): * **small** — seq=8K, top_k=8, hidden=2048, N=2048 (auto_int64=False) * **medium** — seq=128K, top_k=8, hidden=2048, N=2048 (auto_int64=False) * **overflow** — seq=512K with 16 shards → L_scattered=262144, N=16384 (auto_int64=True; the previously-failing bench config) At overflow shapes the int32 path is silently incorrect, so the int32 column is replaced by the chunked workaround from PR axolotl-ai-cloud#3667 as the apples-to-apples baseline. Results land in ``bench_int64_kernel_results.md`` next to the script. Captured on RTX PRO 6000 Blackwell Max-Q (1.79 TB/s HBM): shape int32 ms int64 ms chunked ms penalty ------------- -------- -------- ---------- ------- small 2.687 2.689 — +0.0% medium 40.220 40.581 — +0.9% overflow — 79.572 79.985 -0.5% Both acceptance bounds are comfortably met: ≤5% on the int32 fast path (actual: <1%), and ≤25% on the int64 path vs the chunked workaround (actual: −0.5%, i.e. the int64 kernel is slightly *faster* than chunking at this shape because it avoids the per-chunk launch overhead). Signed-off-by: Wing Lian <wing@axolotl.ai> refactor(scattermoe-lora): deprecate _scatter2scatter_int32_safe chunking now that kernel is int64-safe PR axolotl-ai-cloud#3667's ``_scatter2scatter_int32_safe`` chunking wrapper was the minimum-scope fix for the int32 pointer-overflow at the failing bench config: it tiled the call along the L_scattered axis to keep each sub-launch's ``rows * y_dim < 2**31``. With the kernel-level ``INT64_INDICES`` constexpr now landing on every relevant scatter2scatter family kernel and the wrapper-level auto-dispatch plumbed through both ``parallel_experts.py`` and ``parallel_linear_lora.py``, the chunking workaround is redundant — the kernel handles the overflow itself in a single launch. Removes from ``parallel_experts.py``: * ``_scatter2scatter_int32_safe`` and its 160-line chunking loop * ``_SCATTER2SCATTER_INT32_LIMIT`` and ``_SCATTER2SCATTER_BLOCK_M`` constants used only by the chunking path * the ``RuntimeError`` raise for ``y_grouped=False`` at overflow scale — the int64 kernel handles that case directly Routes ``ParallelLinear.forward`` / ``.backward`` straight to ``kernels.ops.scatter2scatter`` with ``int64_indices=needs_int64``. The bench config (seq=524288, 16 shards → L_scattered=262144, y_dim=16384, output=2**32 elements) now goes through a single int64 kernel launch and matches the bench-recorded perf (79.6 ms/iter vs. the chunked workaround's 80.0 ms/iter — slightly *faster* because it eliminates the per-chunk launch overhead). The PR axolotl-ai-cloud#3667 repro tests are retained as regression guards and updated to call the new direct-kernel path: * ``test_scatter2scatter_below_threshold_no_overhead`` (renamed from ``test_int32_safe_wrapper_matches_direct_call_below_threshold``) asserts INT64_INDICES=False vs True is bit-identical at non- overflow shapes — guards the int32 fast path. * ``test_scatter2scatter_no_corruption_at_overflow_shape`` (renamed from ``test_int32_safe_wrapper_no_corruption_at_overflow_shape``) asserts the int64 kernel populates rows past the int32 overflow boundary — guards the kernel-level overflow fix. * ``test_parallel_linear_long_seq_routing_combination`` is unchanged; it runs ``parallel_linear`` end-to-end at the bench shape and asserts no all-zero rows / no NaNs — guards the auto-dispatch wiring. The new ``test_parallel_linear_overflow_takes_int64_kernel_path`` in ``test_scattermoe_lora_int64_indices.py`` is also updated to monkey- patch ``scatter2scatter_compileable`` and assert the single launch sets ``int64_indices=True``, which directly verifies the auto- dispatch verdict at the failing bench shape. Signed-off-by: Wing Lian <wing@axolotl.ai> * chore(scattermoe-lora): pre-commit fixups for INT64 indices commits * ruff format pass on the four touched files (line-wrapping only, no functional changes). * ``parallel_linear_lora.py``: replace the one-line conditional ``N_dim = expert_weights.N if is_mx else expert_weights.size(-1)`` with an explicit if/else and ``# type: ignore[union-attr]`` on each branch — mypy can't narrow ``Union[Tensor, MXWeights]`` through a ternary, but does respect the explicit branches enough to need the ignore only on the offending attribute access. The pre-existing ternary in the backward path stays as-is (already covered by the surrounding type checks). * ``bench_int64_kernel.py``: drop imports of the removed ``_scatter2scatter_int32_safe`` / ``_SCATTER2SCATTER_INT32_LIMIT`` symbols (they went away in the refactor commit) and the now-unused ``ms_chunk`` column. The bench now reports int32 vs int64 timings only; the overflow row shows only int64 since the int32 kernel is silently incorrect there. Signed-off-by: Wing Lian <wing@axolotl.ai> * test(scattermoe-lora): add small-shape int64 overflow tests (run on L40S/24 GiB) The two bench-shape overflow tests above need ~80 GiB free and skip on the Modal CI L40S 48 GiB runner, so the actual overflow path the kernel fix targets was not exercised on CI. The new ..._small variants repro the same property at the smallest shape that still straddles the int32 boundary: L_scattered * y_dim = 2**32 (2x past 2**31, guaranteed overflow without int64_indices=True), with E=4, K=256, y_dim=4096 so W is ~8 MiB and the only big allocation is the ~8 GiB scatter output. Gated at 12 GiB free to leave headroom for pytest-xdist workers on 48 GiB devices. * fix(scattermoe-lora): bump _SMALL_E so int64-overflow topk is valid _SMALL_TOP_K=8 with _SMALL_E=4 makes torch.topk(logits[T,4], k=8) raise 'selected index k out of range', skipping the two _small overflow tests. The shape invariant L_scattered*y_dim=2**32 (T=131072, y_dim=4096) requires top_k=8, so E must be >= 8. * perf(scattermoe-lora): bucket M in autotune key to dedupe sweeps The 7 multi-config @triton.autotune kernels in lora_ops.py keyed on ["M", "N", "K"]. M = X.size(0) (or DY.size(0)) scales with batch*seq*top_k, so any seqlen variation triggers a fresh 30-60 config sweep per step until the cache happens to cover every realized M. With N, K model-fixed this was the only churning dimension. Add a phantom M_BUCKET arg to each kernel signature and switch the autotune key to ["M_BUCKET", "N", "K"]. The kernel still runs on the real M (loop bounds + masks unchanged); only the cache lookup is bucketed to the next multiple of _M_BUCKET_GRANULARITY=1024. No padding, no wasted FLOPs. autotune_collector._KEY_NAMES tracks the renamed key so telemetry matches what's actually in the .cache dict. Tests: - New test_scattermoe_lora_m_bucket.py pins both directions: same-bucket M values produce one cache entry, distinct-bucket M values produce two. - Updated telemetry test assertions for the renamed key. - Existing scattermoe-lora suite (62 tests) + int64 indices (10 tests) + telemetry (13 tests) all pass unchanged. --------- Signed-off-by: Wing Lian <wing@axolotl.ai>
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughThis PR adds non-streaming multimodal CPT via ChangesNon-streaming Multimodal CPT via Datasets
🎯 3 (Moderate) | ⏱️ ~25 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
|
@coderabbitai full review |
✅ Actions performedFull review triggered. |
|
📖 Documentation Preview: Deployed on Netlify from commit 8d15592 |
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
There was a problem hiding this comment.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/axolotl/prompt_strategies/multimodal_pretrain.py (2)
32-42:⚠️ Potential issue | 🟠 Major | ⚡ Quick winRestore fail-fast sequence-length validation in the map path.
encode_multimodal_pretrain()already avoids truncation; settingenforce_max_length=Falsehere just disables the only early guard for oversized rows. Those samples then survive preprocessing, and the collator only logs a warning when they exceedsequence_len, which makes this much harder to diagnose at training time.Suggested fix
def _encode_batch(self, examples: dict[str, list]) -> dict[str, list]: return encode_multimodal_pretrain( examples, tokenizer=self.tokenizer, max_tokens=self.sequence_len, image_token=self.image_token_spec.image_token, image_token_id=self.image_token_spec.image_token_id, text_column=self.text_column, image_column=self.image_column, - enforce_max_length=False, + enforce_max_length=True, )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/prompt_strategies/multimodal_pretrain.py` around lines 32 - 42, The _encode_batch wrapper disables the early, fail-fast sequence-length guard by passing enforce_max_length=False into encode_multimodal_pretrain, allowing oversized samples to slip through preprocessing and only be warned about later in the collator; update _encode_batch to enable the guard (pass enforce_max_length=True) or remove the override so encode_multimodal_pretrain uses its default fail-fast behavior, referencing _encode_batch, encode_multimodal_pretrain, enforce_max_length and sequence_len to locate the change.
85-100:⚠️ Potential issue | 🟠 Major | ⚡ Quick winFail fast when
tokenizerdoesn’t matchprocessor.tokenizer(and re-enable max-length validation)
build_image_token_spec()resolvesimage_token_idfromprocessor.tokenizer, butMultiModalPretrainDatasetWrappingStrategy._encode_batch()tokenizes/counts placeholders with the separatetokenizerpassed toload(). Enforcetokenizer is processor.tokenizer(dataset encoding happens before the collator) to keep placeholder counting/label masking aligned._encode_batch()callsencode_multimodal_pretrain(... enforce_max_length=False), disabling the only early guard for rows that exceedsequence_len. This defers failures and risks placeholder/image-count inconsistencies later—enableenforce_max_length(or make it configurable) so oversized rows fail fast.Suggested fix
def load( tokenizer, cfg, ds_cfg: Optional[dict[str, Any]] = None, processor: ProcessorMixin | None = None, ): ds_cfg = ds_cfg or {} if processor is None: raise ValueError( "Multimodal CPT (type: multimodal_pretrain) requires a processor. " "Set `processor_type: AutoProcessor` (or the concrete processor " "class) in your config." ) check_processor_compatibility(processor) + proc_tokenizer = getattr(processor, "tokenizer", None) + if proc_tokenizer is not None and proc_tokenizer is not tokenizer: + raise ValueError( + "Multimodal CPT requires `tokenizer` to be `processor.tokenizer` " + "so image placeholder ids stay aligned during encoding." + ) text_column = ds_cfg.get("text_column") or "text"🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/prompt_strategies/multimodal_pretrain.py` around lines 85 - 100, Ensure the tokenizer passed to load() is exactly the same object as processor.tokenizer by adding a fast-fail check (use check_processor_compatibility or add an explicit assertion) so tokenizer is processor.tokenizer before creating MultiModalPretrainDatasetWrappingStrategy; this guarantees build_image_token_spec and MultiModalPretrainDatasetWrappingStrategy._encode_batch use the same vocab for placeholder IDs. Also re-enable max-length validation by calling encode_multimodal_pretrain with enforce_max_length=True (or expose a configurable flag) inside MultiModalPretrainDatasetWrappingStrategy._encode_batch so rows that exceed cfg.sequence_len fail fast rather than downstream.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Outside diff comments:
In `@src/axolotl/prompt_strategies/multimodal_pretrain.py`:
- Around line 32-42: The _encode_batch wrapper disables the early, fail-fast
sequence-length guard by passing enforce_max_length=False into
encode_multimodal_pretrain, allowing oversized samples to slip through
preprocessing and only be warned about later in the collator; update
_encode_batch to enable the guard (pass enforce_max_length=True) or remove the
override so encode_multimodal_pretrain uses its default fail-fast behavior,
referencing _encode_batch, encode_multimodal_pretrain, enforce_max_length and
sequence_len to locate the change.
- Around line 85-100: Ensure the tokenizer passed to load() is exactly the same
object as processor.tokenizer by adding a fast-fail check (use
check_processor_compatibility or add an explicit assertion) so tokenizer is
processor.tokenizer before creating MultiModalPretrainDatasetWrappingStrategy;
this guarantees build_image_token_spec and
MultiModalPretrainDatasetWrappingStrategy._encode_batch use the same vocab for
placeholder IDs. Also re-enable max-length validation by calling
encode_multimodal_pretrain with enforce_max_length=True (or expose a
configurable flag) inside
MultiModalPretrainDatasetWrappingStrategy._encode_batch so rows that exceed
cfg.sequence_len fail fast rather than downstream.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: bb546c3e-e3f3-4e0e-bb9a-17a08bd60bb4
📒 Files selected for processing (7)
examples/qwen2_5-vl/mm-cpt-nonstreaming-qlora.yamlexamples/qwen2_5-vl/mm-cpt-streaming-qlora.yamlsrc/axolotl/prompt_strategies/multimodal_pretrain.pysrc/axolotl/utils/data/shared.pysrc/axolotl/utils/schemas/datasets.pytests/prompt_strategies/test_multimodal_pretrain.pytests/utils/data/test_hash.py
💤 Files with no reviewable changes (1)
- src/axolotl/utils/schemas/datasets.py
🚧 Files skipped from review as they are similar to previous changes (2)
- src/axolotl/utils/data/shared.py
- examples/qwen2_5-vl/mm-cpt-nonstreaming-qlora.yaml
…fault (axolotl-ai-cloud#3680) * feat(qwen): fused RMSNorm+RoPE for Qwen3 / Qwen3-MoE / Qwen3.5 / Qwen3.5-MoE Generalizes the existing Gemma 4 fused RMSNorm+RoPE Triton kernel to four new Qwen attention variants, and auto-enables Liger's fused (m-)rope kernel for the Qwen-VL family. Eager-mode behavior is bit-identical when the new cfg.fused_attn_kernel flag is unset. Changes ------- * New ``cfg.fused_attn_kernel: bool | None`` (default None / off). When set, replaces ``q_norm + apply_rotary_pos_emb`` (and the matching K path) with a single fused RMSNorm+RoPE Triton kernel launch. Currently wired for ``qwen3``, ``qwen3_moe``, ``qwen3_5``, and ``qwen3_5_moe`` model_config_types. Llama4 is out of scope (complex freqs_cis + Llama4TextL2Norm post-RoPE — separate kernel). * Kernel ``UNIT_OFFSET: tl.constexpr`` flag added to the forward + backward Triton kernels for Qwen3.5's Gemma-style ``(1.0 + weight)`` RMSNorm. Default ``False`` keeps Gemma 4 / Qwen3 / Qwen3-MoE bit-identical to before. Threaded through the triton_op + register_autograd plumbing. * Refactors ``fused_rms_norm_rope`` / ``fused_rms_norm_noscale`` to ``torch.library.triton_op`` + ``register_autograd`` so they trace under ``torch.compile(fullgraph=True)``. Validated: 1 Dynamo frame, 0 graph breaks. On sm_120 the compile path composes to +9.2% combined, −33% peak memory. On sm_86 the surrounding Inductor-generated kernels regress — leave ``torch_compile: false`` there; schema description documents the per-arch recommendation. * Liger Qwen-VL auto-default: when ``cfg.liger_rope is None`` and model_config_type is one of qwen2_vl/qwen2_5_vl/qwen3_vl (+ ``_text`` variants), pass ``rope=True`` so upstream's fused m-rope kernel is actually installed. Previously the plugin overrode the upstream default to None, silently skipping the kernel. * Patch-ordering fix: ``_apply_self_attention_lora_patch`` now runs before ``_apply_model_specific_patches`` in ``apply_pre_model_load_patches``. ``patch_self_attn_lora`` reads ``inspect.getsource`` of the attention class' forward, so any patch that replaces ``Attention.forward`` must run *after* the source-rewrite step. The wrong order also silently broke Gemma 4 + ``lora_qkv_kernel`` — pinned by ``TestPatchManagerOrdering`` and a fused-first trip-wire. Tests ----- * Per-model parity + backward grad flow for Qwen3, Qwen3-MoE, Qwen3.5, Qwen3.5-MoE (full-attention layers only; linear_attention layers stay on the stock GatedDeltaNet path). * Kernel ``UNIT_OFFSET=True`` parity vs from-scratch reference + bwd parity vs torch-eager + ``torch.compile(fullgraph=True)`` parity. * ``torch.compile(fullgraph=True)`` parity for the no-offset path. * Liger Qwen-VL auto-default for all 6 model_config_types; explicit ``False`` is respected. * Patch idempotency (double-apply is a no-op). * Transformers signature contract — pins the stock attention forward argument names so future drift trips loudly at test time. * Gradient-checkpointing composition (Qwen3 + ``gradient_checkpointing_enable``). * Flash-Attention 2 composition (skip-if-unavailable). * LoRA + fused composition on Qwen3 / Qwen3.5 / Qwen3.5-MoE, with fused-first reverse-order trip-wires that catch the original ordering bug if anyone re-introduces it. A pre-existing upstream-drift xfail in ``test_gemma4_fused_attn.py`` documents Gemma 4 + ``lora_qkv_kernel`` being broken in transformers 5.8.1 (new ``shared_kv_states: dict[str, ...]`` signature drift in QKV_PATCHES). Out of scope for this PR; flips to XPASS when patched. Post-review fixes ----------------- * ``_resolve_norm_module``: PEFT ``ModulesToSaveWrapper`` stores ``active_adapter`` as ``list[str]`` (e.g. ``["default"]``), not a string. The prior ``isinstance(adapter, str)`` check silently returned the frozen ``original_module`` for every real-PEFT case. Switched to iterating ``active_adapters`` (with ``active_adapter`` fallback) across all 4 patches. Added a direct unit-test plus an end-to-end test that drives real ``peft.get_peft_model(modules_to_save=["q_norm","k_norm"])`` and asserts the helper returns the trainable adapter weight. * ``cfg.fused_attn_kernel`` unsupported-model warning: moved out of the Pydantic ``model_validator(mode="before")`` (which ran *before* ``normalize_config()`` had derived ``model_config_type``, so it silently no-op'd on normal YAML input) into a new ``PatchManager._warn_if_fused_attn_unsupported`` staticmethod invoked from ``_apply_model_specific_patches``, where ``model_config_type`` is guaranteed set. Added a source-line guard that the helper stays wired. * address coderabbit comments * improve bwd pass throughput * feat(qwen3-vl): add fused attention patch * test: capture fused attention logs from concrete loggers * ci: rerun tests --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
…xolotl-ai-cloud#3661) * compute kd loss in trainer * add kd trainer compute_loss tests * remove unused kd kernel patch module * don't materialize all the logits * ensure dtype from hidden states matches dtype for chunked kd since we're not inside the autocasting anymore --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
68d96fc to
8534115
Compare
…tl-ai-cloud#3689) [skip-ci] When the suite runs under pytest-xdist, multiple workers race for the same physical GPU's memory budget. A test that fits comfortably in isolation can OOM purely because peer workers are already holding most of VRAM (observed: 8 workers each holding ~44 GiB on a 44 GiB card). Add a conftest in tests/integrations/kernels/scattermoe_lora/ that hooks pytest_runtest_call and converts torch.OutOfMemoryError into a skip. Real correctness bugs still surface as failures since they raise asserts / typed exceptions, not OOM. Uses a hookwrapper rather than an autouse fixture because pytest captures the test exception before re-entering the fixture's generator, so the fixture's try/except around yield never sees it.
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
8534115 to
0b14383
Compare
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/prompt_strategies/test_multimodal_pretrain.py`:
- Line 154: The pytest.raises assertion uses match="processor.tokenizer" where
the dot is a regex wildcard; change it to a literal match by escaping the dot
(e.g., match=r"processor\.tokenizer") or by using
re.escape("processor.tokenizer") so the test asserts the exact substring; update
the pytest.raises call in tests/prompt_strategies/test_multimodal_pretrain.py
accordingly.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: f480a349-a646-49e2-96c1-33934cc3826c
📒 Files selected for processing (10)
src/axolotl/prompt_strategies/multimodal_pretrain.pysrc/axolotl/utils/data/sft.pysrc/axolotl/utils/data/shared.pysrc/axolotl/utils/schemas/datasets.pytests/prompt_strategies/test_multimodal_pretrain.pytests/test_multimodal_streaming.pytests/utils/data/test_hash.pytests/utils/data/test_mm_cpt_eval.pytests/utils/data/test_mm_pretrain_cache.pytests/utils/schemas/validation/test_multimodal_cpt.py
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/utils/schemas/validation/test_multimodal_cpt.py
- src/axolotl/prompt_strategies/multimodal_pretrain.py
0b14383 to
8d15592
Compare
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
…3697) * add pytorch 2.12 base and prune unused base images * Add back 2.11.0 and add them to basic pytest matrices
* bump transformers to 5.9.0 and trl to 1.5.1 * test(gemma4-kernelize): accept ValueError from transformers 5.9 attach_hidden_kernels transformers ≤5.8 surfaced the non-Module ``_hidden_kernels`` entry as TypeError/AttributeError via ``module.register_module(name, fn)``. 5.9 reworked ``attach_hidden_kernels`` to raise ``ValueError`` directly with a clearer error message. The patch under test (strip dead entries before ``kernelize()`` runs) does the right thing either way; broaden the expected-crash assertion so the test reflects current upstream behavior. * 30 min timeout * fix(activation-offload): drop monkey-patched __enter__ now that TRL 1.5.1 ships upstream fix TRL 1.5.1 implements huggingface/trl#5730 natively — ``OffloadActivations`` now has its own ``__enter__`` that clears tracker / stashes between steps, **plus** two things the axolotl backport never had: - ``self.tensor_id = 0`` reset (without this, the tensor_id counter accumulates across steps; harmless on its own but skews the ``fwd_stash`` eviction window). - ``torch.cuda.empty_cache()`` when bitsandbytes is loaded — flushes the BNB allocator between steps so its compute / optimizer-state buffers don't accumulate as live storage. TRL 1.5.1 also adds a ``__exit__`` that syncs the offload streams (``s0``, ``s1``) before the parent cleanup runs. The axolotl backport only overrode ``__enter__``, so ``__exit__`` was inherited correctly either way. Once we bumped TRL 1.1.0 → 1.5.1 (transformers 5.9 bundle), the monkey-patch became strictly worse than upstream — it shadowed the better ``__enter__``, dropping the ``tensor_id`` reset and the BNB ``empty_cache``. Combined with cu130's stricter cross-stream lifetime checks, this surfaced as XID 43 (driver-killed CUDA channel) during ``test_activation_offloading[lora]``, followed by every subsequent test failing at ``torch.manual_seed(42)`` because the CUDA context was permanently poisoned. Drop the patch and the wrapper — upstream is now the source of truth, per the existing TODO in this file.
* prefer latest pytorch as gated e2e tests
* fix(fsdp2-qlora): match _init_sharded_param anchor for torch 2.12 + fallback to 2.11
torch 2.12.0 rewrote the sharded-param construction in
FSDPParam._init_sharded_param from a two-line form
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
self.sharded_param.requires_grad_(param.requires_grad)
to a single multi-line Parameter() call with requires_grad= as a kwarg
self.sharded_param = nn.Parameter(
self.to_sharded_dtensor(sharded_param),
requires_grad=param.requires_grad,
)
Functionally identical, but the axolotl monkey-patch is source-level
text replacement: the 2.11 anchor no longer matches the 2.12 source, so
the substitution silently falls through to the warning branch and the
method stays unpatched — bnb Params4bit / Int8Params lose their
quantization metadata through the FSDP2 shard cycle.
Try the 2.12 anchor first; fall back to the 2.11 anchor so the patch
keeps working against both torch versions in our test matrix.
init_unsharded_param uses the same kwarg-style call in both 2.11 and
2.12, so its anchor is untouched.
* fix(fsdp2-qlora): match init_unsharded_param anchor for torch 2.12
torch 2.12 hoisted the unsharded-param construction out of the
first-all-gather `else:` branch up to method-body level, so the 2.11
anchor (8-space, inside else) no longer matched and the patch silently
no-op'd. This left bitsandbytes Params4bit unreconstructed under FSDP2,
surfacing as `mat1 and mat2 shapes cannot be multiplied (... 1x36864)`
in QLoRA training. Add the 2.12 method-body-level anchor with its own
replacement indentation, falling back to the 2.11 form.
* test(multigpu): stabilize test_lora_ddp with 20 steps + seed
test_lora_ddp ran only 2 steps with no seed, so train_loss was a random
draw (observed 1.95-3.23 across runs) and the 2.8 threshold tripped
intermittently — the torch 2.12 bump just happened to surface it. Run 20
steps with seed=42 to make the loss deterministic (2.189-2.191 spread),
and tighten the threshold to 2.5.
* fix(optimizers): support torch 2.11 graph health-check rename in ADOPT
torch 2.11 renamed Optimizer._cuda_graph_capture_health_check to
_accelerator_graph_capture_health_check (2.12 re-added the old name as an
alias). ADOPT called the old name, so it raised AttributeError under torch
2.11 — surfaced by bumping the docker-e2e row from 2.9.1 to 2.11.0. Resolve
whichever name exists, preferring the new one. Also swap the deprecated
torch._utils.is_compiling() for torch.compiler.is_compiling().
axolotl-ai-cloud#3700) [skip ci] The pyproject migration removed setup.py, so the publish workflow failed at `python setup.py sdist` (No such file). Build the sdist+wheel with `uv build` (PEP 517; setuptools backend reads the version from VERSION). Also make the GitHub release step idempotent so a re-run/re-tag of an existing release doesn't fail, and drop the unused dependency-install step.
…tl-ai-cloud#3701) The fused Gemma4 attention monkeypatch read and stored shared KV states by `kv_shared_layer_index`/`layer_idx`, but transformers 5.8 dropped the `kv_shared_layer_index` attribute and switched to keying `shared_kv_states` by `layer_type`. On the pinned transformers 5.9, any Gemma4 model with `num_kv_shared_layers > 0` (e.g. gemma-4-E2B vision) raised `AttributeError: 'Gemma4TextAttention' object has no attribute 'kv_shared_layer_index'` once execution reached a shared layer. Derive the read/store key from whichever attribute the installed transformers exposes, keeping compatibility with both the old and new APIs. Add a fused-attn regression with `num_kv_shared_layers > 0` so the shared-KV branch is actually exercised (existing tests defaulted to 0).
…er tests (axolotl-ai-cloud#3705) [skip ci] The Python 3.12 PyTest legs run ~2x slower than 3.14 on the same test set (816s vs 403s) and were tipping over the 30-minute job timeout. Two causes, both in the slow tail: - dataset_num_proc=4 forks 4 dataset workers per .map() on CPU-only runners, each re-importing the torch stack to process a few hundred rows — pure overhead. Lower to 1 in the affected tests (none assert on it or test multiprocessing); results are unchanged. - --dist loadfile pins a whole file to one worker, so the entire builder suite serialized on a single worker at the end. Move shared fixtures to tests/core/conftest.py and split the RL trainer-builder tests into test_builders_rl.py so they run on a separate worker from the SFT/reward builder tests.
Description
Builds on upstream PR axolotl-ai-cloud#3629, which adds the initial raw image+text multimodal continued pretraining path. That upstream PR is still pending, so this branch is the next layer on top of that MM CPT work.
This PR broadens MM CPT from the initial streaming-only path into the dataset modes users need for practical continued pretraining:
datasets: type: multimodal_pretrainfor raw image+text rows.datasets:withskip_prepare_dataset: true.num_epochsand let Axolotl infermax_stepsfrom dataset length.Main implementation pieces:
MultiModalPretrainDatasetWrappingStrategyfor the non-streamingdatasets:pipeline.images,_mm_text,input_ids,attention_mask, andlabels._mm_textandimagesfor processor-driven multimodal collation.dataset_prepared_pathand resume behavior throughignore_data_skip.Example raw non-streaming config:
Example already-tokenized config:
Expected already-tokenized row shape:
{"_mm_text": "<image>\nText target.", "images": ["image.png"], "input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]}Example streaming resume/cache config:
Motivation and Context
The initial MM CPT work proves raw image+text continued pretraining, but it leaves common workflows uncovered:
datasets:path for MM CPT prepared-dataset workflows.max_steps; map-style datasets have a known length, sonum_epochsshould be enough.This PR keeps the existing streaming/pretraining behavior intact while adding the map-style dataset path. The behavior is intentionally split:
pretraining_datasetandstreaming: truestill require explicitmax_steps.datasets: type: multimodal_pretraincan usenum_epochs; Axolotl calculates total training steps from the prepared dataset length.How has this been tested?
Static and focused tests:
Real model smoke validation used a local Qwen3-VL-8B-Instruct checkpoint with 4-bit QLoRA,
processor_type: AutoProcessor,sequence_len: 4096,micro_batch_size: 1,gradient_accumulation_steps: 1,sample_packing: false, andremove_unused_columns: false.Raw non-streaming preprocess:
image_base_dir.axolotl preprocess raw_epoch.ymlimages,input_ids,labels,attention_mask,_mm_text._mm_textwas preserved andimagesremained as image references for batch-time processor/collator loading.Raw non-streaming epoch train:
datasets: type: multimodal_pretrain,streaming: false,dataset_prepared_path,num_epochs: 1, and nomax_steps.axolotl train raw_epoch.ymlMaximum number of steps set at 2.global_step=2,max_steps=2,epoch=1.0.trainer_state.json,optimizer.pt,scheduler.pt, andtokens_state.json.Already-tokenized non-streaming epoch train:
_mm_text,images,input_ids,attention_mask, andlabels.datasets: type: multimodal_pretrain,skip_prepare_dataset: true,num_epochs: 1, and nomax_steps.axolotl train pretokenized_epoch.ymlMaximum number of steps set at 2.global_step=2,max_steps=2,epoch=1.0.trainer_state.json,optimizer.pt,scheduler.pt, andtokens_state.json.Streaming MM CPT cache/resume smoke:
pretraining_dataset: type: multimodal_pretrain,streaming: true,dataset_prepared_path,ignore_data_skip: true, andmax_steps: 1.global_step=1.resume_from_checkpoint: checkpoint-1andmax_steps: 2.global_step=2,max_steps=2.{"total": 128, "trainable": 8}at checkpoint 1 to{"total": 256, "trainable": 16}at checkpoint 2.trainer_state.json,optimizer.pt,scheduler.pt,scaler.pt,rng_state.pth, andtokens_state.json.AI Usage Disclaimer
Yes. OpenAI Codex assisted with implementation, local validation, and drafting this PR summary. The changes were reviewed and tested locally before pushing.
Screenshots (if appropriate)
N/A
Types of changes
datasets: type: multimodal_pretrainskip_prepare_dataset: truemax_stepsSocial Handles (Optional)
N/A
Summary by CodeRabbit
Documentation
New Features
Validation
Data handling
Tests