Skip to content

Register Gemma-4 MoE LoRA extractor to fix grouped_mm contraction crash#624

Merged
danielhanchen merged 2 commits into
mainfrom
fix-gemma4-moe-lora-extractor
May 6, 2026
Merged

Register Gemma-4 MoE LoRA extractor to fix grouped_mm contraction crash#624
danielhanchen merged 2 commits into
mainfrom
fix-gemma4-moe-lora-extractor

Conversation

@danielhanchen
Copy link
Copy Markdown
Member

Summary

unsloth/gemma-4-26B-A4B-it (Gemma-4 MoE, 128 experts) crashes on the first training step when LoRA is configured on the expert weights:

RuntimeError: contraction dimension of mat_a and mat_b must match

raised from torch._grouped_mm in moe_utils.native_moe_grouped_mm.

Cause: gemma4_moe.py patches Gemma4TextExperts.forward to use the grouped-GEMM backend but never registers Gemma4TextExperts._unsloth_lora_extractor_fn. moe_utils._extract_lora_from_wrapper falls through to the default canonical permutation branch, which does not handle PEFT 0.19+ swapped 3D LoRA layouts. The shapes it produces fail the contraction-dim check inside torch._grouped_mm.

The six existing MoE families (qwen3_moe, qwen3_5_moe, qwen3_next_moe, qwen3_vl_moe, glm4_moe, deepseek_v3_moe) all register an extractor already; gemma4_moe was the only one missing it.

Fix

Register _make_qwen_moe_lora_extractor() on Gemma4TextExperts (current Transformers 5.x layout) and on the legacy Gemma4TextMoEBlock (Transformers 4.x layout). Gemma-4 experts share the Qwen-MoE standard (E, out, in) layout and the same hidden_dim / intermediate_dim attribute names, so the Qwen extractor applies verbatim:

Qwen3MoeExperts                Gemma4TextExperts
---------------                -----------------
self.num_experts               self.num_experts
self.hidden_dim                self.hidden_dim
self.intermediate_dim          self.intermediate_dim
gate_up_proj (E, 2*I, H)       gate_up_proj (E, 2*I, H)
down_proj    (E, H, I)         down_proj    (E, H, I)

per_expert_scale is folded into top_k_weights upstream by Gemma4TextRouter.forward, so the extractor needs no Gemma-4-specific handling.

The change is one new helper _register_gemma4_lora_extractor plus three call sites inside _patch_gemma4_moe_current and _patch_gemma4_moe_legacy. Single-file diff.

Compatibility

  • Gated on from transformers.models.gemma4.modeling_gemma4 import ..., so transformers 4.57.6 is unaffected (no Gemma-4 module).
  • The Qwen extractor handles both PEFT 0.18 raw 3D and PEFT 0.19+ swapped layouts via _did_swap_in_out_features, so TRL 0.22.2 / 0.27.1 / 1.0.0 all keep working.
  • Idempotent via _unsloth_lora_extractor_registered flag.

Test plan

  • Run one fwd + bwd on unsloth/gemma-4-26B-A4B-it with LoRA over experts.gate_up_proj and experts.down_proj. Before this PR: contraction-dim crash on first step. After this PR: finite loss, clean grad.
  • Confirm the patch no-ops on transformers without Gemma-4 (4.57.6).
  • Confirm sibling MoE families (Qwen3 variants, GLM4, DeepSeek-V3) are untouched (no edits outside gemma4_moe.py).

unsloth/gemma-4-26B-A4B-it crashes on the first training step with

    RuntimeError: contraction dimension of mat_a and mat_b must match

raised from torch._grouped_mm in moe_utils.native_moe_grouped_mm. Cause:
the gemma4_moe patch never registered Gemma4TextExperts._unsloth_lora_extractor_fn,
so _extract_lora_from_wrapper falls through to the default canonical
permutation branch, which does not handle PEFT 0.19+ swapped 3D LoRA
layouts. The other six MoE families (qwen3_moe, qwen3_5_moe,
qwen3_next_moe, qwen3_vl_moe, glm4_moe, deepseek_v3_moe) all register
one already.

Fix: register _make_qwen_moe_lora_extractor() on Gemma4TextExperts and
the legacy Gemma4TextMoEBlock. Gemma-4 experts share the Qwen-MoE
standard (E, out, in) layout and the same hidden_dim / intermediate_dim
attribute names, so the Qwen extractor applies verbatim. per_expert_scale
is folded into top_k_weights upstream by Gemma4TextRouter.forward, so
the extractor does not need any Gemma-4-specific scale handling.

Idempotent and gated on transformers exposing the Gemma-4 module, so it
no-ops on transformers 4.x and on Gemma-4-free installations.

Validated by running fwd plus bwd on unsloth/gemma-4-26B-A4B-it with
LoRA over experts.gate_up_proj and experts.down_proj.
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a LoRA extractor registration for Gemma-4 MoE models by reusing the Qwen-MoE standard-layout extractor. This change addresses a potential RuntimeError regarding mismatched contraction dimensions in torch._grouped_mm when using PEFT 0.19+ layouts. The registration is integrated into both current and legacy patching paths for Gemma-4 MoE. I have no feedback to provide.

@danielhanchen
Copy link
Copy Markdown
Member Author

Review notes

Scope

Single file, +50/-1: unsloth_zoo/temporary_patches/gemma4_moe.py. The change adds _register_gemma4_lora_extractor(...) and three call sites inside _patch_gemma4_moe_current and _patch_gemma4_moe_legacy that attach _make_qwen_moe_lora_extractor() to Gemma4TextExperts and Gemma4TextMoEBlock. No other surfaces touched.

What I verified

  1. Sibling-family parity. The six existing MoE families (qwen3_moe, qwen3_5_moe, qwen3_next_moe, qwen3_vl_moe, glm4_moe, deepseek_v3_moe) all register _unsloth_lora_extractor_fn on their experts class. gemma4_moe was the only one missing it, which matches the failure mode described in #4847 and the linked Issue.
  2. Layout reuse. Gemma-4 experts share the Qwen-MoE standard (E, 2*I, H) / (E, H, I) storage and use hidden_dim / intermediate_dim attribute names, which is exactly what _make_qwen_moe_lora_extractor() reads. No Gemma-4-specific extractor body is needed.
  3. PEFT 0.18 raw and 0.19 swapped layouts both reconstruct the per-expert LoRA delta correctly via the registered extractor for both gate_up_proj and down_proj. Tested with hand-built weight_A / weight_B against a naive ((X @ A.T) @ B.T) reference. PEFT 0.19 dispatch is gated on wrapper._did_swap_in_out_features first, with shape fallback, so the dispatch is unambiguous on real Gemma-4 dims.
  4. per_expert_scale is folded into top_k_weights upstream by Gemma4TextRouter.forward (current layout) and by the patched _gemma4_moe_forward itself for the legacy layout (_router_ref.per_expert_scale). The extractor returns raw LoRA factors only, so there is no double-application risk.
  5. Idempotency. _register_gemma4_lora_extractor short-circuits on _unsloth_lora_extractor_registered. The if _unsloth_already_patched: ...; return True paths now also re-register defensively, which protects against an older zoo build that patched forward but did not register the extractor.
  6. Exception safety. If the extractor factory raises, registration returns False and leaves the class state untouched (no half-attached attributes).
  7. No-op on missing Gemma-4. transformers==4.57.6 does not ship transformers.models.gemma4. _patch_gemma4_moe_current / _patch_gemma4_moe_legacy both try/except the import and return False. patch_gemma4_moe() then surfaces the missing-module via raise_error(...) (which returns rather than raising) and does not crash. Verified on this branch.

Tests

I added a small synthetic regression test that does not depend on transformers.models.gemma4 existing, since the registration plumbing is the surface area added by this PR and the existing tests/test_qwen_moe_lora_extractor.py and tests/test_forward_native_moe_loop_lora.py already cover the reused extractor body.

tests/test_gemma4_moe_lora_registration.py (11 tests):

  • registration attaches _unsloth_lora_extractor_fn, _unsloth_model_type = \"gemma4_moe\", and _unsloth_lora_extractor_registered = True
  • idempotent across repeated calls (callable identity preserved)
  • _register_gemma4_lora_extractor(None) returns False without raising
  • on factory failure, returns False and leaves the class state untouched
  • per-expert delta reconstruction parity for gate_up_proj and down_proj across both PEFT 0.18 raw and PEFT 0.19 swapped layouts
  • patch_gemma4_moe() is a clean no-op when transformers.models.gemma4 is unavailable
  • legacy Gemma4TextMoEBlock-shaped class accepts the same registration

Suite result on PR head, peft==0.19.1, transformers==4.57.6:

tests/                               42 passed in 5.24s
  test_qwen_moe_lora_extractor.py    18 passed
  test_forward_native_moe_loop_lora.py 10 passed
  test_backend_device_helpers.py      3 passed
  test_gemma4_moe_lora_registration.py 11 passed

I am happy to open a small follow-up PR with this test file if useful.

What was not run locally

Full real-model fwd+bwd on unsloth/gemma-4-26B-A4B-it was not run in this environment because transformers==4.57.6 here lacks models.gemma4. The Gemma-4 notebooks in unslothai/notebooks/nb/ (most relevant: Gemma4_(26B_A4B)-Text.ipynb, AMD-Gemma4_(26B_A4B)-Text.ipynb) were not executed for the same reason. Recommend a separate CI pass on a Gemma-4-capable transformers + PEFT 0.19+ to confirm #4847 is resolved end-to-end and that the contraction-dim crash is gone on the first training step.

Suggestion

Two small things to consider for a follow-up, neither blocking:

  • Document or assert the _unsloth_lora_extractor_registered overwrite policy if upstream transformers ever ships a native Gemma-4 extractor. Today the guard is per-class so behavior is stable across repeated zoo calls, but the first call after a fresh import will still attach the Qwen extractor unconditionally.
  • Land the synthetic registration test alongside this fix so the registration contract is exercised by CI.

LGTM otherwise. The fix is the minimal correct change, parallels the existing six MoE families exactly, and the no-op-on-missing-gemma4 path is preserved.

@danielhanchen
Copy link
Copy Markdown
Member Author

Re-ran on transformers latest

Bumped transformers==4.57.6 to transformers==5.8.0 (the line that ships models.gemma4) and re-ran on PR head with peft==0.19.1, torch==2.9.1+cu128.

Layout claims now verified against the real upstream class

transformers.models.gemma4.modeling_gemma4 in 5.8.0 ships only the current layout (no Gemma4TextMoEBlock). Construction with a tiny Gemma4TextConfig(num_experts=4, hidden_size=8, moe_intermediate_size=12, ...):

experts.num_experts=4
experts.hidden_dim=8
experts.intermediate_dim=12
experts.gate_up_proj.shape=(4, 24, 8)   # (E, 2*I, H)
experts.down_proj.shape   =(4, 8, 12)   # (E, H, I)

Gemma4TextRouter.forward in 5.8.0 has the exact line the PR description references:

top_k_weights = top_k_weights * self.per_expert_scale[top_k_index]

so per-expert scale is folded into routing weights upstream and the LoRA extractor does not need any Gemma-4-specific scale handling.

Patch state on the real class

After patch_gemma4_moe():

  • Gemma4TextExperts._unsloth_already_patched = True
  • Gemma4TextExperts._unsloth_lora_extractor_registered = True
  • Gemma4TextExperts._unsloth_model_type = 'gemma4_moe'
  • Gemma4TextExperts._unsloth_lora_extractor_fn callable
  • Re-applying patch_gemma4_moe() preserves extractor identity (idempotent).

Per-expert delta parity on the real class

Drove the registered extractor against the real Gemma4TextExperts instance for both gate_up_proj and down_proj, in both PEFT 0.18 raw and PEFT 0.19 swapped layouts. All four cases:

  • Returned first.shape == (E, in, R) and second.shape == (E, R, out) as expected.
  • (x @ first[e]) @ second[e] matched naive (x @ A.T @ B.T) (canonical) and (x @ B @ A) (swapped) within atol=1e-4, rtol=1e-4 for all 4 experts.

Tests

pytest -x -q tests/    -> 42 passed in 5.74s

(31 existing + 11 new from tests/test_gemma4_moe_lora_registration.py.) test_patch_gemma4_moe_is_noop_without_gemma4 no longer exercises the missing-module branch since gemma4 is now present, but it now exercises the real registration path and still passes; happy to rename it.

What is still pending end-to-end

Did not run a full unsloth/gemma-4-26B-A4B-it LoRA fine-tune in this environment (model + 80GB-class memory). Issue #4847 should reproduce the original contraction dimension of mat_a and mat_b must match crash on main and disappear on this PR head with transformers==5.8.0 and PEFT 0.19+; the per-expert delta parity above is the closest synthetic substitute.

@danielhanchen
Copy link
Copy Markdown
Member Author

Compatibility sweep across transformers 5.5.0..5.8.0

Ran the full test suite plus a real-class smoke test (instantiate Gemma4TextExperts, call patch_gemma4_moe(), drive the registered extractor against gate_up_proj and down_proj for both PEFT 0.18 raw and PEFT 0.19 swapped layouts, compare per-expert delta to a naive reference within atol=1e-4, rtol=1e-4) across every transformers 5.x release that ships models.gemma4.

Environment: PR head 0709940, peft==0.19.1, torch==2.9.1+cu128.

transformers gemma4 importable pytest tests/ real-class smoke (parity, both PEFT layouts, both projs)
5.5.0 yes 42 passed in 6.72s ok
5.5.1 yes 42 passed in 6.74s ok
5.5.2 yes 42 passed in 6.52s ok
5.5.3 yes 42 passed in 6.54s ok
5.5.4 yes 42 passed in 6.75s ok
5.6.0 yes 42 passed in 7.01s ok
5.6.1 yes 42 passed in 6.91s ok
5.6.2 yes 42 passed in 6.70s ok
5.7.0 yes 42 passed in 6.79s ok
5.8.0 yes 42 passed in 6.44s ok

Every version exposes the current layout (Gemma4TextExperts, Gemma4TextRouter); none of them ship a Gemma4TextMoEBlock, so the legacy registration path is dormant on the public 5.x line and is exercised only by the synthetic test. Layout invariants gate_up_proj.shape == (E, 2*I, H) and down_proj.shape == (E, H, I) and the top_k_weights *= per_expert_scale[top_k_index] fold inside Gemma4TextRouter.forward hold on every version checked.

No flaky behavior, no skipped versions, no warnings beyond an unrelated swig deprecation. The PR is compatible with the entire transformers 5.5.x..5.8.x window.

@danielhanchen
Copy link
Copy Markdown
Member Author

Cross-family MoE sweep on transformers 5.8.0

To make sure PR #624's Gemma-4 registration didn't introduce a regression for any sibling MoE family, I instantiated every unsloth-zoo MoE patch target from a tiny synthetic config (no checkpoint download), called the registered LoRA extractor against hand-built PEFT 0.18 raw and PEFT 0.19 swapped 3D LoRA factors for both gate_up_proj and down_proj, and compared per-expert reconstructed delta to a naive (X @ A.T @ B.T) / (X @ B @ A) reference within atol=1e-4, rtol=1e-4.

apply phase: every entry in TEMPORARY_PATCHES (64 total) ran without raising on transformers 5.8.0 + PR head 0709940.

family transformers class extractor registered gate_up canon gate_up swap down canon down swap
qwen3_moe Qwen3MoeExperts yes ok ok ok ok
qwen3_5_moe Qwen3_5MoeExperts yes ok ok ok ok
qwen3_next_moe Qwen3NextExperts yes ok ok ok ok
qwen3_vl_moe Qwen3VLMoeTextExperts yes ok ok ok ok
gemma4_moe Gemma4TextExperts yes ok ok ok ok
deepseek_v3_moe DeepseekV3NaiveMoe yes ok ok ok ok
glm4_moe_lite Glm4MoeLiteNaiveMoe yes see below

PR #624's target (gemma4_moe) and every Qwen family that the extractor reuse pattern came from pass cleanly. DeepSeek-V3 also passes.

Note on glm4_moe_lite (independent of PR #624)

glm4_moe.py:_glm4_lora_extractor uses a shape-only heuristic (dim1 > dim2 -> "transposed storage" branch) instead of reading wrapper.parameter_name and base.hidden_dim / intermediate_dim like the Qwen / Gemma-4 extractor does. The heuristic gives correct (E, in_dim, R) shapes only when the model has H < 2*I and H > I simultaneously, which is the case for the published Glm4-MoE-Lite "Air" config (hidden_size=2048, moe_intermediate_size=1408, so 2*I=2816 > H=2048 > I=1408).

For synthetic configs outside that regime the heuristic mis-orients the LoRA factors:

  • H=8, I=12 (I > H, our default): down_proj PEFT 0.18 raw and PEFT 0.19 swap both produce first.shape=(E, 8, R) when the runtime needs (E, 12, R).
  • H=20, I=8, 2I=16 (H > 2I): gate_up_proj PEFT 0.18 raw and PEFT 0.19 swap produce first.shape=(E, 16, R) when the runtime needs (E, 20, R).

This is independent of #624 and predates it; PR #624 specifically reuses the Qwen extractor (qwen3_moe.py:_make_qwen_moe_lora_extractor) which uses wrapper.parameter_name + base-layer hidden_dim / intermediate_dim and is layout-aware. Worth flagging as a follow-up: porting glm4_moe's extractor to the same wrapper-aware pattern would harden it against config variations.

Summary

PR #624 is a clean, additive registration plumbing change that does not regress any sibling MoE family. The extractor it reuses (_make_qwen_moe_lora_extractor) is verified to round-trip correctly against the real Gemma4TextExperts and against every other Qwen-family experts class. Approve.

@danielhanchen
Copy link
Copy Markdown
Member Author

Dynamic regression-prevention test + cross-version coverage

Do we need to fix anything?

No, not for this PR. The cross-family sweep showed every other MoE family (qwen3_moe, qwen3_5_moe, qwen3_next_moe, qwen3_vl_moe, deepseek_v3_moe, gemma4_moe itself) is already correct. glm4_moe.py:_glm4_lora_extractor uses a shape-only heuristic (dim1 > dim2 -> "transposed branch") that works in production for the published Glm4-MoE-Lite "Air" config (H=2048, I=1408) because that config satisfies the heuristic's H < 2*I and H > I requirements, but it would mis-orient on configs outside that regime. Worth a follow-up to port it to the same wrapper-aware pattern that _make_qwen_moe_lora_extractor and the new gemma4 extractor use, but it is independent of this PR and not regressing today.

One dynamic test that prevents the PR #624 regression class

Added tests/test_moe_lora_extractor_coverage.py (one pytest case). It enforces this contract:

If unsloth_zoo.temporary_patches.utils.patch_function has replaced
cls.forward AND cls.__init__ declares gate_up_proj and down_proj
as nn.Parameters, then cls._unsloth_lora_extractor_fn MUST be set.

That is exactly the PR #624 failure mode: forward was patched to use the grouped-MoE backend but no extractor was registered, so the default permutation in moe_utils._extract_lora_from_wrapper fired and produced contraction-mismatched factors for torch._grouped_mm on PEFT 0.19+.

Discovery is fully dynamic and version-agnostic:

  • Walks every transformers.models.<x>.modeling_<y> module via pkgutil.iter_modules + importlib.import_module. Import failures are silently skipped (optional deps).
  • For each class, detects "patched by unsloth-zoo" by looking for the _original_<modeling_module_tail>_<ClassName>_forward attribute that patch_function writes when storing the original (utils.py:_get_unique_storage_name). This is uniform across families and survives any per-family marker-naming drift.
  • For each candidate class, applies inspect.getsource(cls.__init__) and checks for self.gate_up_proj = nn.Parameter( and self.down_proj = nn.Parameter( patterns. This excludes gpt_oss-style MoE classes that use a different LoRA path.

Plus an opportunistic per-expert parity check: for each discovered class, tries to instantiate via the sibling Config / TextConfig (found by stripping Experts / NaiveMoe / MoEBlock suffixes from the class name) using a tiny synthetic config (H=16, I=10, 2*I=20, all distinct, in the production H < 2*I and H > I regime), then drives the registered extractor on hand-built PEFT 0.18 raw and PEFT 0.19 swapped 3D LoRA factors for both gate_up_proj and down_proj and asserts per-expert delta parity within atol=1e-4, rtol=1e-4.

CPU-only (uses torch.randn with default device). No checkpoints. No model weights download. Single test, one pytest signal.

Negative-control check

To confirm the test would catch the original PR #624 regression, I ran:

import unsloth_zoo.temporary_patches  # apply patches
del transformers.models.gemma4.modeling_gemma4.Gemma4TextExperts._unsloth_lora_extractor_fn
# run the test

Result: 7 discovered, offenders=['transformers.models.gemma4.modeling_gemma4.Gemma4TextExperts'] and the test fails. The test catches the exact regression.

Cross-version coverage on the new test

Ran the new test on every transformers 5.X.X release on PyPI (5.0.0 through 5.8.0):

transformers discovered patched MoE classes test result
5.0.0 5 1 passed in 11.90s
5.1.0 5 1 passed in 12.47s
5.2.0 6 1 passed in 12.28s
5.3.0 6 1 passed in 12.00s
5.4.0 6 1 passed in 12.92s
5.5.0 7 1 passed in 14.61s
5.5.1 7 1 passed in 12.86s
5.5.2 7 1 passed in 12.82s
5.5.3 7 1 passed in 12.86s
5.5.4 7 1 passed in 12.89s
5.6.0 7 1 passed in 12.99s
5.6.1 7 1 passed in 13.41s
5.6.2 7 1 passed in 13.10s
5.7.0 7 1 passed in 13.05s
5.8.0 7 1 passed in 14.81s

The discovery count grows as MoE families ship upstream (Gemma-4 lands in 5.5); the test stays green because PR #624 closed the only registration gap.

Total test count on PR head with tests/test_moe_lora_extractor_coverage.py added: 43 passed (31 existing + 11 from tests/test_gemma4_moe_lora_registration.py + 1 dynamic).

Happy to fold both new test files into a small follow-up PR if useful.

PR #624 fixed Gemma-4 MoE LoRA training by registering the missing
_unsloth_lora_extractor_fn on Gemma4TextExperts and Gemma4TextMoEBlock.
Add three test files that lock in the contract so the same regression
class cannot return for any current or future MoE family.

tests/test_gemma4_moe_lora_registration.py
  Focused unit tests for the new _register_gemma4_lora_extractor helper:
  registration attaches the Qwen extractor and the gemma4_moe model-type
  tag, repeated calls preserve callable identity, None target returns
  False, factory failures leave class state untouched, and per-expert
  delta parity holds for gate_up_proj and down_proj on both PEFT 0.18
  raw and PEFT 0.19 swapped LoRA layouts.

tests/test_moe_lora_extractor_coverage.py
  Single dynamic test that walks transformers.models.* via pkgutil and
  inspect, applies every TEMPORARY_PATCHES entry, and asserts that
  every class whose forward unsloth_zoo replaced (detected via the
  _original_<modeling_module>_<ClassName>_forward attribute that
  patch_function writes) AND whose __init__ source declares
  gate_up_proj and down_proj as nn.Parameters has
  _unsloth_lora_extractor_fn registered. Plus opportunistic per-expert
  parity on each discovered class via a sibling-Config-driven tiny
  synthetic instantiation. Discovery is transformers-version-agnostic;
  the test grew from finding 5 classes on transformers 5.0.0 to 7
  on 5.8.0 with no source changes.

tests/conftest.py
  Lets the suite run on GPU-free CI runners. Pre-loads the real
  unsloth_zoo.device_type with torch.cuda.is_available temporarily
  mocked so the @functools.cache on get_device_type captures "cuda";
  is_available is then restored to its real value so transitive deps
  (torchao, dynamo) still skip CUDA init. On runners with a real
  accelerator the conftest is a no-op, so full-fidelity GPU testing
  is preserved. Also stubs torch.cuda.memory.mem_get_info so
  gpt_oss.py's top-level GPU memory query does not raise.

Verified: full suite is 31 -> 43 passed, on transformers 5.0.0 through
5.8.0, with both CUDA_VISIBLE_DEVICES='' and a real GPU.
@danielhanchen
Copy link
Copy Markdown
Member Author

Pushed the three test files as d3c18a7 on fix-gemma4-moe-lora-extractor:

  • tests/test_gemma4_moe_lora_registration.py (208 lines): focused unit tests for _register_gemma4_lora_extractor on a synthetic stub Gemma4TextExperts. Covers idempotency, None target, factory-failure rollback, and per-expert delta parity on PEFT 0.18 raw + PEFT 0.19 swapped layouts.
  • tests/test_moe_lora_extractor_coverage.py (361 lines): one dynamic test that walks transformers.models.* via pkgutil + inspect, runs every TEMPORARY_PATCHES entry, and asserts every patched class with grouped-MoE 3D parameter layout has _unsloth_lora_extractor_fn registered. Transformers-version-agnostic; discovery scales from 5 patched classes on transformers 5.0.0 to 7 on 5.8.0 without source changes. Includes opportunistic per-expert parity on a tiny synthetic instantiation.
  • tests/conftest.py (168 lines): GPU-free CI harness. Pre-loads the real unsloth_zoo.device_type with torch.cuda.is_available temporarily mocked so the @functools.cache on get_device_type captures "cuda", then restores is_available to its real value so transitive deps (torchao, dynamo) still skip CUDA init. On real-GPU runners the conftest detects the accelerator and is a no-op.

Suite now 43 passed (31 prior + 12 new). Verified across transformers==5.0.0 through 5.8.0, with both CUDA_VISIBLE_DEVICES="" and a real GPU.

return cls(cfg_arg)
except Exception:
continue
try:

import pytest
import torch
import torch.nn as nn
# Apply every TEMPORARY_PATCHES entry. Importing the package side-effect-
# populates the list; we run each entry once and ignore individual failures
# (a missing transformers submodule is the standard no-op signal).
import unsloth_zoo.temporary_patches # noqa: F401 side effect: register patches
Comment thread tests/conftest.py
try:
if hasattr(torch, "cuda") and torch.cuda.is_available():
return True
except Exception:
Comment thread tests/conftest.py
try:
if hasattr(torch, "xpu") and torch.xpu.is_available():
return True
except Exception:
Comment thread tests/conftest.py
try:
if hasattr(torch, "accelerator") and torch.accelerator.is_available():
return True
except Exception:
Comment thread tests/conftest.py
try:
import torch.cuda.memory as _cuda_memory # type: ignore
_cuda_memory.mem_get_info = lambda *a, **k: (0, 80 * 1024 ** 3)
except Exception:
@danielhanchen danielhanchen merged commit 077bfc1 into main May 6, 2026
3 checks passed
danielhanchen added a commit to unslothai/unsloth that referenced this pull request May 7, 2026
New step "MoE per-family coverage + GRPO patches + grouped_gemm AST"
that hardens the matrix against the recurring MoE bug class behind
unslothai/unsloth-zoo#624 / #612 / #607 / #601 and unslothai/unsloth
#4934 / #3598. Five clusters of pytest cases inside one shim:

1. Per-MoE-family side-effect contract (8 parametrized cases):
   For each `patch_*_moe` in unsloth_zoo.temporary_patches.{qwen3_moe,
   qwen3_5_moe, qwen3_next_moe, qwen3_vl_moe, gemma4_moe, glm4_moe,
   deepseek_v3_moe, gpt_oss}, look up the transformers target classes,
   skip when none import on this matrix cell, run the patch fn, and
   assert at least one importable target now carries an unsloth
   "patched" marker. Accepts five marker conventions used across the
   codebase (_unsloth_already_patched, _unsloth_lora_patched,
   _unsloth_lora_extractor_fn, _original_<modeling_tail>_<cls>_forward,
   plain _original_forward). Surfaces silent early-returns (PR #612)
   that escape the registration-coverage test.

   gpt_oss specifically reads UNSLOTH_MODEL_NAME and only runs on
   transformers >= 5; the shim sets the env var via monkeypatch and
   skips on the 4.57.6 cell with a documented reason.

2. PR #4934 (TRL 1.0 GRPO disable_gradient_checkpointing): rebinding
   contract. After patch_trl_disable_gradient_checkpointing(), the
   no-op decorated function MUST be the symbol on
   trl.models.utils AND every trl.* module that imported it by
   reference. Skips on TRL < 1.0 (no symbol present).

3. PR #3598 (gradient_accumulation): patch_gradient_accumulation_fix
   on a vanilla transformers.Trainer must run cleanly without raising
   AND be idempotent. Catches future double-scale or import-injection
   regressions in the source rewriter.

4. unsloth/kernels/moe/grouped_gemm AST smoke: walks every .py under
   the directory (12 files) and asserts ast.parse succeeds. Triton
   kernels are GPU-only at runtime, but a syntax error in source
   surfaces as ImportError on every install. Also sanity-checks the
   directory layout (interface.py, kernels/forward.py,
   kernels/backward.py, reference/moe_block.py, reference/moe_ops.py
   must exist).

Local verification on host TRL 0.25.1 + transformers 4.57.6: 4 pass
(qwen3_moe, qwen3_vl_moe, GRPO disable-GC, grad-accum, grouped_gemm
AST), 7 skip legitimately (qwen3_5/qwen3_next/gemma4/glm4/deepseek/
gpt_oss absent or version-gated). Wall-time ~10s on host; budget
~30-60s per matrix cell.
CodeMan62 pushed a commit to CodeMan62/unsloth-zoo that referenced this pull request May 14, 2026
…nslothai#630)

The previous canary (any class with `gate_up_proj` + `down_proj` 3D
parameters) was too aggressive. transformers ships such classes in
modules that unsloth_zoo does NOT patch (e.g. Llama4TextExperts has
the 3D pattern but `unsloth_zoo/temporary_patches/` has no llama4.py),
and on CPU-only runners some targeting patches early-exit before
hitting their patch_function call. Both produced false-positive
"discovery zero with surface present" failures on transformers 4.57.6
on ubuntu-latest CPU runners.

Use the `_unsloth_already_patched=True` boolean each patch fn sets
explicitly at the end of its successful path as the canary. That
flag's presence on a class means the patch fn ran to its self-
declaration line. If at least one class carries it AND discovery
returned zero, the test-helper marker convention has drifted (real
regression). If no class carries it, either transformers is too old
or every targeting patch fn raised before its self-declaration --
this test cannot reach its target either way, so skip.

This narrows the assertion to the regression the test was written to
catch (PR unslothai#624: marker drift / forgotten extractor on a class the
patch fn claims to own) without firing on irrelevant 3D-pattern
classes or runtime-only patch failures.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant