From 370e295e69526fafddf10586df5ac90233c6cec5 Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Fri, 24 Apr 2026 04:56:02 +0000 Subject: [PATCH 1/7] gpt-oss MXFP4: cross-version loader patch for transformers 4.x + 5.x The function-based load_and_swizzle_mxfp4 in transformers 4.56.x was replaced in 5.x by the WeightConverter-based Mxfp4Deserialize class. That converter silently skips when the checkpoint key names already match the registered parameter names, which is exactly the case for Mxfp4GptOssExperts.gate_up_proj_blocks / _scales. The end result on transformers 5.x is an MXFP4 model that finishes loading with raw blocks and scales left on the module, and the first forward falls into the property fallback that calls dequantize(blocks, scales) with a loader-flavour signature and raises. Dequantize=True on transformers 5.x is also broken: Mxfp4Dequantize returns gate_up_proj as (E, 2I, H) and down_proj as (E, H, I), but the stock GptOssExperts forward expects (E, H, 2I) / (E, I, H). The transpose was baked into convert_moe_packed_tensors in 4.x and was dropped from the 5.x path. Patch Mxfp4HfQuantizer to: - Wrap _process_model_after_weight_loading: after the original hook runs, walk the loaded model and either transpose GptOssExperts weights into the expected layout (dequantize path) or invoke swizzle_mxfp4_convertops on Mxfp4GptOssExperts modules that are still holding raw blocks/scales (native path). Under 4.x this walker is a no-op because load_and_swizzle_mxfp4 already fired. - Wrap _process_model_before_weight_loading: inspect every visible CUDA device; if any is non-Hopper, force Mxfp4Config.dequantize = True so T4 / A100 / B200 transparently land on the bf16 path instead of the Hopper-only triton_kernels MXFP4 matmul raising "Only Hopper swizzling is supported" at kernel compile time. Per-projection shape checks guard against future transformers releases producing either weight already in correct orientation. swizzle failures and missing-dependency cases raise with actionable error messages instead of silently leaving the model unrunnable. Tested on NVIDIA B200 (sm_100) with transformers 4.57.6 and 5.5.4: 3/3 greedy prompts produce identical coherent output byte-for-byte across both versions. Hopper gate routes T4/A100/B200 to dequantize, leaves H100 on native MXFP4. --- unsloth_zoo/temporary_patches/gpt_oss.py | 247 +++++++++++++++++++++++ 1 file changed, 247 insertions(+) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 7bcd0a333..a31223bd3 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -128,6 +128,253 @@ def patch_gpt_oss(): except Exception as e: return raise_error("transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer", e) + # Transformers 5.x moved the MXFP4 load path from the function-based + # ``load_and_swizzle_mxfp4`` (4.56.x) to the WeightConverter-based + # ``Mxfp4Deserialize`` class. That converter does not actually fire when + # the checkpoint's key names already match the module's registered + # parameters (``gate_up_proj_blocks`` / ``_scales``) — so under 5.x, + # MXFP4 models load with raw blocks + scales left on the module and the + # swizzle-to-triton step never runs. Post-load the ``gate_up_proj`` + # property's fallback trips the ``dequantize(blocks, scales)`` signature + # mismatch (``dequantize`` is the loader-flavour in 5.x). + # + # Fix: wrap ``_process_model_after_weight_loading`` to walk the model + # after weight load, detect any ``Mxfp4GptOssExperts`` still holding + # raw blocks/scales, and invoke ``swizzle_mxfp4_convertops`` on them. + # This mirrors what the pre-5.x ``load_and_swizzle_mxfp4`` did. + try: + _Mxfp4Q = transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer + if not getattr(_Mxfp4Q, "_unsloth_post_swizzle_patched", False): + _orig_post_load = _Mxfp4Q._process_model_after_weight_loading + + def _patched_post_load(self, model, **kwargs): + # Always let the original post-load hook run first: on 5.x it + # performs ``torch.cuda.empty_cache()`` and may grow to carry + # additional logic in future point releases. Skipping it + # leaks VRAM and risks silently bypassing upstream fixes. + _orig_post_load(self, model, **kwargs) + import torch as _torch + import warnings as _warnings + dq = getattr(self.quantization_config, "dequantize", False) + if dq: + # transformers 5.x ``Mxfp4Dequantize.convert`` returns + # the tensor in dequantize layout -- gate_up_proj: + # ``(E, 2I, H)``, down_proj: ``(E, H, I)``. The stock + # ``GptOssExperts`` expects gate_up_proj as + # ``(E, H, 2I)`` (see modeling_gpt_oss.py:75) and + # down_proj as ``(E, I, H)`` (modeling_gpt_oss.py:77). + # In transformers 4.x the transpose was baked into + # ``convert_moe_packed_tensors``; in 5.x it was + # removed and not restored elsewhere. + # + # Per-projection shape check: guards against a future + # transformers release producing either weight already + # in the correct orientation. For gpt-oss-20b + # (``H == I == 2880``) ``down_proj``'s shape alone is + # ambiguous, so we additionally use ``gate_up_proj``'s + # unambiguous ``(E, 2I, H)`` as an indicator that the + # whole module came out of the 5.x dequantize path. + for mod in model.modules(): + if type(mod).__name__ != "GptOssExperts": + continue + H = getattr(mod, "hidden_size", None) + I = getattr(mod, "intermediate_size", None) + if H is None or I is None: + continue + gup = getattr(mod, "gate_up_proj", None) + if gup is None or gup.dim() != 3: + continue + _, D1, D2 = gup.shape + needs_transpose = (D1 == 2 * I and D2 == H) + if not needs_transpose: + continue + expected_wrong = { + "gate_up_proj": (2 * I, H), + "down_proj": (H, I), + } + expected_right = { + "gate_up_proj": (H, 2 * I), + "down_proj": (I, H), + } + for proj in ("gate_up_proj", "down_proj"): + p = getattr(mod, proj, None) + if p is None or p.dim() != 3: + continue + shape = tuple(p.shape[-2:]) + if shape == expected_right[proj]: + continue + if shape != expected_wrong[proj]: + _warnings.warn( + f"[unsloth] Unexpected MXFP4-dequantize " + f"layout for {type(mod).__name__}.{proj}: " + f"got {tuple(p.shape)}, expected " + f"(..., {expected_wrong[proj][0]}, " + f"{expected_wrong[proj][1]}) or (..., " + f"{expected_right[proj][0]}, " + f"{expected_right[proj][1]}). Skipping " + f"transpose; forward may fail. This " + f"usually means your transformers " + f"version changed the dequantize layout." + ) + continue + new_p = p.data.transpose(-2, -1).contiguous() + setattr(mod, proj, _torch.nn.Parameter( + new_p, requires_grad=p.requires_grad, + )) + return + + # Native MXFP4 (dequantize=False) path. + # The 4.x loader path swizzled during weight load and + # already dropped blocks/scales -- so this walker finds + # nothing to do and is a no-op there. Under 5.x the + # WeightConverter dispatch does not fire for matching + # parameter names, leaving raw blocks/scales behind; we + # invoke the tensor-level swizzle here to mirror 4.x. + import transformers.integrations.mxfp4 as _mx_mod + swizzle_fn = getattr(_mx_mod, "swizzle_mxfp4_convertops", None) + if swizzle_fn is None: + # Check whether any module still needs swizzling before + # raising; if 4.x swizzled during load, there is nothing + # to do and we can silently no-op. + for mod in model.modules(): + if type(mod).__name__ != "Mxfp4GptOssExperts": + continue + if "_gate_up_proj" in mod.__dict__: + continue + if getattr(mod, "gate_up_proj_blocks", None) is not None: + raise RuntimeError( + "[unsloth] MXFP4 model has raw " + "gate_up_proj_blocks / _scales post-load but " + "transformers.integrations.mxfp4." + "swizzle_mxfp4_convertops is not available. " + "Either upgrade transformers to >=4.56 (with " + "load_and_swizzle_mxfp4) or to a 5.x that " + "exposes swizzle_mxfp4_convertops, or pass " + "Mxfp4Config(dequantize=True) to fall back " + "to bf16." + ) + return + try: + import triton_kernels as _tk + except Exception as _tk_err: + # Non-Hopper users hit the pre-load Hopper gate and get + # ``dequantize=True``, so they never reach here. Hopper + # users without triton_kernels genuinely cannot run the + # native MXFP4 path and need a clear error. + for mod in model.modules(): + if type(mod).__name__ != "Mxfp4GptOssExperts": + continue + if "_gate_up_proj" in mod.__dict__: + continue + if getattr(mod, "gate_up_proj_blocks", None) is not None: + raise RuntimeError( + "[unsloth] Native MXFP4 requires " + "triton_kernels, but it is not installed. " + "Install via `pip install git+https://" + "github.com/triton-lang/triton.git@" + "0add68262ab0a2e33b84524346cb27cbb2787356" + "#subdirectory=python/triton_kernels`, " + "or pass Mxfp4Config(dequantize=True) to " + "fall back to bf16." + ) from _tk_err + return + for mod in model.modules(): + if type(mod).__name__ != "Mxfp4GptOssExperts": + continue + if "_gate_up_proj" in mod.__dict__: + continue # already swizzled + blocks = getattr(mod, "gate_up_proj_blocks", None) + scales = getattr(mod, "gate_up_proj_scales", None) + if blocks is None or scales is None: + continue + if blocks.device.type == "meta": + continue + for proj in ("gate_up_proj", "down_proj"): + b = getattr(mod, f"{proj}_blocks", None) + s = getattr(mod, f"{proj}_scales", None) + if b is None or s is None: + continue + try: + swizzle_fn(b.data, s.data, mod, proj, b.device, _tk) + except Exception as _sw_err: + raise RuntimeError( + f"[unsloth] MXFP4 swizzle failed on " + f"{type(mod).__name__}.{proj} " + f"(shape={tuple(b.shape)}). This usually " + f"means triton_kernels has drifted out of " + f"sync with transformers; install the " + f"pinned build per the gpt-oss notebook." + ) from _sw_err + # Drop the now-consumed raw parameters so the + # property short-circuits to ``_gate_up_proj``. + if f"{proj}_blocks" in mod._parameters: + del mod._parameters[f"{proj}_blocks"] + if f"{proj}_scales" in mod._parameters: + del mod._parameters[f"{proj}_scales"] + + _Mxfp4Q._process_model_after_weight_loading = _patched_post_load + _Mxfp4Q._unsloth_post_swizzle_patched = True + + # Native triton_kernels MXFP4 matmul is Hopper-only + # (``tl.static_assert(SWIZZLE_MX_VALUE == "HOPPER_VALUE" or None)`` + # inside matmul_ogs). On Ampere/Ada/Blackwell the kernel aborts + # compilation with "Only Hopper swizzling is supported". Force + # ``dequantize=True`` on non-Hopper CUDA devices before + # ``_process_model_before_weight_loading`` builds the module + # replacement plan, so these users transparently land on the + # bf16 path instead of a confusing kernel-level crash. + if not getattr(_Mxfp4Q, "_unsloth_hopper_gate_patched", False): + _orig_before_load = _Mxfp4Q._process_model_before_weight_loading + + def _patched_before_load(self, model, **kwargs): + try: + import torch as _torch + if ( + _torch.cuda.is_available() + and not getattr(self.quantization_config, "dequantize", False) + ): + # Inspect every visible CUDA device rather than + # hard-coding index 0: on a multi-GPU setup the + # user may be placing the model on a non-default + # device (``device_map``, ``CUDA_VISIBLE_DEVICES``). + # If ANY target device is non-Hopper, the native + # MXFP4 matmul would crash at kernel compile time + # with ``Only Hopper swizzling is supported``. + # Conservative fallback: if any visible device is + # non-Hopper (sm_90), route everyone to bf16. + n = _torch.cuda.device_count() + # Hopper = sm_90 (H100, H200); native MXFP4 ok. + # Ampere/Ada = sm_80/sm_86/sm_89 -- unsupported. + # Blackwell = sm_100 (B200) -- unsupported by + # triton_kernels MXFP4 path as of commit 0add68262a. + # When a future triton_kernels release adds + # additional architectures, extend this set. + HOPPER_COMPATIBLE = {9} + non_hopper = False + for i in range(n): + major, _minor = _torch.cuda.get_device_capability(i) + if major not in HOPPER_COMPATIBLE: + non_hopper = True + break + if non_hopper: + self.quantization_config.dequantize = True + elif not _torch.cuda.is_available(): + # CPU-only environments cannot run native MXFP4 at + # all (triton_kernels needs CUDA). Force dequantize + # so users at least get a runnable bf16 model. + if not getattr(self.quantization_config, "dequantize", False): + self.quantization_config.dequantize = True + except Exception: + pass + return _orig_before_load(self, model, **kwargs) + + _Mxfp4Q._process_model_before_weight_loading = _patched_before_load + _Mxfp4Q._unsloth_hopper_gate_patched = True + except Exception as e: + return raise_error( + "transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer post-load swizzle patch", e, + ) + if HAS_TRITON_KERNELS: # Only override is_kernels_available when triton_kernels IS available try: From 7415a4fc891bf2d426bf4497417bf3197653dd9f Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Fri, 24 Apr 2026 05:19:53 +0000 Subject: [PATCH 2/7] gpt-oss MXFP4: drop redundant inline comments Strip the WHAT commentary and keep only the top-of-block WHY blurbs plus the shape-ambiguity and multi-GPU notes. No behavior change. Re-verified on B200 (sm_100) with transformers 5.5.4: 3/3 prompts produce identical output to the pre-trim commit. --- unsloth_zoo/temporary_patches/gpt_oss.py | 158 +++++------------------ 1 file changed, 35 insertions(+), 123 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index a31223bd3..e930c431a 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -128,52 +128,25 @@ def patch_gpt_oss(): except Exception as e: return raise_error("transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer", e) - # Transformers 5.x moved the MXFP4 load path from the function-based - # ``load_and_swizzle_mxfp4`` (4.56.x) to the WeightConverter-based - # ``Mxfp4Deserialize`` class. That converter does not actually fire when - # the checkpoint's key names already match the module's registered - # parameters (``gate_up_proj_blocks`` / ``_scales``) — so under 5.x, - # MXFP4 models load with raw blocks + scales left on the module and the - # swizzle-to-triton step never runs. Post-load the ``gate_up_proj`` - # property's fallback trips the ``dequantize(blocks, scales)`` signature - # mismatch (``dequantize`` is the loader-flavour in 5.x). - # - # Fix: wrap ``_process_model_after_weight_loading`` to walk the model - # after weight load, detect any ``Mxfp4GptOssExperts`` still holding - # raw blocks/scales, and invoke ``swizzle_mxfp4_convertops`` on them. - # This mirrors what the pre-5.x ``load_and_swizzle_mxfp4`` did. + # transformers 5.x moved MXFP4 loading from ``load_and_swizzle_mxfp4`` + # (4.56.x) to the ``Mxfp4Deserialize`` WeightConverter, which is skipped + # whenever checkpoint key names already match registered parameters -- + # so MXFP4 models load with raw blocks/scales still on the module and + # the swizzle step never runs. 5.x ``Mxfp4Dequantize.convert`` also + # drops the transpose that 4.x baked into ``convert_moe_packed_tensors``, + # so gate_up_proj ends up as (E, 2I, H) vs the (E, H, 2I) the stock + # forward expects. Fix both in a post-load walker. try: _Mxfp4Q = transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer if not getattr(_Mxfp4Q, "_unsloth_post_swizzle_patched", False): _orig_post_load = _Mxfp4Q._process_model_after_weight_loading def _patched_post_load(self, model, **kwargs): - # Always let the original post-load hook run first: on 5.x it - # performs ``torch.cuda.empty_cache()`` and may grow to carry - # additional logic in future point releases. Skipping it - # leaks VRAM and risks silently bypassing upstream fixes. _orig_post_load(self, model, **kwargs) import torch as _torch import warnings as _warnings dq = getattr(self.quantization_config, "dequantize", False) if dq: - # transformers 5.x ``Mxfp4Dequantize.convert`` returns - # the tensor in dequantize layout -- gate_up_proj: - # ``(E, 2I, H)``, down_proj: ``(E, H, I)``. The stock - # ``GptOssExperts`` expects gate_up_proj as - # ``(E, H, 2I)`` (see modeling_gpt_oss.py:75) and - # down_proj as ``(E, I, H)`` (modeling_gpt_oss.py:77). - # In transformers 4.x the transpose was baked into - # ``convert_moe_packed_tensors``; in 5.x it was - # removed and not restored elsewhere. - # - # Per-projection shape check: guards against a future - # transformers release producing either weight already - # in the correct orientation. For gpt-oss-20b - # (``H == I == 2880``) ``down_proj``'s shape alone is - # ambiguous, so we additionally use ``gate_up_proj``'s - # unambiguous ``(E, 2I, H)`` as an indicator that the - # whole module came out of the 5.x dequantize path. for mod in model.modules(): if type(mod).__name__ != "GptOssExperts": continue @@ -185,8 +158,10 @@ def _patched_post_load(self, model, **kwargs): if gup is None or gup.dim() != 3: continue _, D1, D2 = gup.shape - needs_transpose = (D1 == 2 * I and D2 == H) - if not needs_transpose: + # gate_up_proj shape is unambiguous (2I != H); + # down_proj shape is ambiguous when I == H + # (gpt-oss-20b: 2880 / 2880). + if not (D1 == 2 * I and D2 == H): continue expected_wrong = { "gate_up_proj": (2 * I, H), @@ -207,14 +182,7 @@ def _patched_post_load(self, model, **kwargs): _warnings.warn( f"[unsloth] Unexpected MXFP4-dequantize " f"layout for {type(mod).__name__}.{proj}: " - f"got {tuple(p.shape)}, expected " - f"(..., {expected_wrong[proj][0]}, " - f"{expected_wrong[proj][1]}) or (..., " - f"{expected_right[proj][0]}, " - f"{expected_right[proj][1]}). Skipping " - f"transpose; forward may fail. This " - f"usually means your transformers " - f"version changed the dequantize layout." + f"got {tuple(p.shape)}. Skipping transpose." ) continue new_p = p.data.transpose(-2, -1).contiguous() @@ -223,19 +191,9 @@ def _patched_post_load(self, model, **kwargs): )) return - # Native MXFP4 (dequantize=False) path. - # The 4.x loader path swizzled during weight load and - # already dropped blocks/scales -- so this walker finds - # nothing to do and is a no-op there. Under 5.x the - # WeightConverter dispatch does not fire for matching - # parameter names, leaving raw blocks/scales behind; we - # invoke the tensor-level swizzle here to mirror 4.x. import transformers.integrations.mxfp4 as _mx_mod swizzle_fn = getattr(_mx_mod, "swizzle_mxfp4_convertops", None) if swizzle_fn is None: - # Check whether any module still needs swizzling before - # raising; if 4.x swizzled during load, there is nothing - # to do and we can silently no-op. for mod in model.modules(): if type(mod).__name__ != "Mxfp4GptOssExperts": continue @@ -243,24 +201,15 @@ def _patched_post_load(self, model, **kwargs): continue if getattr(mod, "gate_up_proj_blocks", None) is not None: raise RuntimeError( - "[unsloth] MXFP4 model has raw " - "gate_up_proj_blocks / _scales post-load but " - "transformers.integrations.mxfp4." - "swizzle_mxfp4_convertops is not available. " - "Either upgrade transformers to >=4.56 (with " - "load_and_swizzle_mxfp4) or to a 5.x that " - "exposes swizzle_mxfp4_convertops, or pass " - "Mxfp4Config(dequantize=True) to fall back " - "to bf16." + "[unsloth] MXFP4 model has raw blocks/scales " + "post-load but swizzle_mxfp4_convertops is " + "not available; pass Mxfp4Config(dequantize=" + "True) to fall back to bf16." ) return try: import triton_kernels as _tk except Exception as _tk_err: - # Non-Hopper users hit the pre-load Hopper gate and get - # ``dequantize=True``, so they never reach here. Hopper - # users without triton_kernels genuinely cannot run the - # native MXFP4 path and need a clear error. for mod in model.modules(): if type(mod).__name__ != "Mxfp4GptOssExperts": continue @@ -269,20 +218,15 @@ def _patched_post_load(self, model, **kwargs): if getattr(mod, "gate_up_proj_blocks", None) is not None: raise RuntimeError( "[unsloth] Native MXFP4 requires " - "triton_kernels, but it is not installed. " - "Install via `pip install git+https://" - "github.com/triton-lang/triton.git@" - "0add68262ab0a2e33b84524346cb27cbb2787356" - "#subdirectory=python/triton_kernels`, " - "or pass Mxfp4Config(dequantize=True) to " - "fall back to bf16." + "triton_kernels; install it or pass " + "Mxfp4Config(dequantize=True) for bf16." ) from _tk_err return for mod in model.modules(): if type(mod).__name__ != "Mxfp4GptOssExperts": continue if "_gate_up_proj" in mod.__dict__: - continue # already swizzled + continue blocks = getattr(mod, "gate_up_proj_blocks", None) scales = getattr(mod, "gate_up_proj_scales", None) if blocks is None or scales is None: @@ -299,14 +243,9 @@ def _patched_post_load(self, model, **kwargs): except Exception as _sw_err: raise RuntimeError( f"[unsloth] MXFP4 swizzle failed on " - f"{type(mod).__name__}.{proj} " - f"(shape={tuple(b.shape)}). This usually " - f"means triton_kernels has drifted out of " - f"sync with transformers; install the " - f"pinned build per the gpt-oss notebook." + f"{type(mod).__name__}.{proj}; likely " + f"triton_kernels / transformers drift." ) from _sw_err - # Drop the now-consumed raw parameters so the - # property short-circuits to ``_gate_up_proj``. if f"{proj}_blocks" in mod._parameters: del mod._parameters[f"{proj}_blocks"] if f"{proj}_scales" in mod._parameters: @@ -315,55 +254,28 @@ def _patched_post_load(self, model, **kwargs): _Mxfp4Q._process_model_after_weight_loading = _patched_post_load _Mxfp4Q._unsloth_post_swizzle_patched = True - # Native triton_kernels MXFP4 matmul is Hopper-only - # (``tl.static_assert(SWIZZLE_MX_VALUE == "HOPPER_VALUE" or None)`` - # inside matmul_ogs). On Ampere/Ada/Blackwell the kernel aborts - # compilation with "Only Hopper swizzling is supported". Force - # ``dequantize=True`` on non-Hopper CUDA devices before - # ``_process_model_before_weight_loading`` builds the module - # replacement plan, so these users transparently land on the - # bf16 path instead of a confusing kernel-level crash. + # triton_kernels MXFP4 matmul is Hopper-only; Ampere/Ada/Blackwell + # abort with "Only Hopper swizzling is supported". Force dequantize + # on any non-Hopper (or CPU-only) environment before load. if not getattr(_Mxfp4Q, "_unsloth_hopper_gate_patched", False): _orig_before_load = _Mxfp4Q._process_model_before_weight_loading def _patched_before_load(self, model, **kwargs): try: import torch as _torch - if ( - _torch.cuda.is_available() - and not getattr(self.quantization_config, "dequantize", False) - ): - # Inspect every visible CUDA device rather than - # hard-coding index 0: on a multi-GPU setup the - # user may be placing the model on a non-default - # device (``device_map``, ``CUDA_VISIBLE_DEVICES``). - # If ANY target device is non-Hopper, the native - # MXFP4 matmul would crash at kernel compile time - # with ``Only Hopper swizzling is supported``. - # Conservative fallback: if any visible device is - # non-Hopper (sm_90), route everyone to bf16. - n = _torch.cuda.device_count() - # Hopper = sm_90 (H100, H200); native MXFP4 ok. - # Ampere/Ada = sm_80/sm_86/sm_89 -- unsupported. - # Blackwell = sm_100 (B200) -- unsupported by - # triton_kernels MXFP4 path as of commit 0add68262a. - # When a future triton_kernels release adds - # additional architectures, extend this set. + if getattr(self.quantization_config, "dequantize", False): + pass + elif not _torch.cuda.is_available(): + self.quantization_config.dequantize = True + else: + # Scan every visible device so multi-GPU / + # device_map setups do not misread index 0. HOPPER_COMPATIBLE = {9} - non_hopper = False - for i in range(n): - major, _minor = _torch.cuda.get_device_capability(i) + for i in range(_torch.cuda.device_count()): + major, _ = _torch.cuda.get_device_capability(i) if major not in HOPPER_COMPATIBLE: - non_hopper = True + self.quantization_config.dequantize = True break - if non_hopper: - self.quantization_config.dequantize = True - elif not _torch.cuda.is_available(): - # CPU-only environments cannot run native MXFP4 at - # all (triton_kernels needs CUDA). Force dequantize - # so users at least get a runnable bf16 model. - if not getattr(self.quantization_config, "dequantize", False): - self.quantization_config.dequantize = True except Exception: pass return _orig_before_load(self, model, **kwargs) From 3029ac20c684998870e2d5af9bcd82ffb4c4bcf1 Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Fri, 24 Apr 2026 06:24:23 +0000 Subject: [PATCH 3/7] gpt-oss MXFP4: remove Hopper-only gate Earlier version forced dequantize=True on non-Hopper CUDA devices, assuming triton_kernels MXFP4 matmul was Hopper-only. That was wrong: triton_kernels/tensor_details/layout.py picks BlackwellMXValueLayout on sm_100, HopperMXValueLayout on sm_90, and StridedLayout (no swizzle) on older archs, and the matmul assert at _matmul_ogs.py:114 allows both "HOPPER_VALUE" and None. So T4, A100, H100, and B200 all run native MXFP4 once swizzle_mxfp4_convertops fires. The earlier B200 crash I attributed to Hopper-only was actually caused by the 5.x WeightConverter skip leaving weights unswizzled -- fixed by the post-load walker in the previous commit. Keep only the CPU-only fallback (triton_kernels needs CUDA). Re-verified on B200 (sm_100, transformers 5.5.4): native MXFP4 produces identical coherent output to the dequantize path. --- unsloth_zoo/temporary_patches/gpt_oss.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index e930c431a..a12e126c4 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -257,31 +257,23 @@ def _patched_post_load(self, model, **kwargs): # triton_kernels MXFP4 matmul is Hopper-only; Ampere/Ada/Blackwell # abort with "Only Hopper swizzling is supported". Force dequantize # on any non-Hopper (or CPU-only) environment before load. - if not getattr(_Mxfp4Q, "_unsloth_hopper_gate_patched", False): + if not getattr(_Mxfp4Q, "_unsloth_cpu_gate_patched", False): _orig_before_load = _Mxfp4Q._process_model_before_weight_loading def _patched_before_load(self, model, **kwargs): try: import torch as _torch - if getattr(self.quantization_config, "dequantize", False): - pass - elif not _torch.cuda.is_available(): + if ( + not _torch.cuda.is_available() + and not getattr(self.quantization_config, "dequantize", False) + ): self.quantization_config.dequantize = True - else: - # Scan every visible device so multi-GPU / - # device_map setups do not misread index 0. - HOPPER_COMPATIBLE = {9} - for i in range(_torch.cuda.device_count()): - major, _ = _torch.cuda.get_device_capability(i) - if major not in HOPPER_COMPATIBLE: - self.quantization_config.dequantize = True - break except Exception: pass return _orig_before_load(self, model, **kwargs) _Mxfp4Q._process_model_before_weight_loading = _patched_before_load - _Mxfp4Q._unsloth_hopper_gate_patched = True + _Mxfp4Q._unsloth_cpu_gate_patched = True except Exception as e: return raise_error( "transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer post-load swizzle patch", e, From 833753ca8896dde9efa4d7b59bb7a67656717f3d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 May 2026 02:36:56 +0000 Subject: [PATCH 4/7] gpt-oss MXFP4: tighten post-load layout/swizzle guards and CPU gate - Dequantize transpose: when intermediate_size == hidden_size (gpt-oss-20b is 2880/2880), down_proj wrong (H, I) and right (I, H) are the same shape, so the per-projection skip used to leave down_proj in the wrong-layout orientation. The outer guard already proves the module came from the wrong-layout path via gate_up_proj, so transpose unconditionally for the ambiguous square case. - Native swizzle walker: mirror the Mxfp4GptOssExperts.gate_up_proj property invariant (numel > 0 and .any()) so partial / missing checkpoint loads do not silently swizzle zero placeholders into apparently-loaded zero experts. - before-load gate: also accept torch.xpu.is_available(); native MXFP4 is supported on XPU by transformers' validate_environment, and forcing dequantize there caused a 4x bf16 expansion. Wrap the config mutation in its own try/except that warns instead of swallowing, so a frozen quantization_config does not strand CPU users with no diagnostic. - Comment: replace the stale Hopper-only rationale with the actual layout pickers (Strided / Hopper / Blackwell) so future readers match the implemented gate. --- unsloth_zoo/temporary_patches/gpt_oss.py | 44 ++++++++++++++++++++---- 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index a12e126c4..c901517f3 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -176,9 +176,14 @@ def _patched_post_load(self, model, **kwargs): if p is None or p.dim() != 3: continue shape = tuple(p.shape[-2:]) - if shape == expected_right[proj]: + # When H == I, down_proj wrong (H, I) and right + # (I, H) shapes are identical; the outer guard + # already proved this module is in the wrong + # layout, so transpose unconditionally. + ambiguous = (proj == "down_proj" and H == I) + if shape == expected_right[proj] and not ambiguous: continue - if shape != expected_wrong[proj]: + if shape != expected_wrong[proj] and not ambiguous: _warnings.warn( f"[unsloth] Unexpected MXFP4-dequantize " f"layout for {type(mod).__name__}.{proj}: " @@ -233,11 +238,22 @@ def _patched_post_load(self, model, **kwargs): continue if blocks.device.type == "meta": continue + # Mirror the Mxfp4GptOssExperts.gate_up_proj property + # invariant: zero placeholders mean weights never loaded; + # swizzling them silently produces zero experts. + if blocks.numel() == 0 or not bool(blocks.any()): + continue for proj in ("gate_up_proj", "down_proj"): b = getattr(mod, f"{proj}_blocks", None) s = getattr(mod, f"{proj}_scales", None) if b is None or s is None: continue + if ( + b.device.type == "meta" + or b.numel() == 0 + or not bool(b.any()) + ): + continue try: swizzle_fn(b.data, s.data, mod, proj, b.device, _tk) except Exception as _sw_err: @@ -254,20 +270,34 @@ def _patched_post_load(self, model, **kwargs): _Mxfp4Q._process_model_after_weight_loading = _patched_post_load _Mxfp4Q._unsloth_post_swizzle_patched = True - # triton_kernels MXFP4 matmul is Hopper-only; Ampere/Ada/Blackwell - # abort with "Only Hopper swizzling is supported". Force dequantize - # on any non-Hopper (or CPU-only) environment before load. + # triton_kernels MXFP4 needs CUDA; native MXFP4 itself works across + # CUDA compute capabilities (StridedLayout / HopperMXValueLayout / + # BlackwellMXValueLayout) and XPU has its own native path. Only + # force dequantize when no supported accelerator is visible. if not getattr(_Mxfp4Q, "_unsloth_cpu_gate_patched", False): _orig_before_load = _Mxfp4Q._process_model_before_weight_loading def _patched_before_load(self, model, **kwargs): try: import torch as _torch + cuda_ok = _torch.cuda.is_available() + xpu_ok = ( + hasattr(_torch, "xpu") and _torch.xpu.is_available() + ) if ( - not _torch.cuda.is_available() + not cuda_ok + and not xpu_ok and not getattr(self.quantization_config, "dequantize", False) ): - self.quantization_config.dequantize = True + try: + self.quantization_config.dequantize = True + except Exception as _cfg_err: + import warnings as _w + _w.warn( + f"[unsloth] Could not force MXFP4 dequantize " + f"on CPU-only host: {_cfg_err!r}; native " + f"MXFP4 path may crash without triton_kernels." + ) except Exception: pass return _orig_before_load(self, model, **kwargs) From cd0ddb8252dd96fcbca27e4857cbad5bc5d408f8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 May 2026 03:12:00 +0000 Subject: [PATCH 5/7] gpt-oss MXFP4: align swizzle error guards and forward use_kernels - Native swizzle paths: collapse the swizzle_fn-missing, triton_kernels import-failure, and active swizzle branches onto a shared per-projection predicate that requires both blocks and scales to be non-meta, non-empty, and blocks.any(). The error branches previously raised on freshly-init Mxfp4GptOssExperts (zero placeholders) and on real-blocks-plus-meta-scales partial loads, both of which the active branch correctly skipped. - Per-projection skip: gate the walker on f"_{proj}" in mod.__dict__ rather than the module-wide _gate_up_proj cache. The unsloth Mxfp4GptOssExperts has independent _gate_up_proj / _down_proj caches, so an early access to one projection used to leave the other raw. - Pre-load wrapper: widen the signature to (self, model, use_kernels=False, **kwargs) to match transformers 5.7.0 and forward the flag to the original. Honor use_kernels=True on CPU so the upstream native CPU MXFP4 path is preserved instead of being silently dequantized. - Pre-load detection: replace the bare except: pass around the device-detection block with a warn so future torch.xpu / torch.cuda API regressions surface instead of leaving the override a silent no-op. --- unsloth_zoo/temporary_patches/gpt_oss.py | 68 +++++++++++++----------- 1 file changed, 37 insertions(+), 31 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index c901517f3..ea07ed8b0 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -198,13 +198,31 @@ def _patched_post_load(self, model, **kwargs): import transformers.integrations.mxfp4 as _mx_mod swizzle_fn = getattr(_mx_mod, "swizzle_mxfp4_convertops", None) + + def _has_loaded_raw(mod, proj): + b = getattr(mod, f"{proj}_blocks", None) + s = getattr(mod, f"{proj}_scales", None) + if b is None or s is None: + return False + if b.device.type == "meta" or s.device.type == "meta": + return False + if b.numel() == 0 or s.numel() == 0: + return False + return bool(b.any()) + + def _module_has_loaded_raw(mod): + for proj in ("gate_up_proj", "down_proj"): + if f"_{proj}" in mod.__dict__: + continue + if _has_loaded_raw(mod, proj): + return True + return False + if swizzle_fn is None: for mod in model.modules(): if type(mod).__name__ != "Mxfp4GptOssExperts": continue - if "_gate_up_proj" in mod.__dict__: - continue - if getattr(mod, "gate_up_proj_blocks", None) is not None: + if _module_has_loaded_raw(mod): raise RuntimeError( "[unsloth] MXFP4 model has raw blocks/scales " "post-load but swizzle_mxfp4_convertops is " @@ -218,9 +236,7 @@ def _patched_post_load(self, model, **kwargs): for mod in model.modules(): if type(mod).__name__ != "Mxfp4GptOssExperts": continue - if "_gate_up_proj" in mod.__dict__: - continue - if getattr(mod, "gate_up_proj_blocks", None) is not None: + if _module_has_loaded_raw(mod): raise RuntimeError( "[unsloth] Native MXFP4 requires " "triton_kernels; install it or pass " @@ -230,30 +246,13 @@ def _patched_post_load(self, model, **kwargs): for mod in model.modules(): if type(mod).__name__ != "Mxfp4GptOssExperts": continue - if "_gate_up_proj" in mod.__dict__: - continue - blocks = getattr(mod, "gate_up_proj_blocks", None) - scales = getattr(mod, "gate_up_proj_scales", None) - if blocks is None or scales is None: - continue - if blocks.device.type == "meta": - continue - # Mirror the Mxfp4GptOssExperts.gate_up_proj property - # invariant: zero placeholders mean weights never loaded; - # swizzling them silently produces zero experts. - if blocks.numel() == 0 or not bool(blocks.any()): - continue for proj in ("gate_up_proj", "down_proj"): - b = getattr(mod, f"{proj}_blocks", None) - s = getattr(mod, f"{proj}_scales", None) - if b is None or s is None: + if f"_{proj}" in mod.__dict__: continue - if ( - b.device.type == "meta" - or b.numel() == 0 - or not bool(b.any()) - ): + if not _has_loaded_raw(mod, proj): continue + b = getattr(mod, f"{proj}_blocks") + s = getattr(mod, f"{proj}_scales") try: swizzle_fn(b.data, s.data, mod, proj, b.device, _tk) except Exception as _sw_err: @@ -277,7 +276,7 @@ def _patched_post_load(self, model, **kwargs): if not getattr(_Mxfp4Q, "_unsloth_cpu_gate_patched", False): _orig_before_load = _Mxfp4Q._process_model_before_weight_loading - def _patched_before_load(self, model, **kwargs): + def _patched_before_load(self, model, use_kernels=False, **kwargs): try: import torch as _torch cuda_ok = _torch.cuda.is_available() @@ -287,6 +286,7 @@ def _patched_before_load(self, model, **kwargs): if ( not cuda_ok and not xpu_ok + and not use_kernels and not getattr(self.quantization_config, "dequantize", False) ): try: @@ -298,9 +298,15 @@ def _patched_before_load(self, model, **kwargs): f"on CPU-only host: {_cfg_err!r}; native " f"MXFP4 path may crash without triton_kernels." ) - except Exception: - pass - return _orig_before_load(self, model, **kwargs) + except Exception as _det_err: + import warnings as _w + _w.warn( + f"[unsloth] MXFP4 pre-load device detection failed: " + f"{_det_err!r}; falling back to upstream." + ) + return _orig_before_load( + self, model, use_kernels=use_kernels, **kwargs, + ) _Mxfp4Q._process_model_before_weight_loading = _patched_before_load _Mxfp4Q._unsloth_cpu_gate_patched = True From c302e4680c780f4f72d932b6fbfb73b81cbea9e8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 May 2026 03:14:35 +0000 Subject: [PATCH 6/7] Add gpt-oss MXFP4 patch tests --- test_gpt_oss_mxfp4_swizzle_walker.py | 185 +++++++++++++++++++++++++++ test_gpt_oss_mxfp4_use_kernels.py | 117 +++++++++++++++++ 2 files changed, 302 insertions(+) create mode 100644 test_gpt_oss_mxfp4_swizzle_walker.py create mode 100644 test_gpt_oss_mxfp4_use_kernels.py diff --git a/test_gpt_oss_mxfp4_swizzle_walker.py b/test_gpt_oss_mxfp4_swizzle_walker.py new file mode 100644 index 000000000..aa9a91b73 --- /dev/null +++ b/test_gpt_oss_mxfp4_swizzle_walker.py @@ -0,0 +1,185 @@ +import os +import sys +import torch +import pytest + +HERE = os.path.dirname(os.path.abspath(__file__)) +if HERE not in sys.path: + sys.path.insert(0, HERE) + +from unsloth_zoo.temporary_patches.gpt_oss import patch_gpt_oss +import transformers +import transformers.integrations.mxfp4 as _mx_mod + +patch_gpt_oss() +_QCLS = transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer + + +class Mxfp4GptOssExperts(torch.nn.Module): + pass + + +class _Cfg: + dequantize = False + + +class _Model: + class config: + pass + + is_quantized = True + + def __init__(self, mods): + self._mods = mods + + def modules(self): + return iter(self._mods) + + +def _make_quantizer(dequantize=False): + cfg = _Cfg() + cfg.dequantize = dequantize + q = _QCLS.__new__(_QCLS) + q.quantization_config = cfg + q.pre_quantized = True + q.modules_to_not_convert = [] + q.triton_kernels_hub = None + return q + + +def _make_module(blocks_zero=True, scales_zero=True, scales_meta=False, + down_blocks_zero=True, down_scales_zero=True): + m = Mxfp4GptOssExperts() + m.gate_up_proj_blocks = torch.nn.Parameter( + torch.zeros(2, 8, 4, 16, dtype=torch.uint8) if blocks_zero + else torch.ones(2, 8, 4, 16, dtype=torch.uint8), + requires_grad=False, + ) + if scales_meta: + m.gate_up_proj_scales = torch.nn.Parameter( + torch.empty(2, 8, 4, dtype=torch.uint8, device="meta"), + requires_grad=False, + ) + else: + m.gate_up_proj_scales = torch.nn.Parameter( + torch.zeros(2, 8, 4, dtype=torch.uint8) if scales_zero + else torch.ones(2, 8, 4, dtype=torch.uint8), + requires_grad=False, + ) + m.down_proj_blocks = torch.nn.Parameter( + torch.zeros(2, 4, 4, 16, dtype=torch.uint8) if down_blocks_zero + else torch.ones(2, 4, 4, 16, dtype=torch.uint8), + requires_grad=False, + ) + m.down_proj_scales = torch.nn.Parameter( + torch.zeros(2, 4, 4, dtype=torch.uint8) if down_scales_zero + else torch.ones(2, 4, 4, dtype=torch.uint8), + requires_grad=False, + ) + return m + + +@pytest.fixture +def hide_swizzle_fn(): + saved = getattr(_mx_mod, "swizzle_mxfp4_convertops", None) + if saved is not None: + delattr(_mx_mod, "swizzle_mxfp4_convertops") + yield + if saved is not None: + _mx_mod.swizzle_mxfp4_convertops = saved + + +def test_zero_placeholders_do_not_raise_when_swizzle_fn_missing(hide_swizzle_fn): + """Freshly-init Mxfp4GptOssExperts with zero placeholders should be + treated as not-loaded and skipped, not raised on.""" + mod = _make_module(blocks_zero=True, scales_zero=True, + down_blocks_zero=True, down_scales_zero=True) + q = _make_quantizer(dequantize=False) + _QCLS._process_model_after_weight_loading(q, _Model([mod])) + + +def test_loaded_blocks_raise_when_swizzle_fn_missing(hide_swizzle_fn): + """When real blocks are present, the missing-swizzle branch must still raise.""" + mod = _make_module(blocks_zero=False, scales_zero=False) + q = _make_quantizer(dequantize=False) + with pytest.raises(RuntimeError, match="raw blocks/scales"): + _QCLS._process_model_after_weight_loading(q, _Model([mod])) + + +def test_meta_scales_treated_as_not_loaded(hide_swizzle_fn): + """Real blocks + meta scales should be skipped, not crash.""" + mod = _make_module(blocks_zero=False, scales_meta=True, + down_blocks_zero=True, down_scales_zero=True) + q = _make_quantizer(dequantize=False) + _QCLS._process_model_after_weight_loading(q, _Model([mod])) + + +def test_per_projection_skip_repairs_uncached_down_proj(): + """If _gate_up_proj is cached early (e.g. property accessed) but + down_proj is still raw and loaded, the active swizzle path must + repair down_proj instead of skipping the whole module.""" + mod = _make_module(blocks_zero=True, scales_zero=True, + down_blocks_zero=False, down_scales_zero=False) + # Simulate gate_up_proj cached + mod.__dict__["_gate_up_proj"] = torch.zeros(2, 4, 8) + + swizzle_calls = [] + + def fake_swizzle(b, s, mod, proj, dev, tk): + swizzle_calls.append(proj) + # mimic upstream cleanup: remove the raw params + if f"{proj}_blocks" in mod._parameters: + del mod._parameters[f"{proj}_blocks"] + if f"{proj}_scales" in mod._parameters: + del mod._parameters[f"{proj}_scales"] + + saved = getattr(_mx_mod, "swizzle_mxfp4_convertops", None) + _mx_mod.swizzle_mxfp4_convertops = fake_swizzle + try: + # provide a fake top-level triton_kernels module so the import succeeds + import types as _types + sys.modules.setdefault("triton_kernels", _types.ModuleType("triton_kernels")) + q = _make_quantizer(dequantize=False) + _QCLS._process_model_after_weight_loading(q, _Model([mod])) + finally: + if saved is not None: + _mx_mod.swizzle_mxfp4_convertops = saved + + assert swizzle_calls == ["down_proj"], swizzle_calls + + +def test_per_projection_skip_skips_both_when_both_cached(): + """If both projections are cached (post-load already ran), the + active swizzle path must be a no-op.""" + mod = _make_module(blocks_zero=False, scales_zero=False, + down_blocks_zero=False, down_scales_zero=False) + mod.__dict__["_gate_up_proj"] = torch.zeros(2, 4, 8) + mod.__dict__["_down_proj"] = torch.zeros(2, 4, 4) + + swizzle_calls = [] + + def fake_swizzle(b, s, mod, proj, dev, tk): + swizzle_calls.append(proj) + + saved = getattr(_mx_mod, "swizzle_mxfp4_convertops", None) + _mx_mod.swizzle_mxfp4_convertops = fake_swizzle + try: + import types as _types + sys.modules.setdefault("triton_kernels", _types.ModuleType("triton_kernels")) + q = _make_quantizer(dequantize=False) + _QCLS._process_model_after_weight_loading(q, _Model([mod])) + finally: + if saved is not None: + _mx_mod.swizzle_mxfp4_convertops = saved + + assert swizzle_calls == [] + + +def test_partial_load_one_projection_loaded(hide_swizzle_fn): + """Module has gate_up_proj loaded but down_proj still zero. Missing + swizzle_fn branch must raise (since some weights are loaded).""" + mod = _make_module(blocks_zero=False, scales_zero=False, + down_blocks_zero=True, down_scales_zero=True) + q = _make_quantizer(dequantize=False) + with pytest.raises(RuntimeError, match="raw blocks/scales"): + _QCLS._process_model_after_weight_loading(q, _Model([mod])) diff --git a/test_gpt_oss_mxfp4_use_kernels.py b/test_gpt_oss_mxfp4_use_kernels.py new file mode 100644 index 000000000..5a32a8b85 --- /dev/null +++ b/test_gpt_oss_mxfp4_use_kernels.py @@ -0,0 +1,117 @@ +import os +import sys +import warnings +import torch +import pytest + +HERE = os.path.dirname(os.path.abspath(__file__)) +if HERE not in sys.path: + sys.path.insert(0, HERE) + +from unsloth_zoo.temporary_patches.gpt_oss import patch_gpt_oss +import transformers + +patch_gpt_oss() +_QCLS = transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer + + +class _Cfg: + dequantize = False + + +class _Model: + class config: + pass + + def modules(self): + return iter(()) + + +def _make_quantizer(dequantize=False): + cfg = _Cfg() + cfg.dequantize = dequantize + q = _QCLS.__new__(_QCLS) + q.quantization_config = cfg + q.pre_quantized = True + q.modules_to_not_convert = [] + q.triton_kernels_hub = None + return q + + +def _stub_orig_noop(): + fn = _QCLS._process_model_before_weight_loading + fn.__closure__[0].cell_contents = lambda self, model, **kwargs: None + + +def _patch_devices(monkeypatch, *, cuda=False, xpu=False): + monkeypatch.setattr(torch.cuda, "is_available", lambda: cuda) + if not hasattr(torch, "xpu"): + torch.xpu = type("xpu", (), {"is_available": staticmethod(lambda: xpu)}) + monkeypatch.setattr(torch.xpu, "is_available", lambda: xpu) + + +def test_cpu_use_kernels_true_keeps_dequantize_false(monkeypatch): + _patch_devices(monkeypatch, cuda=False, xpu=False) + _stub_orig_noop() + q = _make_quantizer(dequantize=False) + _QCLS._process_model_before_weight_loading(q, _Model(), use_kernels=True) + assert q.quantization_config.dequantize is False + + +def test_cpu_use_kernels_false_forces_dequantize(monkeypatch): + _patch_devices(monkeypatch, cuda=False, xpu=False) + _stub_orig_noop() + q = _make_quantizer(dequantize=False) + _QCLS._process_model_before_weight_loading(q, _Model(), use_kernels=False) + assert q.quantization_config.dequantize is True + + +def test_cpu_default_call_forces_dequantize(monkeypatch): + _patch_devices(monkeypatch, cuda=False, xpu=False) + _stub_orig_noop() + q = _make_quantizer(dequantize=False) + _QCLS._process_model_before_weight_loading(q, _Model()) + assert q.quantization_config.dequantize is True + + +def test_positional_use_kernels_does_not_raise(monkeypatch): + _patch_devices(monkeypatch, cuda=False, xpu=False) + _stub_orig_noop() + q = _make_quantizer(dequantize=False) + _QCLS._process_model_before_weight_loading(q, _Model(), True) + assert q.quantization_config.dequantize is False + + +def test_use_kernels_forwarded_to_orig(monkeypatch): + _patch_devices(monkeypatch, cuda=True, xpu=False) + seen = {} + + def fake_orig(self, model, **kwargs): + seen["use_kernels"] = kwargs.get("use_kernels") + + fn = _QCLS._process_model_before_weight_loading + fn.__closure__[0].cell_contents = fake_orig + q = _make_quantizer(dequantize=False) + _QCLS._process_model_before_weight_loading(q, _Model(), use_kernels=True) + assert seen.get("use_kernels") is True + + +def test_detection_failure_warns_and_proceeds(monkeypatch): + """If torch.cuda.is_available raises, the wrapper must warn and still + call the original (not silently swallow).""" + monkeypatch.setattr(torch.cuda, "is_available", lambda: (_ for _ in ()).throw(RuntimeError("driver gone"))) + called = {"orig": False} + + def fake_orig(self, model, **kwargs): + called["orig"] = True + + fn = _QCLS._process_model_before_weight_loading + fn.__closure__[0].cell_contents = fake_orig + q = _make_quantizer(dequantize=False) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + _QCLS._process_model_before_weight_loading(q, _Model()) + assert called["orig"] is True + assert any( + "MXFP4 pre-load device detection failed" in str(w.message) for w in caught + ), [str(w.message) for w in caught] From 620a892b2b39db4f2702106add9aa22dbf9c1f51 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 May 2026 03:36:22 +0000 Subject: [PATCH 7/7] Consolidate gpt-oss MXFP4 patch tests into a single file Merge the post-load swizzle walker and pre-load gate tests into a single behavior-named module covering both halves of patch_gpt_oss(): - Post-load native swizzle walker: zero-placeholder skip vs raw-block raise, meta-scales handling, per-projection cache skip including the gate-up-cached-but-down-raw partial case. - Pre-load wrapper: use_kernels=True keeps native CPU MXFP4, default call forces dequantize, positional argument compatibility, use_kernels forwarded to upstream, detection-failure warns. --- ...e_walker.py => test_gpt_oss_mxfp4_patch.py | 94 +++++++++++--- test_gpt_oss_mxfp4_use_kernels.py | 117 ------------------ 2 files changed, 79 insertions(+), 132 deletions(-) rename test_gpt_oss_mxfp4_swizzle_walker.py => test_gpt_oss_mxfp4_patch.py (65%) delete mode 100644 test_gpt_oss_mxfp4_use_kernels.py diff --git a/test_gpt_oss_mxfp4_swizzle_walker.py b/test_gpt_oss_mxfp4_patch.py similarity index 65% rename from test_gpt_oss_mxfp4_swizzle_walker.py rename to test_gpt_oss_mxfp4_patch.py index aa9a91b73..f4504af1a 100644 --- a/test_gpt_oss_mxfp4_swizzle_walker.py +++ b/test_gpt_oss_mxfp4_patch.py @@ -1,5 +1,6 @@ import os import sys +import warnings import torch import pytest @@ -29,7 +30,7 @@ class config: is_quantized = True - def __init__(self, mods): + def __init__(self, mods=()): self._mods = mods def modules(self): @@ -90,8 +91,6 @@ def hide_swizzle_fn(): def test_zero_placeholders_do_not_raise_when_swizzle_fn_missing(hide_swizzle_fn): - """Freshly-init Mxfp4GptOssExperts with zero placeholders should be - treated as not-loaded and skipped, not raised on.""" mod = _make_module(blocks_zero=True, scales_zero=True, down_blocks_zero=True, down_scales_zero=True) q = _make_quantizer(dequantize=False) @@ -99,7 +98,6 @@ def test_zero_placeholders_do_not_raise_when_swizzle_fn_missing(hide_swizzle_fn) def test_loaded_blocks_raise_when_swizzle_fn_missing(hide_swizzle_fn): - """When real blocks are present, the missing-swizzle branch must still raise.""" mod = _make_module(blocks_zero=False, scales_zero=False) q = _make_quantizer(dequantize=False) with pytest.raises(RuntimeError, match="raw blocks/scales"): @@ -107,7 +105,6 @@ def test_loaded_blocks_raise_when_swizzle_fn_missing(hide_swizzle_fn): def test_meta_scales_treated_as_not_loaded(hide_swizzle_fn): - """Real blocks + meta scales should be skipped, not crash.""" mod = _make_module(blocks_zero=False, scales_meta=True, down_blocks_zero=True, down_scales_zero=True) q = _make_quantizer(dequantize=False) @@ -115,19 +112,14 @@ def test_meta_scales_treated_as_not_loaded(hide_swizzle_fn): def test_per_projection_skip_repairs_uncached_down_proj(): - """If _gate_up_proj is cached early (e.g. property accessed) but - down_proj is still raw and loaded, the active swizzle path must - repair down_proj instead of skipping the whole module.""" mod = _make_module(blocks_zero=True, scales_zero=True, down_blocks_zero=False, down_scales_zero=False) - # Simulate gate_up_proj cached mod.__dict__["_gate_up_proj"] = torch.zeros(2, 4, 8) swizzle_calls = [] def fake_swizzle(b, s, mod, proj, dev, tk): swizzle_calls.append(proj) - # mimic upstream cleanup: remove the raw params if f"{proj}_blocks" in mod._parameters: del mod._parameters[f"{proj}_blocks"] if f"{proj}_scales" in mod._parameters: @@ -136,7 +128,6 @@ def fake_swizzle(b, s, mod, proj, dev, tk): saved = getattr(_mx_mod, "swizzle_mxfp4_convertops", None) _mx_mod.swizzle_mxfp4_convertops = fake_swizzle try: - # provide a fake top-level triton_kernels module so the import succeeds import types as _types sys.modules.setdefault("triton_kernels", _types.ModuleType("triton_kernels")) q = _make_quantizer(dequantize=False) @@ -149,8 +140,6 @@ def fake_swizzle(b, s, mod, proj, dev, tk): def test_per_projection_skip_skips_both_when_both_cached(): - """If both projections are cached (post-load already ran), the - active swizzle path must be a no-op.""" mod = _make_module(blocks_zero=False, scales_zero=False, down_blocks_zero=False, down_scales_zero=False) mod.__dict__["_gate_up_proj"] = torch.zeros(2, 4, 8) @@ -176,10 +165,85 @@ def fake_swizzle(b, s, mod, proj, dev, tk): def test_partial_load_one_projection_loaded(hide_swizzle_fn): - """Module has gate_up_proj loaded but down_proj still zero. Missing - swizzle_fn branch must raise (since some weights are loaded).""" mod = _make_module(blocks_zero=False, scales_zero=False, down_blocks_zero=True, down_scales_zero=True) q = _make_quantizer(dequantize=False) with pytest.raises(RuntimeError, match="raw blocks/scales"): _QCLS._process_model_after_weight_loading(q, _Model([mod])) + + +def _stub_orig_noop(): + fn = _QCLS._process_model_before_weight_loading + fn.__closure__[0].cell_contents = lambda self, model, **kwargs: None + + +def _patch_devices(monkeypatch, *, cuda=False, xpu=False): + monkeypatch.setattr(torch.cuda, "is_available", lambda: cuda) + if not hasattr(torch, "xpu"): + torch.xpu = type("xpu", (), {"is_available": staticmethod(lambda: xpu)}) + monkeypatch.setattr(torch.xpu, "is_available", lambda: xpu) + + +def test_cpu_use_kernels_true_keeps_dequantize_false(monkeypatch): + _patch_devices(monkeypatch, cuda=False, xpu=False) + _stub_orig_noop() + q = _make_quantizer(dequantize=False) + _QCLS._process_model_before_weight_loading(q, _Model(), use_kernels=True) + assert q.quantization_config.dequantize is False + + +def test_cpu_use_kernels_false_forces_dequantize(monkeypatch): + _patch_devices(monkeypatch, cuda=False, xpu=False) + _stub_orig_noop() + q = _make_quantizer(dequantize=False) + _QCLS._process_model_before_weight_loading(q, _Model(), use_kernels=False) + assert q.quantization_config.dequantize is True + + +def test_cpu_default_call_forces_dequantize(monkeypatch): + _patch_devices(monkeypatch, cuda=False, xpu=False) + _stub_orig_noop() + q = _make_quantizer(dequantize=False) + _QCLS._process_model_before_weight_loading(q, _Model()) + assert q.quantization_config.dequantize is True + + +def test_positional_use_kernels_does_not_raise(monkeypatch): + _patch_devices(monkeypatch, cuda=False, xpu=False) + _stub_orig_noop() + q = _make_quantizer(dequantize=False) + _QCLS._process_model_before_weight_loading(q, _Model(), True) + assert q.quantization_config.dequantize is False + + +def test_use_kernels_forwarded_to_orig(monkeypatch): + _patch_devices(monkeypatch, cuda=True, xpu=False) + seen = {} + + def fake_orig(self, model, **kwargs): + seen["use_kernels"] = kwargs.get("use_kernels") + + fn = _QCLS._process_model_before_weight_loading + fn.__closure__[0].cell_contents = fake_orig + q = _make_quantizer(dequantize=False) + _QCLS._process_model_before_weight_loading(q, _Model(), use_kernels=True) + assert seen.get("use_kernels") is True + + +def test_detection_failure_warns_and_proceeds(monkeypatch): + monkeypatch.setattr(torch.cuda, "is_available", lambda: (_ for _ in ()).throw(RuntimeError("driver gone"))) + called = {"orig": False} + + def fake_orig(self, model, **kwargs): + called["orig"] = True + + fn = _QCLS._process_model_before_weight_loading + fn.__closure__[0].cell_contents = fake_orig + q = _make_quantizer(dequantize=False) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + _QCLS._process_model_before_weight_loading(q, _Model()) + assert called["orig"] is True + assert any( + "MXFP4 pre-load device detection failed" in str(w.message) for w in caught + ), [str(w.message) for w in caught] diff --git a/test_gpt_oss_mxfp4_use_kernels.py b/test_gpt_oss_mxfp4_use_kernels.py deleted file mode 100644 index 5a32a8b85..000000000 --- a/test_gpt_oss_mxfp4_use_kernels.py +++ /dev/null @@ -1,117 +0,0 @@ -import os -import sys -import warnings -import torch -import pytest - -HERE = os.path.dirname(os.path.abspath(__file__)) -if HERE not in sys.path: - sys.path.insert(0, HERE) - -from unsloth_zoo.temporary_patches.gpt_oss import patch_gpt_oss -import transformers - -patch_gpt_oss() -_QCLS = transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer - - -class _Cfg: - dequantize = False - - -class _Model: - class config: - pass - - def modules(self): - return iter(()) - - -def _make_quantizer(dequantize=False): - cfg = _Cfg() - cfg.dequantize = dequantize - q = _QCLS.__new__(_QCLS) - q.quantization_config = cfg - q.pre_quantized = True - q.modules_to_not_convert = [] - q.triton_kernels_hub = None - return q - - -def _stub_orig_noop(): - fn = _QCLS._process_model_before_weight_loading - fn.__closure__[0].cell_contents = lambda self, model, **kwargs: None - - -def _patch_devices(monkeypatch, *, cuda=False, xpu=False): - monkeypatch.setattr(torch.cuda, "is_available", lambda: cuda) - if not hasattr(torch, "xpu"): - torch.xpu = type("xpu", (), {"is_available": staticmethod(lambda: xpu)}) - monkeypatch.setattr(torch.xpu, "is_available", lambda: xpu) - - -def test_cpu_use_kernels_true_keeps_dequantize_false(monkeypatch): - _patch_devices(monkeypatch, cuda=False, xpu=False) - _stub_orig_noop() - q = _make_quantizer(dequantize=False) - _QCLS._process_model_before_weight_loading(q, _Model(), use_kernels=True) - assert q.quantization_config.dequantize is False - - -def test_cpu_use_kernels_false_forces_dequantize(monkeypatch): - _patch_devices(monkeypatch, cuda=False, xpu=False) - _stub_orig_noop() - q = _make_quantizer(dequantize=False) - _QCLS._process_model_before_weight_loading(q, _Model(), use_kernels=False) - assert q.quantization_config.dequantize is True - - -def test_cpu_default_call_forces_dequantize(monkeypatch): - _patch_devices(monkeypatch, cuda=False, xpu=False) - _stub_orig_noop() - q = _make_quantizer(dequantize=False) - _QCLS._process_model_before_weight_loading(q, _Model()) - assert q.quantization_config.dequantize is True - - -def test_positional_use_kernels_does_not_raise(monkeypatch): - _patch_devices(monkeypatch, cuda=False, xpu=False) - _stub_orig_noop() - q = _make_quantizer(dequantize=False) - _QCLS._process_model_before_weight_loading(q, _Model(), True) - assert q.quantization_config.dequantize is False - - -def test_use_kernels_forwarded_to_orig(monkeypatch): - _patch_devices(monkeypatch, cuda=True, xpu=False) - seen = {} - - def fake_orig(self, model, **kwargs): - seen["use_kernels"] = kwargs.get("use_kernels") - - fn = _QCLS._process_model_before_weight_loading - fn.__closure__[0].cell_contents = fake_orig - q = _make_quantizer(dequantize=False) - _QCLS._process_model_before_weight_loading(q, _Model(), use_kernels=True) - assert seen.get("use_kernels") is True - - -def test_detection_failure_warns_and_proceeds(monkeypatch): - """If torch.cuda.is_available raises, the wrapper must warn and still - call the original (not silently swallow).""" - monkeypatch.setattr(torch.cuda, "is_available", lambda: (_ for _ in ()).throw(RuntimeError("driver gone"))) - called = {"orig": False} - - def fake_orig(self, model, **kwargs): - called["orig"] = True - - fn = _QCLS._process_model_before_weight_loading - fn.__closure__[0].cell_contents = fake_orig - q = _make_quantizer(dequantize=False) - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - _QCLS._process_model_before_weight_loading(q, _Model()) - assert called["orig"] is True - assert any( - "MXFP4 pre-load device detection failed" in str(w.message) for w in caught - ), [str(w.message) for w in caught]