Register Gemma-4 MoE LoRA extractor to fix grouped_mm contraction crash#624
Conversation
unsloth/gemma-4-26B-A4B-it crashes on the first training step with
RuntimeError: contraction dimension of mat_a and mat_b must match
raised from torch._grouped_mm in moe_utils.native_moe_grouped_mm. Cause:
the gemma4_moe patch never registered Gemma4TextExperts._unsloth_lora_extractor_fn,
so _extract_lora_from_wrapper falls through to the default canonical
permutation branch, which does not handle PEFT 0.19+ swapped 3D LoRA
layouts. The other six MoE families (qwen3_moe, qwen3_5_moe,
qwen3_next_moe, qwen3_vl_moe, glm4_moe, deepseek_v3_moe) all register
one already.
Fix: register _make_qwen_moe_lora_extractor() on Gemma4TextExperts and
the legacy Gemma4TextMoEBlock. Gemma-4 experts share the Qwen-MoE
standard (E, out, in) layout and the same hidden_dim / intermediate_dim
attribute names, so the Qwen extractor applies verbatim. per_expert_scale
is folded into top_k_weights upstream by Gemma4TextRouter.forward, so
the extractor does not need any Gemma-4-specific scale handling.
Idempotent and gated on transformers exposing the Gemma-4 module, so it
no-ops on transformers 4.x and on Gemma-4-free installations.
Validated by running fwd plus bwd on unsloth/gemma-4-26B-A4B-it with
LoRA over experts.gate_up_proj and experts.down_proj.
There was a problem hiding this comment.
Code Review
This pull request introduces a LoRA extractor registration for Gemma-4 MoE models by reusing the Qwen-MoE standard-layout extractor. This change addresses a potential RuntimeError regarding mismatched contraction dimensions in torch._grouped_mm when using PEFT 0.19+ layouts. The registration is integrated into both current and legacy patching paths for Gemma-4 MoE. I have no feedback to provide.
Review notesScopeSingle file, +50/-1: What I verified
TestsI added a small synthetic regression test that does not depend on
Suite result on PR head, I am happy to open a small follow-up PR with this test file if useful. What was not run locallyFull real-model fwd+bwd on SuggestionTwo small things to consider for a follow-up, neither blocking:
LGTM otherwise. The fix is the minimal correct change, parallels the existing six MoE families exactly, and the no-op-on-missing-gemma4 path is preserved. |
Re-ran on transformers latestBumped Layout claims now verified against the real upstream class
top_k_weights = top_k_weights * self.per_expert_scale[top_k_index]so per-expert scale is folded into routing weights upstream and the LoRA extractor does not need any Gemma-4-specific scale handling. Patch state on the real classAfter
Per-expert delta parity on the real classDrove the registered extractor against the real
Tests(31 existing + 11 new from What is still pending end-to-endDid not run a full |
Compatibility sweep across transformers 5.5.0..5.8.0Ran the full test suite plus a real-class smoke test (instantiate Environment: PR head
Every version exposes the current layout ( No flaky behavior, no skipped versions, no warnings beyond an unrelated swig deprecation. The PR is compatible with the entire transformers 5.5.x..5.8.x window. |
Cross-family MoE sweep on transformers 5.8.0To make sure PR #624's Gemma-4 registration didn't introduce a regression for any sibling MoE family, I instantiated every unsloth-zoo MoE patch target from a tiny synthetic config (no checkpoint download), called the registered LoRA extractor against hand-built PEFT 0.18 raw and PEFT 0.19 swapped 3D LoRA factors for both
PR #624's target (gemma4_moe) and every Qwen family that the extractor reuse pattern came from pass cleanly. DeepSeek-V3 also passes. Note on glm4_moe_lite (independent of PR #624)
For synthetic configs outside that regime the heuristic mis-orients the LoRA factors:
This is independent of #624 and predates it; PR #624 specifically reuses the Qwen extractor ( SummaryPR #624 is a clean, additive registration plumbing change that does not regress any sibling MoE family. The extractor it reuses ( |
Dynamic regression-prevention test + cross-version coverageDo we need to fix anything?No, not for this PR. The cross-family sweep showed every other MoE family ( One dynamic test that prevents the PR #624 regression classAdded
That is exactly the PR #624 failure mode: forward was patched to use the grouped-MoE backend but no extractor was registered, so the default permutation in Discovery is fully dynamic and version-agnostic:
Plus an opportunistic per-expert parity check: for each discovered class, tries to instantiate via the sibling CPU-only (uses Negative-control checkTo confirm the test would catch the original PR #624 regression, I ran: import unsloth_zoo.temporary_patches # apply patches
del transformers.models.gemma4.modeling_gemma4.Gemma4TextExperts._unsloth_lora_extractor_fn
# run the testResult: Cross-version coverage on the new testRan the new test on every transformers 5.X.X release on PyPI (5.0.0 through 5.8.0):
The discovery count grows as MoE families ship upstream (Gemma-4 lands in 5.5); the test stays green because PR #624 closed the only registration gap. Total test count on PR head with Happy to fold both new test files into a small follow-up PR if useful. |
PR #624 fixed Gemma-4 MoE LoRA training by registering the missing _unsloth_lora_extractor_fn on Gemma4TextExperts and Gemma4TextMoEBlock. Add three test files that lock in the contract so the same regression class cannot return for any current or future MoE family. tests/test_gemma4_moe_lora_registration.py Focused unit tests for the new _register_gemma4_lora_extractor helper: registration attaches the Qwen extractor and the gemma4_moe model-type tag, repeated calls preserve callable identity, None target returns False, factory failures leave class state untouched, and per-expert delta parity holds for gate_up_proj and down_proj on both PEFT 0.18 raw and PEFT 0.19 swapped LoRA layouts. tests/test_moe_lora_extractor_coverage.py Single dynamic test that walks transformers.models.* via pkgutil and inspect, applies every TEMPORARY_PATCHES entry, and asserts that every class whose forward unsloth_zoo replaced (detected via the _original_<modeling_module>_<ClassName>_forward attribute that patch_function writes) AND whose __init__ source declares gate_up_proj and down_proj as nn.Parameters has _unsloth_lora_extractor_fn registered. Plus opportunistic per-expert parity on each discovered class via a sibling-Config-driven tiny synthetic instantiation. Discovery is transformers-version-agnostic; the test grew from finding 5 classes on transformers 5.0.0 to 7 on 5.8.0 with no source changes. tests/conftest.py Lets the suite run on GPU-free CI runners. Pre-loads the real unsloth_zoo.device_type with torch.cuda.is_available temporarily mocked so the @functools.cache on get_device_type captures "cuda"; is_available is then restored to its real value so transitive deps (torchao, dynamo) still skip CUDA init. On runners with a real accelerator the conftest is a no-op, so full-fidelity GPU testing is preserved. Also stubs torch.cuda.memory.mem_get_info so gpt_oss.py's top-level GPU memory query does not raise. Verified: full suite is 31 -> 43 passed, on transformers 5.0.0 through 5.8.0, with both CUDA_VISIBLE_DEVICES='' and a real GPU.
|
Pushed the three test files as
Suite now 43 passed (31 prior + 12 new). Verified across |
| return cls(cfg_arg) | ||
| except Exception: | ||
| continue | ||
| try: |
|
|
||
| import pytest | ||
| import torch | ||
| import torch.nn as nn |
| # Apply every TEMPORARY_PATCHES entry. Importing the package side-effect- | ||
| # populates the list; we run each entry once and ignore individual failures | ||
| # (a missing transformers submodule is the standard no-op signal). | ||
| import unsloth_zoo.temporary_patches # noqa: F401 side effect: register patches |
| try: | ||
| if hasattr(torch, "cuda") and torch.cuda.is_available(): | ||
| return True | ||
| except Exception: |
| try: | ||
| if hasattr(torch, "xpu") and torch.xpu.is_available(): | ||
| return True | ||
| except Exception: |
| try: | ||
| if hasattr(torch, "accelerator") and torch.accelerator.is_available(): | ||
| return True | ||
| except Exception: |
| try: | ||
| import torch.cuda.memory as _cuda_memory # type: ignore | ||
| _cuda_memory.mem_get_info = lambda *a, **k: (0, 80 * 1024 ** 3) | ||
| except Exception: |
New step "MoE per-family coverage + GRPO patches + grouped_gemm AST" that hardens the matrix against the recurring MoE bug class behind unslothai/unsloth-zoo#624 / #612 / #607 / #601 and unslothai/unsloth #4934 / #3598. Five clusters of pytest cases inside one shim: 1. Per-MoE-family side-effect contract (8 parametrized cases): For each `patch_*_moe` in unsloth_zoo.temporary_patches.{qwen3_moe, qwen3_5_moe, qwen3_next_moe, qwen3_vl_moe, gemma4_moe, glm4_moe, deepseek_v3_moe, gpt_oss}, look up the transformers target classes, skip when none import on this matrix cell, run the patch fn, and assert at least one importable target now carries an unsloth "patched" marker. Accepts five marker conventions used across the codebase (_unsloth_already_patched, _unsloth_lora_patched, _unsloth_lora_extractor_fn, _original_<modeling_tail>_<cls>_forward, plain _original_forward). Surfaces silent early-returns (PR #612) that escape the registration-coverage test. gpt_oss specifically reads UNSLOTH_MODEL_NAME and only runs on transformers >= 5; the shim sets the env var via monkeypatch and skips on the 4.57.6 cell with a documented reason. 2. PR #4934 (TRL 1.0 GRPO disable_gradient_checkpointing): rebinding contract. After patch_trl_disable_gradient_checkpointing(), the no-op decorated function MUST be the symbol on trl.models.utils AND every trl.* module that imported it by reference. Skips on TRL < 1.0 (no symbol present). 3. PR #3598 (gradient_accumulation): patch_gradient_accumulation_fix on a vanilla transformers.Trainer must run cleanly without raising AND be idempotent. Catches future double-scale or import-injection regressions in the source rewriter. 4. unsloth/kernels/moe/grouped_gemm AST smoke: walks every .py under the directory (12 files) and asserts ast.parse succeeds. Triton kernels are GPU-only at runtime, but a syntax error in source surfaces as ImportError on every install. Also sanity-checks the directory layout (interface.py, kernels/forward.py, kernels/backward.py, reference/moe_block.py, reference/moe_ops.py must exist). Local verification on host TRL 0.25.1 + transformers 4.57.6: 4 pass (qwen3_moe, qwen3_vl_moe, GRPO disable-GC, grad-accum, grouped_gemm AST), 7 skip legitimately (qwen3_5/qwen3_next/gemma4/glm4/deepseek/ gpt_oss absent or version-gated). Wall-time ~10s on host; budget ~30-60s per matrix cell.
…nslothai#630) The previous canary (any class with `gate_up_proj` + `down_proj` 3D parameters) was too aggressive. transformers ships such classes in modules that unsloth_zoo does NOT patch (e.g. Llama4TextExperts has the 3D pattern but `unsloth_zoo/temporary_patches/` has no llama4.py), and on CPU-only runners some targeting patches early-exit before hitting their patch_function call. Both produced false-positive "discovery zero with surface present" failures on transformers 4.57.6 on ubuntu-latest CPU runners. Use the `_unsloth_already_patched=True` boolean each patch fn sets explicitly at the end of its successful path as the canary. That flag's presence on a class means the patch fn ran to its self- declaration line. If at least one class carries it AND discovery returned zero, the test-helper marker convention has drifted (real regression). If no class carries it, either transformers is too old or every targeting patch fn raised before its self-declaration -- this test cannot reach its target either way, so skip. This narrows the assertion to the regression the test was written to catch (PR unslothai#624: marker drift / forgotten extractor on a class the patch fn claims to own) without firing on irrelevant 3D-pattern classes or runtime-only patch failures.
Summary
unsloth/gemma-4-26B-A4B-it(Gemma-4 MoE, 128 experts) crashes on the first training step when LoRA is configured on the expert weights:raised from
torch._grouped_mminmoe_utils.native_moe_grouped_mm.Cause:
gemma4_moe.pypatchesGemma4TextExperts.forwardto use the grouped-GEMM backend but never registersGemma4TextExperts._unsloth_lora_extractor_fn.moe_utils._extract_lora_from_wrapperfalls through to the default canonical permutation branch, which does not handle PEFT 0.19+ swapped 3D LoRA layouts. The shapes it produces fail the contraction-dim check insidetorch._grouped_mm.The six existing MoE families (
qwen3_moe,qwen3_5_moe,qwen3_next_moe,qwen3_vl_moe,glm4_moe,deepseek_v3_moe) all register an extractor already;gemma4_moewas the only one missing it.Fix
Register
_make_qwen_moe_lora_extractor()onGemma4TextExperts(current Transformers 5.x layout) and on the legacyGemma4TextMoEBlock(Transformers 4.x layout). Gemma-4 experts share the Qwen-MoE standard(E, out, in)layout and the samehidden_dim/intermediate_dimattribute names, so the Qwen extractor applies verbatim:per_expert_scaleis folded intotop_k_weightsupstream byGemma4TextRouter.forward, so the extractor needs no Gemma-4-specific handling.The change is one new helper
_register_gemma4_lora_extractorplus three call sites inside_patch_gemma4_moe_currentand_patch_gemma4_moe_legacy. Single-file diff.Compatibility
from transformers.models.gemma4.modeling_gemma4 import ..., so transformers 4.57.6 is unaffected (no Gemma-4 module)._did_swap_in_out_features, so TRL 0.22.2 / 0.27.1 / 1.0.0 all keep working._unsloth_lora_extractor_registeredflag.Test plan
unsloth/gemma-4-26B-A4B-itwith LoRA overexperts.gate_up_projandexperts.down_proj. Before this PR: contraction-dim crash on first step. After this PR: finite loss, clean grad.gemma4_moe.py).