Skip to content

Fused ScatterMoE-LoRA for MXFP4 weights#3663

Merged
winglian merged 15 commits into
mainfrom
feat/scattermoe-lora-mxfp4
May 28, 2026
Merged

Fused ScatterMoE-LoRA for MXFP4 weights#3663
winglian merged 15 commits into
mainfrom
feat/scattermoe-lora-mxfp4

Conversation

@winglian

@winglian winglian commented May 18, 2026

Copy link
Copy Markdown
Collaborator

Summary by CodeRabbit

  • New Features

    • Added support for MXFP4 quantized expert weights in ScatterMoE LoRA, enabling memory-efficient training with compressed model parameters.
  • Tests

    • Added integration tests validating MXFP4 expert weight functionality across forward and backward passes.
    • Added performance benchmarks comparing MXFP4 strategies against baseline configurations.

Review Change Stack

@coderabbitai

coderabbitai Bot commented May 18, 2026

Copy link
Copy Markdown
Contributor

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 3806f3b6-91f9-4388-add1-ef9a177af3dc

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds MXFP4 support to ScatterMoE LoRA with fused forward and backward Triton kernels that dequantize expert weights on-the-fly, a new MXWeights container type for expert quantization metadata, selective expert dequantization for both strategies, integration into ScatterMoELoRA's forward and backward passes, and comprehensive unit, integration, and performance tests.

Changes

MXFP4 Fused Kernels for ScatterMoE LoRA

Layer / File(s) Summary
MXFP4 Type Foundation and FP4 Codebook
src/axolotl/integrations/kernels/libs/scattermoe_lora/mx_weights.py
MXWeights dataclass holds packed uint8 qdata and E8M0 scales; MXLayout enum defines layout modes; fp4_codebook provides per-device cached FP4 lookup tensors; helpers extract qdata/scales from torchao MXTensor and construct expert subsets by slicing.
Selective MXFP4 Dequantization
src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant.py
is_mxfp4_param detects torchao MXFP4 parameters; _selective_dequant_mxfp4 extracts active-expert qdata/scale subsets and dequantizes via torchao; selective_expert_weights dispatch routes MXFP4 to the new dequant path.
MXFP4 Forward Fused Kernel
src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py (lines 2244–2748)
_compute_expert_block_lora_mxfp4 decodes fp4 nibbles from packed buffers, looks up fp32 codebook values, applies E8M0 scaling via exp2(scale_byte - 127), and fuses base and LoRA GEMMs within the same K-loop. Public dispatcher scatter2scatter_lora_mx validates layout, loads the codebook, and launches the kernel with autotuned config and boundary masking.
MXFP4 Backward dX Fused Kernel
src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py (lines 2750–3195)
_compute_expert_block_lora_dX_mxfp4 dequantizes W on-the-fly during N-reduction using the same forward MXFP4 buffers (no transpose), accumulates dY @ W^T + dY @ B in a single pass, and applies the LoRA-A epilogue. Public dispatcher scatter2scatter_lora_dX_mx builds the grid and launches with layout validation.
ScatterMoELoRA MX Integration
src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py
Forward accepts expert_weights: Union[torch.Tensor, MXWeights], detects MX mode, and dispatches to scatter2scatter_lora_mx or tensor path. Backward branches on ctx.is_mx to use scatter2scatter_lora_dX_mx for input gradients; E and N dimensions derive from MXWeights fields when MX, else tensor shape; d_weights explicitly None for MX weights (immutable containers).
HF Layer Selective Dequant Detection
src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py
Imports is_mxfp4_param and detects MXFP4 on experts.gate_up_proj, enabling use_selective when MXFP4 or BnB quantization is present; backward selective dequantization triggers for detected quantized experts.
MXFP4 Expert Weight Tests
tests/integrations/kernels/scattermoe_lora/test_mxfp4_expert_weights.py
Strategy A (selective dequant) forward/backward verified against bf16 baseline with bitwise dX equality and active-expert-slice LoRA-grad checks; Strategy B (fused MX kernel) forward/backward validated within shape-dependent tolerances including a ratio-std drift guard for uniform-scaling bugs; all fused flag combinations tested.
MXFP4 Integration Test
tests/integrations/kernels/scattermoe_lora/test_mxfp4_integration.py
Validates against a PyTorch per-expert reference on a small DeepSeek-V4-style MoE block; both strategies (selective dequant and fused MX) produce outputs matching the reference within 5e-3 relative tolerance.
Benchmark Suite and Results
tests/integrations/kernels/scattermoe_lora/bench_mxfp4.py, tests/integrations/kernels/scattermoe_lora/bench_mxfp4_results.md
Benchmarks bf16 baseline, Strategy A, and Strategy B across dense/sparse/balanced routing modes; includes GPU peak bandwidth estimation, per-M sweeps, tokens/s throughput, HBM utilization metrics, and notes on OOM behavior and expert-saturation effects.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes


Possibly related PRs

  • axolotl-ai-cloud/axolotl#3410: Extends the existing ScatterMoE LoRA kernel/layer stack from the main PR by adding MXFP4-specific fused LoRA forward/backward paths and wiring them through the integration layer via MXWeights and selective MX dequantization.

Suggested reviewers

  • NanoCode012
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 60.94% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'Fused ScatterMoE-LoRA for MXFP4 weights' accurately and concisely summarizes the main change: adding fused kernel support for MXFP4-quantized expert weights in ScatterMoE-LoRA.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feat/scattermoe-lora-mxfp4

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/axolotl/integrations/kernels/libs/scattermoe_lora/mx_weights.py (1)

212-212: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Resolve pre-commit formatting drift before merge.

CI already reports end-of-file-fixer and ruff-format modifying this file; please commit those formatting changes so lint passes.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/mx_weights.py` at line
212, Run the project's pre-commit hooks / formatters on the affected file to
resolve formatting drift: apply end-of-file-fixer (ensure a single trailing
newline) and run ruff/ruff-format (or the project's formatter) on mx_weights.py,
then stage and commit the updated file so CI no longer modifies it;
alternatively run the project's pre-commit install and `pre-commit run
--all-files`, review the changes, and commit them.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/mx_weights.py`:
- Around line 91-104: MXWeights.__post_init__ currently silently reinterprets
non-uint8 scales and doesn't validate shapes, which can corrupt buffers; change
it to require scales.dtype == torch.uint8 (raise ValueError instead of view())
and add explicit shape/rank checks tying packed and scales to K, N, and
block_size: validate packed.dtype == uint8, packed.ndim and scales.ndim, confirm
packed.size(1)/packed.size(2) or packed layout matches expected K and N derived
from attributes (and that N is divisible by MX_BLOCK_SIZE), and ensure
scales.shape matches (num_experts, N // MX_BLOCK_SIZE) or the correct per-block
layout used by your Triton kernel; only after these validations set num_experts
from packed.size(0) if None, otherwise raise descriptive errors on mismatch.
- Around line 151-162: Update the misleading ImportError strings that reference
the wrong torchao minimum version: locate the MXFP4 import guard where MXTensor
= _torchao_mxtensor_cls() and the subsequent raise ImportError call that
currently says "MXFP4 path requires torchao (install `torchao>=0.7`)" (and the
similar message later around the other check at ~line 186) and change the
messages to require "torchao>=0.17.0" so they match the documented and pinned
dependency and the constructor signature used.

In
`@src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py`:
- Around line 71-79: The mypy error comes from using is_mx =
isinstance(expert_weights, MXWeights) which doesn't narrow the union type for
the else branch; replace the boolean temp with a direct isinstance check or
explicitly narrow the type in the else branch—e.g., change "is_mx =
isinstance(expert_weights, MXWeights); if is_mx: ..." to "if
isinstance(expert_weights, MXWeights): ..." or, after the current check, add
"from typing import cast; expert_weights = cast(torch.Tensor, expert_weights)"
before using tensor-specific attributes so expert_weights (and symbols
MXWeights, MXLayout, and the else branch that accesses expert_weights.dtype) are
properly narrowed for mypy.

In `@tests/integrations/kernels/scattermoe_lora/bench_mxfp4.py`:
- Around line 122-143: The variable `template` is Optional and is dereferenced
when constructing the returned MXTensor, which trips mypy; add an assertion like
`assert template is not None` (or equivalent narrowing) immediately before the
MXTensor(...) return to guarantee to the type checker that `template` is
non-None; ensure the assertion sits right before the MXTensor construction so
references to `template.elem_dtype`, `template.block_size`,
`template.orig_dtype`, `template.kernel_preference`,
`template.act_quant_kwargs`, and `template.is_swizzled_scales` are accepted by
mypy.

In `@tests/integrations/kernels/scattermoe_lora/test_mxfp4_expert_weights.py`:
- Line 441: The local variable fwd_tol is assigned from _tol_for_shape(K) but
never used, causing a lint F841; remove the unused assignment or explicitly
discard it (e.g. call _tol_for_shape(K) and assign to _ ) in the
test_mxfp4_expert_weights.py test where fwd_tol is created so that the linter no
longer flags an unused variable; locate the statement involving fwd_tol and
either delete that line or replace the variable name with an underscore.

---

Outside diff comments:
In `@src/axolotl/integrations/kernels/libs/scattermoe_lora/mx_weights.py`:
- Line 212: Run the project's pre-commit hooks / formatters on the affected file
to resolve formatting drift: apply end-of-file-fixer (ensure a single trailing
newline) and run ruff/ruff-format (or the project's formatter) on mx_weights.py,
then stage and commit the updated file so CI no longer modifies it;
alternatively run the project's pre-commit install and `pre-commit run
--all-files`, review the changes, and commit them.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 6e0becca-338b-4c18-8fd4-fe47a239927f

📥 Commits

Reviewing files that changed from the base of the PR and between d7cb1c9 and 5f7e9bb.

📒 Files selected for processing (11)
  • src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py
  • src/axolotl/integrations/kernels/libs/scattermoe_lora/layers.py
  • src/axolotl/integrations/kernels/libs/scattermoe_lora/mx_weights.py
  • src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py
  • src/axolotl/integrations/kernels/libs/scattermoe_lora/selective_dequant.py
  • tests/integrations/kernels/__init__.py
  • tests/integrations/kernels/scattermoe_lora/__init__.py
  • tests/integrations/kernels/scattermoe_lora/bench_mxfp4.py
  • tests/integrations/kernels/scattermoe_lora/bench_mxfp4_results.md
  • tests/integrations/kernels/scattermoe_lora/test_mxfp4_expert_weights.py
  • tests/integrations/kernels/scattermoe_lora/test_mxfp4_integration.py

Comment thread src/axolotl/integrations/kernels/libs/scattermoe_lora/mx_weights.py
Comment thread src/axolotl/integrations/kernels/libs/scattermoe_lora/mx_weights.py Outdated
Comment thread src/axolotl/integrations/kernels/libs/scattermoe_lora/parallel_linear_lora.py Outdated
Comment thread tests/integrations/kernels/scattermoe_lora/bench_mxfp4.py
@github-actions

github-actions Bot commented May 18, 2026

Copy link
Copy Markdown
Contributor

📖 Documentation Preview: https://6a17b69b033abf02a98fa668--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit 005c12e

@codecov

codecov Bot commented May 18, 2026

Copy link
Copy Markdown

winglian added 15 commits May 28, 2026 03:16
Add an MXFP4 branch to `selective_expert_weights()` that detects a
torchao `MXTensor` parameter (elem_dtype=float4_e2m1fn_x2) and
dequantizes only the active experts via index-then-construct of a
compact sub-MXTensor. The K-axis OCP block layout (last storage dim)
matches `experts.gate_up_proj` natural shape `[E, N, K]`, so the
caller's existing `.transpose(2, 1)` post-step keeps producing the
kernel's `[E, K, N]` weight tile unchanged.

`HFScatterMoEGatedMLP.forward` now also routes through the selective
path whenever the experts hold MXFP4 weights — full-tensor MX dequant
of 256-expert models is prohibitive and the kernel needs bf16 input.

Tests (CUDA-only) compare against a bf16 baseline produced by the
same MXTensor's full dequant; outputs are bitwise identical for both
forward and backward (dX, dA, dB) across small [E=8,K=128,N=256] and
representative [E=32,K=2048,N=1024] shapes, and across all four
combinations of \`use_fused_dX\` / \`use_fused_gather\`.

Signed-off-by: Wing Lian <wing@axolotl.ai>
Add MX-aware forward and dX kernels that consume an ``MXWeights``
container (packed uint8 + E8M0 scales) directly, so the base-weight
tile is dequantized inside the K-loop instead of through a materialized
bf16 buffer. The K-loop loads two FP4 values per uint8 byte, looks them
up in a 16-entry codebook tensor (``±{0, 0.5, 1, 1.5, 2, 3, 4, 6}``),
multiplies by ``2^(scale_byte - 127)``, and casts to bf16 for the
matmul. ``BLOCK_K`` is constrained to a multiple of the OCP block size
(32) so each tile aligns with whole scale blocks; an MX-aware autotune
pruner accounts for the extra packed/scale SMEM.

The dX kernel reuses the *forward* MX layout (block axis = K, the dX
output axis) — for each (K_tile, N_tile) sub-tile, nibbles decode
along the K rows (the byte is shared by two adjacent K rows) and
scales broadcast within their MX block. This avoids the
dequant + re-quantize "pre-transpose" the spec suggested and the
extra MX-rounding error that round-trip would have introduced.

``ScatterMoELoRA.forward`` now accepts either a dense tensor or an
``MXWeights``; the MX branch always selects the fused-dX and
fused-gather backward kernels (the non-fused dX path would have to
materialize a bf16 weight tile, defeating the win).

Unit tests cover forward, dX, dA, dB parity for small
[E=8, K=128, N=256] and representative [E=32, K=2048, N=1024] shapes;
tolerances are calibrated to bf16 MMA noise (atomic-add ordering and
FMA reordering between the full-E baseline and compact-active MX path).
Integration test exercises a tiny synthetic DeepSeek-V4-style MoE
block (E=8, hidden=512, intermediate=256, top_k=2) end-to-end through
both Strategy A and Strategy B with LoRA disabled.

Signed-off-by: Wing Lian <wing@axolotl.ai>
Add ``bench_mxfp4.py`` and committed results for the representative
DeepSeek-V4-style shape (E=128, K=2048, N=1024, top_k=8, M=4096,
rank=16). Reports ms/iter, tokens/s, peak GPU memory, and HBM
bandwidth utilisation for three configurations: bf16 baseline,
Strategy A (selective dequant), Strategy B (fused MX).

On the RTX PRO 6000 Blackwell, the all-active-experts shape used
here doesn't exercise selective dequant's memory savings (active = E
= 128) — A pays the cost of materialising the full bf16 dequant
buffer per step (~9 GB peak vs 1.9 GB for B) while still routing
through the bf16 kernel. B halves A's wall time (~12 ms vs 30 ms) by
eliminating the buffer, but stays slower than the bf16 baseline (5
ms) which assumes the bf16 weights already exist in memory.

Signed-off-by: Wing Lian <wing@axolotl.ai>
Signed-off-by: Wing Lian <wing@axolotl.ai>
The MX-aware autotune pruner for the forward kernel under-accounted
SMEM: it computed the packed-tile cost as BLOCK_N * BLOCK_K/2 and the
scale-tile cost as BLOCK_N * BLOCK_K/MX_BLOCK_SIZE, but the actual
tl.load issues a full [BLOCK_N, BLOCK_K]-shaped uint8 fetch for both
buffers (the packed buffer reads each byte twice because K_byte =
K // 2 indexes a [BLOCK_K]-wide vector; the scale buffer broadcasts
within each MX_BLOCK_SIZE K-block). Bring the forward pruner up to the
same conservative full-tile accounting already used by
_prune_dX_mx_configs. Without this, on the [E=128, K=2048, N=1024]
shape with the typical GPU SMEM caps, two to six high-stage configs
that were previously selectable would have overflowed SMEM at launch
under correct accounting — a silent OOM-in-the-future risk.

Signed-off-by: Wing Lian <wing@axolotl.ai>
The file-level docstring for the MXFP4 kernels described the dX kernel
as using a pre-transposed [E, K, N/2] layout produced by a
'mx_pre_transpose_for_dx' helper. That helper doesn't exist; the dX
kernel actually reuses the forward [E, N, K/2] layout, iterating the N
reduction in outer tiles and decoding nibbles along the K rows of each
tile. Rewrite the docstring to describe what the code actually does,
including the rationale — reusing the forward buffer avoids the
dequant + re-quantize round-trip that a pre-transpose would require
and keeps dX numerics free of a second MX rounding error stacked on
top of the forward quantization.

Signed-off-by: Wing Lian <wing@axolotl.ai>
F4: Hoist 'is_mxfp4_param' import from inside 'HFScatterMoEGatedMLP.forward'
to the top of layers.py — it was being re-imported every step on the hot
path.

F5: Add a thin compatibility shim for torchao MXTensor internals access in
mx_weights.py. The MX paths in selective_dequant.py / mx_weights.py used
to reach into 'mx_param.qdata', 'mx_param.scale',
'mx_param.kernel_preference' and call 'MXTensor(...)' with positional
args directly. That works at the pinned torchao 0.17.0 but is fragile to
internal renames in future torchao releases. Funnel through three
helpers — '_mx_qdata', '_mx_scale', '_construct_mxtensor_subset' — that
use 'getattr' fallbacks for the buffer attributes and pass the
constructor's optional args via 'getattr' too. Single point of pain,
no API change.

F7: Remove the unused 'NO_K_MASK' heuristic + tl.constexpr param from
the dX MX kernel '_scatter2scatter_lora_dX_mx'. The dX kernel never
references it (its inner loop masks N, not K), so the constexpr just
forced extra autotune key entries.

F8: Consolidate the duplicate '_torchao_mxtensor_cls()' definitions
(one in selective_dequant.py, one in mx_weights.py) into a single
definition in mx_weights.py. selective_dequant.py imports it.

Signed-off-by: Wing Lian <wing@axolotl.ai>
F3: 'test_strategy_a_backward_fused_variants' previously used
'torch.ones_like(output)' as the grad input and asserted only on dX.
A uniform grad zeros out cross-token differences in the fused-gather
accumulation, masking reordering bugs; restricting the assertion to dX
silently let the dA/dB paths go unchecked across the four
'(use_fused_dX, use_fused_gather)' production variants.

  * Drive the backward with 'torch.randn_like(output) * 0.1'.
  * Capture and assert dA and dB parity across all four variants
    using the same 'row_idx' gather pattern as
    'test_strategy_a_backward_matches_bf16'.
  * Forward and dX are still asserted bitwise via 'torch.equal'. dA/dB
    fall back to atol/rtol = 1e-3 because the fused dA/dB kernel uses
    'atomic_add' across N-block programs and the in-flight program
    count differs between the full-E baseline and the compact-active
    path; combined with FMA reordering, the 'use_fused_dX=True'
    variants accumulate ~1 bf16 ULP of unavoidable atomic-order noise.
    The new bound is still an order of magnitude below that noise
    floor, so it catches real bugs.

F9: The 'test_strategy_b_backward_matches_bf16' dX comparison runs at
'atol=0.5, rtol=2e-2' (small) / 'atol=2.0, rtol=3e-2' (representative)
to allow for accumulated bf16 MMA noise over the N reduction. Those
bounds are appropriate for legitimate per-element drift but would also
admit a uniform multiplicative bug — e.g. an off-by-one on the E8M0
exponent that scales every dX element by 2x.

Add a guard alongside the existing 'torch.allclose': mask out
near-zero baseline elements (relative to 'bf16_dX.abs().max()'), then
require the per-element ratio 'mx_dX / bf16_dX' to have std < 0.5. A
uniform multiplicative bug pushes that std to ~0 while the mean shifts;
a real-bug per-element drift pushes the std up. This crosscuts the
allclose check rather than replacing it.

Signed-off-by: Wing Lian <wing@axolotl.ai>
The previous bench harness did a fresh '.clone()' of x and a
'requires_grad_(True)' on cloned lora A/B tensors every iter inside
the timed window. That accounts for buffer allocation, not kernel
cost, and biases the numbers toward whichever path produced the
smallest activations. Restructure the runners so:

  * 'x' is cloned once into a leaf tensor with 'requires_grad_(True)'
    inside 'bench()' (outside the timed warmup + timed loop).
  * LoRA A/B leaf tensors are constructed once in the runner factory,
    not per iter.
  * Each iter calls the runner which sets 'x.grad = A.grad = B.grad =
    None' (cheap, no GPU sync) so the autograd graph for the timed
    iteration is fresh and grads don't accumulate.

Re-run all three configs end-to-end after this change (dense E=128,
sparse E=256 / 10-active, balanced E=256 M-sweep at M ∈ {256, 1024,
4096, 16384}) and refresh the numbers in bench_mxfp4_results.md.
Headers and table structure are unchanged. The qualitative ordering
holds (Strategy A wins at low active/E, Strategy B wins near
active/E ≈ 1, and Strategy A still OOMs across the balanced sweep on
the workstation with vLLM colocated), with per-cell numbers within
single-digit percent of the prior runs.

Signed-off-by: Wing Lian <wing@axolotl.ai>
…arity assertions

Wing's "lint and PR review fixes" commit (9007a82) reverted three fixes
from the prior lint pass. Restore them:

1. parallel_linear_lora.py: use isinstance(expert_weights, MXWeights)
   directly so mypy can narrow the union — the `is_mx` boolean alias
   blocks narrowing and re-introduces 2 union-attr errors.

2. bench_mxfp4.py: assert template is not None before the MXTensor(...)
   constructor — the chunked converter initializes template to None
   then sets it inside the loop, which mypy can't prove non-None at
   the call site (6 None-attr errors).

3. test_mxfp4_expert_weights.py: the F841 on fwd_tol was actually a
   smell of dropped logic. Both backward tests
   (test_strategy_a_backward_matches_bf16 and
   test_strategy_b_backward_matches_bf16) compute the forward outputs
   out_b/out_a/out_s, run backward, and assert gradients match — but
   never assert that the forward outputs match. A forward bug
   producing a constant offset (and therefore zero gradient delta)
   would slip past the bwd-only checks. Add the missing
   torch.equal(out_b, out_a) for Strategy A (bitwise contract) and
   torch.allclose(out_b, out_s, **fwd_tol) for Strategy B (MX tol).

Signed-off-by: Wing Lian <wing@axolotl.ai>
@winglian winglian force-pushed the feat/scattermoe-lora-mxfp4 branch from 3f64bbd to 005c12e Compare May 28, 2026 03:22
@winglian winglian merged commit 5c1a266 into main May 28, 2026
15 of 16 checks passed
@winglian winglian deleted the feat/scattermoe-lora-mxfp4 branch May 28, 2026 14:32
winglian added a commit that referenced this pull request May 29, 2026
Earlier pass rejected fp8/nvfp4/mxfp4 at the schema layer, telling
users to use QAT/PTQ instead. That was wrong:

- NVFP4 has a real weight-only torchao config (NVFP4WeightOnlyConfig
  in torchao.prototype.mx_formats) — it's a 4-bit quant, perfectly
  suited to QLoRA. Now auto-promotes adapter lora -> qlora and
  builds NVFP4WeightOnlyConfig at load.
- FP8 (float8_e4m3fn) has Float8WeightOnlyConfig in torchao.quantization
  — a one-byte-per-weight quant that mirrors INT8's role. Keeps
  adapter as lora.
- MXFP4 is the genuine 'no weight-only flavor' case. The schema now
  passes it through; the loader raises with a pointer to
  quantize_moe_experts: true for MoE models (which is where MXFP4
  LoRA actually lives, via the ScatterMoE-LoRA path landed in #3663)
  and to qat/ptq for inference-time MXFP4.

CUDA smoke-tested on SmolLM2-135M:
- weight_dtype: fp8  -> Float8WeightOnlyConfig, forward+backward OK
- weight_dtype: nvfp4 (group_size=16) -> NVFP4WeightOnlyConfig, OK
- weight_dtype: mxfp4 -> loader error pointing to quantize_moe_experts

Docs and the dtype table updated; schema/loader tests extended.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
winglian added a commit that referenced this pull request May 29, 2026
Earlier pass rejected fp8/nvfp4/mxfp4 at the schema layer, telling
users to use QAT/PTQ instead. That was wrong:

- NVFP4 has a real weight-only torchao config (NVFP4WeightOnlyConfig
  in torchao.prototype.mx_formats) — it's a 4-bit quant, perfectly
  suited to QLoRA. Now auto-promotes adapter lora -> qlora and
  builds NVFP4WeightOnlyConfig at load.
- FP8 (float8_e4m3fn) has Float8WeightOnlyConfig in torchao.quantization
  — a one-byte-per-weight quant that mirrors INT8's role. Keeps
  adapter as lora.
- MXFP4 is the genuine 'no weight-only flavor' case. The schema now
  passes it through; the loader raises with a pointer to
  quantize_moe_experts: true for MoE models (which is where MXFP4
  LoRA actually lives, via the ScatterMoE-LoRA path landed in #3663)
  and to qat/ptq for inference-time MXFP4.

CUDA smoke-tested on SmolLM2-135M:
- weight_dtype: fp8  -> Float8WeightOnlyConfig, forward+backward OK
- weight_dtype: nvfp4 (group_size=16) -> NVFP4WeightOnlyConfig, OK
- weight_dtype: mxfp4 -> loader error pointing to quantize_moe_experts

Docs and the dtype table updated; schema/loader tests extended.
@ved1beta ved1beta mentioned this pull request Jun 8, 2026
8 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

scheduled_release This PR is slated for the upcoming release

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant