diff --git a/test_gpt_oss_mxfp4_patch.py b/test_gpt_oss_mxfp4_patch.py new file mode 100644 index 000000000..f4504af1a --- /dev/null +++ b/test_gpt_oss_mxfp4_patch.py @@ -0,0 +1,249 @@ +import os +import sys +import warnings +import torch +import pytest + +HERE = os.path.dirname(os.path.abspath(__file__)) +if HERE not in sys.path: + sys.path.insert(0, HERE) + +from unsloth_zoo.temporary_patches.gpt_oss import patch_gpt_oss +import transformers +import transformers.integrations.mxfp4 as _mx_mod + +patch_gpt_oss() +_QCLS = transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer + + +class Mxfp4GptOssExperts(torch.nn.Module): + pass + + +class _Cfg: + dequantize = False + + +class _Model: + class config: + pass + + is_quantized = True + + def __init__(self, mods=()): + self._mods = mods + + def modules(self): + return iter(self._mods) + + +def _make_quantizer(dequantize=False): + cfg = _Cfg() + cfg.dequantize = dequantize + q = _QCLS.__new__(_QCLS) + q.quantization_config = cfg + q.pre_quantized = True + q.modules_to_not_convert = [] + q.triton_kernels_hub = None + return q + + +def _make_module(blocks_zero=True, scales_zero=True, scales_meta=False, + down_blocks_zero=True, down_scales_zero=True): + m = Mxfp4GptOssExperts() + m.gate_up_proj_blocks = torch.nn.Parameter( + torch.zeros(2, 8, 4, 16, dtype=torch.uint8) if blocks_zero + else torch.ones(2, 8, 4, 16, dtype=torch.uint8), + requires_grad=False, + ) + if scales_meta: + m.gate_up_proj_scales = torch.nn.Parameter( + torch.empty(2, 8, 4, dtype=torch.uint8, device="meta"), + requires_grad=False, + ) + else: + m.gate_up_proj_scales = torch.nn.Parameter( + torch.zeros(2, 8, 4, dtype=torch.uint8) if scales_zero + else torch.ones(2, 8, 4, dtype=torch.uint8), + requires_grad=False, + ) + m.down_proj_blocks = torch.nn.Parameter( + torch.zeros(2, 4, 4, 16, dtype=torch.uint8) if down_blocks_zero + else torch.ones(2, 4, 4, 16, dtype=torch.uint8), + requires_grad=False, + ) + m.down_proj_scales = torch.nn.Parameter( + torch.zeros(2, 4, 4, dtype=torch.uint8) if down_scales_zero + else torch.ones(2, 4, 4, dtype=torch.uint8), + requires_grad=False, + ) + return m + + +@pytest.fixture +def hide_swizzle_fn(): + saved = getattr(_mx_mod, "swizzle_mxfp4_convertops", None) + if saved is not None: + delattr(_mx_mod, "swizzle_mxfp4_convertops") + yield + if saved is not None: + _mx_mod.swizzle_mxfp4_convertops = saved + + +def test_zero_placeholders_do_not_raise_when_swizzle_fn_missing(hide_swizzle_fn): + mod = _make_module(blocks_zero=True, scales_zero=True, + down_blocks_zero=True, down_scales_zero=True) + q = _make_quantizer(dequantize=False) + _QCLS._process_model_after_weight_loading(q, _Model([mod])) + + +def test_loaded_blocks_raise_when_swizzle_fn_missing(hide_swizzle_fn): + mod = _make_module(blocks_zero=False, scales_zero=False) + q = _make_quantizer(dequantize=False) + with pytest.raises(RuntimeError, match="raw blocks/scales"): + _QCLS._process_model_after_weight_loading(q, _Model([mod])) + + +def test_meta_scales_treated_as_not_loaded(hide_swizzle_fn): + mod = _make_module(blocks_zero=False, scales_meta=True, + down_blocks_zero=True, down_scales_zero=True) + q = _make_quantizer(dequantize=False) + _QCLS._process_model_after_weight_loading(q, _Model([mod])) + + +def test_per_projection_skip_repairs_uncached_down_proj(): + mod = _make_module(blocks_zero=True, scales_zero=True, + down_blocks_zero=False, down_scales_zero=False) + mod.__dict__["_gate_up_proj"] = torch.zeros(2, 4, 8) + + swizzle_calls = [] + + def fake_swizzle(b, s, mod, proj, dev, tk): + swizzle_calls.append(proj) + if f"{proj}_blocks" in mod._parameters: + del mod._parameters[f"{proj}_blocks"] + if f"{proj}_scales" in mod._parameters: + del mod._parameters[f"{proj}_scales"] + + saved = getattr(_mx_mod, "swizzle_mxfp4_convertops", None) + _mx_mod.swizzle_mxfp4_convertops = fake_swizzle + try: + import types as _types + sys.modules.setdefault("triton_kernels", _types.ModuleType("triton_kernels")) + q = _make_quantizer(dequantize=False) + _QCLS._process_model_after_weight_loading(q, _Model([mod])) + finally: + if saved is not None: + _mx_mod.swizzle_mxfp4_convertops = saved + + assert swizzle_calls == ["down_proj"], swizzle_calls + + +def test_per_projection_skip_skips_both_when_both_cached(): + mod = _make_module(blocks_zero=False, scales_zero=False, + down_blocks_zero=False, down_scales_zero=False) + mod.__dict__["_gate_up_proj"] = torch.zeros(2, 4, 8) + mod.__dict__["_down_proj"] = torch.zeros(2, 4, 4) + + swizzle_calls = [] + + def fake_swizzle(b, s, mod, proj, dev, tk): + swizzle_calls.append(proj) + + saved = getattr(_mx_mod, "swizzle_mxfp4_convertops", None) + _mx_mod.swizzle_mxfp4_convertops = fake_swizzle + try: + import types as _types + sys.modules.setdefault("triton_kernels", _types.ModuleType("triton_kernels")) + q = _make_quantizer(dequantize=False) + _QCLS._process_model_after_weight_loading(q, _Model([mod])) + finally: + if saved is not None: + _mx_mod.swizzle_mxfp4_convertops = saved + + assert swizzle_calls == [] + + +def test_partial_load_one_projection_loaded(hide_swizzle_fn): + mod = _make_module(blocks_zero=False, scales_zero=False, + down_blocks_zero=True, down_scales_zero=True) + q = _make_quantizer(dequantize=False) + with pytest.raises(RuntimeError, match="raw blocks/scales"): + _QCLS._process_model_after_weight_loading(q, _Model([mod])) + + +def _stub_orig_noop(): + fn = _QCLS._process_model_before_weight_loading + fn.__closure__[0].cell_contents = lambda self, model, **kwargs: None + + +def _patch_devices(monkeypatch, *, cuda=False, xpu=False): + monkeypatch.setattr(torch.cuda, "is_available", lambda: cuda) + if not hasattr(torch, "xpu"): + torch.xpu = type("xpu", (), {"is_available": staticmethod(lambda: xpu)}) + monkeypatch.setattr(torch.xpu, "is_available", lambda: xpu) + + +def test_cpu_use_kernels_true_keeps_dequantize_false(monkeypatch): + _patch_devices(monkeypatch, cuda=False, xpu=False) + _stub_orig_noop() + q = _make_quantizer(dequantize=False) + _QCLS._process_model_before_weight_loading(q, _Model(), use_kernels=True) + assert q.quantization_config.dequantize is False + + +def test_cpu_use_kernels_false_forces_dequantize(monkeypatch): + _patch_devices(monkeypatch, cuda=False, xpu=False) + _stub_orig_noop() + q = _make_quantizer(dequantize=False) + _QCLS._process_model_before_weight_loading(q, _Model(), use_kernels=False) + assert q.quantization_config.dequantize is True + + +def test_cpu_default_call_forces_dequantize(monkeypatch): + _patch_devices(monkeypatch, cuda=False, xpu=False) + _stub_orig_noop() + q = _make_quantizer(dequantize=False) + _QCLS._process_model_before_weight_loading(q, _Model()) + assert q.quantization_config.dequantize is True + + +def test_positional_use_kernels_does_not_raise(monkeypatch): + _patch_devices(monkeypatch, cuda=False, xpu=False) + _stub_orig_noop() + q = _make_quantizer(dequantize=False) + _QCLS._process_model_before_weight_loading(q, _Model(), True) + assert q.quantization_config.dequantize is False + + +def test_use_kernels_forwarded_to_orig(monkeypatch): + _patch_devices(monkeypatch, cuda=True, xpu=False) + seen = {} + + def fake_orig(self, model, **kwargs): + seen["use_kernels"] = kwargs.get("use_kernels") + + fn = _QCLS._process_model_before_weight_loading + fn.__closure__[0].cell_contents = fake_orig + q = _make_quantizer(dequantize=False) + _QCLS._process_model_before_weight_loading(q, _Model(), use_kernels=True) + assert seen.get("use_kernels") is True + + +def test_detection_failure_warns_and_proceeds(monkeypatch): + monkeypatch.setattr(torch.cuda, "is_available", lambda: (_ for _ in ()).throw(RuntimeError("driver gone"))) + called = {"orig": False} + + def fake_orig(self, model, **kwargs): + called["orig"] = True + + fn = _QCLS._process_model_before_weight_loading + fn.__closure__[0].cell_contents = fake_orig + q = _make_quantizer(dequantize=False) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + _QCLS._process_model_before_weight_loading(q, _Model()) + assert called["orig"] is True + assert any( + "MXFP4 pre-load device detection failed" in str(w.message) for w in caught + ), [str(w.message) for w in caught] diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 7bcd0a333..ea07ed8b0 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -128,6 +128,193 @@ def patch_gpt_oss(): except Exception as e: return raise_error("transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer", e) + # transformers 5.x moved MXFP4 loading from ``load_and_swizzle_mxfp4`` + # (4.56.x) to the ``Mxfp4Deserialize`` WeightConverter, which is skipped + # whenever checkpoint key names already match registered parameters -- + # so MXFP4 models load with raw blocks/scales still on the module and + # the swizzle step never runs. 5.x ``Mxfp4Dequantize.convert`` also + # drops the transpose that 4.x baked into ``convert_moe_packed_tensors``, + # so gate_up_proj ends up as (E, 2I, H) vs the (E, H, 2I) the stock + # forward expects. Fix both in a post-load walker. + try: + _Mxfp4Q = transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer + if not getattr(_Mxfp4Q, "_unsloth_post_swizzle_patched", False): + _orig_post_load = _Mxfp4Q._process_model_after_weight_loading + + def _patched_post_load(self, model, **kwargs): + _orig_post_load(self, model, **kwargs) + import torch as _torch + import warnings as _warnings + dq = getattr(self.quantization_config, "dequantize", False) + if dq: + for mod in model.modules(): + if type(mod).__name__ != "GptOssExperts": + continue + H = getattr(mod, "hidden_size", None) + I = getattr(mod, "intermediate_size", None) + if H is None or I is None: + continue + gup = getattr(mod, "gate_up_proj", None) + if gup is None or gup.dim() != 3: + continue + _, D1, D2 = gup.shape + # gate_up_proj shape is unambiguous (2I != H); + # down_proj shape is ambiguous when I == H + # (gpt-oss-20b: 2880 / 2880). + if not (D1 == 2 * I and D2 == H): + continue + expected_wrong = { + "gate_up_proj": (2 * I, H), + "down_proj": (H, I), + } + expected_right = { + "gate_up_proj": (H, 2 * I), + "down_proj": (I, H), + } + for proj in ("gate_up_proj", "down_proj"): + p = getattr(mod, proj, None) + if p is None or p.dim() != 3: + continue + shape = tuple(p.shape[-2:]) + # When H == I, down_proj wrong (H, I) and right + # (I, H) shapes are identical; the outer guard + # already proved this module is in the wrong + # layout, so transpose unconditionally. + ambiguous = (proj == "down_proj" and H == I) + if shape == expected_right[proj] and not ambiguous: + continue + if shape != expected_wrong[proj] and not ambiguous: + _warnings.warn( + f"[unsloth] Unexpected MXFP4-dequantize " + f"layout for {type(mod).__name__}.{proj}: " + f"got {tuple(p.shape)}. Skipping transpose." + ) + continue + new_p = p.data.transpose(-2, -1).contiguous() + setattr(mod, proj, _torch.nn.Parameter( + new_p, requires_grad=p.requires_grad, + )) + return + + import transformers.integrations.mxfp4 as _mx_mod + swizzle_fn = getattr(_mx_mod, "swizzle_mxfp4_convertops", None) + + def _has_loaded_raw(mod, proj): + b = getattr(mod, f"{proj}_blocks", None) + s = getattr(mod, f"{proj}_scales", None) + if b is None or s is None: + return False + if b.device.type == "meta" or s.device.type == "meta": + return False + if b.numel() == 0 or s.numel() == 0: + return False + return bool(b.any()) + + def _module_has_loaded_raw(mod): + for proj in ("gate_up_proj", "down_proj"): + if f"_{proj}" in mod.__dict__: + continue + if _has_loaded_raw(mod, proj): + return True + return False + + if swizzle_fn is None: + for mod in model.modules(): + if type(mod).__name__ != "Mxfp4GptOssExperts": + continue + if _module_has_loaded_raw(mod): + raise RuntimeError( + "[unsloth] MXFP4 model has raw blocks/scales " + "post-load but swizzle_mxfp4_convertops is " + "not available; pass Mxfp4Config(dequantize=" + "True) to fall back to bf16." + ) + return + try: + import triton_kernels as _tk + except Exception as _tk_err: + for mod in model.modules(): + if type(mod).__name__ != "Mxfp4GptOssExperts": + continue + if _module_has_loaded_raw(mod): + raise RuntimeError( + "[unsloth] Native MXFP4 requires " + "triton_kernels; install it or pass " + "Mxfp4Config(dequantize=True) for bf16." + ) from _tk_err + return + for mod in model.modules(): + if type(mod).__name__ != "Mxfp4GptOssExperts": + continue + for proj in ("gate_up_proj", "down_proj"): + if f"_{proj}" in mod.__dict__: + continue + if not _has_loaded_raw(mod, proj): + continue + b = getattr(mod, f"{proj}_blocks") + s = getattr(mod, f"{proj}_scales") + try: + swizzle_fn(b.data, s.data, mod, proj, b.device, _tk) + except Exception as _sw_err: + raise RuntimeError( + f"[unsloth] MXFP4 swizzle failed on " + f"{type(mod).__name__}.{proj}; likely " + f"triton_kernels / transformers drift." + ) from _sw_err + if f"{proj}_blocks" in mod._parameters: + del mod._parameters[f"{proj}_blocks"] + if f"{proj}_scales" in mod._parameters: + del mod._parameters[f"{proj}_scales"] + + _Mxfp4Q._process_model_after_weight_loading = _patched_post_load + _Mxfp4Q._unsloth_post_swizzle_patched = True + + # triton_kernels MXFP4 needs CUDA; native MXFP4 itself works across + # CUDA compute capabilities (StridedLayout / HopperMXValueLayout / + # BlackwellMXValueLayout) and XPU has its own native path. Only + # force dequantize when no supported accelerator is visible. + if not getattr(_Mxfp4Q, "_unsloth_cpu_gate_patched", False): + _orig_before_load = _Mxfp4Q._process_model_before_weight_loading + + def _patched_before_load(self, model, use_kernels=False, **kwargs): + try: + import torch as _torch + cuda_ok = _torch.cuda.is_available() + xpu_ok = ( + hasattr(_torch, "xpu") and _torch.xpu.is_available() + ) + if ( + not cuda_ok + and not xpu_ok + and not use_kernels + and not getattr(self.quantization_config, "dequantize", False) + ): + try: + self.quantization_config.dequantize = True + except Exception as _cfg_err: + import warnings as _w + _w.warn( + f"[unsloth] Could not force MXFP4 dequantize " + f"on CPU-only host: {_cfg_err!r}; native " + f"MXFP4 path may crash without triton_kernels." + ) + except Exception as _det_err: + import warnings as _w + _w.warn( + f"[unsloth] MXFP4 pre-load device detection failed: " + f"{_det_err!r}; falling back to upstream." + ) + return _orig_before_load( + self, model, use_kernels=use_kernels, **kwargs, + ) + + _Mxfp4Q._process_model_before_weight_loading = _patched_before_load + _Mxfp4Q._unsloth_cpu_gate_patched = True + except Exception as e: + return raise_error( + "transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer post-load swizzle patch", e, + ) + if HAS_TRITON_KERNELS: # Only override is_kernels_available when triton_kernels IS available try: