feat(mm-cpt): broaden multimodal CPT dataset paths #31

Closed
thad0ctor wants to merge 39 commits into feat/mm-cpt-dataset-pipeline-base from feat/mm-cpt-dataset-pipeline-clean

Conversation

@thad0ctor thad0ctor commented May 29, 2026

Owner

Description

Builds on upstream PR axolotl-ai-cloud#3629, which adds the initial raw image+text multimodal continued pretraining path. That upstream PR is still pending, so this branch is the next layer on top of that MM CPT work.

This PR broadens MM CPT from the initial streaming-only path into the dataset modes users need for practical continued pretraining:

  • Non-streaming datasets: type: multimodal_pretrain for raw image+text rows.
  • Already-tokenized MM CPT rows under datasets: with skip_prepare_dataset: true.
  • Epoch-based non-streaming MM CPT runs, so users can set num_epochs and let Axolotl infer max_steps from dataset length.
  • Streaming MM CPT cache/resume support, so resume can avoid slow dataloader fast-forward through already-seen multimodal/image samples.

Main implementation pieces:

  • Adds a MultiModalPretrainDatasetWrappingStrategy for the non-streaming datasets: pipeline.
  • Produces prepared MM CPT rows with images, _mm_text, input_ids, attention_mask, and labels.
  • Preserves image references in prepared datasets and lets the MM CPT collator load/process images at batch time with the configured processor.
  • Lets already-tokenized rows pass through the non-streaming dataset path while still retaining _mm_text and images for processor-driven multimodal collation.
  • Adds streaming MM CPT caching through dataset_prepared_path and resume behavior through ignore_data_skip.
  • Wires MM CPT through dataset strategy selection, config normalization, schema validation, eval dataset handling, dataset hashing, docs, and tests.
  • Adds Qwen2.5-VL MM CPT example configs for non-streaming epoch-based QLoRA and streaming cache/resume QLoRA.

Example raw non-streaming config:

datasets:
  - path: /path/to/train.jsonl
    ds_type: json
    type: multimodal_pretrain
    split: train
    text_column: text
    image_column: images
    image_base_dir: /path/to/images

streaming: false
dataset_prepared_path: ./data/mm-cpt-prepared
sample_packing: false
remove_unused_columns: false
num_epochs: 1

Example already-tokenized config:

datasets:
  - path: /path/to/pretokenized.jsonl
    ds_type: json
    type: multimodal_pretrain
    split: train
    image_base_dir: /path/to/images

streaming: false
skip_prepare_dataset: true
sample_packing: false
remove_unused_columns: false
num_epochs: 1

Expected already-tokenized row shape:

{"_mm_text": "<image>\nText target.", "images": ["image.png"], "input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]}
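A quick pre-flight check for rows of this shape could look like the sketch below. It is not part of the PR: `check_pretokenized_row` is a hypothetical helper, and the rule that `attention_mask` and `labels` must match `input_ids` in length is an assumption based on the example row rather than something the PR states.

```python
import json

# Column names taken from the expected row shape above.
REQUIRED_KEYS = ("_mm_text", "images", "input_ids", "attention_mask", "labels")


def check_pretokenized_row(row: dict) -> list[str]:
    """Return a list of problems found in one already-tokenized MM CPT row.

    Hypothetical helper for illustration; Axolotl's own validation may differ.
    """
    problems = [f"missing key: {key}" for key in REQUIRED_KEYS if key not in row]
    if problems:
        return problems
    expected_len = len(row["input_ids"])
    # Assumption: per-token columns share the length of input_ids.
    for key in ("attention_mask", "labels"):
        if len(row[key]) != expected_len:
            problems.append(f"{key} has length {len(row[key])}, expected {expected_len}")
    if not isinstance(row["images"], list):
        problems.append("images should be a list of image references")
    return problems


line = (
    '{"_mm_text": "<image>\\nText target.", "images": ["image.png"], '
    '"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]}'
)
example_row = json.loads(line)
print(check_pretokenized_row(example_row))
```

Running the same check over every line of the JSONL file before training would surface malformed rows early, without loading any images.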

Example streaming resume/cache config:

pretraining_dataset:
  - path: /path/to/shards/*.jsonl
    ds_type: json
    type: multimodal_pretrain
    split: train
    text_column: text
    image_column: images
    image_base_dir: /path/to/images

streaming: true
dataset_prepared_path: ./data/mm-cpt-stream-cache
ignore_data_skip: true
max_steps: 10000

Motivation and Context

The initial MM CPT work proves raw image+text continued pretraining, but it leaves common workflows uncovered:

  • Users cannot use the normal non-streaming datasets: path for MM CPT prepared-dataset workflows.
  • Users with already-tokenized MM CPT rows have no clear supported path.
  • Users doing non-streaming MM CPT should not need to manually calculate max_steps; map-style datasets have a known length, so num_epochs should be enough.
  • Streaming multimodal resume can be slow because dataloader skip/fast-forward may replay image loading and processor work for samples that were already trained.

This PR keeps the existing streaming/pretraining behavior intact while adding the map-style dataset path. The behavior is intentionally split:

  • pretraining_dataset and streaming: true still require explicit max_steps.
  • Non-streaming datasets: type: multimodal_pretrain can use num_epochs; Axolotl calculates total training steps from the prepared dataset length.
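The epoch-based step calculation can be illustrated with a rough sketch. This is not Axolotl's actual code: `estimate_max_steps` and its `world_size` parameter are hypothetical, and the real calculation may handle partial batches differently.

```python
import math


def estimate_max_steps(
    num_rows: int,
    num_epochs: int,
    micro_batch_size: int = 1,
    gradient_accumulation_steps: int = 1,
    world_size: int = 1,  # hypothetical: number of data-parallel ranks
) -> int:
    """Estimate optimizer steps for a map-style dataset of known length."""
    rows_per_step = micro_batch_size * gradient_accumulation_steps * world_size
    steps_per_epoch = math.ceil(num_rows / rows_per_step)
    return steps_per_epoch * num_epochs


# Two rows, one epoch, micro_batch_size=1, gradient_accumulation_steps=1.
print(estimate_max_steps(num_rows=2, num_epochs=1))
```

A streaming dataset has no known length, so the same estimate cannot be made there, which is why that path still requires an explicit max_steps.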

How has this been tested?

Static and focused tests:

git diff --check
pass

python -m pre_commit run check-yaml --files \
  examples/qwen2_5-vl/mm-cpt-nonstreaming-qlora.yaml \
  examples/qwen2_5-vl/mm-cpt-streaming-qlora.yaml
check yaml Passed

python -m pre_commit run trailing-whitespace --files \
  docs/multimodal.qmd \
  examples/qwen2_5-vl/mm-cpt-nonstreaming-qlora.yaml \
  examples/qwen2_5-vl/mm-cpt-streaming-qlora.yaml
trim trailing whitespace Passed

python -m pre_commit run end-of-file-fixer --files \
  docs/multimodal.qmd \
  examples/qwen2_5-vl/mm-cpt-nonstreaming-qlora.yaml \
  examples/qwen2_5-vl/mm-cpt-streaming-qlora.yaml
fix end of files Passed

python -m ruff check \
  tests/utils/schemas/validation/test_multimodal_cpt.py \
  tests/test_multimodal_streaming.py \
  tests/prompt_strategies/test_multimodal_pretrain.py \
  tests/utils/data/test_hash.py \
  tests/utils/data/test_mm_pretrain_cache.py \
  tests/utils/data/test_mm_pretrain_cache_integration.py \
  src/axolotl/prompt_strategies/multimodal_pretrain.py \
  src/axolotl/utils/data/streaming.py \
  src/axolotl/utils/data/wrappers.py \
  src/axolotl/utils/data/sft.py
All checks passed!

TRL_EXPERIMENTAL_SILENCE=1 python -m pytest -q \
  tests/utils/schemas/validation/test_multimodal_cpt.py \
  tests/test_multimodal_streaming.py \
  tests/prompt_strategies/test_multimodal_pretrain.py \
  tests/utils/data/test_hash.py \
  tests/utils/data/test_mm_pretrain_cache.py \
  tests/utils/data/test_mm_pretrain_cache_integration.py
98 passed in 26.92s

Real model smoke validation used a local Qwen3-VL-8B-Instruct checkpoint with 4-bit QLoRA, processor_type: AutoProcessor, sequence_len: 4096, micro_batch_size: 1, gradient_accumulation_steps: 1, sample_packing: false, and remove_unused_columns: false.

Raw non-streaming preprocess:

  • Input: 2 raw JSONL rows with Qwen3-VL vision placeholder text and relative image paths resolved through image_base_dir.
  • Command: axolotl preprocess raw_epoch.yml
  • Result: prepared dataset saved successfully.
  • Prepared columns: images, input_ids, labels, attention_mask, _mm_text.
  • Verified _mm_text was preserved and images remained as image references for batch-time processor/collator loading.

Raw non-streaming epoch train:

  • Config used datasets: type: multimodal_pretrain, streaming: false, dataset_prepared_path, num_epochs: 1, and no max_steps.
  • Command: axolotl train raw_epoch.yml
  • Axolotl computed the step count and reported: Maximum number of steps set at 2.
  • Training completed to global_step=2, max_steps=2, epoch=1.0.
  • Checkpoint 2 included trainer_state.json, optimizer.pt, scheduler.pt, and tokens_state.json.

Already-tokenized non-streaming epoch train:

  • Input rows already contained _mm_text, images, input_ids, attention_mask, and labels.
  • Config used datasets: type: multimodal_pretrain, skip_prepare_dataset: true, num_epochs: 1, and no max_steps.
  • Command: axolotl train pretokenized_epoch.yml
  • Axolotl computed the step count and reported: Maximum number of steps set at 2.
  • Training completed to global_step=2, max_steps=2, epoch=1.0.
  • Checkpoint 2 included trainer_state.json, optimizer.pt, scheduler.pt, and tokens_state.json.

Streaming MM CPT cache/resume smoke:

  • Step 1 config used pretraining_dataset: type: multimodal_pretrain, streaming: true, dataset_prepared_path, ignore_data_skip: true, and max_steps: 1.
  • Step 1 created the prepared stream cache and saved checkpoint 1 at global_step=1.
  • Resume config set resume_from_checkpoint: checkpoint-1 and max_steps: 2.
  • Resume loaded the prepared stream cache from disk and restored token accounting from checkpoint 1.
  • Resume completed to global_step=2, max_steps=2.
  • Token state advanced from {"total": 128, "trainable": 8} at checkpoint 1 to {"total": 256, "trainable": 16} at checkpoint 2.
  • Checkpoint 2 included trainer_state.json, optimizer.pt, scheduler.pt, scaler.pt, rng_state.pth, and tokens_state.json.
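One way to read the token-state numbers above is as running counts that carry across a resume. The sketch below is an assumption about what the counters mean, not the PR's implementation: it treats "total" as attended tokens and "trainable" as labels not equal to the conventional ignore index of -100, and the batch is synthetic, sized to reproduce the reported figures.

```python
IGNORE_INDEX = -100  # conventional "do not train on this token" label value


def count_tokens(batch: dict) -> dict:
    """Count attended tokens and tokens that contribute to the loss."""
    total = sum(sum(mask) for mask in batch["attention_mask"])
    trainable = sum(
        1 for labels in batch["labels"] for token in labels if token != IGNORE_INDEX
    )
    return {"total": total, "trainable": trainable}


def accumulate(state: dict, batch: dict) -> dict:
    """Add one batch's counts to a running token state."""
    step = count_tokens(batch)
    return {key: state[key] + step[key] for key in state}


# Synthetic batch: 128 attended tokens, of which only 8 carry real labels.
batch = {
    "attention_mask": [[1] * 128],
    "labels": [[IGNORE_INDEX] * 120 + list(range(8))],
}

state = accumulate({"total": 0, "trainable": 0}, batch)  # after step 1
resumed = accumulate(state, batch)                       # after step 2
print(state, resumed)
```

If the state restored from checkpoint 1 were lost on resume, the second call would start from zero and checkpoint 2 would report 128 and 8 again instead of 256 and 16.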

AI Usage Disclaimer

Yes. OpenAI Codex assisted with implementation, local validation, and drafting this PR summary. The changes were reviewed and tested locally before pushing.

Screenshots (if appropriate)

N/A

Types of changes

  • New feature: non-streaming datasets: type: multimodal_pretrain
  • New feature: already-tokenized MM CPT rows with skip_prepare_dataset: true
  • New feature: epoch-based non-streaming MM CPT runs without manually calculating max_steps
  • New feature: streaming MM CPT cache/resume support
  • Documentation update
  • YAML config examples
  • Tests
  • Bug fix
  • Breaking change

Social Handles (Optional)

N/A

Summary by CodeRabbit

  • Documentation

    • Expanded multimodal CPT guide with streaming vs non-streaming routes, prepared-YAML examples, resume/cache guidance, and notes on pre-tokenized rows.
  • New Features

    • Added example configs for streaming and non-streaming multimodal CPT (Qwen2.5-VL/QLoRA) and support for pre-tokenized multimodal rows with collator re-tokenization.
  • Validation

    • New checks rejecting ambiguous/unsupported multimodal dataset combinations (dual declarations, streaming-with-datasets, truncate-with-datasets).
  • Data handling

    • Dataset hashing and processor fingerprinting now include multimodal fields for cache sensitivity.
  • Tests

    • Expanded multimodal test coverage for routes, collators, hashing, and caching/resume.

ved1beta and others added 25 commits May 22, 2026 15:21
* cp fix for nemo

* nemo and falcon patch

* import patch

* Revert "import patch"

This reverts commit ef42d1f.

* undo falcon

* packing + mamba support for nemo, falcon, granite, zamba

* training run bugs

* docs

* doc string coverage + test

* mamba guard

* 2k n 2*1k test

* not is_cp_active()

* seq_len fix

* model list

* val ring atten fix

* disable double splitting in hf

* less comments

* undo zamba and bamba

* new configs

* lint
* feat: update transformers to 5.8.1

* ignore uv.lock for now

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
* feat-qgalore

* spell check

---------

Co-authored-by: Your Name <you@example.com>
* support with autoprocessor

* simple add dict

---------

Co-authored-by: Your Name <you@example.com>
… [skip ci]

* fix AssertionError: Original QKV code not found

* skip if gemma for lora

* fix misleading commentsT_T'
Co-authored-by: Your Name <you@example.com>
* rmv skip

* test version

* lint

* undo
* fix: ep test missed teardown

* fix: change hardcoded ports
…#3679) [skip ci]

* fix broken MX tests from transformers 5.8.1 upgrade

* test isolation

* wrap for torchao possible import error

* isolate reward model test more

* fix PRM
…olotl-ai-cloud#3670)

* feat(fsdp2): add fp32_norms for keeping RMSNorm/LayerNorm in fp32

Add an opt-in config flag that shards norm modules under their own FSDP2
MixedPrecisionPolicy (fp32) before the standard decoder-layer wrap, so the
norm and decoder shard groups stay independent. This lets models that
declare fp32 norms for training stability train under FSDP2 while the rest
of the model runs in bf16/fp16.

FSDP1 enforces flat-param dtype uniformity within each wrap group, which is
incompatible with keeping norms in fp32; the validator therefore requires
fsdp_version: 2.

Matching: patterns without a "." match type(module).__name__ as a suffix
(catches LlamaRMSNorm, Qwen3RMSNorm, AfmoeRMSNorm, nn.LayerNorm, etc.);
patterns containing a "." match the fully qualified class path exactly.
Defaults to ["RMSNorm", "LayerNorm"].
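The matching rule described above can be sketched in plain Python. This is an illustration, not the code from the commit: `matches_norm_pattern` is a hypothetical name, the stand-in classes replace real torch modules, and the skip for empty patterns follows the review fix noted in the follow-up commit.

```python
DEFAULT_PATTERNS = ("RMSNorm", "LayerNorm")


def matches_norm_pattern(module: object, patterns=DEFAULT_PATTERNS) -> bool:
    """Return True if the module's class matches any pattern.

    Patterns without "." match the class name as a suffix; patterns
    containing "." must equal the fully qualified class path exactly.
    """
    cls = type(module)
    qualified = f"{cls.__module__}.{cls.__qualname__}"
    for pattern in patterns:
        pattern = pattern.strip()
        if not pattern:
            # "".endswith-style matching would accept every class, so skip.
            continue
        if "." in pattern:
            if qualified == pattern:
                return True
        elif cls.__name__.endswith(pattern):
            return True
    return False


# Stand-in classes for illustration only.
class LlamaRMSNorm:
    pass


class Linear:
    pass


print(matches_norm_pattern(LlamaRMSNorm()), matches_norm_pattern(Linear()))
```

Suffix matching keeps the default list short while still covering model-specific class names; the exact-path form is there for cases where a suffix would be too broad.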

Signed-off-by: Wing Lian <wing@axolotl.ai>

* fixup! feat(fsdp2): address review findings + fix CI caplog assertions

- matcher: skip empty/whitespace-only patterns (cls_name.endswith("") is
  True for any class, which would silently match everything).
- validator: also require fsdp_config to be set, not just fsdp_version==2.
  fsdp_config is the canonical "is_fsdp" signal elsewhere in the codebase
  (used by check_fsdp_torch_version, sample_packing validators, etc.).
- tests: temporarily flip propagate=True on the `axolotl` logger so
  pytest caplog can see the warnings. axolotl.cli.configure_logging()
  sets propagate=False at import time, which is the documented reason
  the assertions were failing in CI even though the warnings were
  firing visibly in stdout.
- comment: replace multi-line rationale near the fp32_norms helpers with
  a one-line summary (the longer version lives in the PR description).
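The caplog workaround in the tests bullet can be expressed as a small context manager. This is a sketch using only the standard library; `propagating` is a hypothetical helper name, not the one used in the test suite.

```python
import logging
from contextlib import contextmanager


@contextmanager
def propagating(logger_name: str):
    """Temporarily set propagate=True on a logger, then restore it.

    pytest's caplog handler sits on the root logger, so records from a
    logger with propagate=False never reach it.
    """
    logger = logging.getLogger(logger_name)
    previous = logger.propagate
    logger.propagate = True
    try:
        yield logger
    finally:
        logger.propagate = previous


# Simulate a logger configured with propagate=False, as described above.
demo_logger = logging.getLogger("axolotl")
demo_logger.propagate = False

with propagating("axolotl") as active:
    inside = active.propagate
after = demo_logger.propagate
print(inside, after)
```

Restoring the previous value in a finally block keeps one test's logging setup from leaking into the next, even when an assertion fails inside the with block.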

Signed-off-by: Wing Lian <wing@axolotl.ai>

* test(fsdp2): multi-GPU e2e for fp32_norms with dtype-preservation assertion

The existing fp32_norms tests are pure-CPU and monkeypatch fully_shard —
they cover the matcher logic and validator guard rails but never exercise
the actual FSDP2 path that motivated this PR.

Adds tests/e2e/multigpu/test_fsdp2_fp32_norms.py: spawns a 2-GPU
`axolotl train` subprocess with `fp32_norms: true` + `fsdp_version: 2` +
`bf16: true` on tiny-qwen3-129m (full FT, 2 steps) and asserts:

  1. Training completes — the original FSDP1 flat-param dtype crash
     can't recur because we're on FSDP2 with the per-module
     MixedPrecisionPolicy.
  2. All RMSNorm params are float32 after step 1 — captured via a
     test-only TrainerCallback in
     tests/e2e/multigpu/_fp32_norms_dtype_capture.py, dumped to JSON
     at $FP32_NORMS_DTYPE_DUMP_PATH on rank 0.
  3. At least one non-norm param is bfloat16 — proves the two FSDP2
     MixedPrecisionPolicy groups are independent (catches a silent
     globally-fp32 fallback that would technically satisfy assertion 2
     but defeat the point of the feature).

The dtype-capture plugin is plumbed in via the test's yaml `plugins:`
list, with PYTHONPATH=<repo_root> on the subprocess env so the
tests.e2e.multigpu._fp32_norms_dtype_capture module resolves.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* chore: lint

---------

Signed-off-by: Wing Lian <wing@axolotl.ai>
* latest typer breaks HF CLI

* wrong comparison
…tl-ai-cloud#3687)

transformers decorates Gemma4VisionAttention with
@use_kernelized_func(apply_rotary_pos_emb) where the target is a bare
function. Under use_kernels=True (force-enabled by KernelsArgs for the
ScatterMoE path), from_pretrained calls model.kernelize(), whose
attach_hidden_kernels step does register_module(name, fn) for each
_hidden_kernels entry. register_module rejects the non-Module function:

    TypeError: ...apply_rotary_pos_emb is not a Module subclass

with a follow-on AttributeError from the cleanup path. The MoE itself is
accelerated via the transformers ExpertsInterface (experts_implementation),
independent of this path, and the vision forward uses
apply_multidimensional_rope, never apply_rotary_pos_emb -- so the
registered entry is dead weight.

Add monkeypatch gemma4_kernelize that strips non-Module _hidden_kernels
entries from Gemma4VisionAttention, wired in
patch_manager._apply_model_specific_patches for gemma4 when use_kernels is
set. state_dict is unchanged, so the fix is behavior-neutral.

Also add ddp_find_unused_parameters: true to the 26b-a4b MoE QLoRA example
(multi-GPU only -- text-backbone LoRA plus KV-sharing layers leave some
adapter params gradient-less under DDP).
…xolotl-ai-cloud#3651)

* fix: refactor kernels patch to drop routing and inject into Expert
registry

* chore: add to optim doc

* feat: update sonicmoe version

* chore: cleanup with DEEPEP and kernels compat

* gate/guard model expert setup

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
… ci]

* feat(scattermoe-lora): selective dequant for mxfp4 expert weights

Add an MXFP4 branch to `selective_expert_weights()` that detects a
torchao `MXTensor` parameter (elem_dtype=float4_e2m1fn_x2) and
dequantizes only the active experts via index-then-construct of a
compact sub-MXTensor. The K-axis OCP block layout (last storage dim)
matches `experts.gate_up_proj` natural shape `[E, N, K]`, so the
caller's existing `.transpose(2, 1)` post-step keeps producing the
kernel's `[E, K, N]` weight tile unchanged.

`HFScatterMoEGatedMLP.forward` now also routes through the selective
path whenever the experts hold MXFP4 weights — full-tensor MX dequant
of 256-expert models is prohibitive and the kernel needs bf16 input.

Tests (CUDA-only) compare against a bf16 baseline produced by the
same MXTensor's full dequant; outputs are bitwise identical for both
forward and backward (dX, dA, dB) across small [E=8,K=128,N=256] and
representative [E=32,K=2048,N=1024] shapes, and across all four
combinations of `use_fused_dX` / `use_fused_gather`.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* feat(scattermoe-lora): fused mxfp4 dequant in triton kernel

Add MX-aware forward and dX kernels that consume an ``MXWeights``
container (packed uint8 + E8M0 scales) directly, so the base-weight
tile is dequantized inside the K-loop instead of through a materialized
bf16 buffer. The K-loop loads two FP4 values per uint8 byte, looks them
up in a 16-entry codebook tensor (``±{0, 0.5, 1, 1.5, 2, 3, 4, 6}``),
multiplies by ``2^(scale_byte - 127)``, and casts to bf16 for the
matmul. ``BLOCK_K`` is constrained to a multiple of the OCP block size
(32) so each tile aligns with whole scale blocks; an MX-aware autotune
pruner accounts for the extra packed/scale SMEM.
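The per-element arithmetic of that dequant step can be shown on the CPU in plain Python. This is a sketch of the number format only, not of the Triton kernel: the nibble order (low nibble first) is an assumption, and a real MX block holds 32 values (16 bytes) per scale byte while the example uses 2 bytes for brevity.

```python
# FP4 E2M1 magnitudes; indices 8-15 are the same values with the sign bit set.
FP4_E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
CODEBOOK = FP4_E2M1_MAGNITUDES + tuple(-m for m in FP4_E2M1_MAGNITUDES)


def dequant_block(packed: bytes, scale_byte: int) -> list[float]:
    """Decode packed FP4 pairs that share one E8M0 scale byte."""
    scale = 2.0 ** (scale_byte - 127)  # E8M0: a pure power-of-two exponent
    values = []
    for byte in packed:
        values.append(CODEBOOK[byte & 0x0F] * scale)  # low nibble (assumed first)
        values.append(CODEBOOK[byte >> 4] * scale)    # high nibble
    return values


decoded = dequant_block(bytes([0x21, 0xF7]), scale_byte=128)
print(decoded)
```

Because the scale is a power of two and the codebook has only 16 entries, every decoded value is exactly representable, which is consistent with the bitwise parity reported for the selective-dequant path.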

The dX kernel reuses the *forward* MX layout (block axis = K, the dX
output axis) — for each (K_tile, N_tile) sub-tile, nibbles decode
along the K rows (the byte is shared by two adjacent K rows) and
scales broadcast within their MX block. This avoids the
dequant + re-quantize "pre-transpose" the spec suggested and the
extra MX-rounding error that round-trip would have introduced.

``ScatterMoELoRA.forward`` now accepts either a dense tensor or an
``MXWeights``; the MX branch always selects the fused-dX and
fused-gather backward kernels (the non-fused dX path would have to
materialize a bf16 weight tile, defeating the win).

Unit tests cover forward, dX, dA, dB parity for small
[E=8, K=128, N=256] and representative [E=32, K=2048, N=1024] shapes;
tolerances are calibrated to bf16 MMA noise (atomic-add ordering and
FMA reordering between the full-E baseline and compact-active MX path).
Integration test exercises a tiny synthetic DeepSeek-V4-style MoE
block (E=8, hidden=512, intermediate=256, top_k=2) end-to-end through
both Strategy A and Strategy B with LoRA disabled.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* chore(scattermoe-lora): mxfp4 forward/backward benchmark

Add ``bench_mxfp4.py`` and committed results for the representative
DeepSeek-V4-style shape (E=128, K=2048, N=1024, top_k=8, M=4096,
rank=16). Reports ms/iter, tokens/s, peak GPU memory, and HBM
bandwidth utilisation for three configurations: bf16 baseline,
Strategy A (selective dequant), Strategy B (fused MX).

On the RTX PRO 6000 Blackwell, the all-active-experts shape used
here doesn't exercise selective dequant's memory savings (active = E
= 128) — A pays the cost of materialising the full bf16 dequant
buffer per step (~9 GB peak vs 1.9 GB for B) while still routing
through the bf16 kernel. B halves A's wall time (~12 ms vs 30 ms) by
eliminating the buffer, but stays slower than the bf16 baseline (5
ms) which assumes the bf16 weights already exist in memory.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* bench(scattermoe-lora): mxfp4 sparse-routing benchmark numbers

Signed-off-by: Wing Lian <wing@axolotl.ai>

* bench(scattermoe-lora): mxfp4 seqlen sweep with load-balanced routing

Signed-off-by: Wing Lian <wing@axolotl.ai>

* fix(scattermoe-lora): correct mx forward smem accounting

The MX-aware autotune pruner for the forward kernel under-accounted
SMEM: it computed the packed-tile cost as BLOCK_N * BLOCK_K/2 and the
scale-tile cost as BLOCK_N * BLOCK_K/MX_BLOCK_SIZE, but the actual
tl.load issues a full [BLOCK_N, BLOCK_K]-shaped uint8 fetch for both
buffers (the packed buffer reads each byte twice because K_byte =
K // 2 indexes a [BLOCK_K]-wide vector; the scale buffer broadcasts
within each MX_BLOCK_SIZE K-block). Bring the forward pruner up to the
same conservative full-tile accounting already used by
_prune_dX_mx_configs. Without this, on the [E=128, K=2048, N=1024]
shape with the typical GPU SMEM caps, two to six high-stage configs
that were previously selectable would have overflowed SMEM at launch
under correct accounting — a silent OOM-in-the-future risk.
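The size of the under-count can be estimated with simple arithmetic. The sketch below covers only the packed-weight and scale tiles for one pipeline stage and ignores every other buffer the kernel uses; the function names and the example tile size are illustrative assumptions.

```python
MX_BLOCK_SIZE = 32  # OCP MX block size along K


def weight_tile_bytes_old(block_n: int, block_k: int) -> int:
    """Previous estimate: half a byte per FP4 value plus one scale per MX block."""
    return block_n * block_k // 2 + block_n * block_k // MX_BLOCK_SIZE


def weight_tile_bytes_full(block_n: int, block_k: int) -> int:
    """Conservative estimate: a full [BLOCK_N, BLOCK_K] uint8 tile for each buffer."""
    return 2 * block_n * block_k


old = weight_tile_bytes_old(128, 64)
full = weight_tile_bytes_full(128, 64)
print(old, full)
```

For this tile the conservative figure is close to four times the previous one, and the gap is multiplied again by the number of pipeline stages, which is how high-stage configs could slip past the pruner.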

Signed-off-by: Wing Lian <wing@axolotl.ai>

* docs(scattermoe-lora): align mx dx kernel docstring with implementation

The file-level docstring for the MXFP4 kernels described the dX kernel
as using a pre-transposed [E, K, N/2] layout produced by a
'mx_pre_transpose_for_dx' helper. That helper doesn't exist; the dX
kernel actually reuses the forward [E, N, K/2] layout, iterating the N
reduction in outer tiles and decoding nibbles along the K rows of each
tile. Rewrite the docstring to describe what the code actually does,
including the rationale — reusing the forward buffer avoids the
dequant + re-quantize round-trip that a pre-transpose would require
and keeps dX numerics free of a second MX rounding error stacked on
top of the forward quantization.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* chore(scattermoe-lora): mx code-review nit cleanup

F4: Hoist 'is_mxfp4_param' import from inside 'HFScatterMoEGatedMLP.forward'
to the top of layers.py — it was being re-imported every step on the hot
path.

F5: Add a thin compatibility shim for torchao MXTensor internals access in
mx_weights.py. The MX paths in selective_dequant.py / mx_weights.py used
to reach into 'mx_param.qdata', 'mx_param.scale',
'mx_param.kernel_preference' and call 'MXTensor(...)' with positional
args directly. That works at the pinned torchao 0.17.0 but is fragile to
internal renames in future torchao releases. Funnel through three
helpers — '_mx_qdata', '_mx_scale', '_construct_mxtensor_subset' — that
use 'getattr' fallbacks for the buffer attributes and pass the
constructor's optional args via 'getattr' too. Single point of pain,
no API change.

F7: Remove the unused 'NO_K_MASK' heuristic + tl.constexpr param from
the dX MX kernel '_scatter2scatter_lora_dX_mx'. The dX kernel never
references it (its inner loop masks N, not K), so the constexpr just
forced extra autotune key entries.

F8: Consolidate the duplicate '_torchao_mxtensor_cls()' definitions
(one in selective_dequant.py, one in mx_weights.py) into a single
definition in mx_weights.py. selective_dequant.py imports it.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* test(scattermoe-lora): strengthen mx backward test coverage

F3: 'test_strategy_a_backward_fused_variants' previously used
'torch.ones_like(output)' as the grad input and asserted only on dX.
A uniform grad zeros out cross-token differences in the fused-gather
accumulation, masking reordering bugs; restricting the assertion to dX
silently let the dA/dB paths go unchecked across the four
'(use_fused_dX, use_fused_gather)' production variants.

  * Drive the backward with 'torch.randn_like(output) * 0.1'.
  * Capture and assert dA and dB parity across all four variants
    using the same 'row_idx' gather pattern as
    'test_strategy_a_backward_matches_bf16'.
  * Forward and dX are still asserted bitwise via 'torch.equal'. dA/dB
    fall back to atol/rtol = 1e-3 because the fused dA/dB kernel uses
    'atomic_add' across N-block programs and the in-flight program
    count differs between the full-E baseline and the compact-active
    path; combined with FMA reordering, the 'use_fused_dX=True'
    variants accumulate ~1 bf16 ULP of unavoidable atomic-order noise.
    The new bound is still an order of magnitude below that noise
    floor, so it catches real bugs.

F9: The 'test_strategy_b_backward_matches_bf16' dX comparison runs at
'atol=0.5, rtol=2e-2' (small) / 'atol=2.0, rtol=3e-2' (representative)
to allow for accumulated bf16 MMA noise over the N reduction. Those
bounds are appropriate for legitimate per-element drift but would also
admit a uniform multiplicative bug — e.g. an off-by-one on the E8M0
exponent that scales every dX element by 2x.

Add a guard alongside the existing 'torch.allclose': mask out
near-zero baseline elements (relative to 'bf16_dX.abs().max()'), then
require the per-element ratio 'mx_dX / bf16_dX' to have std < 0.5. A
uniform multiplicative bug pushes that std to ~0 while the mean shifts;
a real-bug per-element drift pushes the std up. This crosscuts the
allclose check rather than replacing it.
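The ratio statistic described above can be sketched without torch. This is an illustration, not the test's code: `ratio_stats` and the `rel_floor` masking threshold are assumptions, since the commit message gives only the std < 0.5 bound.

```python
from statistics import fmean, pstdev


def ratio_stats(candidate, baseline, rel_floor=1e-2):
    """Return (mean, std) of candidate/baseline over non-negligible baseline entries.

    Entries whose magnitude is below rel_floor * max|baseline| are masked out
    so that near-zero denominators do not dominate the ratios.
    """
    peak = max(abs(b) for b in baseline)
    ratios = [
        c / b for c, b in zip(candidate, baseline) if abs(b) > rel_floor * peak
    ]
    return fmean(ratios), pstdev(ratios)


baseline = [1.0, -2.0, 0.5, 4.0, 1e-6]

# Uniform multiplicative bug: every element scaled by 2x.
doubled_mean, doubled_std = ratio_stats([2.0 * b for b in baseline], baseline)

# Per-element drift: different elements are off by different factors.
drift_mean, drift_std = ratio_stats([1.5, -1.0, 0.5, 8.0, 1e-6], baseline)

print(doubled_mean, doubled_std, drift_std)
```

In the uniform case the ratios collapse to a single value, so the std is zero and the bug shows up only as a mean away from 1; in the drift case the spread of the ratios is what grows.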

Signed-off-by: Wing Lian <wing@axolotl.ai>

* bench(scattermoe-lora): exclude per-iter setup from timed window

The previous bench harness did a fresh '.clone()' of x and a
'requires_grad_(True)' on cloned lora A/B tensors every iter inside
the timed window. That accounts for buffer allocation, not kernel
cost, and biases the numbers toward whichever path produced the
smallest activations. Restructure the runners so:

  * 'x' is cloned once into a leaf tensor with 'requires_grad_(True)'
    inside 'bench()' (outside the timed warmup + timed loop).
  * LoRA A/B leaf tensors are constructed once in the runner factory,
    not per iter.
  * Each iter calls the runner which sets 'x.grad = A.grad = B.grad =
    None' (cheap, no GPU sync) so the autograd graph for the timed
    iteration is fresh and grads don't accumulate.

Re-run all three configs end-to-end after this change (dense E=128,
sparse E=256 / 10-active, balanced E=256 M-sweep at M ∈ {256, 1024,
4096, 16384}) and refresh the numbers in bench_mxfp4_results.md.
Headers and table structure are unchanged. The qualitative ordering
holds (Strategy A wins at low active/E, Strategy B wins near
active/E ≈ 1, and Strategy A still OOMs across the balanced sweep on
the workstation with vLLM colocated), with per-cell numbers within
single-digit percent of the prior runs.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* style(scattermoe-lora): apply pre-commit auto-fixes and mypy fixes

Signed-off-by: Wing Lian <wing@axolotl.ai>

* fix(scattermoe-lora): mxfp4 shape validation + torchao version messages

Signed-off-by: Wing Lian <wing@axolotl.ai>

* lint and PR review fixes

* fix(scattermoe-lora): restore lint task fixes + add missing forward parity assertions

Wing's "lint and PR review fixes" commit (9007a82) reverted three fixes
from the prior lint pass. Restore them:

1. parallel_linear_lora.py: use isinstance(expert_weights, MXWeights)
   directly so mypy can narrow the union — the `is_mx` boolean alias
   blocks narrowing and re-introduces 2 union-attr errors.

2. bench_mxfp4.py: assert template is not None before the MXTensor(...)
   constructor — the chunked converter initializes template to None
   then sets it inside the loop, which mypy can't prove non-None at
   the call site (6 None-attr errors).

3. test_mxfp4_expert_weights.py: the F841 on fwd_tol was actually a
   smell of dropped logic. Both backward tests
   (test_strategy_a_backward_matches_bf16 and
   test_strategy_b_backward_matches_bf16) compute the forward outputs
   out_b/out_a/out_s, run backward, and assert gradients match — but
   never assert that the forward outputs match. A forward bug
   producing a constant offset (and therefore zero gradient delta)
   would slip past the bwd-only checks. Add the missing
   torch.equal(out_b, out_a) for Strategy A (bitwise contract) and
   torch.allclose(out_b, out_s, **fwd_tol) for Strategy B (MX tol).

Signed-off-by: Wing Lian <wing@axolotl.ai>

* don't worry about flash-attn direct patches for now

---------

Signed-off-by: Wing Lian <wing@axolotl.ai>
… fp32 fix + AC-vs-tiled gap analysis (axolotl-ai-cloud#3666) [skip ci]

* feat(scattermoe-lora): selective dequant for mxfp4 expert weights

Add an MXFP4 branch to `selective_expert_weights()` that detects a
torchao `MXTensor` parameter (elem_dtype=float4_e2m1fn_x2) and
dequantizes only the active experts via index-then-construct of a
compact sub-MXTensor. The K-axis OCP block layout (last storage dim)
matches `experts.gate_up_proj` natural shape `[E, N, K]`, so the
caller's existing `.transpose(2, 1)` post-step keeps producing the
kernel's `[E, K, N]` weight tile unchanged.

`HFScatterMoEGatedMLP.forward` now also routes through the selective
path whenever the experts hold MXFP4 weights — full-tensor MX dequant
of 256-expert models is prohibitive and the kernel needs bf16 input.

Tests (CUDA-only) compare against a bf16 baseline produced by the
same MXTensor's full dequant; outputs are bitwise identical for both
forward and backward (dX, dA, dB) across small [E=8,K=128,N=256] and
representative [E=32,K=2048,N=1024] shapes, and across all four
combinations of `use_fused_dX` / `use_fused_gather`.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* feat(scattermoe-lora): fused mxfp4 dequant in triton kernel

Add MX-aware forward and dX kernels that consume an ``MXWeights``
container (packed uint8 + E8M0 scales) directly, so the base-weight
tile is dequantized inside the K-loop instead of through a materialized
bf16 buffer. The K-loop loads two FP4 values per uint8 byte, looks them
up in a 16-entry codebook tensor (``±{0, 0.5, 1, 1.5, 2, 3, 4, 6}``),
multiplies by ``2^(scale_byte - 127)``, and casts to bf16 for the
matmul. ``BLOCK_K`` is constrained to a multiple of the OCP block size
(32) so each tile aligns with whole scale blocks; an MX-aware autotune
pruner accounts for the extra packed/scale SMEM.

The dX kernel reuses the *forward* MX layout (block axis = K, the dX
output axis) — for each (K_tile, N_tile) sub-tile, nibbles decode
along the K rows (the byte is shared by two adjacent K rows) and
scales broadcast within their MX block. This avoids the
dequant + re-quantize "pre-transpose" the spec suggested and the
extra MX-rounding error that round-trip would have introduced.

``ScatterMoELoRA.forward`` now accepts either a dense tensor or an
``MXWeights``; the MX branch always selects the fused-dX and
fused-gather backward kernels (the non-fused dX path would have to
materialize a bf16 weight tile, defeating the win).

Unit tests cover forward, dX, dA, dB parity for small
[E=8, K=128, N=256] and representative [E=32, K=2048, N=1024] shapes;
tolerances are calibrated to bf16 MMA noise (atomic-add ordering and
FMA reordering between the full-E baseline and compact-active MX path).
Integration test exercises a tiny synthetic DeepSeek-V4-style MoE
block (E=8, hidden=512, intermediate=256, top_k=2) end-to-end through
both Strategy A and Strategy B with LoRA disabled.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* chore(scattermoe-lora): mxfp4 forward/backward benchmark

Add ``bench_mxfp4.py`` and committed results for the representative
DeepSeek-V4-style shape (E=128, K=2048, N=1024, top_k=8, M=4096,
rank=16). Reports ms/iter, tokens/s, peak GPU memory, and HBM
bandwidth utilisation for three configurations: bf16 baseline,
Strategy A (selective dequant), Strategy B (fused MX).

On the RTX PRO 6000 Blackwell, the all-active-experts shape used
here doesn't exercise selective dequant's memory savings (active = E
= 128) — A pays the cost of materialising the full bf16 dequant
buffer per step (~9 GB peak vs 1.9 GB for B) while still routing
through the bf16 kernel. B halves A's wall time (~12 ms vs 30 ms) by
eliminating the buffer, but stays slower than the bf16 baseline (5
ms) which assumes the bf16 weights already exist in memory.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* bench(scattermoe-lora): mxfp4 sparse-routing benchmark numbers

Signed-off-by: Wing Lian <wing@axolotl.ai>

* bench(scattermoe-lora): mxfp4 seqlen sweep with load-balanced routing

Signed-off-by: Wing Lian <wing@axolotl.ai>

* fix(scattermoe-lora): correct mx forward smem accounting

The MX-aware autotune pruner for the forward kernel under-accounted
SMEM: it computed the packed-tile cost as BLOCK_N * BLOCK_K/2 and the
scale-tile cost as BLOCK_N * BLOCK_K/MX_BLOCK_SIZE, but the actual
tl.load issues a full [BLOCK_N, BLOCK_K]-shaped uint8 fetch for both
buffers (the packed buffer reads each byte twice because K_byte =
K // 2 indexes a [BLOCK_K]-wide vector; the scale buffer broadcasts
within each MX_BLOCK_SIZE K-block). Bring the forward pruner up to the
same conservative full-tile accounting already used by
_prune_dX_mx_configs. Without this, on the [E=128, K=2048, N=1024]
shape with the typical GPU SMEM caps, two to six high-stage configs
that were previously selectable would have overflowed SMEM at launch
under correct accounting — a silent OOM-in-the-future risk.
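
A rough sketch of the accounting difference, in plain Python. It counts
only the two weight-side buffers named above. The helper names, the
example configs, and the 100 KiB cap are hypothetical; multiplying by
the stage count is an assumption about how the pruner budgets pipelined
buffers, and ``MX_BLOCK_SIZE = 32`` is assumed rather than stated above.

```python
MX_BLOCK_SIZE = 32  # assumed MXFP4 scale-block size (elements per scale)


def mx_weight_smem_optimistic(block_n, block_k, num_stages):
    """Old accounting: packed nibbles plus one scale byte per K-block."""
    packed = block_n * block_k // 2
    scales = block_n * block_k // MX_BLOCK_SIZE
    return (packed + scales) * num_stages


def mx_weight_smem_conservative(block_n, block_k, num_stages):
    """Full-tile accounting: both loads are [BLOCK_N, BLOCK_K] uint8 fetches."""
    packed = block_n * block_k
    scales = block_n * block_k
    return (packed + scales) * num_stages


def prune(configs, smem_cap_bytes, cost_fn):
    """Keep only configs whose estimated weight-tile SMEM fits under the cap."""
    return [c for c in configs if cost_fn(*c) <= smem_cap_bytes]


configs = [(128, 128, 2), (128, 128, 4), (256, 128, 4)]  # (BLOCK_N, BLOCK_K, stages)
cap = 100 * 1024  # hypothetical per-block SMEM budget in bytes
kept_old = prune(configs, cap, mx_weight_smem_optimistic)
kept_new = prune(configs, cap, mx_weight_smem_conservative)
```

Under these made-up numbers the optimistic estimate keeps all three
configs, while the full-tile estimate keeps only the 2-stage one.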

Signed-off-by: Wing Lian <wing@axolotl.ai>

* docs(scattermoe-lora): align mx dx kernel docstring with implementation

The file-level docstring for the MXFP4 kernels described the dX kernel
as using a pre-transposed [E, K, N/2] layout produced by a
'mx_pre_transpose_for_dx' helper. That helper doesn't exist; the dX
kernel actually reuses the forward [E, N, K/2] layout, iterating the N
reduction in outer tiles and decoding nibbles along the K rows of each
tile. Rewrite the docstring to describe what the code actually does,
including the rationale — reusing the forward buffer avoids the
dequant + re-quantize round-trip that a pre-transpose would require
and keeps dX numerics free of a second MX rounding error stacked on
top of the forward quantization.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* chore(scattermoe-lora): mx code-review nit cleanup

F4: Hoist 'is_mxfp4_param' import from inside 'HFScatterMoEGatedMLP.forward'
to the top of layers.py — it was being re-imported every step on the hot
path.

F5: Add a thin compatibility shim for torchao MXTensor internals access in
mx_weights.py. The MX paths in selective_dequant.py / mx_weights.py used
to reach into 'mx_param.qdata', 'mx_param.scale',
'mx_param.kernel_preference' and call 'MXTensor(...)' with positional
args directly. That works at the pinned torchao 0.17.0 but is fragile to
internal renames in future torchao releases. Funnel through three
helpers — '_mx_qdata', '_mx_scale', '_construct_mxtensor_subset' — that
use 'getattr' fallbacks for the buffer attributes and pass the
constructor's optional args via 'getattr' too. Single point of pain,
no API change.

F7: Remove the unused 'NO_K_MASK' heuristic + tl.constexpr param from
the dX MX kernel '_scatter2scatter_lora_dX_mx'. The dX kernel never
references it (its inner loop masks N, not K), so the constexpr just
forced extra autotune key entries.

F8: Consolidate the duplicate '_torchao_mxtensor_cls()' definitions
(one in selective_dequant.py, one in mx_weights.py) into a single
definition in mx_weights.py. selective_dequant.py imports it.
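
The F5 shim idea can be sketched without torchao. This is not the code
in mx_weights.py: the function names are illustrative, the stand-in
objects are ``SimpleNamespace`` instances, and the fallback attribute
names (``_data``, ``_scale_e8m0``) are hypothetical examples of what a
rename might look like.

```python
from types import SimpleNamespace


def _first_attr(obj, names):
    """Return the first attribute in `names` that exists on `obj`."""
    for name in names:
        value = getattr(obj, name, None)
        if value is not None:
            return value
    raise AttributeError(f"none of {names} found on {type(obj).__name__}")


def mx_qdata(mx_param):
    # "qdata" is the name cited above; "_data" is a hypothetical fallback.
    return _first_attr(mx_param, ("qdata", "_data"))


def mx_scale(mx_param):
    # "scale" is the name cited above; "_scale_e8m0" is a hypothetical fallback.
    return _first_attr(mx_param, ("scale", "_scale_e8m0"))


# Two stand-ins: one with the current attribute names, one "renamed".
current = SimpleNamespace(qdata=b"\x12\x34", scale=b"\x7f")
renamed = SimpleNamespace(_data=b"\x12\x34", _scale_e8m0=b"\x7f")
```

Callers only ever go through ``mx_qdata`` / ``mx_scale``, so a rename
is absorbed in one place.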

Signed-off-by: Wing Lian <wing@axolotl.ai>

* test(scattermoe-lora): strengthen mx backward test coverage

F3: 'test_strategy_a_backward_fused_variants' previously used
'torch.ones_like(output)' as the grad input and asserted only on dX.
A uniform grad zeros out cross-token differences in the fused-gather
accumulation, masking reordering bugs; restricting the assertion to dX
silently let the dA/dB paths go unchecked across the four
'(use_fused_dX, use_fused_gather)' production variants.

  * Drive the backward with 'torch.randn_like(output) * 0.1'.
  * Capture and assert dA and dB parity across all four variants
    using the same 'row_idx' gather pattern as
    'test_strategy_a_backward_matches_bf16'.
  * Forward and dX are still asserted bitwise via 'torch.equal'. dA/dB
    fall back to atol/rtol = 1e-3 because the fused dA/dB kernel uses
    'atomic_add' across N-block programs and the in-flight program
    count differs between the full-E baseline and the compact-active
    path; combined with FMA reordering, the 'use_fused_dX=True'
    variants accumulate ~1 bf16 ULP of unavoidable atomic-order noise.
    The new bound is still an order of magnitude below that noise
    floor, so it catches real bugs.

F9: The 'test_strategy_b_backward_matches_bf16' dX comparison runs at
'atol=0.5, rtol=2e-2' (small) / 'atol=2.0, rtol=3e-2' (representative)
to allow for accumulated bf16 MMA noise over the N reduction. Those
bounds are appropriate for legitimate per-element drift but would also
admit a uniform multiplicative bug — e.g. an off-by-one on the E8M0
exponent that scales every dX element by 2x.

Add a guard alongside the existing 'torch.allclose': mask out
near-zero baseline elements (relative to 'bf16_dX.abs().max()'), then
require the per-element ratio 'mx_dX / bf16_dX' to have std < 0.5. A
uniform multiplicative bug pushes that std to ~0 while the mean shifts;
a genuine per-element drift bug pushes the std up. This complements the
allclose check rather than replacing it.
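
A minimal sketch of the ratio statistics on plain Python lists. The
helper name, the ``rel_floor`` value, and the sample data are
hypothetical, and how the real test turns the mean and std into
assertions is not shown here.

```python
import statistics


def ratio_stats(mx_dx, bf16_dx, rel_floor=1e-2):
    """Mean and population std of mx/bf16 over elements that are not near zero.

    Elements whose baseline magnitude is below rel_floor * max|bf16_dx| are
    masked out so tiny denominators do not dominate the ratio.
    """
    floor = rel_floor * max(abs(v) for v in bf16_dx)
    ratios = [m / b for m, b in zip(mx_dx, bf16_dx) if abs(b) > floor]
    return statistics.fmean(ratios), statistics.pstdev(ratios)


baseline = [0.5, -1.25, 2.0, 1e-6, -0.75]
doubled = [2.0 * v for v in baseline]      # uniform 2x scaling bug
noisy = [0.51, -1.24, 2.02, 3e-6, -0.74]   # ordinary per-element drift

mean_bug, std_bug = ratio_stats(doubled, baseline)
mean_ok, std_ok = ratio_stats(noisy, baseline)
```

With the uniform 2x input the ratio collapses to a constant (std 0,
mean 2), whereas the noisy input keeps a mean near 1 with a small
spread.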

Signed-off-by: Wing Lian <wing@axolotl.ai>

* bench(scattermoe-lora): exclude per-iter setup from timed window

The previous bench harness did a fresh '.clone()' of x and a
'requires_grad_(True)' on cloned lora A/B tensors every iter inside
the timed window. That accounts for buffer allocation, not kernel
cost, and biases the numbers toward whichever path produced the
smallest activations. Restructure the runners so:

  * 'x' is cloned once into a leaf tensor with 'requires_grad_(True)'
    inside 'bench()' (outside the timed warmup + timed loop).
  * LoRA A/B leaf tensors are constructed once in the runner factory,
    not per iter.
  * Each iter calls the runner which sets 'x.grad = A.grad = B.grad =
    None' (cheap, no GPU sync) so the autograd graph for the timed
    iteration is fresh and grads don't accumulate.

Re-run all three configs end-to-end after this change (dense E=128,
sparse E=256 / 10-active, balanced E=256 M-sweep at M ∈ {256, 1024,
4096, 16384}) and refresh the numbers in bench_mxfp4_results.md.
Headers and table structure are unchanged. The qualitative ordering
holds (Strategy A wins at low active/E, Strategy B wins near
active/E ≈ 1, and Strategy A still OOMs across the balanced sweep on
the workstation with vLLM colocated), with per-cell numbers within
single-digit percent of the prior runs.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* style(scattermoe-lora): apply pre-commit auto-fixes and mypy fixes

Signed-off-by: Wing Lian <wing@axolotl.ai>

* fix(scattermoe-lora): mxfp4 shape validation + torchao version messages

Signed-off-by: Wing Lian <wing@axolotl.ai>

* lint and PR review fixes

* fix(scattermoe-lora): restore lint task fixes + add missing forward parity assertions

Wing's "lint and PR review fixes" commit (9007a82) reverted three fixes
from the prior lint pass. Restore them:

1. parallel_linear_lora.py: use isinstance(expert_weights, MXWeights)
   directly so mypy can narrow the union — the `is_mx` boolean alias
   blocks narrowing and re-introduces 2 union-attr errors.

2. bench_mxfp4.py: assert template is not None before the MXTensor(...)
   constructor — the chunked converter initializes template to None
   then sets it inside the loop, which mypy can't prove non-None at
   the call site (6 None-attr errors).

3. test_mxfp4_expert_weights.py: the F841 on fwd_tol was actually a
   smell of dropped logic. Both backward tests
   (test_strategy_a_backward_matches_bf16 and
   test_strategy_b_backward_matches_bf16) compute the forward outputs
   out_b/out_a/out_s, run backward, and assert gradients match — but
   never assert that the forward outputs match. A forward bug
   producing a constant offset (and therefore zero gradient delta)
   would slip past the bwd-only checks. Add the missing
   torch.equal(out_b, out_a) for Strategy A (bitwise contract) and
   torch.allclose(out_b, out_s, **fwd_tol) for Strategy B (MX tol).

Signed-off-by: Wing Lian <wing@axolotl.ai>

* don't worry about flash-attn direct patches for now

* feat(tiled-mlp): support MoE block classes in patcher

Extend patch_tiled_mlp to discover MoE block classes
({prefix}SparseMoeBlock / MoeMLP / MoE) and patch the routing+expert
forward when scattermoe-lora is active.

The kernels library installs HFScatterMoEGatedMLP.forward per instance
during model.kernelize(), which shadows class-level patches. Add a
post-model-load step (patch_tiled_mlp_moe_instances) that re-wraps each
MoE block instance so tiling layers on top of the kernels-installed
forward instead of being bypassed.
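
The shadowing problem is ordinary Python attribute lookup and can be
shown without torch. The class and function names below are
hypothetical stand-ins, not the patcher's real symbols.

```python
import types


class MoEBlock:
    def forward(self, x):
        return f"base({x})"


def class_level_patch(self, x):
    return f"tiled-class({x})"


def kernelized_forward(self, x):
    return f"kernel({x})"


def wrap_instance_with_tiling(block):
    """Re-wrap whatever forward the instance currently resolves to."""
    inner = block.forward  # bound method; the per-instance one if present

    def tiled(self, x):
        return f"tiled({inner(x)})"

    block.forward = types.MethodType(tiled, block)


MoEBlock.forward = class_level_patch                         # class-level patch
block = MoEBlock()
block.forward = types.MethodType(kernelized_forward, block)  # per-instance install
shadowed = block.forward("x")                                # class patch is bypassed

wrap_instance_with_tiling(block)
layered = block.forward("x")
```

The instance attribute wins over the class attribute, so only a second
per-instance wrap puts tiling back on top.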

Falls back to the existing dense {prefix}MLP / {prefix}TextMLP path
when no MoE block class exists. The gpt_oss special case for
DeepSpeedTiledMLPMoE is preserved and extended to every MoE block.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* fix(tiled-mlp): defer FSDP2 reshard + correct per-shard grad accumulation

Two backward-pass correctness fixes in TiledMLP.backward.

1) Defer FSDP2 post-backward reshard across the tile loop.

   The backward issues one torch.autograd.backward per shard. Under
   FSDP2 (torch.distributed.fsdp.fully_shard), the first inner backward
   triggers the wrapping FSDPModule's post-backward hook, which reshards
   parameters; subsequent shards then recompute against only-local
   DTensor shards. The result is a crash at best and silent gradient
   corruption at worst.

   LinkedIn's Liger-Kernel PR axolotl-ai-cloud#1128 fixed this for FSDP1 with
   FSDP.summon_full_params(writeback=True). That API does not exist in
   FSDP2. The PyTorch 2.11 FSDP2 surface is
   FSDPModule.set_reshard_after_backward(False) — toggle off around the
   tile loop, restore the prior value, and issue one explicit reshard()
   afterwards.

   The wrapping FSDPModule is discovered by walking the global
   _module_state_mapping registry (FSDP2 is typically applied at the
   decoder-layer level, so the MLP itself is rarely the FSDPModule).
   Result is cached on the MLP instance so the walk runs once.

   No-op under DDP, single-GPU, or DeepSpeed. DeepSpeedTiledMLPMoE is
   left alone — DeepSpeed coordinates its own gather and the two
   backends are mutually exclusive.

2) Replace the hook-based GradientAccumulator with inline fp32
   accumulation.

   The previous implementation called grad_accumulator.install_hooks()
   inside every shard iteration, so the N-th shard ran N stacked hooks
   that each accumulated the same shard contribution — and on the last
   shard the manually-set param.grad was then re-added by AccumulateGrad,
   doubling it. The accumulator also scaled by 1/N, but sequence-dim
   sharded gradients are additive (not averaged). Combined, param.grad
   came out ~2x-2.5x the analytical value.

   Inline accumulation captures param.grad after each shard's inner
   backward, sums into a per-param fp32 accumulator, clears the running
   grad, and writes the total back once at the end (preserving any
   pre-existing .grad from earlier graph segments).
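
The sum-versus-average point can be checked with a toy scalar model.
This sketch covers only the 1/N scaling issue, not the stacked hooks,
and every name in it is illustrative.

```python
def grad_w(xs, gs):
    """dL/dw for y_i = w * x_i with upstream grads g_i: sum_i x_i * g_i."""
    return sum(x * g for x, g in zip(xs, gs))


def sharded_grad_w(xs, gs, num_shards, scale=1.0):
    """Accumulate per-shard grads over equal sequence chunks, each times `scale`."""
    step = len(xs) // num_shards
    total = 0.0
    for s in range(num_shards):
        lo, hi = s * step, (s + 1) * step
        total += scale * grad_w(xs[lo:hi], gs[lo:hi])
    return total


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
gs = [0.5, -1.0, 0.25, 2.0, -0.5, 1.0, 0.75, -0.25]  # non-uniform upstream grads

full = grad_w(xs, gs)
summed = sharded_grad_w(xs, gs, num_shards=4)
averaged = sharded_grad_w(xs, gs, num_shards=4, scale=1 / 4)
```

Summing the four shard contributions reproduces the unsharded
gradient; scaling each by 1/4 leaves a quarter of it.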

Signed-off-by: Wing Lian <wing@axolotl.ai>

* test(tiled-mlp): single-gpu MoE + scattermoe-lora coverage

Three parity checks (tiled vs un-tiled forward+backward) plus two
patcher-internals tests. All gated on CUDA.

- Dense LlamaMLP-shape (hidden=64, intermediate=128, seq=64): tight
  atol=1e-5 on outputs, dX, and every parameter grad. Uses batch=1
  to match the sequence-packed inputs production sees.
- Hand-rolled MoE block (E=8, hidden=64, intermediate=128, top_k=2):
  same shape + same tolerances against an index_add-based reference.
- ScatterMoEGatedMLP in bf16: norm-relative tolerance < 1%, matching
  the established bar in tests/integrations/test_scattermoe_lora_kernels.py
  (bf16 + tiled reduction order makes max abs error a noisy signal).
- Patcher unit tests: MoE block class discovery prefers SparseMoeBlock
  / MoeMLP over MoE, and returns None for dense models.

Synthetic-shape modules only — no transformers checkpoints loaded.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* test(tiled-mlp): FSDP2 multi-rank correctness

Two parity tests (dense + scattermoe-lora) that wrap a tiny MLP /
ScatterMoEGatedMLP with FSDP2 (`fully_shard`) and compare tiled
forward+backward against a non-tiled FSDP2 reference. Both must run
through the FSDPModule's __call__ so FSDP2's pre-forward hooks
materialize the unsharded params before TiledMLP.apply chunks the
input; the helper _install_tiled_forward mirrors what the production
patcher does instance-side.

Designed to be launched with
`torchrun --nproc-per-node=2 -m pytest tests/e2e/multigpu/test_tiled_mlp_fsdp2.py`.
Skips with a clear reason on a 1-GPU executor or when launched without
torchrun. Verified to pass on a 2-GPU runner (RTX PRO 6000 Blackwell).

Signed-off-by: Wing Lian <wing@axolotl.ai>

* feat(scattermoe-lora): shared dequant buffer across tile shards

Adds shared_dequant_across_shards() which hoists the MXFP4 dequant out of the per-shard selective path. The orthogonal tiled wrapper calls selective_expert_weights once per shard; when active-expert sets overlap (the common case under softmax routing) the dequant is wasted work. The helper computes the union of active experts across all shards, dequantizes that union once, and returns per-shard remaps so each shard's parallel_linear_lora call uses the correct slice.

Bitwise contract: a shard's gathered slice is byte-identical to the per-shard selective_expert_weights output, verified by test_shared_dequant_helper.py with N=4 overlapping shards plus disjoint and single-shard regression cases.
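
A sketch of the union-plus-remap bookkeeping on plain lists. It is not
the body of shared_dequant_across_shards(); the function name, the
return layout, and the string "weights" are placeholders.

```python
def shared_active_experts(shard_expert_ids):
    """Return (union, remaps) for per-shard active-expert id lists.

    union     : sorted expert ids that appear in any shard
    remaps[i] : position of each of shard i's experts inside `union`
    """
    union = sorted(set().union(*shard_expert_ids))
    position = {expert: idx for idx, expert in enumerate(union)}
    remaps = [[position[e] for e in shard] for shard in shard_expert_ids]
    return union, remaps


shards = [[0, 3, 5], [3, 5, 7], [0, 7]]
union, remaps = shared_active_experts(shards)

# Stand-in for the single dequant pass: one "weight" per expert in the union.
dequantized = [f"W{e}" for e in union]
per_shard = [[dequantized[i] for i in remap] for remap in remaps]
```

Here four experts are dequantized once instead of eight per-shard
dequants, and each shard still sees exactly its own experts.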

Signed-off-by: Wing Lian <wing@axolotl.ai>

* fix(tiled-mlp): default grad accumulator to param dtype, skip redundant casts

The orthogonal TiledMLP wrapper pre-allocated an fp32 accumulator the
size of every compute param, then cast each shard's bf16 ``param.grad``
to fp32 inside the loop before adding it. For E=128 / hidden=2048 /
intermediate=8192 MoE training in bf16 that's roughly 16 GiB (~17 GB) of fp32
buffer on the gate_up_proj alone — net 2x parameter-side memory
regression vs. simply accumulating at the param's own dtype. The
per-shard ``grad.to(fp32)`` cast was also a per-shard HBM bandwidth tax
that dominated the wall-clock regression at intermediate=8192.

Match what AccumulateGrad does in the unsharded backward: accumulate
at the param's own dtype, skip the cast when shard-grad dtype matches
the accumulator dtype, and only cast back to param dtype at write-back
when the buffer dtype differs. fp32 accumulation is opt-in via
AXOLOTL_TILED_MLP_ACCUM_FP32=1 for callers who care about bf16
round-off in very-large-N-shard sums.
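
The buffer size can be reproduced with plain arithmetic, assuming
gate_up_proj is shaped E x hidden x (2 * intermediate); that layout is
an assumption, and the helper name is illustrative.

```python
def accumulator_bytes(num_experts, hidden, intermediate, bytes_per_elem):
    """Size of a gate_up_proj-shaped buffer: E x hidden x (2 * intermediate)."""
    return num_experts * hidden * 2 * intermediate * bytes_per_elem


fp32 = accumulator_bytes(128, 2048, 8192, 4)  # fp32 accumulator
bf16 = accumulator_bytes(128, 2048, 8192, 2)  # param-dtype accumulator
fp32_gib = fp32 / 2**30
fp32_gb = fp32 / 1e9
```

Under that assumption the fp32 buffer comes to 16 GiB (about 17.2 GB in
decimal units), twice the bf16 one.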

The dead ``GradientAccumulator`` class (no longer called after the
inline-accumulation refactor in b13375a0) is updated to the same
defaults — param-dtype accumulator, gradient_scale=1.0 — so it is in
a coherent state if anyone re-introduces a hook-based path.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* test(tiled-mlp): strengthen tiled-vs-untiled grad parity

Add three regression guards for the TiledMLP gradient-accumulator
fix:

1) ``test_tiled_dense_mlp_grad_parity_nonuniform_weights`` and
   ``test_tiled_moe_grad_parity_nonuniform_weights`` exercise
   shards in {1, 2, 4} with non-uniform per-token upstream weights.
   A mean-vs-sum scaling bug in the per-shard accumulator (the
   historical ``gradient_scale = 1/total_shards``) would show up as
   roughly ``(N-1)/N`` relative drift in the param grads. The old
   tests used a single shard count and uniform-magnitude upstream,
   which allowed the bug to slip through.

2) ``test_tiled_dense_mlp_grad_parity_bf16`` runs the same parity at
   bf16 to lock the default param-dtype accumulator path (no fp32
   buffer) against regression.

3) ``test_tiled_grad_accumulator_dtype_matches_param_dtype`` is an
   allocation-side guard: spy on ``torch.zeros_like`` during a
   bf16 tiled backward and assert none of the per-param accumulator
   allocations request fp32. A future change that re-introduces the
   fp32 buffer by default would fail this check without needing a
   memory-resident bench.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* fix(tiled-mlp): default to ~32K tokens/shard, not ceil(seq/hidden)

The previous heuristic put only ~2K tokens/shard at long context — well
below the MoE Triton kernel's BLOCK_M sweet spot. An empirical sweep at
seq ∈ {64K, 128K, 256K, 512K} showed 3.2× speed-up at 64–256K and 2.1×
at 512K from raising per-shard tokens to ~32K, with only a modest
peak-mem cost (~5–10 GiB extra at seq=256K) because the routed
intermediate buffer dominates and scales linearly with per-shard
tokens.

Bench data is operator-archived locally; the headline numbers are
included in the PR description.

The 32K target is empirical, not theoretical: it's the largest
tokens-per-shard that fits at seq up to 256K without OOM and stays
inside the cuBLAS safe regime; a separate large-batch_count bug
surfaces at seq=512K + s=16. Operators can override via
cfg_num_shards for niche cases (smaller intermediate, larger top_k).
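
A sketch of the two heuristics side by side. The function names are
hypothetical, and the exact rounding or clamping in the real patcher is
not shown above, so treat the ``ceil`` / ``max(1, ...)`` details as
assumptions.

```python
import math

TARGET_TOKENS_PER_SHARD = 32 * 1024  # the ~32K target described above


def old_num_shards(seq_len, hidden):
    """Previous heuristic: ceil(seq / hidden), about `hidden` tokens per shard."""
    return math.ceil(seq_len / hidden)


def new_num_shards(seq_len, cfg_num_shards=None):
    """Aim for ~32K tokens per shard unless an explicit override is given."""
    if cfg_num_shards is not None:
        return cfg_num_shards
    return max(1, math.ceil(seq_len / TARGET_TOKENS_PER_SHARD))


for seq in (64 * 1024, 256 * 1024, 512 * 1024):
    print(seq, old_num_shards(seq, 2048), new_num_shards(seq))
```

With hidden=2048 the old rule yields 32, 128, and 256 shards at these
lengths; the 32K target yields 2, 8, and 16.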

Also includes ruff-format cleanup of cherry-picked commits.

Signed-off-by: Wing Lian <wing@axolotl.ai>

---------

Signed-off-by: Wing Lian <wing@axolotl.ai>
…oud#3660)

LigerFusedLinearKLTopKLogprobFunction.forward's loss_fn_for_grad returned
(soft_loss, ce_loss) directly to torch.func.grad_and_value(has_aux=True),
which treats only the first element as the grad target. CE was silently
dropped from the backward graph: CE-only training had grad_norm=0 every
step, and KD-mix training updated parameters with KD-only gradients
despite reported loss showing both terms.

Combine the two losses inside loss_fn_for_grad so both contribute to
backward, and keep (soft_loss, ce_loss) as aux for reporting. Outer
accumulators, temperature scaling, and the final reported loss formula
are unchanged.
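
The has_aux behaviour can be illustrated without torch. The helper
below is a finite-difference stand-in that mimics the
``(grad, (value, aux))`` return shape on a scalar; it is not
torch.func.grad_and_value, the toy losses are made up, and the plain
sum in ``fixed`` is an assumption about how the two terms are combined.

```python
def grad_and_value_fd(fn, eps=1e-6):
    """Finite-difference stand-in for grad_and_value(has_aux=True) on a scalar.

    `fn(w)` must return (target, aux); only `target` is differentiated.
    """
    def wrapped(w):
        target, aux = fn(w)
        hi, _ = fn(w + eps)
        lo, _ = fn(w - eps)
        return (hi - lo) / (2 * eps), (target, aux)
    return wrapped


def soft_loss(w):
    return 3.0 * w  # d/dw = 3


def ce_loss(w):
    return 5.0 * w  # d/dw = 5


def buggy(w):
    # ce_loss lands in the aux slot, so it never reaches the gradient.
    return soft_loss(w), ce_loss(w)


def fixed(w):
    soft, ce = soft_loss(w), ce_loss(w)
    # Plain sum is an assumption; the real weighting is not shown above.
    return soft + ce, (soft, ce)


grad_buggy, _ = grad_and_value_fd(buggy)(2.0)
grad_fixed, (_, (soft, ce)) = grad_and_value_fd(fixed)(2.0)
```

The buggy variant differentiates only the soft term (gradient 3); the
fixed variant sees both (gradient 8) while still reporting each term
separately.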

Adds regression tests in tests/integrations/test_kd_liger.py.
… (supersedes chunking workaround) (axolotl-ai-cloud#3667)

* test(scattermoe-lora): repro CUBLAS_STATUS_EXECUTION_FAILED at large batch_count

The tiled-MLP long-context bench surfaces a hard failure at seq=524288
with 16 shards: ``cublasGemmStridedBatchedEx`` raises
``CUBLAS_STATUS_EXECUTION_FAILED`` at parallel_experts.py:72's
``gates.unsqueeze(1) @ output_expanded``. The crash is reproducible on
the bench shape (T=32K tokens/shard, top_k=8, hidden=2048,
intermediate=8192) and is a downstream symptom of an int32
pointer-offset overflow in the upstream ``scatter2scatter`` Triton
kernel during the up-projection — at that shape its output buffer is
2**32 elements, so ``M_block * stride_ym`` overflows int32 for the
trailing rows.

Add three repro tests in tests/integrations covering:

  * fast-path bit-identity vs the raw kernel below the threshold,
  * non-corruption at the overflow shape via the int32-safe wrapper,
  * end-to-end ``parallel_linear`` smoke at the failing bench shape.

The tests are marked ``pytest.mark.skip`` pending the fix in the next
commit; the follow-up "un-mark" commit re-enables them so they guard
the fix going forward. Symbols imported inside each test body
(``_scatter2scatter_int32_safe``, ``_SCATTER2SCATTER_INT32_LIMIT``)
land alongside the fix in the next commit, so the skipped tests do
not fail collection at this point in the history.

Signed-off-by: Wing Lian <wing@axolotl.ai>

fix(scattermoe-lora): work around cuBLAS large-batch_count failure in gates @ output_expanded

The originally-reported symptom — ``CUBLAS_STATUS_EXECUTION_FAILED``
at parallel_experts.py:72 — is NOT a cuBLAS bug. The cuBLAS bmm
shape at the failing seq=512K / 16-shard config is tiny
(batch_count=32768, M=1, K=8, N=2048) and works in isolation. The
crash is a downstream symptom of an int32 pointer-offset overflow in
the upstream ``scatter2scatter`` Triton kernel during the
up-projection, surfaced at the next CUDA-sync point (the bmm).

Diagnosis (verified by inserting a ``torch.cuda.synchronize()``
immediately after the up-projection's scatter2scatter — that sync
itself raises "an illegal memory access was encountered", proving the
fault is upstream of the bmm):

The Triton kernel computes output pointer offsets as
``Y_ptr + M_block * stride_ym + N_block * stride_yn`` with int32
``M_block`` / ``stride_ym``. At seq=524288 / shards=16 the
up-projection's output is
``[L_scattered=262144, y_dim=2*INTERMEDIATE=16384]`` = ``2**32``
elements; the trailing rows whose ``M_block * stride_ym`` overflows
int32 have their masked stores silently drop (rows come back as zeros)
or land at bogus pointers, which then trips a delayed
``CUDA illegal memory access`` that the next kernel surfaces.
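
The overflow boundary can be located with integer arithmetic, assuming
a contiguous row-major output so that the row stride equals ``y_dim``.
``int32_wrap`` is an illustrative helper that emulates two's-complement
truncation; it is not part of the kernel.

```python
import ctypes


def int32_wrap(value):
    """Value as it would appear after two's-complement int32 truncation."""
    return ctypes.c_int32(value).value


L_SCATTERED = 262144  # rows in the up-projection output
Y_DIM = 16384         # 2 * INTERMEDIATE; assumed to equal the row stride

total_elements = L_SCATTERED * Y_DIM
first_bad_row = 2**31 // Y_DIM  # first row whose offset exceeds int32
last_ok_offset = int32_wrap((first_bad_row - 1) * Y_DIM)
wrapped_offset = int32_wrap(first_bad_row * Y_DIM)
```

Rows from 131072 onward, the second half of the buffer, have offsets
that no longer fit in int32; the first of them wraps to -2**31.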

Workaround at the smallest scope appropriate to the actual root cause
(NOT at parallel_experts.py:72, which is downstream): wrap
``kernels.ops.scatter2scatter`` with ``_scatter2scatter_int32_safe``
and route both call sites (``ParallelLinear.forward`` and
``ParallelLinear.backward``) through it.

The wrapper:

  * Fast path (common case): when ``L_scattered * y_dim < 2**31``,
    dispatches a single direct kernel call — no overhead vs the
    pre-fix code. Verified at seq=524288 / shards=64: 36741 tokens/s
    post-fix vs ~37512 tokens/s pre-fix (within noise, no regression).
  * Slow path: when the output would overflow AND ``y_grouped=True``,
    allocates the full output and chunks along the L_scattered axis.
    Each sub-call writes to ``out[chunk_start:chunk_end]`` with the
    matching sei / ssi slice; the chunk size is the largest
    BLOCK_M-aligned row count keeping ``rows * y_dim < 2**31``. The
    chunked path drops into ``kernels.ops.scatter2scatter_compileable``
    directly to bypass the high-level wrapper's
    ``sorted_scattered_idxs.size(0) == X.size(0) * k`` assertion that
    only holds for full calls.
  * When ``x_grouped=True`` (the down-proj backward), X is sliced in
    lockstep so the kernel's ``M_in_idx = M_block`` correctly reads
    ``X_chunk[0..chunk_size-1]``. When ``x_grouped=False`` (the
    up-proj forward) X stays full because the kernel indexes X via
    global ``M_idx // FAN_OUT`` from the per-position
    ``sorted_scattered_idxs`` values.
  * For ``y_grouped=False`` at overflow scale, the wrapper hard-raises
    ``RuntimeError`` — the kernel uses per-position scattered indices
    as output row indices so the wrapper cannot tile that case
    safely; the kernel itself needs an int64 pointer-arithmetic fix
    before that path is callable at this scale. Production paths
    today are all ``y_grouped=True`` so this branch is unreachable in
    the bench. Silent corruption is strictly worse than a clear raise.
  * Includes ``assert L_scattered % chunk_rows == 0`` for the
    ``x_grouped=False`` chunked path, since the kernel's
    ``M_boundary_mask`` uses the full (unchunked) X size and a
    partial last chunk would let the final tile read past sei_chunk
    / ssi_chunk. The assertion holds for all realistic power-of-2
    shapes and fires loudly if a future caller hits a non-aligned
    one. The ``x_grouped=True`` chunked path is naturally bounded
    because X is chunked in lockstep.

Before / after on the bench config:

  * pre-fix:  CUBLAS_STATUS_EXECUTION_FAILED (no result)
  * post-fix: 10084 ms/iter, 51989 tokens/s, peak 64.16 GiB
              (matches the previously-fastest non-failing rows of
              the s=64 / s=256 sweep; s=16 was the predicted fastest
              row that the bug had been hiding)

Constraints honoured: no changes to ``kernels.ops.scatter2scatter``
or any other Triton kernel; no public-API change on
``parallel_experts.py``; scope confined to scattermoe-lora's
ParallelLinear; common-case fast path untouched. The LoRA-path
counterpart (``parallel_linear_lora.py`` → ``scatter2scatter_lora``)
has the same architectural risk but is not exercised by the failing
bench config and is left for follow-up.

Signed-off-by: Wing Lian <wing@axolotl.ai>

test(scattermoe-lora): enable large-batch repro tests now the fix has landed

Remove the ``pytest.mark.skip`` marker added in the
``repro CUBLAS_STATUS_EXECUTION_FAILED`` commit. The fix in the
previous ``work around cuBLAS large-batch_count failure`` commit
provides ``_scatter2scatter_int32_safe`` and the matching
``_SCATTER2SCATTER_INT32_LIMIT`` constant referenced by these tests,
so the three tests now run and guard the fix going forward:

  * ``test_int32_safe_wrapper_matches_direct_call_below_threshold`` —
    fast-path equivalence (no overhead in the common case).
  * ``test_int32_safe_wrapper_no_corruption_at_overflow_shape`` —
    chunked slow-path correctness at the bench shape.
  * ``test_parallel_linear_long_seq_routing_combination`` —
    end-to-end smoke through ``ScatterMoEGatedMLP.forward`` shape
    sequence at seq=524288 / shards=16.

All three pass on CUDA hardware; they self-skip when CUDA is
unavailable.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* feat(scattermoe-lora): add INT64_INDICES tl.constexpr to dense scatter2scatter

The Triton scatter2scatter kernel computes output pointer offsets as
``Y_ptr + M_block * stride_ym + N_block * stride_yn`` where M_block /
M_idx are int32 by default. At seq=512K with coarse shards the
``L_scattered * y_dim`` product exceeds 2**31 elements and the int32
arithmetic overflows; PR axolotl-ai-cloud#3667 worked around this by chunking the call
along the L_scattered axis when y_grouped=True, but that workaround
doesn't cover y_grouped=False (raises) or the LoRA-path kernels.

Add an ``INT64_INDICES: tl.constexpr = False`` knob to the dense
``_scatter2scatter`` kernel signature. When True, the M_block range and
the scattered-index lookup ``M_idx`` are cast to int64 before they enter
the pointer-offset multiplication, so all downstream pointer arithmetic
propagates int64. Strides themselves stay as the kernel sees them
(coming from ``tensor.stride()`` they're already int64 at the Python
level); only the *index* values change type. Triton will JIT a separate
variant per constexpr value, so the existing int32 fast path is
unaffected.

The wrapper-level auto-dispatch (compute ``needs_int64`` from tensor
sizes and forward to the kernel) lands in a follow-up commit; this
commit just exposes the constexpr and a Python-side ``int64_indices``
kwarg on ``scatter2scatter`` / ``scatter2scatter_compileable``.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* feat(scattermoe-lora): add INT64_INDICES to LoRA forward + dX kernels (bf16 + MX)

Adds the same ``INT64_INDICES: tl.constexpr = False`` knob as the dense
kernel to the four LoRA-path scatter2scatter kernels:

  * ``_scatter2scatter_lora``       — bf16 fused base+LoRA forward
  * ``_scatter2scatter_lora_dX``    — bf16 fused dX backward
  * ``_scatter2scatter_lora_mx``    — MXFP4 fused base+LoRA forward
  * ``_scatter2scatter_lora_dX_mx`` — MXFP4 fused dX backward

The cast pattern matches the dense kernel: when ``INT64_INDICES=True``,
the per-launch ``M_block`` range and the scattered ``M_idx`` lookup are
cast to int64 before they enter the ``M_*_idx * stride_*m`` pointer
arithmetic. That promotes the multiplication to int64 and prevents the
silent overflow at ``L_scattered * y_dim >= 2**31`` that the chunking
workaround on the dense path was guarding against.

The Python-side wrappers (``scatter2scatter_lora``,
``scatter2scatter_lora_dX``, ``scatter2scatter_lora_mx``,
``scatter2scatter_lora_dX_mx``) gain an ``int64_indices: bool = False``
kwarg and forward it to the kernel via the constexpr. Auto-dispatch
from tensor sizes lands in a follow-up commit.

PR axolotl-ai-cloud#3667's chunking workaround only covered the bf16 dense forward; the
LoRA path had the same architectural risk and wasn't covered. With
these constexprs in place and the wrapper-side dispatch coming next,
the kernel itself becomes int64-safe for all five variants and the
chunking wrapper can be retired.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* feat(scattermoe-lora): add INT64_INDICES to group_bwd_lora kernels

Adds ``INT64_INDICES: tl.constexpr = False`` to the three LoRA gradient
kernels that index into the grouped M dimension:

  * ``_group_bwd_lora``         — non-split LoRA-grad kernel (used by
                                  autotune-collector mocks; kept in
                                  sync for telemetry / future callers)
  * ``_group_bwd_lora_split``   — split dA/dB kernel that the public
                                  ``group_bwd_lora`` wrapper actually
                                  dispatches today
  * ``_group_bwd_lora_fused``   — fused gather + dA/dB kernel used by
                                  the LoRA path in ``parallel_linear_lora.py``

In each kernel, when ``INT64_INDICES=True`` we cast:

  - the per-expert ``start_idx`` / ``end_idx`` (and the fused kernel's
    ``real_*`` variants) to int64 on load,
  - ``M_block = tl.arange(0, BLOCK_M)`` to int64 so the per-iter
    ``M_idx = start_idx + i * BLOCK_M + M_block`` propagates int64,
  - and (in the fused kernel) ``scatter_idx`` from sorted-index lookups
    to int64 so ``scatter_idx * stride_dym`` and the
    ``X_token_idx = scatter_idx // FAN_OUT`` arithmetic stay int64.

Strides themselves stay as the kernel receives them (already int64 at
the Python level via ``tensor.stride()``). Triton JITs a separate
variant per constexpr value, so the int32 fast path is unchanged.

Python wrappers (``group_bwd_lora`` and ``group_bwd_lora_fused``) gain
an ``int64_indices: bool = False`` kwarg that forwards to the kernel.
Wrapper-level auto-dispatch from tensor sizes lands in the next commit.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* feat(scattermoe-lora): auto-dispatch INT64_INDICES based on tensor sizes

Adds ``_needs_int64_indices(*tensors)`` in ``parallel_experts.py``: True
iff any input/output tensor's ``numel() >= 2**31 - 1``. That's a
conservative threshold for the point at which the kernel's
``M_idx * stride_*m`` pointer arithmetic can overflow int32 somewhere in
the buffer.

Wires the result through the autograd Functions:

  * ``ParallelLinear`` (``parallel_experts.py``): forward computes
    ``needs_int64 = (L_scattered * y_dim) >= INT_MAX or
    _needs_int64_indices(x)`` and forwards via the new ``int64_indices``
    kwarg on ``_scatter2scatter_int32_safe``. The wrapper's fast path
    now passes ``int64_indices`` through to ``kernels.ops.scatter2scatter``
    so the kernel takes the int64 path at overflow scale. The wrapper
    also adds a new branch above the chunking path that routes directly
    to the int64 kernel when ``int64_indices=True`` is requested —
    notably this covers the y_grouped=False overflow case that the
    chunking workaround used to raise on. Backward follows the same
    pattern using ``L_scattered * K`` for the dX-axis bound.

  * ``ScatterMoELoRA`` (``parallel_linear_lora.py``): forward computes
    ``needs_int64`` from ``L_scattered * N`` and forwards to
    ``scatter2scatter_lora`` / ``scatter2scatter_lora_mx``. Backward
    computes a single ``needs_int64_bwd`` from ``M_total * max(N, K)``
    (covering both the dX and the dA/dB kernels' index ranges) and
    forwards to ``group_bwd_lora`` / ``group_bwd_lora_fused`` and to
    the dX kernels (bf16 + MX, fused and non-fused).

The auto-dispatch is cheap (one ``Tensor.numel()`` per check) and
Triton JITs a separate kernel variant per constexpr value, so the int32
fast path is unaffected for small/medium shapes.
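
A shape-based sketch of the dispatch rule. It takes shape tuples in
place of tensors so it runs without torch; the function names and the
example shapes passed for ``x`` are illustrative, not the code in
``parallel_experts.py``.

```python
import math

INT32_MAX = 2**31 - 1


def needs_int64_indices(*shapes):
    """True if any tensor shape holds at least INT32_MAX elements."""
    return any(math.prod(shape) >= INT32_MAX for shape in shapes)


def forward_needs_int64(l_scattered, y_dim, x_shape):
    """Forward check: output extent or any input at int32 scale."""
    return (l_scattered * y_dim) >= INT32_MAX or needs_int64_indices(x_shape)


# seq=8K, top_k=8, hidden=2048, N=2048: well below the threshold.
small = forward_needs_int64(8 * 1024 * 8, 2048, (8 * 1024, 2048))
# L_scattered=262144, y_dim=16384: the 2**32-element bench output.
overflow = forward_needs_int64(262144, 16384, (32768, 2048))
```

Only the overflow-scale call selects the int64 variant; the small shape
stays on the int32 path.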

Signed-off-by: Wing Lian <wing@axolotl.ai>

* test(scattermoe-lora): int32-vs-int64 parity and overflow correctness

Adds ``tests/integrations/test_scattermoe_lora_int64_indices.py``
covering two properties of the new INT64_INDICES path:

* **Bitwise parity at non-overflow shapes.** For each of the modified
  kernels, ``INT64_INDICES=False`` and ``INT64_INDICES=True`` compute
  the same MMA in the same accumulation order — only the index *type*
  changes. The tests assert ``torch.equal`` between the two variants
  for the dense forward (both y_grouped=True and y_grouped=False), the
  LoRA forward, and the LoRA dX backward. For ``_group_bwd_lora_split``
  the assertion is bitwise; for ``_group_bwd_lora_fused`` it's
  ``torch.allclose`` within bf16 tolerance because that kernel uses
  ``tl.atomic_add`` whose ordering is non-deterministic across launches
  (so bit-equality is not achievable between any two runs of the same
  variant, let alone across variants).

* **Overflow correctness at the failing bench shape.** At
  L_scattered=262144 / y_dim=16384 (2**32 element output), the
  ``INT64_INDICES=True`` kernel populates every row of the output
  (including rows past the int32 overflow boundary) and matches the
  chunked workaround within a generous bf16 tolerance. A second
  bench-shape test runs the real ``ParallelLinear`` forward and uses a
  monkeypatched spy on ``scatter2scatter_compileable`` to assert the
  auto-dispatcher routes through the *direct* int64 kernel call (one
  launch) and **not** the chunking workaround (>=2 launches).
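
The distinction in the first bullet above between bitwise and tolerance-based assertions comes down to floating-point addition not being associative: if the order of accumulation can change between launches, the last bits of the sum can change too. The snippet below is a minimal sketch of that effect using ordinary Python floats as a stand-in for bf16 accumulation; it does not reproduce the kernels or `tl.atomic_add`.

```python
import math

values = [0.1, 0.2, 0.3]

# Same three addends, two different accumulation orders.
left_to_right = (values[0] + values[1]) + values[2]
right_to_left = values[0] + (values[1] + values[2])

# Exact comparison fails, tolerance-based comparison succeeds.
bitwise_equal = left_to_right == right_to_left
close_enough = math.isclose(left_to_right, right_to_left, rel_tol=1e-9)
```

This is why an exact-equality assertion suits kernels with a fixed accumulation order, while a tolerance-based assertion is the practical choice when ordering is non-deterministic.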

Also folds in a kernel-side fix that the parity tests caught: the
group_bwd_lora kernels' ``if E_idx == 0: start_idx = 0`` branch
produced a plain int32 zero in Triton, which clashes with the int64
``start_idx`` produced by the else-branch under ``INT64_INDICES=True``,
firing ``AssertionError: Mismatched type for start_idx between then
block (int32) and else block (int64)`` at compile. Switching the
zero-initialisation to ``tl.zeros([], dtype=tl.int64/tl.int32)`` keeps
both branches' types consistent.

The bench-shape tests are skipped when free GPU memory is below 80 GiB.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* bench(scattermoe-lora): int64-vs-int32 indexing overhead

Adds ``tests/integrations/kernels/scattermoe_lora/bench_int64_kernel.py``,
a stand-alone script (not pytest) that times the dense
``kernels.ops.scatter2scatter`` at three representative shapes and
reports ms/iter for both ``INT64_INDICES=False`` (int32 fast path) and
``INT64_INDICES=True`` (int64 safe path):

  * **small**    — seq=8K, top_k=8, hidden=2048, N=2048 (auto_int64=False)
  * **medium**   — seq=128K, top_k=8, hidden=2048, N=2048 (auto_int64=False)
  * **overflow** — seq=512K with 16 shards → L_scattered=262144, N=16384
                   (auto_int64=True; the previously-failing bench config)

At overflow shapes the int32 path is silently incorrect, so the int32
column is replaced by the chunked workaround from PR axolotl-ai-cloud#3667 as the
apples-to-apples baseline. Results land in
``bench_int64_kernel_results.md`` next to the script.

Captured on RTX PRO 6000 Blackwell Max-Q (1.79 TB/s HBM):

  shape           int32 ms    int64 ms    chunked ms   penalty
  -------------   --------    --------    ----------   -------
  small             2.687       2.689        —          +0.0%
  medium           40.220      40.581        —          +0.9%
  overflow            —        79.572      79.985        -0.5%

Both acceptance bounds are comfortably met: ≤5% on the int32 fast path
(actual: <1%), and ≤25% on the int64 path vs the chunked workaround
(actual: −0.5%, i.e. the int64 kernel is slightly *faster* than
chunking at this shape because it avoids the per-chunk launch overhead).
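
The acceptance bounds quoted above can be re-derived from the table. The following is a small sketch that recomputes the relative penalty from the rounded millisecond values; the helper name `penalty_pct` is hypothetical, and because the inputs are already rounded, the last digit may differ slightly from the table's own penalty column.

```python
def penalty_pct(candidate_ms: float, baseline_ms: float) -> float:
    """Relative slowdown of candidate vs baseline, in percent."""
    return (candidate_ms - baseline_ms) / baseline_ms * 100.0


# int64 vs int32 at the non-overflow shapes.
small = penalty_pct(2.689, 2.687)
medium = penalty_pct(40.581, 40.220)

# int64 vs the chunked workaround at the overflow shape.
overflow = penalty_pct(79.572, 79.985)

fast_path_ok = max(small, medium) <= 5.0   # bound on the int32 fast path
overflow_ok = overflow <= 25.0             # bound vs the chunked workaround
```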

Signed-off-by: Wing Lian <wing@axolotl.ai>

refactor(scattermoe-lora): deprecate _scatter2scatter_int32_safe chunking now that kernel is int64-safe

PR axolotl-ai-cloud#3667's ``_scatter2scatter_int32_safe`` chunking wrapper was the
minimum-scope fix for the int32 pointer-overflow at the failing bench
config: it tiled the call along the L_scattered axis to keep each
sub-launch's ``rows * y_dim < 2**31``. With the kernel-level
``INT64_INDICES`` constexpr now landing on every relevant
scatter2scatter family kernel and the wrapper-level auto-dispatch
plumbed through both ``parallel_experts.py`` and
``parallel_linear_lora.py``, the chunking workaround is redundant —
the kernel handles the overflow itself in a single launch.
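
For context on what is being removed, the tiling idea described above can be sketched in a few lines. This is an illustrative reconstruction under stated assumptions, not the deleted wrapper: the function name `chunk_row_ranges` is hypothetical, and block-size alignment of the chunks (the role of the removed block-size constant) is deliberately left out.

```python
INT32_LIMIT = 2**31


def chunk_row_ranges(total_rows: int, y_dim: int) -> list[tuple[int, int]]:
    """Split rows into [start, end) ranges so each has rows * y_dim < 2**31."""
    max_rows = (INT32_LIMIT - 1) // y_dim
    if max_rows == 0:
        raise ValueError("a single row already exceeds the int32 limit")
    return [
        (start, min(start + max_rows, total_rows))
        for start in range(0, total_rows, max_rows)
    ]


# Bench shape from the text: L_scattered=262144, y_dim=16384.
ranges = chunk_row_ranges(262144, 16384)
```

Under this simplified tiling the bench shape needs several sub-launches, whereas the int64 kernel covers the same rows in one.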

Removes from ``parallel_experts.py``:

  * ``_scatter2scatter_int32_safe`` and its 160-line chunking loop
  * ``_SCATTER2SCATTER_INT32_LIMIT`` and ``_SCATTER2SCATTER_BLOCK_M``
    constants used only by the chunking path
  * the ``RuntimeError`` raise for ``y_grouped=False`` at overflow
    scale — the int64 kernel handles that case directly

Routes ``ParallelLinear.forward`` / ``.backward`` straight to
``kernels.ops.scatter2scatter`` with ``int64_indices=needs_int64``.

The bench config (seq=524288, 16 shards → L_scattered=262144,
y_dim=16384, output=2**32 elements) now goes through a single int64
kernel launch and matches the bench-recorded perf
(79.6 ms/iter vs. the chunked workaround's 80.0 ms/iter — slightly
*faster* because it eliminates the per-chunk launch overhead).

The PR axolotl-ai-cloud#3667 repro tests are retained as regression guards and
updated to call the new direct-kernel path:

  * ``test_scatter2scatter_below_threshold_no_overhead`` (renamed
    from ``test_int32_safe_wrapper_matches_direct_call_below_threshold``)
    asserts INT64_INDICES=False vs True is bit-identical at non-
    overflow shapes — guards the int32 fast path.
  * ``test_scatter2scatter_no_corruption_at_overflow_shape`` (renamed
    from ``test_int32_safe_wrapper_no_corruption_at_overflow_shape``)
    asserts the int64 kernel populates rows past the int32 overflow
    boundary — guards the kernel-level overflow fix.
  * ``test_parallel_linear_long_seq_routing_combination`` is
    unchanged; it runs ``parallel_linear`` end-to-end at the bench
    shape and asserts no all-zero rows / no NaNs — guards the
    auto-dispatch wiring.

The new ``test_parallel_linear_overflow_takes_int64_kernel_path`` in
``test_scattermoe_lora_int64_indices.py`` is also updated to monkey-
patch ``scatter2scatter_compileable`` and assert the single launch
sets ``int64_indices=True``, which directly verifies the auto-
dispatch verdict at the failing bench shape.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* chore(scattermoe-lora): pre-commit fixups for INT64 indices commits

* ruff format pass on the four touched files (line-wrapping only,
  no functional changes).
* ``parallel_linear_lora.py``: replace the one-line conditional
  ``N_dim = expert_weights.N if is_mx else expert_weights.size(-1)``
  with an explicit if/else and ``# type: ignore[union-attr]`` on each
  branch — mypy can't narrow ``Union[Tensor, MXWeights]`` through a
  ternary, but does respect the explicit branches enough to need the
  ignore only on the offending attribute access. The pre-existing
  ternary in the backward path stays as-is (already covered by the
  surrounding type checks).
* ``bench_int64_kernel.py``: drop imports of the removed
  ``_scatter2scatter_int32_safe`` / ``_SCATTER2SCATTER_INT32_LIMIT``
  symbols (they went away in the refactor commit) and the
  now-unused ``ms_chunk`` column. The bench now reports int32 vs
  int64 timings only; the overflow row shows only int64 since the
  int32 kernel is silently incorrect there.

Signed-off-by: Wing Lian <wing@axolotl.ai>

* test(scattermoe-lora): add small-shape int64 overflow tests (run on L40S/24 GiB)

The two bench-shape overflow tests above need ~80 GiB free and skip on
the Modal CI L40S 48 GiB runner, so the actual overflow path the kernel
fix targets was not exercised on CI. The new ..._small variants repro
the same property at the smallest shape that still straddles the int32
boundary: L_scattered * y_dim = 2**32 (2x past 2**31, guaranteed
overflow without int64_indices=True), with E=4, K=256, y_dim=4096 so W
is ~8 MiB and the only big allocation is the ~8 GiB scatter output.
Gated at 12 GiB free to leave headroom for pytest-xdist workers on
48 GiB devices.
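
The sizes quoted above follow from simple arithmetic. The sketch below reproduces them assuming 2 bytes per element (bf16); the dtype is an assumption here, as are the variable names.

```python
BYTES_PER_ELEMENT = 2  # assumes bf16 storage
GIB = 2**30
MIB = 2**20

y_dim = 4096
output_elements = 2**32                  # L_scattered * y_dim
l_scattered = output_elements // y_dim   # rows in the scatter output

num_experts, k_dim = 4, 256              # E and K from the text

output_gib = output_elements * BYTES_PER_ELEMENT / GIB
weight_mib = num_experts * k_dim * y_dim * BYTES_PER_ELEMENT / MIB
times_past_int32 = output_elements // 2**31
```

With these numbers the weight tensor is negligible next to the output buffer, which is why the free-memory gate can sit well below the bench-shape requirement.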

* fix(scattermoe-lora): bump _SMALL_E so int64-overflow topk is valid

_SMALL_TOP_K=8 with _SMALL_E=4 makes torch.topk(logits[T,4], k=8) raise
'selected index k out of range', skipping the two _small overflow tests.
The shape invariant L_scattered*y_dim=2**32 (T=131072, y_dim=4096) requires
top_k=8, so E must be >= 8.
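
The constraint above can be checked without a GPU. This is a minimal sketch with a hypothetical helper name; it assumes `L_scattered = T * top_k`, which is consistent with the shape invariant quoted in the commit message.

```python
def scattered_output_elements(tokens: int, top_k: int, num_experts: int, y_dim: int) -> int:
    """Return L_scattered * y_dim, rejecting top_k values that top-k selection cannot serve."""
    if top_k > num_experts:
        # Selecting k entries from fewer than k experts is invalid.
        raise ValueError(f"top_k={top_k} exceeds num_experts={num_experts}")
    l_scattered = tokens * top_k
    return l_scattered * y_dim


# T=131072, top_k=8, y_dim=4096 from the text, with E raised to 8.
elements = scattered_output_elements(131072, 8, 8, 4096)
```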


* perf(scattermoe-lora): bucket M in autotune key to dedupe sweeps

The 7 multi-config @triton.autotune kernels in lora_ops.py keyed on
["M", "N", "K"]. M = X.size(0) (or DY.size(0)) scales with
batch*seq*top_k, so any seqlen variation triggers a fresh 30-60 config
sweep per step until the cache happens to cover every realized M. With
N, K model-fixed this was the only churning dimension.

Add a phantom M_BUCKET arg to each kernel signature and switch the
autotune key to ["M_BUCKET", "N", "K"]. The kernel still runs on the
real M (loop bounds + masks unchanged); only the cache lookup is
bucketed to the next multiple of _M_BUCKET_GRANULARITY=1024. No padding,
no wasted FLOPs.
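
The bucketing described above amounts to rounding M up to a multiple of the granularity before it is used as a cache key. A minimal sketch follows; the function name `m_bucket` is hypothetical, and treating an exact multiple as its own bucket is an assumption.

```python
M_BUCKET_GRANULARITY = 1024  # value quoted in the commit message


def m_bucket(m: int, granularity: int = M_BUCKET_GRANULARITY) -> int:
    """Round m up to the next multiple of granularity (cache-key use only)."""
    return ((m + granularity - 1) // granularity) * granularity


# Two sequence lengths in the same bucket share one autotune cache entry,
# while a value in a different bucket gets its own.
same_bucket = m_bucket(3000) == m_bucket(3072)
different_bucket = m_bucket(3072) != m_bucket(3073)
```

Only the key is bucketed; the kernel would still receive the real M for its loop bounds and masks, so no padding is involved.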

autotune_collector._KEY_NAMES tracks the renamed key so telemetry
matches what's actually in the .cache dict.

Tests:
- New test_scattermoe_lora_m_bucket.py pins both directions: same-bucket
  M values produce one cache entry, distinct-bucket M values produce two.
- Updated telemetry test assertions for the renamed key.
- Existing scattermoe-lora suite (62 tests) + int64 indices (10 tests)
  + telemetry (13 tests) all pass unchanged.

---------

Signed-off-by: Wing Lian <wing@axolotl.ai>
@coderabbitai

coderabbitai Bot commented May 29, 2026


Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Walkthrough

This PR adds non-streaming multimodal CPT via datasets: type: multimodal_pretrain, consolidates streaming and non-streaming encoding, refactors MM-CPT detection and validation, makes dataset hashing and prepared-data caches processor-sensitive, updates docs and examples, and adds unit/regression tests.

Changes

Non-streaming Multimodal CPT via Datasets

| Layer | File(s) | Summary |
| --- | --- | --- |
| Schema definition for multimodal pretrain dataset | src/axolotl/utils/schemas/datasets.py, src/axolotl/utils/schemas/config.py | MultiModalPretrainDataset model class extending PretrainingDataset with type validation is introduced and added to training datasets union. |
| Core encoder and dataset wrapping strategy for non-streaming MM CPT | src/axolotl/prompt_strategies/multimodal_pretrain.py | encode_multimodal_pretrain() validates and tokenizes text+image pairs with placeholder alignment; MultiModalPretrainDatasetWrappingStrategy wraps datasets; updated load() requires a processor and returns the strategy. |
| Refactor streaming encoder to delegate to shared implementation | src/axolotl/utils/data/streaming.py | encode_streaming_multimodal() delegates all tokenization and validation to encode_multimodal_pretrain(), removing duplicate logic. |
| Refactor MM CPT detection to support both streaming and non-streaming pathways | src/axolotl/core/builders/causal.py | Add _entry_is_multimodal_cpt() and _get_mm_cpt_config() to scan test_datasets, pretraining_dataset, and datasets; update _is_multimodal_cpt() and collator config selection. |
| Normalize and validate multimodal dataset entries in config | src/axolotl/utils/config/__init__.py | validate_config() converts datasets: entries with type: multimodal_pretrain to MultiModalPretrainDataset instances and rejects multimodal: true shortcut for datasets. |
| Extend multimodal CPT validation for datasets configuration | src/axolotl/utils/schemas/validation.py | check_multimodal_cpt() detects and validates multimodal entries under datasets:, enforcing single-entry rules, forbidding streaming/truncate combinations, and preventing mixed pretraining_dataset+datasets MM CPT. |
| Update dataset hash computation to include multimodal-specific fields | src/axolotl/utils/data/shared.py | _dataset_hash_component() extracts multimodal-related fields with fallback attribute access; generate_dataset_hash_from_config() uses per-dataset components (order-sensitive) and conditionally appends processor fingerprint when multimodal datasets are present. |
| Make prepared dataset cache processor-sensitive | src/axolotl/utils/data/sft.py | Include processor fingerprint (now including boi_token) in prepared/raw dataset hash generation calls so prepared dataset reuse depends on processor identity. |
| Documentation and example configurations for both MM CPT pathways | docs/multimodal.qmd, examples/qwen2_5-vl/mm-cpt-nonstreaming-qlora.yaml, examples/qwen2_5-vl/mm-cpt-streaming-qlora.yaml | Docs describe streaming and non-streaming MM CPT routes, prepared-path resume guidance, and already-tokenized-row behavior; includes Qwen2.5-VL QLoRA YAML examples for streaming and non-streaming runs. |
| Unit and regression tests for MM CPT pathways | tests/... (multiple files) | Add and update tests covering non-streaming wrapping, streaming detection & partial-patching, collator image path resolution, tokens-per-second state restore, dataset hash sensitivity, processor fingerprint, and validation gate coverage for datasets-style MM CPT. |

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 11.63% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The PR title clearly and specifically describes the main change: broadening multimodal CPT to support non-streaming dataset paths in addition to streaming. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


@thad0ctor

Owner Author

@coderabbitai full review

@coderabbitai

coderabbitai Bot commented May 29, 2026

✅ Actions performed

Full review triggered.

@github-actions

github-actions Bot commented May 29, 2026


📖 Documentation Preview:

Deployed on Netlify from commit 8d15592

@thad0ctor

Owner Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented May 29, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai coderabbitai Bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/axolotl/prompt_strategies/multimodal_pretrain.py (2)

32-42: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Restore fail-fast sequence-length validation in the map path.

encode_multimodal_pretrain() already avoids truncation; setting enforce_max_length=False here just disables the only early guard for oversized rows. Those samples then survive preprocessing, and the collator only logs a warning when they exceed sequence_len, which makes this much harder to diagnose at training time.

Suggested fix
     def _encode_batch(self, examples: dict[str, list]) -> dict[str, list]:
         return encode_multimodal_pretrain(
             examples,
             tokenizer=self.tokenizer,
             max_tokens=self.sequence_len,
             image_token=self.image_token_spec.image_token,
             image_token_id=self.image_token_spec.image_token_id,
             text_column=self.text_column,
             image_column=self.image_column,
-            enforce_max_length=False,
+            enforce_max_length=True,
         )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/prompt_strategies/multimodal_pretrain.py` around lines 32 - 42,
The _encode_batch wrapper disables the early, fail-fast sequence-length guard by
passing enforce_max_length=False into encode_multimodal_pretrain, allowing
oversized samples to slip through preprocessing and only be warned about later
in the collator; update _encode_batch to enable the guard (pass
enforce_max_length=True) or remove the override so encode_multimodal_pretrain
uses its default fail-fast behavior, referencing _encode_batch,
encode_multimodal_pretrain, enforce_max_length and sequence_len to locate the
change.

85-100: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Fail fast when tokenizer doesn’t match processor.tokenizer (and re-enable max-length validation)

  • build_image_token_spec() resolves image_token_id from processor.tokenizer, but MultiModalPretrainDatasetWrappingStrategy._encode_batch() tokenizes/counts placeholders with the separate tokenizer passed to load(). Enforce tokenizer is processor.tokenizer (dataset encoding happens before the collator) to keep placeholder counting/label masking aligned.
  • _encode_batch() calls encode_multimodal_pretrain(... enforce_max_length=False), disabling the only early guard for rows that exceed sequence_len. This defers failures and risks placeholder/image-count inconsistencies later—enable enforce_max_length (or make it configurable) so oversized rows fail fast.
Suggested fix
 def load(
     tokenizer,
     cfg,
     ds_cfg: Optional[dict[str, Any]] = None,
     processor: ProcessorMixin | None = None,
 ):
     ds_cfg = ds_cfg or {}
     if processor is None:
         raise ValueError(
             "Multimodal CPT (type: multimodal_pretrain) requires a processor. "
             "Set `processor_type: AutoProcessor` (or the concrete processor "
             "class) in your config."
         )
     check_processor_compatibility(processor)
+    proc_tokenizer = getattr(processor, "tokenizer", None)
+    if proc_tokenizer is not None and proc_tokenizer is not tokenizer:
+        raise ValueError(
+            "Multimodal CPT requires `tokenizer` to be `processor.tokenizer` "
+            "so image placeholder ids stay aligned during encoding."
+        )
 
     text_column = ds_cfg.get("text_column") or "text"
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/prompt_strategies/multimodal_pretrain.py` around lines 85 - 100,
Ensure the tokenizer passed to load() is exactly the same object as
processor.tokenizer by adding a fast-fail check (use
check_processor_compatibility or add an explicit assertion) so tokenizer is
processor.tokenizer before creating MultiModalPretrainDatasetWrappingStrategy;
this guarantees build_image_token_spec and
MultiModalPretrainDatasetWrappingStrategy._encode_batch use the same vocab for
placeholder IDs. Also re-enable max-length validation by calling
encode_multimodal_pretrain with enforce_max_length=True (or expose a
configurable flag) inside
MultiModalPretrainDatasetWrappingStrategy._encode_batch so rows that exceed
cfg.sequence_len fail fast rather than downstream.

ℹ️ Review info
📥 Commits

Reviewing files that changed from the base of the PR and between 157b7d4 and c42adf4.

📒 Files selected for processing (7)
  • examples/qwen2_5-vl/mm-cpt-nonstreaming-qlora.yaml
  • examples/qwen2_5-vl/mm-cpt-streaming-qlora.yaml
  • src/axolotl/prompt_strategies/multimodal_pretrain.py
  • src/axolotl/utils/data/shared.py
  • src/axolotl/utils/schemas/datasets.py
  • tests/prompt_strategies/test_multimodal_pretrain.py
  • tests/utils/data/test_hash.py
💤 Files with no reviewable changes (1)
  • src/axolotl/utils/schemas/datasets.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • src/axolotl/utils/data/shared.py
  • examples/qwen2_5-vl/mm-cpt-nonstreaming-qlora.yaml

thad0ctor and others added 3 commits May 28, 2026 17:42
…fault (axolotl-ai-cloud#3680)

* feat(qwen): fused RMSNorm+RoPE for Qwen3 / Qwen3-MoE / Qwen3.5 / Qwen3.5-MoE

Generalizes the existing Gemma 4 fused RMSNorm+RoPE Triton kernel to four
new Qwen attention variants, and auto-enables Liger's fused (m-)rope
kernel for the Qwen-VL family. Eager-mode behavior is bit-identical when
the new cfg.fused_attn_kernel flag is unset.

Changes
-------
* New ``cfg.fused_attn_kernel: bool | None`` (default None / off). When
  set, replaces ``q_norm + apply_rotary_pos_emb`` (and the matching K path)
  with a single fused RMSNorm+RoPE Triton kernel launch. Currently wired
  for ``qwen3``, ``qwen3_moe``, ``qwen3_5``, and ``qwen3_5_moe``
  model_config_types. Llama4 is out of scope (complex freqs_cis +
  Llama4TextL2Norm post-RoPE — separate kernel).
* Kernel ``UNIT_OFFSET: tl.constexpr`` flag added to the forward + backward
  Triton kernels for Qwen3.5's Gemma-style ``(1.0 + weight)`` RMSNorm.
  Default ``False`` keeps Gemma 4 / Qwen3 / Qwen3-MoE bit-identical to
  before. Threaded through the triton_op + register_autograd plumbing.
* Refactors ``fused_rms_norm_rope`` / ``fused_rms_norm_noscale`` to
  ``torch.library.triton_op`` + ``register_autograd`` so they trace under
  ``torch.compile(fullgraph=True)``. Validated: 1 Dynamo frame, 0 graph
  breaks. On sm_120 the compile path composes to +9.2% combined,
  −33% peak memory. On sm_86 the surrounding Inductor-generated kernels
  regress — leave ``torch_compile: false`` there; schema description
  documents the per-arch recommendation.
* Liger Qwen-VL auto-default: when ``cfg.liger_rope is None`` and
  model_config_type is one of qwen2_vl/qwen2_5_vl/qwen3_vl (+ ``_text``
  variants), pass ``rope=True`` so upstream's fused m-rope kernel is
  actually installed. Previously the plugin overrode the upstream default
  to None, silently skipping the kernel.
* Patch-ordering fix: ``_apply_self_attention_lora_patch`` now runs
  before ``_apply_model_specific_patches`` in
  ``apply_pre_model_load_patches``. ``patch_self_attn_lora`` reads
  ``inspect.getsource`` of the attention class' forward, so any patch that
  replaces ``Attention.forward`` must run *after* the source-rewrite step.
  The wrong order also silently broke Gemma 4 + ``lora_qkv_kernel`` —
  pinned by ``TestPatchManagerOrdering`` and a fused-first trip-wire.

Tests
-----
* Per-model parity + backward grad flow for Qwen3, Qwen3-MoE, Qwen3.5,
  Qwen3.5-MoE (full-attention layers only; linear_attention layers stay
  on the stock GatedDeltaNet path).
* Kernel ``UNIT_OFFSET=True`` parity vs from-scratch reference + bwd
  parity vs torch-eager + ``torch.compile(fullgraph=True)`` parity.
* ``torch.compile(fullgraph=True)`` parity for the no-offset path.
* Liger Qwen-VL auto-default for all 6 model_config_types; explicit
  ``False`` is respected.
* Patch idempotency (double-apply is a no-op).
* Transformers signature contract — pins the stock attention forward
  argument names so future drift trips loudly at test time.
* Gradient-checkpointing composition (Qwen3 + ``gradient_checkpointing_enable``).
* Flash-Attention 2 composition (skip-if-unavailable).
* LoRA + fused composition on Qwen3 / Qwen3.5 / Qwen3.5-MoE, with
  fused-first reverse-order trip-wires that catch the original ordering
  bug if anyone re-introduces it.

A pre-existing upstream-drift xfail in ``test_gemma4_fused_attn.py``
documents Gemma 4 + ``lora_qkv_kernel`` being broken in transformers
5.8.1 (new ``shared_kv_states: dict[str, ...]`` signature drift in
QKV_PATCHES). Out of scope for this PR; flips to XPASS when patched.

Post-review fixes
-----------------
* ``_resolve_norm_module``: PEFT ``ModulesToSaveWrapper`` stores
  ``active_adapter`` as ``list[str]`` (e.g. ``["default"]``), not a string.
  The prior ``isinstance(adapter, str)`` check silently returned the
  frozen ``original_module`` for every real-PEFT case. Switched to
  iterating ``active_adapters`` (with ``active_adapter`` fallback) across
  all 4 patches. Added a direct unit-test plus an end-to-end test that
  drives real ``peft.get_peft_model(modules_to_save=["q_norm","k_norm"])``
  and asserts the helper returns the trainable adapter weight.
* ``cfg.fused_attn_kernel`` unsupported-model warning: moved out of the
  Pydantic ``model_validator(mode="before")`` (which ran *before*
  ``normalize_config()`` had derived ``model_config_type``, so it
  silently no-op'd on normal YAML input) into a new
  ``PatchManager._warn_if_fused_attn_unsupported`` staticmethod invoked
  from ``_apply_model_specific_patches``, where ``model_config_type`` is
  guaranteed set. Added a source-line guard that the helper stays wired.
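
The first post-review fix above can be illustrated with a stand-in object. This is a sketch under stated assumptions, not the patched helper: `FakeModulesToSaveWrapper` and `resolve_norm_module` are hypothetical names, and the `modules_to_save` mapping keyed by adapter name is assumed for illustration rather than taken from the text.

```python
class FakeModulesToSaveWrapper:
    """Hypothetical stand-in exposing only the attributes this sketch needs."""

    def __init__(self, original_module, modules_to_save, active_adapter):
        self.original_module = original_module
        self.modules_to_save = modules_to_save   # assumed: dict keyed by adapter name
        self.active_adapter = active_adapter     # may be a str or a list[str]


def resolve_norm_module(module):
    """Return the trainable copy for the first active adapter, if there is one."""
    saved = getattr(module, "modules_to_save", None)
    if saved is None:
        return module  # not wrapped, use as-is
    adapters = getattr(module, "active_adapters", None)
    if adapters is None:
        adapters = getattr(module, "active_adapter", None)
    if isinstance(adapters, str):
        adapters = [adapters]  # normalise the legacy string form
    for name in adapters or []:
        if name in saved:
            return saved[name]
    return module.original_module


frozen, trainable = object(), object()
wrapped = FakeModulesToSaveWrapper(frozen, {"default": trainable}, ["default"])
resolved = resolve_norm_module(wrapped)
```

Handling both the list and the string form avoids the failure mode described above, where a string-only check silently falls through to the frozen module.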

* address coderabbit comments

* improve bwd pass throughput

* feat(qwen3-vl): add fused attention patch

* test: capture fused attention logs from concrete loggers

* ci: rerun tests

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
…xolotl-ai-cloud#3661)

* compute kd loss in trainer

* add kd trainer compute_loss tests

* remove unused kd kernel patch module

* don't materialize all the logits

* ensure dtype from hidden states matches dtype for chunked kd since we're not inside the autocasting anymore

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
@thad0ctor thad0ctor force-pushed the feat/mm-cpt-dataset-pipeline-clean branch from 68d96fc to 8534115 Compare May 29, 2026 03:59
…tl-ai-cloud#3689) [skip-ci]

When the suite runs under pytest-xdist, multiple workers race for the same
physical GPU's memory budget. A test that fits comfortably in isolation
can OOM purely because peer workers are already holding most of VRAM
(observed: 8 workers each holding ~44 GiB on a 44 GiB card).

Add a conftest in tests/integrations/kernels/scattermoe_lora/ that hooks
pytest_runtest_call and converts torch.OutOfMemoryError into a skip. Real
correctness bugs still surface as failures since they raise asserts /
typed exceptions, not OOM.

Uses a hookwrapper rather than an autouse fixture because pytest captures
the test exception before re-entering the fixture's generator, so the
fixture's try/except around yield never sees it.
@thad0ctor

Owner Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented May 29, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@thad0ctor thad0ctor force-pushed the feat/mm-cpt-dataset-pipeline-clean branch from 8534115 to 0b14383 Compare May 29, 2026 06:39
@thad0ctor

Owner Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented May 29, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tests/prompt_strategies/test_multimodal_pretrain.py`:
- Line 154: The pytest.raises assertion uses match="processor.tokenizer" where
the dot is a regex wildcard; change it to a literal match by escaping the dot
(e.g., match=r"processor\.tokenizer") or by using
re.escape("processor.tokenizer") so the test asserts the exact substring; update
the pytest.raises call in tests/prompt_strategies/test_multimodal_pretrain.py
accordingly.
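
The regex issue flagged above is easy to demonstrate with the standard library alone, since `pytest.raises(match=...)` applies the pattern with `re.search`. The snippet is a minimal sketch; the example message strings are made up for illustration.

```python
import re

unescaped = "processor.tokenizer"            # "." matches any character
escaped = re.escape("processor.tokenizer")   # "." must be a literal dot

exact_message = "tokenizer must be processor.tokenizer"
lookalike_message = "tokenizer must be processor_tokenizer"

# The unescaped pattern accepts both messages; the escaped one only the exact text.
unescaped_hits = [bool(re.search(unescaped, m)) for m in (exact_message, lookalike_message)]
escaped_hits = [bool(re.search(escaped, m)) for m in (exact_message, lookalike_message)]
```

Writing the pattern as a raw string with an escaped dot, or passing it through `re.escape`, gives the same stricter behaviour.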

ℹ️ Review info
📥 Commits

Reviewing files that changed from the base of the PR and between c42adf4 and 8534115.

📒 Files selected for processing (10)
  • src/axolotl/prompt_strategies/multimodal_pretrain.py
  • src/axolotl/utils/data/sft.py
  • src/axolotl/utils/data/shared.py
  • src/axolotl/utils/schemas/datasets.py
  • tests/prompt_strategies/test_multimodal_pretrain.py
  • tests/test_multimodal_streaming.py
  • tests/utils/data/test_hash.py
  • tests/utils/data/test_mm_cpt_eval.py
  • tests/utils/data/test_mm_pretrain_cache.py
  • tests/utils/schemas/validation/test_multimodal_cpt.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/utils/schemas/validation/test_multimodal_cpt.py
  • src/axolotl/prompt_strategies/multimodal_pretrain.py

Comment thread tests/prompt_strategies/test_multimodal_pretrain.py Outdated
@thad0ctor thad0ctor force-pushed the feat/mm-cpt-dataset-pipeline-clean branch from 0b14383 to 8d15592 Compare May 29, 2026 06:45
@thad0ctor

Owner Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented May 29, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

winglian added 2 commits June 1, 2026 12:23
…3697)

* add pytorch 2.12 base and prune unused base images

* Add back 2.11.0 and add them to basic pytest matrices
* bump transformers to 5.9.0 and trl to 1.5.1

* test(gemma4-kernelize): accept ValueError from transformers 5.9 attach_hidden_kernels

transformers ≤5.8 surfaced the non-Module ``_hidden_kernels`` entry as
TypeError/AttributeError via ``module.register_module(name, fn)``. 5.9
reworked ``attach_hidden_kernels`` to raise ``ValueError`` directly with
a clearer error message. The patch under test (strip dead entries
before ``kernelize()`` runs) does the right thing either way; broaden
the expected-crash assertion so the test reflects current upstream
behavior.
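
The broadened assertion described above can be sketched without any test framework. This is an illustration only: `crashes_as_expected`, `old_style`, and `new_style` are hypothetical names, and the real test presumably uses `pytest.raises`, which also accepts a tuple of exception types.

```python
# Exception types the unpatched path may raise, depending on the installed
# transformers version (per the commit message above).
EXPECTED_CRASHES = (TypeError, AttributeError, ValueError)


def crashes_as_expected(fn):
    """Return True if fn() raises one of EXPECTED_CRASHES, False if it returns."""
    try:
        fn()
    except EXPECTED_CRASHES:
        return True
    return False


# Hypothetical stand-ins for the older and newer upstream behaviour.
def old_style():
    raise TypeError("entry is not a Module")


def new_style():
    raise ValueError("invalid hidden kernel entry")
```

Any exception outside the tuple still propagates, so an unrelated failure is not mistaken for the expected crash.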

* 30 min timeout

* fix(activation-offload): drop monkey-patched __enter__ now that TRL 1.5.1 ships upstream fix

TRL 1.5.1 implements huggingface/trl#5730 natively — ``OffloadActivations``
now has its own ``__enter__`` that clears tracker / stashes between steps,
**plus** two things the axolotl backport never had:

- ``self.tensor_id = 0`` reset (without this, the tensor_id counter accumulates
  across steps; harmless on its own but skews the ``fwd_stash`` eviction window).
- ``torch.cuda.empty_cache()`` when bitsandbytes is loaded — flushes the BNB
  allocator between steps so its compute / optimizer-state buffers don't
  accumulate as live storage.

TRL 1.5.1 also adds a ``__exit__`` that syncs the offload streams (``s0``,
``s1``) before the parent cleanup runs. The axolotl backport only overrode
``__enter__``, so ``__exit__`` was inherited correctly either way.

Once we bumped TRL 1.1.0 → 1.5.1 (transformers 5.9 bundle), the monkey-patch
became strictly worse than upstream — it shadowed the better ``__enter__``,
dropping the ``tensor_id`` reset and the BNB ``empty_cache``. Combined with
cu130's stricter cross-stream lifetime checks, this surfaced as XID 43
(driver-killed CUDA channel) during ``test_activation_offloading[lora]``,
followed by every subsequent test failing at ``torch.manual_seed(42)``
because the CUDA context was permanently poisoned.

Drop the patch and the wrapper — upstream is now the source of truth, per
the existing TODO in this file.
winglian and others added 5 commits June 2, 2026 16:49
* prefer latest pytorch as gated e2e tests

* fix(fsdp2-qlora): match _init_sharded_param anchor for torch 2.12 + fallback to 2.11

torch 2.12.0 rewrote the sharded-param construction in
FSDPParam._init_sharded_param from a two-line form

    self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
    self.sharded_param.requires_grad_(param.requires_grad)

to a single multi-line Parameter() call with requires_grad= as a kwarg

    self.sharded_param = nn.Parameter(
        self.to_sharded_dtensor(sharded_param),
        requires_grad=param.requires_grad,
    )

Functionally identical, but the axolotl monkey-patch is source-level
text replacement: the 2.11 anchor no longer matches the 2.12 source, so
the substitution silently falls through to the warning branch and the
method stays unpatched — bnb Params4bit / Int8Params lose their
quantization metadata through the FSDP2 shard cycle.

Try the 2.12 anchor first; fall back to the 2.11 anchor so the patch
keeps working against both torch versions in our test matrix.

init_unsharded_param uses the same kwarg-style call in both 2.11 and
2.12, so its anchor is untouched.
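
The anchor-fallback idea can be sketched as follows. This is not the axolotl patch itself: `patch_source` and the placeholder replacement text are hypothetical, and the anchors are the two forms quoted above with their original indentation simplified.

```python
import logging

LOG = logging.getLogger(__name__)


def patch_source(source, candidates):
    """Apply the first (anchor, replacement) pair whose anchor occurs in source.

    Returns (patched_source, matched). If no anchor matches, the source is
    returned unchanged and a warning is logged.
    """
    for anchor, replacement in candidates:
        if anchor in source:
            return source.replace(anchor, replacement, 1), True
    LOG.warning("no anchor matched; method left unpatched")
    return source, False


# Anchors mirror the two forms quoted in the commit message (dedented here).
ANCHOR_2_12 = (
    "self.sharded_param = nn.Parameter(\n"
    "    self.to_sharded_dtensor(sharded_param),\n"
    "    requires_grad=param.requires_grad,\n"
    ")"
)
ANCHOR_2_11 = (
    "self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))\n"
    "self.sharded_param.requires_grad_(param.requires_grad)"
)

# Newest form first, then the fallback; the replacement is a placeholder.
CANDIDATES = [
    (ANCHOR_2_12, "# <patched construction goes here>"),
    (ANCHOR_2_11, "# <patched construction goes here>"),
]
```

Returning the `matched` flag lets the caller fail loudly instead of relying on a log warning, which is the failure mode the commit describes.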

* fix(fsdp2-qlora): match init_unsharded_param anchor for torch 2.12

torch 2.12 hoisted the unsharded-param construction out of the
first-all-gather `else:` branch up to method-body level, so the 2.11
anchor (8-space, inside else) no longer matched and the patch silently
no-op'd. This left bitsandbytes Params4bit unreconstructed under FSDP2,
surfacing as `mat1 and mat2 shapes cannot be multiplied (... 1x36864)`
in QLoRA training. Add the 2.12 method-body-level anchor with its own
replacement indentation, falling back to the 2.11 form.

* test(multigpu): stabilize test_lora_ddp with 20 steps + seed

test_lora_ddp ran only 2 steps with no seed, so train_loss was a random
draw (observed 1.95-3.23 across runs) and the 2.8 threshold tripped
intermittently — the torch 2.12 bump just happened to surface it. Run 20
steps with seed=42 to make the loss deterministic (2.189-2.191 spread),
and tighten the threshold to 2.5.

* fix(optimizers): support torch 2.11 graph health-check rename in ADOPT

torch 2.11 renamed Optimizer._cuda_graph_capture_health_check to
_accelerator_graph_capture_health_check (2.12 re-added the old name as an
alias). ADOPT called the old name, so it raised AttributeError under torch
2.11 — surfaced by bumping the docker-e2e row from 2.9.1 to 2.11.0. Resolve
whichever name exists, preferring the new one. Also swap the deprecated
torch._utils.is_compiling() for torch.compiler.is_compiling().
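
A minimal sketch of the "resolve whichever name exists" approach, assuming a plain `getattr` lookup. The two method names come from the commit message; `resolve_health_check` and the stand-in optimizer classes are hypothetical and do not import torch.

```python
def resolve_health_check(optimizer):
    """Return the graph-capture health-check method, preferring the new name."""
    for name in (
        "_accelerator_graph_capture_health_check",  # name introduced in torch 2.11
        "_cuda_graph_capture_health_check",         # older name
    ):
        method = getattr(optimizer, name, None)
        if method is not None:
            return method
    raise AttributeError("optimizer exposes no graph-capture health check")


# Hypothetical stand-ins for optimizers from different torch versions.
class OldOptimizer:
    def _cuda_graph_capture_health_check(self):
        return "old"


class NewOptimizer:
    def _accelerator_graph_capture_health_check(self):
        return "new"
```

Because the new name is tried first, a version that exposes both names resolves to the new one.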
axolotl-ai-cloud#3700) [skip ci]

The pyproject migration removed setup.py, so the publish workflow failed at
`python setup.py sdist` (No such file). Build the sdist+wheel with `uv build`
(PEP 517; setuptools backend reads the version from VERSION). Also make the
GitHub release step idempotent so a re-run/re-tag of an existing release
doesn't fail, and drop the unused dependency-install step.
…tl-ai-cloud#3701)

The fused Gemma4 attention monkeypatch read and stored shared KV states
by `kv_shared_layer_index`/`layer_idx`, but transformers 5.8 dropped the
`kv_shared_layer_index` attribute and switched to keying `shared_kv_states`
by `layer_type`. On the pinned transformers 5.9, any Gemma4 model with
`num_kv_shared_layers > 0` (e.g. gemma-4-E2B vision) raised
`AttributeError: 'Gemma4TextAttention' object has no attribute
'kv_shared_layer_index'` once execution reached a shared layer.

Derive the read/store key from whichever attribute the installed
transformers exposes, keeping compatibility with both the old and new
APIs. Add a fused-attn regression with `num_kv_shared_layers > 0` so the
shared-KV branch is actually exercised (existing tests defaulted to 0).
…er tests (axolotl-ai-cloud#3705) [skip ci]

The Python 3.12 PyTest legs run ~2x slower than 3.14 on the same test set
(816s vs 403s) and were tipping over the 30-minute job timeout. Two causes,
both in the slow tail:

- dataset_num_proc=4 forks 4 dataset workers per .map() on CPU-only runners,
  each re-importing the torch stack to process a few hundred rows — pure
  overhead. Lower to 1 in the affected tests (none assert on it or test
  multiprocessing); results are unchanged.
- --dist loadfile pins a whole file to one worker, so the entire builder
  suite serialized on a single worker at the end. Move shared fixtures to
  tests/core/conftest.py and split the RL trainer-builder tests into
  test_builders_rl.py so they run on a separate worker from the SFT/reward
  builder tests.
@thad0ctor thad0ctor deleted the branch feat/mm-cpt-dataset-pipeline-base June 5, 2026 21:12
@thad0ctor thad0ctor closed this Jun 5, 2026
5 participants