Fused ScatterMoE-LoRA for MXFP4 weights#3663
Conversation
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
📝 WalkthroughWalkthroughAdds MXFP4 support to ScatterMoE LoRA with fused forward and backward Triton kernels that dequantize expert weights on-the-fly, a new MXWeights container type for expert quantization metadata, selective expert dequantization for both strategies, integration into ScatterMoELoRA's forward and backward passes, and comprehensive unit, integration, and performance tests. ChangesMXFP4 Fused Kernels for ScatterMoE LoRA
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/axolotl/integrations/kernels/libs/scattermoe_lora/mx_weights.py (1)
212-212:⚠️ Potential issue | 🟠 Major | ⚡ Quick winResolve pre-commit formatting drift before merge.
CI already reports
end-of-file-fixerandruff-formatmodifying this file; please commit those formatting changes so lint passes.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/mx_weights.py` at line 212, Run the project's pre-commit hooks / formatters on the affected file to resolve formatting drift: apply end-of-file-fixer (ensure a single trailing newline) and run ruff/ruff-format (or the project's formatter) on mx_weights.py, then stage and commit the updated file so CI no longer modifies it; alternatively run the project's pre-commit install and `pre-commit run --all-files`, review the changes, and commit them.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/mx_weights.py`:
- Around line 91-104: MXWeights.__post_init__ currently silently reinterprets
non-uint8 scales and doesn't validate shapes, which can corrupt buffers; change
it to require scales.dtype == torch.uint8 (raise ValueError instead of view())
and add explicit shape/rank checks tying packed and scales to K, N, and
block_size: validate packed.dtype == uint8, packed.ndim and scales.ndim, confirm
packed.size(1)/packed.size(2) or packed layout matches expected K and N derived
from attributes (and that N is divisible by MX_BLOCK_SIZE), and ensure
scales.shape matches (num_experts, N // MX_BLOCK_SIZE) or the correct per-block
layout used by your Triton kernel; only after these validations set num_experts
from packed.size(0) if None, otherwise raise descriptive errors on mismatch.
- Around line 151-162: Update the misleading ImportError strings that reference
the wrong torchao minimum version: locate the MXFP4 import guard where MXTensor
= _torchao_mxtensor_cls() and the subsequent raise ImportError call that
currently says "MXFP4 path requires torchao (install `torchao>=0.7`)" (and the
similar message later around the other check at ~line 186) and change the
messages to require "torchao>=0.17.0" so they match the documented and pinned
dependency and the constructor signature used.
In
`@src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py`:
- Around line 71-79: The mypy error comes from using is_mx =
isinstance(expert_weights, MXWeights) which doesn't narrow the union type for
the else branch; replace the boolean temp with a direct isinstance check or
explicitly narrow the type in the else branch—e.g., change "is_mx =
isinstance(expert_weights, MXWeights); if is_mx: ..." to "if
isinstance(expert_weights, MXWeights): ..." or, after the current check, add
"from typing import cast; expert_weights = cast(torch.Tensor, expert_weights)"
before using tensor-specific attributes so expert_weights (and symbols
MXWeights, MXLayout, and the else branch that accesses expert_weights.dtype) are
properly narrowed for mypy.
In `@tests/integrations/kernels/scattermoe_lora/bench_mxfp4.py`:
- Around line 122-143: The variable `template` is Optional and is dereferenced
when constructing the returned MXTensor, which trips mypy; add an assertion like
`assert template is not None` (or equivalent narrowing) immediately before the
MXTensor(...) return to guarantee to the type checker that `template` is
non-None; ensure the assertion sits right before the MXTensor construction so
references to `template.elem_dtype`, `template.block_size`,
`template.orig_dtype`, `template.kernel_preference`,
`template.act_quant_kwargs`, and `template.is_swizzled_scales` are accepted by
mypy.
In `@tests/integrations/kernels/scattermoe_lora/test_mxfp4_expert_weights.py`:
- Line 441: The local variable fwd_tol is assigned from _tol_for_shape(K) but
never used, causing a lint F841; remove the unused assignment or explicitly
discard it (e.g. call _tol_for_shape(K) and assign to _ ) in the
test_mxfp4_expert_weights.py test where fwd_tol is created so that the linter no
longer flags an unused variable; locate the statement involving fwd_tol and
either delete that line or replace the variable name with an underscore.
---
Outside diff comments:
In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/mx_weights.py`:
- Line 212: Run the project's pre-commit hooks / formatters on the affected file
to resolve formatting drift: apply end-of-file-fixer (ensure a single trailing
newline) and run ruff/ruff-format (or the project's formatter) on mx_weights.py,
then stage and commit the updated file so CI no longer modifies it;
alternatively run the project's pre-commit install and `pre-commit run
--all-files`, review the changes, and commit them.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 6e0becca-338b-4c18-8fd4-fe47a239927f
📒 Files selected for processing (11)
src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.pysrc/axolotl/integrations/kernels/libs/scattermoe_lora/layers.pysrc/axolotl/integrations/kernels/libs/scattermoe_lora/mx_weights.pysrc/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.pysrc/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant.pytests/integrations/kernels/__init__.pytests/integrations/kernels/scattermoe_lora/__init__.pytests/integrations/kernels/scattermoe_lora/bench_mxfp4.pytests/integrations/kernels/scattermoe_lora/bench_mxfp4_results.mdtests/integrations/kernels/scattermoe_lora/test_mxfp4_expert_weights.pytests/integrations/kernels/scattermoe_lora/test_mxfp4_integration.py
|
📖 Documentation Preview: https://6a17b69b033abf02a98fa668--resonant-treacle-0fd729.netlify.app Deployed on Netlify from commit 005c12e |
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
1dbb3bc to
3f64bbd
Compare
Add an MXFP4 branch to `selective_expert_weights()` that detects a torchao `MXTensor` parameter (elem_dtype=float4_e2m1fn_x2) and dequantizes only the active experts via index-then-construct of a compact sub-MXTensor. The K-axis OCP block layout (last storage dim) matches `experts.gate_up_proj` natural shape `[E, N, K]`, so the caller's existing `.transpose(2, 1)` post-step keeps producing the kernel's `[E, K, N]` weight tile unchanged. `HFScatterMoEGatedMLP.forward` now also routes through the selective path whenever the experts hold MXFP4 weights — full-tensor MX dequant of 256-expert models is prohibitive and the kernel needs bf16 input. Tests (CUDA-only) compare against a bf16 baseline produced by the same MXTensor's full dequant; outputs are bitwise identical for both forward and backward (dX, dA, dB) across small [E=8,K=128,N=256] and representative [E=32,K=2048,N=1024] shapes, and across all four combinations of \`use_fused_dX\` / \`use_fused_gather\`. Signed-off-by: Wing Lian <wing@axolotl.ai>
Add MX-aware forward and dX kernels that consume an ``MXWeights``
container (packed uint8 + E8M0 scales) directly, so the base-weight
tile is dequantized inside the K-loop instead of through a materialized
bf16 buffer. The K-loop loads two FP4 values per uint8 byte, looks them
up in a 16-entry codebook tensor (``±{0, 0.5, 1, 1.5, 2, 3, 4, 6}``),
multiplies by ``2^(scale_byte - 127)``, and casts to bf16 for the
matmul. ``BLOCK_K`` is constrained to a multiple of the OCP block size
(32) so each tile aligns with whole scale blocks; an MX-aware autotune
pruner accounts for the extra packed/scale SMEM.
The dX kernel reuses the *forward* MX layout (block axis = K, the dX
output axis) — for each (K_tile, N_tile) sub-tile, nibbles decode
along the K rows (the byte is shared by two adjacent K rows) and
scales broadcast within their MX block. This avoids the
dequant + re-quantize "pre-transpose" the spec suggested and the
extra MX-rounding error that round-trip would have introduced.
``ScatterMoELoRA.forward`` now accepts either a dense tensor or an
``MXWeights``; the MX branch always selects the fused-dX and
fused-gather backward kernels (the non-fused dX path would have to
materialize a bf16 weight tile, defeating the win).
Unit tests cover forward, dX, dA, dB parity for small
[E=8, K=128, N=256] and representative [E=32, K=2048, N=1024] shapes;
tolerances are calibrated to bf16 MMA noise (atomic-add ordering and
FMA reordering between the full-E baseline and compact-active MX path).
Integration test exercises a tiny synthetic DeepSeek-V4-style MoE
block (E=8, hidden=512, intermediate=256, top_k=2) end-to-end through
both Strategy A and Strategy B with LoRA disabled.
Signed-off-by: Wing Lian <wing@axolotl.ai>
Add ``bench_mxfp4.py`` and committed results for the representative DeepSeek-V4-style shape (E=128, K=2048, N=1024, top_k=8, M=4096, rank=16). Reports ms/iter, tokens/s, peak GPU memory, and HBM bandwidth utilisation for three configurations: bf16 baseline, Strategy A (selective dequant), Strategy B (fused MX). On the RTX PRO 6000 Blackwell, the all-active-experts shape used here doesn't exercise selective dequant's memory savings (active = E = 128) — A pays the cost of materialising the full bf16 dequant buffer per step (~9 GB peak vs 1.9 GB for B) while still routing through the bf16 kernel. B halves A's wall time (~12 ms vs 30 ms) by eliminating the buffer, but stays slower than the bf16 baseline (5 ms) which assumes the bf16 weights already exist in memory. Signed-off-by: Wing Lian <wing@axolotl.ai>
Signed-off-by: Wing Lian <wing@axolotl.ai>
Signed-off-by: Wing Lian <wing@axolotl.ai>
The MX-aware autotune pruner for the forward kernel under-accounted SMEM: it computed the packed-tile cost as BLOCK_N * BLOCK_K/2 and the scale-tile cost as BLOCK_N * BLOCK_K/MX_BLOCK_SIZE, but the actual tl.load issues a full [BLOCK_N, BLOCK_K]-shaped uint8 fetch for both buffers (the packed buffer reads each byte twice because K_byte = K // 2 indexes a [BLOCK_K]-wide vector; the scale buffer broadcasts within each MX_BLOCK_SIZE K-block). Bring the forward pruner up to the same conservative full-tile accounting already used by _prune_dX_mx_configs. Without this, on the [E=128, K=2048, N=1024] shape with the typical GPU SMEM caps, two to six high-stage configs that were previously selectable would have overflowed SMEM at launch under correct accounting — a silent OOM-in-the-future risk. Signed-off-by: Wing Lian <wing@axolotl.ai>
The file-level docstring for the MXFP4 kernels described the dX kernel as using a pre-transposed [E, K, N/2] layout produced by a 'mx_pre_transpose_for_dx' helper. That helper doesn't exist; the dX kernel actually reuses the forward [E, N, K/2] layout, iterating the N reduction in outer tiles and decoding nibbles along the K rows of each tile. Rewrite the docstring to describe what the code actually does, including the rationale — reusing the forward buffer avoids the dequant + re-quantize round-trip that a pre-transpose would require and keeps dX numerics free of a second MX rounding error stacked on top of the forward quantization. Signed-off-by: Wing Lian <wing@axolotl.ai>
F4: Hoist 'is_mxfp4_param' import from inside 'HFScatterMoEGatedMLP.forward' to the top of layers.py — it was being re-imported every step on the hot path. F5: Add a thin compatibility shim for torchao MXTensor internals access in mx_weights.py. The MX paths in selective_dequant.py / mx_weights.py used to reach into 'mx_param.qdata', 'mx_param.scale', 'mx_param.kernel_preference' and call 'MXTensor(...)' with positional args directly. That works at the pinned torchao 0.17.0 but is fragile to internal renames in future torchao releases. Funnel through three helpers — '_mx_qdata', '_mx_scale', '_construct_mxtensor_subset' — that use 'getattr' fallbacks for the buffer attributes and pass the constructor's optional args via 'getattr' too. Single point of pain, no API change. F7: Remove the unused 'NO_K_MASK' heuristic + tl.constexpr param from the dX MX kernel '_scatter2scatter_lora_dX_mx'. The dX kernel never references it (its inner loop masks N, not K), so the constexpr just forced extra autotune key entries. F8: Consolidate the duplicate '_torchao_mxtensor_cls()' definitions (one in selective_dequant.py, one in mx_weights.py) into a single definition in mx_weights.py. selective_dequant.py imports it. Signed-off-by: Wing Lian <wing@axolotl.ai>
F3: 'test_strategy_a_backward_fused_variants' previously used
'torch.ones_like(output)' as the grad input and asserted only on dX.
A uniform grad zeros out cross-token differences in the fused-gather
accumulation, masking reordering bugs; restricting the assertion to dX
silently let the dA/dB paths go unchecked across the four
'(use_fused_dX, use_fused_gather)' production variants.
* Drive the backward with 'torch.randn_like(output) * 0.1'.
* Capture and assert dA and dB parity across all four variants
using the same 'row_idx' gather pattern as
'test_strategy_a_backward_matches_bf16'.
* Forward and dX are still asserted bitwise via 'torch.equal'. dA/dB
fall back to atol/rtol = 1e-3 because the fused dA/dB kernel uses
'atomic_add' across N-block programs and the in-flight program
count differs between the full-E baseline and the compact-active
path; combined with FMA reordering, the 'use_fused_dX=True'
variants accumulate ~1 bf16 ULP of unavoidable atomic-order noise.
The new bound is still an order of magnitude below that noise
floor, so it catches real bugs.
F9: The 'test_strategy_b_backward_matches_bf16' dX comparison runs at
'atol=0.5, rtol=2e-2' (small) / 'atol=2.0, rtol=3e-2' (representative)
to allow for accumulated bf16 MMA noise over the N reduction. Those
bounds are appropriate for legitimate per-element drift but would also
admit a uniform multiplicative bug — e.g. an off-by-one on the E8M0
exponent that scales every dX element by 2x.
Add a guard alongside the existing 'torch.allclose': mask out
near-zero baseline elements (relative to 'bf16_dX.abs().max()'), then
require the per-element ratio 'mx_dX / bf16_dX' to have std < 0.5. A
uniform multiplicative bug pushes that std to ~0 while the mean shifts;
a real-bug per-element drift pushes the std up. This crosscuts the
allclose check rather than replacing it.
Signed-off-by: Wing Lian <wing@axolotl.ai>
The previous bench harness did a fresh '.clone()' of x and a
'requires_grad_(True)' on cloned lora A/B tensors every iter inside
the timed window. That accounts for buffer allocation, not kernel
cost, and biases the numbers toward whichever path produced the
smallest activations. Restructure the runners so:
* 'x' is cloned once into a leaf tensor with 'requires_grad_(True)'
inside 'bench()' (outside the timed warmup + timed loop).
* LoRA A/B leaf tensors are constructed once in the runner factory,
not per iter.
* Each iter calls the runner which sets 'x.grad = A.grad = B.grad =
None' (cheap, no GPU sync) so the autograd graph for the timed
iteration is fresh and grads don't accumulate.
Re-run all three configs end-to-end after this change (dense E=128,
sparse E=256 / 10-active, balanced E=256 M-sweep at M ∈ {256, 1024,
4096, 16384}) and refresh the numbers in bench_mxfp4_results.md.
Headers and table structure are unchanged. The qualitative ordering
holds (Strategy A wins at low active/E, Strategy B wins near
active/E ≈ 1, and Strategy A still OOMs across the balanced sweep on
the workstation with vLLM colocated), with per-cell numbers within
single-digit percent of the prior runs.
Signed-off-by: Wing Lian <wing@axolotl.ai>
Signed-off-by: Wing Lian <wing@axolotl.ai>
Signed-off-by: Wing Lian <wing@axolotl.ai>
…arity assertions Wing's "lint and PR review fixes" commit (9007a82) reverted three fixes from the prior lint pass. Restore them: 1. parallel_linear_lora.py: use isinstance(expert_weights, MXWeights) directly so mypy can narrow the union — the `is_mx` boolean alias blocks narrowing and re-introduces 2 union-attr errors. 2. bench_mxfp4.py: assert template is not None before the MXTensor(...) constructor — the chunked converter initializes template to None then sets it inside the loop, which mypy can't prove non-None at the call site (6 None-attr errors). 3. test_mxfp4_expert_weights.py: the F841 on fwd_tol was actually a smell of dropped logic. Both backward tests (test_strategy_a_backward_matches_bf16 and test_strategy_b_backward_matches_bf16) compute the forward outputs out_b/out_a/out_s, run backward, and assert gradients match — but never assert that the forward outputs match. A forward bug producing a constant offset (and therefore zero gradient delta) would slip past the bwd-only checks. Add the missing torch.equal(out_b, out_a) for Strategy A (bitwise contract) and torch.allclose(out_b, out_s, **fwd_tol) for Strategy B (MX tol). Signed-off-by: Wing Lian <wing@axolotl.ai>
3f64bbd to
005c12e
Compare
Earlier pass rejected fp8/nvfp4/mxfp4 at the schema layer, telling users to use QAT/PTQ instead. That was wrong: - NVFP4 has a real weight-only torchao config (NVFP4WeightOnlyConfig in torchao.prototype.mx_formats) — it's a 4-bit quant, perfectly suited to QLoRA. Now auto-promotes adapter lora -> qlora and builds NVFP4WeightOnlyConfig at load. - FP8 (float8_e4m3fn) has Float8WeightOnlyConfig in torchao.quantization — a one-byte-per-weight quant that mirrors INT8's role. Keeps adapter as lora. - MXFP4 is the genuine 'no weight-only flavor' case. The schema now passes it through; the loader raises with a pointer to quantize_moe_experts: true for MoE models (which is where MXFP4 LoRA actually lives, via the ScatterMoE-LoRA path landed in #3663) and to qat/ptq for inference-time MXFP4. CUDA smoke-tested on SmolLM2-135M: - weight_dtype: fp8 -> Float8WeightOnlyConfig, forward+backward OK - weight_dtype: nvfp4 (group_size=16) -> NVFP4WeightOnlyConfig, OK - weight_dtype: mxfp4 -> loader error pointing to quantize_moe_experts Docs and the dtype table updated; schema/loader tests extended. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Earlier pass rejected fp8/nvfp4/mxfp4 at the schema layer, telling users to use QAT/PTQ instead. That was wrong: - NVFP4 has a real weight-only torchao config (NVFP4WeightOnlyConfig in torchao.prototype.mx_formats) — it's a 4-bit quant, perfectly suited to QLoRA. Now auto-promotes adapter lora -> qlora and builds NVFP4WeightOnlyConfig at load. - FP8 (float8_e4m3fn) has Float8WeightOnlyConfig in torchao.quantization — a one-byte-per-weight quant that mirrors INT8's role. Keeps adapter as lora. - MXFP4 is the genuine 'no weight-only flavor' case. The schema now passes it through; the loader raises with a pointer to quantize_moe_experts: true for MoE models (which is where MXFP4 LoRA actually lives, via the ScatterMoE-LoRA path landed in #3663) and to qat/ptq for inference-time MXFP4. CUDA smoke-tested on SmolLM2-135M: - weight_dtype: fp8 -> Float8WeightOnlyConfig, forward+backward OK - weight_dtype: nvfp4 (group_size=16) -> NVFP4WeightOnlyConfig, OK - weight_dtype: mxfp4 -> loader error pointing to quantize_moe_experts Docs and the dtype table updated; schema/loader tests extended.
Summary by CodeRabbit
New Features
Tests