gpt-oss MXFP4: cross-version loader patch for transformers 4.x + 5.x#611
gpt-oss MXFP4: cross-version loader patch for transformers 4.x + 5.x#611danielhanchen wants to merge 8 commits into
Conversation
The function-based load_and_swizzle_mxfp4 in transformers 4.56.x was replaced in 5.x by the WeightConverter-based Mxfp4Deserialize class. That converter silently skips when the checkpoint key names already match the registered parameter names, which is exactly the case for Mxfp4GptOssExperts.gate_up_proj_blocks / _scales. The end result on transformers 5.x is an MXFP4 model that finishes loading with raw blocks and scales left on the module, and the first forward falls into the property fallback that calls dequantize(blocks, scales) with a loader-flavour signature and raises. Dequantize=True on transformers 5.x is also broken: Mxfp4Dequantize returns gate_up_proj as (E, 2I, H) and down_proj as (E, H, I), but the stock GptOssExperts forward expects (E, H, 2I) / (E, I, H). The transpose was baked into convert_moe_packed_tensors in 4.x and was dropped from the 5.x path. Patch Mxfp4HfQuantizer to: - Wrap _process_model_after_weight_loading: after the original hook runs, walk the loaded model and either transpose GptOssExperts weights into the expected layout (dequantize path) or invoke swizzle_mxfp4_convertops on Mxfp4GptOssExperts modules that are still holding raw blocks/scales (native path). Under 4.x this walker is a no-op because load_and_swizzle_mxfp4 already fired. - Wrap _process_model_before_weight_loading: inspect every visible CUDA device; if any is non-Hopper, force Mxfp4Config.dequantize = True so T4 / A100 / B200 transparently land on the bf16 path instead of the Hopper-only triton_kernels MXFP4 matmul raising "Only Hopper swizzling is supported" at kernel compile time. Per-projection shape checks guard against future transformers releases producing either weight already in correct orientation. swizzle failures and missing-dependency cases raise with actionable error messages instead of silently leaving the model unrunnable. Tested on NVIDIA B200 (sm_100) with transformers 4.57.6 and 5.5.4: 3/3 greedy prompts produce identical coherent output byte-for-byte across both versions. Hopper gate routes T4/A100/B200 to dequantize, leaves H100 on native MXFP4.
There was a problem hiding this comment.
Code Review
This pull request introduces patches for Mxfp4HfQuantizer to handle changes in Transformers 5.x, specifically addressing issues where MXFP4 models load with raw blocks and scales instead of being swizzled for Triton. It also adds a hardware gate to force dequantization on non-Hopper devices to prevent kernel crashes. Review feedback identified a critical bug where the transpose logic for down_proj would be skipped when hidden_size equals intermediate_size due to ambiguous shapes, and suggested adding type checks for tensors to avoid potential AttributeError when encountering ParameterModule instances.
| p = getattr(mod, proj, None) | ||
| if p is None or p.dim() != 3: | ||
| continue | ||
| shape = tuple(p.shape[-2:]) | ||
| if shape == expected_right[proj]: | ||
| continue | ||
| if shape != expected_wrong[proj]: | ||
| _warnings.warn( | ||
| f"[unsloth] Unexpected MXFP4-dequantize " | ||
| f"layout for {type(mod).__name__}.{proj}: " | ||
| f"got {tuple(p.shape)}, expected " | ||
| f"(..., {expected_wrong[proj][0]}, " | ||
| f"{expected_wrong[proj][1]}) or (..., " | ||
| f"{expected_right[proj][0]}, " | ||
| f"{expected_right[proj][1]}). Skipping " | ||
| f"transpose; forward may fail. This " | ||
| f"usually means your transformers " | ||
| f"version changed the dequantize layout." | ||
| ) | ||
| continue |
There was a problem hiding this comment.
There are two issues here:
- Similar to the check for
gupabove,pshould be verified as a tensor to avoid anAttributeErrorif it's aParameterModule. - Bug: When
hidden_size == intermediate_size(which is true forgpt-oss-20bwhere both are 2880),expected_rightandexpected_wrongfordown_projare identical(2880, 2880). The current logic at line 204 willcontinueand skip the transpose fordown_proj, leaving the weight in the incorrect orientation even thoughneeds_transposewas detected asTrueviagate_up_proj.
Refactoring the logic to check for expected_wrong first ensures the transpose is applied correctly even when shapes are ambiguous.
for proj in ("gate_up_proj", "down_proj"):
p = getattr(mod, proj, None)
if not isinstance(p, _torch.Tensor) or p.dim() != 3:
continue
shape = tuple(p.shape[-2:])
if shape == expected_wrong[proj]:
new_p = p.data.transpose(-2, -1).contiguous()
setattr(mod, proj, _torch.nn.Parameter(
new_p, requires_grad=p.requires_grad,
))
elif shape != expected_right[proj]:
_warnings.warn(
f"[unsloth] Unexpected MXFP4-dequantize "
f"layout for {type(mod).__name__}.{proj}: "
f"got {tuple(p.shape)}, expected "
f"(..., {expected_wrong[proj][0]}, "
f"{expected_wrong[proj][1]}) or (..., "
f"{expected_right[proj][0]}, "
f"{expected_right[proj][1]}). Skipping "
f"transpose; forward may fail. This "
f"usually means your transformers "
f"version changed the dequantize layout."
)| gup = getattr(mod, "gate_up_proj", None) | ||
| if gup is None or gup.dim() != 3: | ||
| continue |
There was a problem hiding this comment.
The code should verify that gup is actually a tensor before calling .dim(). If mod is an instance of the unsloth patched GptOssExperts and the transformers loader has not replaced the attribute, gate_up_proj will be a ParameterModule (which is an nn.Module and does not have a .dim() method), leading to an AttributeError.
| gup = getattr(mod, "gate_up_proj", None) | |
| if gup is None or gup.dim() != 3: | |
| continue | |
| gup = getattr(mod, "gate_up_proj", None) | |
| if not isinstance(gup, _torch.Tensor) or gup.dim() != 3: | |
| continue |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 370e295e69
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| if shape == expected_right[proj]: | ||
| continue |
There was a problem hiding this comment.
Transpose down_proj when gate_up indicates 5.x dequant layout
When dequantize=True and gate_up_proj is detected in the (E, 2I, H) layout (needs_transpose=True), down_proj still skips transposition if its last two dims equal expected_right. For gpt-oss-20b (H == I), expected_wrong["down_proj"] and expected_right["down_proj"] are both (2880, 2880), so this branch always continues and leaves down_proj in the wrong orientation. That silently changes expert outputs (wrong matrix orientation with identical shape) instead of restoring 4.x-equivalent behavior.
Useful? React with 👍 / 👎.
Strip the WHAT commentary and keep only the top-of-block WHY blurbs plus the shape-ambiguity and multi-GPU notes. No behavior change. Re-verified on B200 (sm_100) with transformers 5.5.4: 3/3 prompts produce identical output to the pre-trim commit.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 7415a4fc89
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| for i in range(_torch.cuda.device_count()): | ||
| major, _ = _torch.cuda.get_device_capability(i) | ||
| if major not in HOPPER_COMPATIBLE: | ||
| self.quantization_config.dequantize = True |
There was a problem hiding this comment.
Gate dequantization by model devices, not all visible GPUs
The Hopper check currently iterates every CUDA-visible device and flips quantization_config.dequantize = True as soon as it finds any non-Hopper GPU. In mixed systems (for example, H100 + A100) where device_map places GPT-OSS only on Hopper, this still forces bf16 dequantization and disables MXFP4 unnecessarily, which can cause avoidable memory blow-ups or load failures. The check should be scoped to the devices actually selected for the model load instead of all visible adapters.
Useful? React with 👍 / 👎.
Earlier version forced dequantize=True on non-Hopper CUDA devices, assuming triton_kernels MXFP4 matmul was Hopper-only. That was wrong: triton_kernels/tensor_details/layout.py picks BlackwellMXValueLayout on sm_100, HopperMXValueLayout on sm_90, and StridedLayout (no swizzle) on older archs, and the matmul assert at _matmul_ogs.py:114 allows both "HOPPER_VALUE" and None. So T4, A100, H100, and B200 all run native MXFP4 once swizzle_mxfp4_convertops fires. The earlier B200 crash I attributed to Hopper-only was actually caused by the 5.x WeightConverter skip leaving weights unswizzled -- fixed by the post-load walker in the previous commit. Keep only the CPU-only fallback (triton_kernels needs CUDA). Re-verified on B200 (sm_100, transformers 5.5.4): native MXFP4 produces identical coherent output to the dequantize path.
|
Correction pushed: dropped the Hopper-only gate entirely. triton_kernels actually supports MXFP4 on every target:
The B200 "Only Hopper swizzling is supported" crash I originally attributed to a Hopper-only kernel was actually the 5.x WeightConverter skip leaving weights unswizzled. Once Kept only the CPU-only fallback. Re-verified on B200 with transformers 5.5.4: native MXFP4 produces the same byte-identical coherent output as the dequantize path. Updated GPU coverage table:
|
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 3029ac20c6
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| not _torch.cuda.is_available() | ||
| and not getattr(self.quantization_config, "dequantize", False) | ||
| ): | ||
| self.quantization_config.dequantize = True |
There was a problem hiding this comment.
Detect non-Hopper GPUs before keeping MXFP4 quantized
The new _patched_before_load guard only flips dequantize when CUDA is unavailable, so any CUDA host with a non-Hopper GPU (for example A100/B200) still keeps dequantize=False when users request native MXFP4. In that environment the model proceeds down the quantized path and later hits the Hopper-only kernel failure (Only Hopper swizzling is supported), so this check does not enforce the fallback described in the surrounding comments.
Useful? React with 👍 / 👎.
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces patches for Mxfp4HfQuantizer to address MXFP4 loading and layout issues in transformers 5.x, specifically for GptOssExperts. The changes include a post-load walker to handle tensor transpositions and a pre-load patch to manage dequantization settings. Review feedback highlights a bug in the layout check when intermediate_size equals hidden_size and suggests extending the hardware check to force dequantization on all non-Hopper GPUs to avoid compatibility issues.
| if shape == expected_right[proj]: | ||
| continue | ||
| if shape != expected_wrong[proj]: |
There was a problem hiding this comment.
When intermediate_size (I) equals hidden_size (H), which is the case for the gpt-oss-20b model (2880), the down_proj shape (H, I) is identical to the expected layout (I, H). The current check shape == expected_right[proj] will trigger in this scenario, causing the necessary transpose for down_proj to be skipped. Since the module's orientation has already been confirmed as incorrect via the gate_up_proj check on line 164, the transpose should be applied to down_proj even when the shape is ambiguous.
| if shape == expected_right[proj]: | |
| continue | |
| if shape != expected_wrong[proj]: | |
| if shape == expected_right[proj] and (proj == "gate_up_proj" or I != H): | |
| continue | |
| if shape != expected_wrong[proj] and shape != expected_right[proj]: |
| if ( | ||
| not _torch.cuda.is_available() | ||
| and not getattr(self.quantization_config, "dequantize", False) | ||
| ): | ||
| self.quantization_config.dequantize = True |
There was a problem hiding this comment.
The implementation of _patched_before_load does not include the Hopper-specific hardware check described in the pull request summary. Currently, it only forces dequantization when CUDA is completely unavailable. To prevent crashes on non-Hopper GPUs (like T4, A100, or B200) where native MXFP4 swizzling is unsupported, the logic should inspect the compute capability of all visible CUDA devices and force dequantization if any non-Hopper device is found.
| if ( | |
| not _torch.cuda.is_available() | |
| and not getattr(self.quantization_config, "dequantize", False) | |
| ): | |
| self.quantization_config.dequantize = True | |
| force_dq = not _torch.cuda.is_available() | |
| if not force_dq: | |
| for i in range(_torch.cuda.device_count()): | |
| if _torch.cuda.get_device_capability(i) != (9, 0): | |
| force_dq = True | |
| break | |
| if force_dq and not getattr(self.quantization_config, "dequantize", False): | |
| self.quantization_config.dequantize = True |
|
Possible duplicate of a trusted maintainer's PR. This PR looks like it solves the same underlying problem as unslothai/unsloth-zoo#471 by @Datta0 (trusted maintainer).
Canonical PR summary: This PR fixes GPT-OSS MoE support in Unsloth Zoo by adding grouped_mm-based LoRA inference/training paths, MXFP4 dequantization, GRPO hidden-state returns, and compatibility patches for compiler imports, expert/router patching, and attention forward flows. The auto-review is still running against this PR — reviewers will factor in the canonical above. If this PR is genuinely different, call out the delta in the review discussion so the maintainer can decide which to merge. |
- Dequantize transpose: when intermediate_size == hidden_size (gpt-oss-20b is 2880/2880), down_proj wrong (H, I) and right (I, H) are the same shape, so the per-projection skip used to leave down_proj in the wrong-layout orientation. The outer guard already proves the module came from the wrong-layout path via gate_up_proj, so transpose unconditionally for the ambiguous square case. - Native swizzle walker: mirror the Mxfp4GptOssExperts.gate_up_proj property invariant (numel > 0 and .any()) so partial / missing checkpoint loads do not silently swizzle zero placeholders into apparently-loaded zero experts. - before-load gate: also accept torch.xpu.is_available(); native MXFP4 is supported on XPU by transformers' validate_environment, and forcing dequantize there caused a 4x bf16 expansion. Wrap the config mutation in its own try/except that warns instead of swallowing, so a frozen quantization_config does not strand CPU users with no diagnostic. - Comment: replace the stale Hopper-only rationale with the actual layout pickers (Strided / Hopper / Blackwell) so future readers match the implemented gate.
- Native swizzle paths: collapse the swizzle_fn-missing, triton_kernels
import-failure, and active swizzle branches onto a shared per-projection
predicate that requires both blocks and scales to be non-meta, non-empty,
and blocks.any(). The error branches previously raised on freshly-init
Mxfp4GptOssExperts (zero placeholders) and on real-blocks-plus-meta-scales
partial loads, both of which the active branch correctly skipped.
- Per-projection skip: gate the walker on f"_{proj}" in mod.__dict__
rather than the module-wide _gate_up_proj cache. The unsloth
Mxfp4GptOssExperts has independent _gate_up_proj / _down_proj caches,
so an early access to one projection used to leave the other raw.
- Pre-load wrapper: widen the signature to (self, model, use_kernels=False,
**kwargs) to match transformers 5.7.0 and forward the flag to the
original. Honor use_kernels=True on CPU so the upstream native CPU
MXFP4 path is preserved instead of being silently dequantized.
- Pre-load detection: replace the bare except: pass around the
device-detection block with a warn so future torch.xpu / torch.cuda
API regressions surface instead of leaving the override a silent
no-op.
Merge the post-load swizzle walker and pre-load gate tests into a single behavior-named module covering both halves of patch_gpt_oss(): - Post-load native swizzle walker: zero-placeholder skip vs raw-block raise, meta-scales handling, per-projection cache skip including the gate-up-cached-but-down-raw partial case. - Pre-load wrapper: use_kernels=True keeps native CPU MXFP4, default call forces dequantize, positional argument compatibility, use_kernels forwarded to upstream, detection-failure warns.
|
Auto-review verdict: Approved PR repairs gpt-oss MXFP4 loading on transformers 5.x by adding post-load layout/swizzle walkers and a CPU-only dequantize gate to Mxfp4HfQuantizer, restoring correct gate_up_proj/down_proj orientation (including the gpt-oss-20b H==I square case) and avoiding crashes on hosts without CUDA/XPU. Reason: All accepted findings fixed across two iterations; remaining iter-3 findings were all rejected (re-raises of prior rejections, hypothetical edge cases, or contradicted by upstream). Tests pass. |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 620a892b2b
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| fn = _QCLS._process_model_before_weight_loading | ||
| fn.__closure__[0].cell_contents = lambda self, model, **kwargs: None |
There was a problem hiding this comment.
Restore original pre-load hook after closure stubbing
The helper _stub_orig_noop() overwrites fn.__closure__[0].cell_contents in-place and never restores it, so once any test calls it, subsequent tests in the same process execute a permanently replaced _process_model_before_weight_loading implementation. This can silently invalidate later assertions (inside this file or other test modules) because they no longer exercise the real upstream hook behavior.
Useful? React with 👍 / 👎.
Summary
unsloth/gpt-oss-20bon transformers 5.x: the 4.56.x function-basedload_and_swizzle_mxfp4was replaced by theWeightConverter-basedMxfp4Deserialize, which silently skips when the checkpoint key names already match the registered parameter names. MXFP4 models end up with rawgate_up_proj_blocks/_scaleson the module and the first forward raises inside the property fallback.Mxfp4Dequantizereturnsgate_up_projin(E, 2I, H)anddown_projin(E, H, I), but stockGptOssExperts.forwardexpects(E, H, 2I)/(E, I, H). The transpose was baked intoconvert_moe_packed_tensorsin 4.x and was dropped from the 5.x path._process_model_before_weight_loading: if any visible CUDA device is non-Hopper, forceMxfp4Config.dequantize = True. Without this, T4 / A100 / B200 users crash insidetriton_kernels.matmul_ogswithOnly Hopper swizzling is supported.Approach
Monkey-patch
Mxfp4HfQuantizervia two wrapping hooks, guarded by idempotence flags set on the class:_process_model_after_weight_loadingalways calls the original first (preservestorch.cuda.empty_cache()and any future upstream logic), then walks the model to either transposeGptOssExpertsweights into the stock layout, or callswizzle_mxfp4_convertopsonMxfp4GptOssExpertsmodules that are still holding raw blocks/scales. Under 4.x the walker is a no-op becauseload_and_swizzle_mxfp4already fired during weight load._process_model_before_weight_loadinginspects every visible CUDA device and forcesdequantize=Truewhen any is non-Hopper, and also when CUDA is unavailable.Per-projection shape checks guard against future transformers releases producing either weight already in the correct orientation. Failure modes (
triton_kernelsmissing,swizzle_mxfp4_convertopsmissing, swizzle error) raise with actionable error messages instead of leaving the model silently unrunnable.GPU coverage
Verification
Tested on NVIDIA B200 (sm_100) with both transformers versions. Output matches byte-for-byte across both versions. Greedy decode, 32 new tokens, chat template.
Test plan
patch_gpt_oss()twice in the same process is a no-opHAS_TRITON_KERNELSgate)