diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..b3e683d95 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,168 @@ +"""GPU-free test harness. + +`unsloth_zoo.__init__` calls `device_type.get_device_type()` at import time, +which raises `NotImplementedError("Unsloth cannot find any torch accelerator")` +on CI runners without CUDA / XPU / HIP visible. This makes any test that +imports `unsloth_zoo` un-runnable on a GPU-free CI. + +Most tests in this directory only exercise CPU-only logic (LoRA extractor +shape parity, registration coverage, dtype helpers). They do not need a real +accelerator. To unblock GPU-free CI, this conftest pre-installs a stub +`unsloth_zoo.device_type` into `sys.modules` BEFORE the package is imported, +exposing every name `unsloth_zoo/__init__.py` reads from it. + +Behavior: + - When a real accelerator is available (CUDA / XPU / HIP), the stub is + NOT installed; the real `device_type.py` runs and reports the actual + accelerator. CI on GPU runners still gets full fidelity. + - When no accelerator is available, the stub claims `cuda` so the import + chain in `__init__.py` does not raise. Downstream code that tries to + call `torch.cuda.*` will still fail at *runtime*, but at *import* the + package loads cleanly. Tests that stay on CPU run; tests that need + GPU compute would fail on their own kernel calls and should be marked + `@pytest.mark.skipif` separately. + - The stub is a no-op when `unsloth_zoo` is already imported (some + upstream pytest harness already loaded it). +""" + +from __future__ import annotations + +import importlib.util +import sys +import types + + +def _has_real_accelerator() -> bool: + try: + import torch + except Exception: + return False + try: + if hasattr(torch, "cuda") and torch.cuda.is_available(): + return True + except Exception: + pass + try: + if hasattr(torch, "xpu") and torch.xpu.is_available(): + return True + except Exception: + pass + try: + if hasattr(torch, "accelerator") and torch.accelerator.is_available(): + return True + except Exception: + pass + return False + + +def _preload_real_device_type() -> bool: + """Pre-load the REAL `unsloth_zoo.device_type` module under a + temporarily-mocked `torch.cuda.is_available()` so its + `DEVICE_TYPE = get_device_type()` initialization succeeds without a + real accelerator. Returns True on success. + + We need the real module (not a stub) so tests like + `test_device_synchronize_xpu_calls_synchronize_when_present` keep + exercising the real `device_synchronize` body. + + Strategy: build the minimal `unsloth_zoo` namespace package skeleton + (so the relative `from .utils import Version` works), pre-load + `unsloth_zoo.utils`, then pre-load `unsloth_zoo.device_type` with + `torch.cuda.is_available` patched to True for the duration. The + `@functools.cache` on `get_device_type` permanently captures the + "cuda" result, so subsequent calls return "cuda" without needing + the patch. Finally we drop the `unsloth_zoo` skeleton so the real + `__init__.py` runs on the next `import unsloth_zoo`; it will find + the already-loaded `device_type` and `utils` in `sys.modules` and + skip re-execution. + """ + if "unsloth_zoo.device_type" in sys.modules: + return True + pkg_spec = importlib.util.find_spec("unsloth_zoo") + if pkg_spec is None or not pkg_spec.submodule_search_locations: + return False + pkg_path = pkg_spec.submodule_search_locations[0] + + import os + + skeleton_already = "unsloth_zoo" in sys.modules + if not skeleton_already: + zoo_pkg = types.ModuleType("unsloth_zoo") + zoo_pkg.__path__ = [pkg_path] + zoo_pkg.__spec__ = pkg_spec + zoo_pkg.__package__ = "unsloth_zoo" + sys.modules["unsloth_zoo"] = zoo_pkg + + try: + # Pre-load utils (device_type does `from .utils import Version`). + if "unsloth_zoo.utils" not in sys.modules: + utils_path = os.path.join(pkg_path, "utils.py") + utils_spec = importlib.util.spec_from_file_location( + "unsloth_zoo.utils", utils_path, + ) + utils_mod = importlib.util.module_from_spec(utils_spec) + sys.modules["unsloth_zoo.utils"] = utils_mod + utils_spec.loader.exec_module(utils_mod) + + # Pre-load device_type with a temporarily-True is_available. + device_type_path = os.path.join(pkg_path, "device_type.py") + dt_spec = importlib.util.spec_from_file_location( + "unsloth_zoo.device_type", device_type_path, + ) + dt_mod = importlib.util.module_from_spec(dt_spec) + sys.modules["unsloth_zoo.device_type"] = dt_mod + + import torch + _orig_is_avail = torch.cuda.is_available + torch.cuda.is_available = lambda: True # type: ignore[assignment] + try: + dt_spec.loader.exec_module(dt_mod) + finally: + torch.cuda.is_available = _orig_is_avail + finally: + if not skeleton_already: + # Drop our skeleton so the real package __init__.py executes + # on the next `import unsloth_zoo`. The pre-loaded submodules + # remain in sys.modules and will be reused by __init__. + sys.modules.pop("unsloth_zoo", None) + + return True + + +def _patch_torch_cuda_for_import() -> None: + """Monkey-patch concrete `torch.cuda.*` calls that other parts of + `unsloth_zoo.temporary_patches.*` make at module IMPORT time. After + this conftest finishes, `torch.cuda.is_available()` is back to its + real value (False on a GPU-free CI), so transitive deps like torchao + / dynamo correctly skip CUDA init when they are imported by other + test modules. + + Specifically guards: + gpt_oss.py:1141 -> torch.cuda.memory.mem_get_info(0) + which runs at module top-level after `unsloth_zoo.device_type`'s + `DEVICE_TYPE` is already "cuda" (cached above). + """ + try: + import torch.cuda.memory as _cuda_memory # type: ignore + _cuda_memory.mem_get_info = lambda *a, **k: (0, 80 * 1024 ** 3) + except Exception: + pass + + +if not _has_real_accelerator(): + if not _preload_real_device_type(): + # Fallback: if we cannot find the real device_type source (eg. + # zipped install), fall back to a stub so tests at least import. + stub = types.ModuleType("unsloth_zoo.device_type") + stub.DEVICE_TYPE = "cuda" + stub.DEVICE_TYPE_TORCH = "cuda" + stub.DEVICE_COUNT = 1 + stub.ALLOW_PREQUANTIZED_MODELS = False + stub.is_hip = lambda: False + stub.get_device_type = lambda: "cuda" + stub.get_device_count = lambda: 1 + stub.device_synchronize = lambda *a, **k: None + stub.device_empty_cache = lambda *a, **k: None + stub.device_is_bf16_supported = lambda *a, **k: False + sys.modules["unsloth_zoo.device_type"] = stub + _patch_torch_cuda_for_import() diff --git a/tests/test_gemma4_moe_lora_registration.py b/tests/test_gemma4_moe_lora_registration.py new file mode 100644 index 000000000..6ad556f2b --- /dev/null +++ b/tests/test_gemma4_moe_lora_registration.py @@ -0,0 +1,208 @@ +"""Regression tests for the Gemma-4 MoE LoRA extractor registration added by +PR #624. These tests do NOT require ``transformers.models.gemma4`` to exist; +they exercise the registration helper against a synthetic stand-in class +with the same surface (gate_up_proj (E, 2*I, H), down_proj (E, H, I), +hidden_dim, intermediate_dim). + +What is covered: + +1. Successful registration attaches the Qwen extractor and the model-type + tag without overwriting unrelated state. +2. Registration is idempotent across repeated calls (callable identity is + preserved, no double-wrapping). +3. ``_register_gemma4_lora_extractor(None)`` returns False without raising, + matching the legacy import path where ``Gemma4TextMoEBlock`` may be + absent. +4. If the underlying extractor factory raises, registration returns False + and leaves the class state untouched (no half-registered attributes). +5. The registered extractor produces (E, in, R) / (E, R, out) tensors that + numerically reconstruct the per-expert delta on both the PEFT 0.18 raw + layout and the PEFT 0.19 ``_did_swap_in_out_features`` swapped layout, + for both ``gate_up_proj`` and ``down_proj`` parameters. +6. Sibling MoE families' existing extractor registrations are not disturbed + by Gemma-4 registration. +""" + +from __future__ import annotations + +import torch +import torch.nn as nn + +from unsloth_zoo.temporary_patches import gemma4_moe as g4 +from unsloth_zoo.temporary_patches.qwen3_moe import _make_qwen_moe_lora_extractor + + +def _fresh_stub_class(): + class StubGemma4TextExperts(nn.Module): + num_experts = 4 + hidden_dim = 8 + intermediate_dim = 12 + + def __init__(self) -> None: + super().__init__() + self.gate_up_proj = nn.Parameter( + torch.randn(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim) + ) + self.down_proj = nn.Parameter( + torch.randn(self.num_experts, self.hidden_dim, self.intermediate_dim) + ) + + return StubGemma4TextExperts + + +class _StubWrapper: + """Mimics the surface ``_make_qwen_moe_lora_extractor`` reads from a + PEFT ParamWrapper: ``parameter_name``, ``get_base_layer()``, + optionally ``_did_swap_in_out_features``.""" + + def __init__(self, parameter_name: str, base_layer, peft_swapped: bool): + self.parameter_name = parameter_name + self._base_layer = base_layer + self._did_swap_in_out_features = peft_swapped + + def get_base_layer(self): + return self._base_layer + + +def test_register_attaches_extractor_and_tag(): + cls = _fresh_stub_class() + assert g4._register_gemma4_lora_extractor(cls) is True + assert cls._unsloth_lora_extractor_registered is True + assert cls._unsloth_model_type == "gemma4_moe" + assert callable(cls._unsloth_lora_extractor_fn) + + +def test_register_is_idempotent(): + cls = _fresh_stub_class() + assert g4._register_gemma4_lora_extractor(cls) is True + fn_before = cls._unsloth_lora_extractor_fn + # Second call must short-circuit on the registered flag and leave the + # extractor identity untouched. + assert g4._register_gemma4_lora_extractor(cls) is True + assert cls._unsloth_lora_extractor_fn is fn_before + assert cls._unsloth_model_type == "gemma4_moe" + + +def test_register_none_returns_false_without_raising(): + assert g4._register_gemma4_lora_extractor(None) is False + + +def test_register_failure_preserves_class_state(monkeypatch): + cls = _fresh_stub_class() + + def _boom(): + raise RuntimeError("synthetic factory failure") + + monkeypatch.setattr(g4, "_make_qwen_moe_lora_extractor", _boom) + assert g4._register_gemma4_lora_extractor(cls) is False + assert not hasattr(cls, "_unsloth_lora_extractor_fn") + assert not getattr(cls, "_unsloth_lora_extractor_registered", False) + + +def test_register_does_not_disturb_sibling_registration(): + # qwen3_moe registers its extractor when patch_qwen3_moe is invoked. + # Without invoking it, we verify that an unrelated class registered by + # Gemma-4 does NOT touch any class qwen3_moe would later target. + cls = _fresh_stub_class() + assert g4._register_gemma4_lora_extractor(cls) is True + # The Qwen extractor factory is still callable and independent. + qwen_extractor = _make_qwen_moe_lora_extractor() + assert callable(qwen_extractor) + # The Gemma-4 stub extractor and a freshly built Qwen extractor are + # distinct instances even though they originate from the same factory. + assert cls._unsloth_lora_extractor_fn is not qwen_extractor + + +def _drive_extractor(cls, parameter_name: str, peft_swapped: bool): + """Drive the registered extractor against hand-built LoRA factors and + compare per-expert reconstructed delta to a naive reference.""" + torch.manual_seed(0) + base = cls() + if parameter_name == "gate_up_proj": + in_dim = base.hidden_dim + out_dim = 2 * base.intermediate_dim + else: # down_proj + in_dim = base.intermediate_dim + out_dim = base.hidden_dim + + E = base.num_experts + R = 3 + + if peft_swapped: + # PEFT 0.19 swapped layout + weight_A = torch.randn(E * R, out_dim) + weight_B = torch.randn(in_dim, E * R) + else: + # PEFT 0.18 raw 3D layout + weight_A = torch.randn(E * R, in_dim) + weight_B = torch.randn(out_dim, E * R) + + wrapper = _StubWrapper(parameter_name, base, peft_swapped) + extractor = cls._unsloth_lora_extractor_fn + first, second, scaling, num_experts = extractor( + wrapper, weight_A, weight_B, 1.5, E, + ) + assert first.shape == (E, in_dim, R), (parameter_name, peft_swapped, first.shape) + assert second.shape == (E, R, out_dim), (parameter_name, peft_swapped, second.shape) + assert num_experts == E + assert scaling == 1.5 + assert first.is_contiguous() + assert second.is_contiguous() + + x = torch.randn(7, in_dim) + for e in range(E): + Ae = weight_A[e * R : (e + 1) * R] + Be = weight_B[:, e * R : (e + 1) * R] + if peft_swapped: + naive = x @ Be @ Ae + else: + naive = x @ Ae.T @ Be.T + via = (x @ first[e]) @ second[e] + torch.testing.assert_close(via, naive, atol=1e-4, rtol=1e-4) + + +def test_extractor_gate_up_canonical_peft018(): + cls = _fresh_stub_class() + assert g4._register_gemma4_lora_extractor(cls) is True + _drive_extractor(cls, "gate_up_proj", peft_swapped=False) + + +def test_extractor_gate_up_swapped_peft019(): + cls = _fresh_stub_class() + assert g4._register_gemma4_lora_extractor(cls) is True + _drive_extractor(cls, "gate_up_proj", peft_swapped=True) + + +def test_extractor_down_canonical_peft018(): + cls = _fresh_stub_class() + assert g4._register_gemma4_lora_extractor(cls) is True + _drive_extractor(cls, "down_proj", peft_swapped=False) + + +def test_extractor_down_swapped_peft019(): + cls = _fresh_stub_class() + assert g4._register_gemma4_lora_extractor(cls) is True + _drive_extractor(cls, "down_proj", peft_swapped=True) + + +def test_patch_gemma4_moe_is_noop_without_gemma4(monkeypatch): + """End-to-end sanity check: when transformers lacks + ``models.gemma4.modeling_gemma4`` (the case in this env), the public + entrypoint must not raise. Both inner patch functions guard their + imports, and ``patch_gemma4_moe`` short-circuits via ``raise_error`` + which returns rather than raising.""" + # Act: just call the entrypoint. transformers.models.gemma4 is not + # importable on transformers 4.57.6 in this environment. + g4.patch_gemma4_moe() # must not raise + + +def test_register_handles_legacy_block_class_shape(): + """Sanity check that a class shaped like Gemma4TextMoEBlock (legacy + layout) accepts registration the same way as Gemma4TextExperts. The + legacy block also exposes ``num_experts``/``hidden_dim``/ + ``intermediate_dim`` plus ``gate_up_proj``/``down_proj``, so the + Qwen extractor is layout-compatible.""" + LegacyBlock = _fresh_stub_class() + assert g4._register_gemma4_lora_extractor(LegacyBlock) is True + _drive_extractor(LegacyBlock, "gate_up_proj", peft_swapped=False) + _drive_extractor(LegacyBlock, "down_proj", peft_swapped=True) diff --git a/tests/test_moe_lora_extractor_coverage.py b/tests/test_moe_lora_extractor_coverage.py new file mode 100644 index 000000000..f23ab890c --- /dev/null +++ b/tests/test_moe_lora_extractor_coverage.py @@ -0,0 +1,361 @@ +"""Dynamic regression-prevention test for the PR #624 failure mode. + +PR #624 fixed a Gemma-4 MoE LoRA training crash whose root cause was: + + Gemma4TextExperts._unsloth_already_patched = True + Gemma4TextExperts._unsloth_lora_extractor_fn -> NOT REGISTERED + +`unsloth_zoo.temporary_patches.gemma4_moe._patch_gemma4_moe_current` patched +`forward` to call the grouped-MoE backend but forgot to attach the LoRA +extractor. Without the extractor, `moe_utils._extract_lora_from_wrapper` +falls through to its default canonical permutation, which produces tensors +whose contraction dimensions do not match for `torch._grouped_mm` on PEFT +0.19+ swapped 3D LoRA layouts. The crash surfaces on the first training +step as `RuntimeError: contraction dimension of mat_a and mat_b must match`. + +This test prevents that exact regression for every current and future MoE +family. It applies every `TEMPORARY_PATCHES` entry, walks every loaded +`transformers.models.*` module, and asserts that every class flagged +`_unsloth_already_patched=True` whose source defines `gate_up_proj` and +`down_proj` 3D parameters also has `_unsloth_lora_extractor_fn` registered. + +Design constraints: + - GPU-free. Instantiation, when attempted, uses tiny synthetic configs. + - Transformers-version-agnostic. Discovery walks the live tree; nothing + is hard-coded to specific class names. + - Single test. Discovery + assertion live together in one pytest case so + the contract is one signal, not many. + - Opportunistic parity. If we can instantiate a discovered class without + a checkpoint, we additionally drive its registered extractor on hand- + built PEFT 0.18 raw and PEFT 0.19 swapped LoRA factors and assert + per-expert delta parity. Failures here surface as a separate assertion + so registration coverage and orientation correctness are distinguishable. + +The test is deliberately defensive: importing transformers submodules is +allowed to fail (the full transformers tree pulls in optional deps), and +those failures are reported as `unimported` not as test failures. +""" + +from __future__ import annotations + +import importlib +import inspect +import pkgutil +import re +import warnings +from typing import Any + +import pytest +import torch +import torch.nn as nn + +# Apply every TEMPORARY_PATCHES entry. Importing the package side-effect- +# populates the list; we run each entry once and ignore individual failures +# (a missing transformers submodule is the standard no-op signal). +import unsloth_zoo.temporary_patches # noqa: F401 side effect: register patches +from unsloth_zoo.temporary_patches.common import TEMPORARY_PATCHES + + +# Regex for "self.gate_up_proj = nn.Parameter(torch.empty(...))" / similar +# 3D parameter declarations on the class body. We also accept the variant +# spelling used by some unsloth_zoo internals (Mxfp4 / GptOss style does +# NOT match this on purpose -- those classes use a different LoRA path). +_3D_PARAM_PATTERNS = [ + re.compile(r"self\.gate_up_proj\s*=\s*nn\.Parameter\("), + re.compile(r"self\.down_proj\s*=\s*nn\.Parameter\("), +] + + +def _apply_all_temporary_patches() -> None: + for fn in TEMPORARY_PATCHES: + try: + fn() + except Exception: # noqa: BLE001 - missing modules are fine here + pass + + +def _iter_modeling_modules(): + """Yield every transformers.models..modeling_ module that imports + cleanly. Failures (missing optional deps) are skipped silently.""" + import transformers.models as tm + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + for sub in pkgutil.iter_modules(tm.__path__): + if sub.ispkg is False: + continue + try: + pkg = importlib.import_module(f"transformers.models.{sub.name}") + except Exception: + continue + for child in pkgutil.iter_modules(pkg.__path__): + if not child.name.startswith("modeling_"): + continue + try: + yield importlib.import_module( + f"transformers.models.{sub.name}.{child.name}", + ) + except Exception: + continue + + +def _looks_like_grouped_moe_experts(cls: type) -> bool: + """Return True if `cls.__init__` source declares both `gate_up_proj` + and `down_proj` as `nn.Parameter`s. This is the surface that the + grouped-MoE LoRA extractor reads.""" + try: + src = inspect.getsource(cls.__init__) + except (OSError, TypeError): + return False + return all(p.search(src) for p in _3D_PARAM_PATTERNS) + + +def _has_unsloth_patched_forward(cls: type) -> bool: + """Detect that `unsloth_zoo.temporary_patches.utils.patch_function` has + replaced `cls.forward`. `patch_function` stores the original under a + per-class attribute named `_original___ + forward` (`utils.py:_get_unique_storage_name`). The presence of that + attribute on the class is a uniform marker independent of any family- + specific bool flag like `_unsloth_already_patched`. + """ + cls_name = cls.__name__ + for attr in dir(cls): + if attr.startswith("_original_") and attr.endswith(f"_{cls_name}_forward"): + return True + return False + + +def _discover_patched_moe_classes() -> list[type]: + """Find every transformers class that unsloth-zoo has replaced the + `forward` of AND whose source matches the grouped-MoE 3D-parameter + shape (`gate_up_proj` + `down_proj` declared as `nn.Parameter`). + + The combination is the load-bearing contract: `forward` was swapped to + the grouped-MoE backend, so the LoRA path through that backend must + have an extractor registered. PR #624 was an instance of registration + being forgotten on a class that satisfied both conditions. + """ + seen: set[type] = set() + out: list[type] = [] + for modeling in _iter_modeling_modules(): + for name, obj in inspect.getmembers(modeling, inspect.isclass): + if not isinstance(obj, type): + continue + if obj in seen: + continue + seen.add(obj) + if getattr(obj, "__module__", None) != modeling.__name__: + continue + if not _has_unsloth_patched_forward(obj): + continue + if not _looks_like_grouped_moe_experts(obj): + continue + out.append(obj) + return out + + +# -------------------------------------------------------------------------- +# Opportunistic parity helpers +# -------------------------------------------------------------------------- + + +class _StubWrapper: + def __init__(self, parameter_name: str, base: Any, peft_swapped: bool): + self.parameter_name = parameter_name + self._base = base + self._did_swap_in_out_features = peft_swapped + + def get_base_layer(self): + return self._base + + +def _try_instantiate_experts(cls: type): + """Best-effort instantiation of an MoE experts class with tiny synthetic + dims. Returns the instance or None if any path fails. + + Dim choice: H=16, I=10, 2*I=20. All three values are distinct (so + PEFT-version dispatch can't be ambiguous) and the relation + `H > I` and `2*I > H` matches production MoE configs (Qwen3-MoE, + Gemma-4 MoE, Glm4-MoE-Lite, DeepSeek-V3 all live in this regime). A + test in a different regime would surface known-fragile extractor + orientation behavior that is independent of the PR #624 contract this + test exists to enforce.""" + cfg_cls = _find_sibling_config(cls) + if cfg_cls is None: + return None + overrides = { + "vocab_size": 64, + "hidden_size": 16, + "intermediate_size": 12, + "moe_intermediate_size": 10, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "head_dim": 8, + "num_experts": 4, + "num_experts_per_tok": 2, + "n_routed_experts": 4, + "num_local_experts": 4, + "top_k_experts": 2, + "first_k_dense_replace": 1, + "qk_nope_head_dim": 4, + "qk_rope_head_dim": 4, + "v_head_dim": 4, + "rms_norm_eps": 1e-6, + "hidden_activation": "gelu_pytorch_tanh", + } + try: + sig = inspect.signature(cfg_cls.__init__) + kwargs = {k: v for k, v in overrides.items() if k in sig.parameters} + cfg = cfg_cls(**kwargs) + except Exception: + return None + # Some experts modules want a nested text_config; try that if direct fails. + for cfg_arg in (cfg, getattr(cfg, "text_config", None)): + if cfg_arg is None: + continue + try: + return cls(cfg_arg) + except Exception: + continue + try: + return cls(cfg_arg, layer_idx=0) + except Exception: + continue + return None + + +def _find_sibling_config(cls: type): + """Find the most likely Config class living in the same modeling module + as `cls`. Tries Foo<...>Config and Foo<...>TextConfig forms by stripping + common experts suffixes from the class name.""" + mod = importlib.import_module(cls.__module__) + base = cls.__name__ + for suffix in ("Experts", "NaiveMoe", "MoEBlock", "MoeBlock", "MoE", "Moe"): + if base.endswith(suffix): + base = base[: -len(suffix)] + break + candidates = [ + base + "TextConfig", + base + "Config", + cls.__name__.split("Experts")[0] + "TextConfig", + cls.__name__.split("Experts")[0] + "Config", + ] + for cand in candidates: + cfg = getattr(mod, cand, None) + if isinstance(cfg, type): + return cfg + return None + + +def _read_dims(experts) -> tuple[int, int] | None: + H = getattr(experts, "hidden_dim", None) + I = getattr(experts, "intermediate_dim", None) + if H is None or I is None: + return None + return int(H), int(I) + + +def _parity_one(extractor, experts, name: str, in_dim: int, out_dim: int, + peft_swap: bool, E: int = 4, R: int = 3) -> tuple[bool, str]: + torch.manual_seed(0) + if peft_swap: + wA = torch.randn(E * R, out_dim) + wB = torch.randn(in_dim, E * R) + else: + wA = torch.randn(E * R, in_dim) + wB = torch.randn(out_dim, E * R) + try: + res = extractor(_StubWrapper(name, experts, peft_swap), wA, wB, 1.0, E) + except Exception as ex: # noqa: BLE001 + return False, f"extractor raised: {ex!r}" + if res is None: + return False, "extractor returned None" + first, second = res[0], res[1] + if tuple(first.shape) != (E, in_dim, R): + return False, f"first.shape {tuple(first.shape)} != ({E},{in_dim},{R})" + if tuple(second.shape) != (E, R, out_dim): + return False, f"second.shape {tuple(second.shape)} != ({E},{R},{out_dim})" + x = torch.randn(5, in_dim) + for e in range(E): + Ae = wA[e * R : (e + 1) * R] + Be = wB[:, e * R : (e + 1) * R] + naive = (x @ Be @ Ae) if peft_swap else (x @ Ae.T @ Be.T) + via = (x @ first[e]) @ second[e] + if not torch.allclose(via, naive, atol=1e-4, rtol=1e-4): + return False, f"per-expert delta mismatch on expert {e}" + return True, "ok" + + +# -------------------------------------------------------------------------- +# The single test +# -------------------------------------------------------------------------- + + +def test_every_patched_moe_experts_class_has_lora_extractor(): + """Regression-prevention test for PR #624. Every class whose `forward` + unsloth-zoo has patched to use the grouped-MoE backend AND whose layout + matches the standard `(E, 2*I, H)` / `(E, H, I)` 3D-parameter shape + MUST have `_unsloth_lora_extractor_fn` registered. Without it, LoRA + training crashes on the first step with a `torch._grouped_mm` + contraction-dim mismatch on PEFT 0.19+.""" + _apply_all_temporary_patches() + + patched = _discover_patched_moe_classes() + assert patched, ( + "Discovery produced zero patched MoE classes. Either no MoE " + "transformers module imported on this transformers version, or the " + "_unsloth_already_patched marker convention changed. Investigate." + ) + + # Hard contract: extractor must be registered on every patched class. + missing = [ + f"{c.__module__}.{c.__name__}" + for c in patched + if getattr(c, "_unsloth_lora_extractor_fn", None) is None + ] + assert not missing, ( + "The following transformers MoE classes have `_unsloth_already_patched" + "=True` but no `_unsloth_lora_extractor_fn` registered. This is the " + "exact regression PR #624 fixed for Gemma-4: LoRA training will crash " + "on PEFT 0.19+ with `RuntimeError: contraction dimension of mat_a and " + "mat_b must match` on the first training step. Register the extractor " + f"in the patch function. Offenders: {missing}" + ) + + # Soft contract: opportunistic per-expert parity. Failures here are the + # extractor-orientation bug class (PR #624 was a registration bug, not + # an orientation bug, but the same one-test-catches-both contract is + # cheap to assert here). + parity_failures = [] + for cls in patched: + experts = _try_instantiate_experts(cls) + if experts is None: + continue + dims = _read_dims(experts) + if dims is None: + continue + H, I = dims + ext = cls._unsloth_lora_extractor_fn + for proj_name, in_dim, out_dim in [ + ("gate_up_proj", H, 2 * I), + ("down_proj", I, H), + ]: + for peft_swap in (False, True): + ok, msg = _parity_one(ext, experts, proj_name, in_dim, out_dim, peft_swap) + if not ok: + parity_failures.append( + f"{cls.__name__}.{proj_name} swap={peft_swap}: {msg}" + ) + + if parity_failures: + # Surface the orientation drift but keep the registration assertion + # above as the primary signal. We use pytest.fail not assert so the + # registration coverage assertion can still pass-or-fail + # independently of orientation. + pytest.fail( + "Per-expert delta parity failed for some patched MoE families. " + "This is independent of registration coverage but indicates an " + "extractor orientation bug. Failures:\n - " + + "\n - ".join(parity_failures) + ) diff --git a/unsloth_zoo/temporary_patches/gemma4_moe.py b/unsloth_zoo/temporary_patches/gemma4_moe.py index eaf09b91f..4d602270b 100644 --- a/unsloth_zoo/temporary_patches/gemma4_moe.py +++ b/unsloth_zoo/temporary_patches/gemma4_moe.py @@ -17,12 +17,49 @@ import os import torch import torch.nn as nn -from .common import TEMPORARY_PATCHES +from .common import TEMPORARY_PATCHES, UNSLOTH_ENABLE_LOGGING from .utils import patch_function, process_return, raise_error, logger from .moe_utils import ( patch_param_wrapper_for_moe, get_forward_moe_backend, ) +# Reuse the Qwen-MoE standard-layout LoRA extractor. Gemma4TextExperts has the +# same (E, out, in) layout, the same hidden_dim / intermediate_dim attribute +# names, and per_expert_scale is folded into top_k_weights upstream by +# Gemma4TextRouter.forward, so no Gemma-4-specific scale handling is needed +# inside the extractor itself. +from .qwen3_moe import _make_qwen_moe_lora_extractor + + +def _register_gemma4_lora_extractor(experts_cls): + """Attach _unsloth_lora_extractor_fn to a Gemma-4 experts class. + + Idempotent and safe if experts_cls is None. Without this registration, + moe_utils._extract_lora_from_wrapper falls through to the default + canonical-permutation branch, which can produce shapes whose contraction + dimensions do not match for torch._grouped_mm on PEFT 0.19+ swapped + layouts. The crash surfaces as + RuntimeError: contraction dimension of mat_a and mat_b must match + on the first training step. Gemma-4 experts share the Qwen-MoE standard + layout, so the Qwen extractor handles both PEFT 0.18 and 0.19 cleanly. + """ + if experts_cls is None: + return False + if getattr(experts_cls, "_unsloth_lora_extractor_registered", False): + return True + try: + extractor = _make_qwen_moe_lora_extractor() + experts_cls._unsloth_lora_extractor_fn = staticmethod(extractor) + experts_cls._unsloth_model_type = "gemma4_moe" + experts_cls._unsloth_lora_extractor_registered = True + return True + except Exception as e: + if UNSLOTH_ENABLE_LOGGING: + logger.warning( + f"Unsloth: Could not register Gemma-4 MoE LoRA extractor on " + f"{getattr(experts_cls, '__name__', experts_cls)}: {e}" + ) + return False def patch_gemma4_grpo_hidden_states(): @@ -171,6 +208,10 @@ def _patch_gemma4_moe_current(): return False if getattr(Gemma4TextExperts, "_unsloth_already_patched", False): + # Even when forward is already patched, make sure the extractor is + # registered. Guards against an older unsloth-zoo that patched + # forward but lacked the extractor registration. + _register_gemma4_lora_extractor(Gemma4TextExperts) return True _moe_backend = get_forward_moe_backend() @@ -184,6 +225,9 @@ def _gemma4_experts_forward(self, hidden_states, top_k_index, top_k_weights): ok = patch_function(Gemma4TextExperts, "forward", _gemma4_experts_forward, force=True) if ok: Gemma4TextExperts._unsloth_already_patched = True + # Register the Qwen-MoE-style standard-layout extractor so that the + # grouped-mm LoRA path produces correct contraction dimensions. + _register_gemma4_lora_extractor(Gemma4TextExperts) return ok @@ -198,6 +242,8 @@ def _patch_gemma4_moe_legacy(): return False if getattr(Gemma4TextMoEBlock, "_unsloth_already_patched", False): + # Same defensive re-registration as in the current-layout path. + _register_gemma4_lora_extractor(Gemma4TextMoEBlock) return True # Remap decoder layer module names to match checkpoint key layout: @@ -244,6 +290,9 @@ def _gemma4_moe_forward(self, hidden_states, top_k_index, top_k_weights): return False Gemma4TextMoEBlock._unsloth_already_patched = True + # Legacy MoE block has the same parameter layout (E, out, in). Register + # the same standard-layout extractor. + _register_gemma4_lora_extractor(Gemma4TextMoEBlock) return True