[MoE] Fix Qwen-family MoE LoRA extractor shape mismatch#601
Conversation
The shared `_make_qwen_moe_lora_extractor` (used by Qwen3-MoE, Qwen3.5/3.6
MoE, and Qwen3-Next) produced `first=(E, out_dim, R)` instead of the
`(E, in_dim, R)` shape expected by `forward_native_grouped_mm`. On models
like Qwen3.6-35B-A3B this triggered, during the first training step:
torch._grouped_mm(inputs, weight, offs=offsets)
RuntimeError: contraction dimension of mat_a and mat_b must match
when `permuted_input` (N, in_dim) was matmul'd against a first_weight whose
second-to-last dim was `out_dim` (e.g. `2*intermediate_dim` for gate_up_proj
on Qwen3.6-35B-A3B's 256-expert architecture).
Root cause: the explicit `param_name in ("gate_up_proj", "down_proj")` branches
and the `dim_B == hidden_dim` branch all constructed
`first_weight = weight_B.view(dim_B, E, R).permute(1, 0, 2)` — i.e. derived
from `lora_B`, which has shape `(out_dim, E*R)` — so `first.shape[-2]` ended up
as `out_dim`, not `in_dim`. The final fallback at the bottom of the function
was already correct.
Fix: drop the broken branches. The correct mapping — identical to the default
extractor in `moe_utils.py::_extract_lora_from_wrapper` and to the working
Qwen3-VL-MoE extractor in `qwen3_vl_moe.py::_qwen3_vl_lora_extractor` — is
format-independent:
weight_A : (E*R, in_dim) -> view(E, R, in_dim).permute(0, 2, 1) = (E, in_dim, R)
weight_B : (out_dim, E*R) -> view(out_dim, E, R).permute(1, 2, 0) = (E, R, out_dim)
PEFT LoRA weights have fixed shape relative to the linear's in/out dims; they
don't depend on whether base weights are stored "standard" (E, out, in) or
"transposed" (E, in, out) — that distinction is handled upstream by
`preprocess_weight`.
Verified against Qwen3.6-35B-A3B (unsloth/Qwen3.6-35B-A3B): the LoRA forward
path through `torch._grouped_mm` no longer fails with the contraction-dim
error, and training progresses past the first forward into the expected
memory-bound regime.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
Code Review
This pull request simplifies the _qwen_moe_lora_extractor function for Qwen-family MoE models by removing complex conditional logic that previously checked for specific parameter names and dimensions. The implementation now uses a unified approach to reshape and permute LoRA weights (weight_A and weight_B) into the format required for grouped_mm, regardless of the base weight storage layout. The docstring has also been updated to reflect this streamlined logic and explain the mapping from PEFT's dimension layout. I have no feedback to provide as no review comments were submitted.
|
The changes do look good to me. But can you do a comparison run with transformers baseline vs unsloth and report_to wandb. |
|
Possible duplicate of a trusted maintainer's PR. This PR looks like it solves the same underlying problem as unslothai/unsloth-zoo#495 by @Datta0 (trusted maintainer).
Canonical PR summary: The PR updates Unsloth Zoo compatibility patches for newer transformers v5, adding new Qwen3Next and Qwen3.5 MoE temporary patches, relaxing Gemma3/CSM/DeepSeekV3/GLM4-MoE signatures, and fixing MoE backend, 4-bit loading, and compiler behavior. Other related PRs:
The auto-review is still running against this PR — reviewers will factor in the canonical above. If this PR is genuinely different, call out the delta in the review discussion so the maintainer can decide which to merge. |
|
Auto-review verdict: Approved Replaces the branchy Qwen-family MoE LoRA extractor with the canonical PEFT shape mapping (weight_A -> (E, in_dim, R), weight_B -> (E, R, out_dim)), eliminating the torch._grouped_mm contraction-dimension crash that reproduced on every gate_up_proj and down_proj step for Qwen3-MoE, Qwen3.5/3.6, and Qwen3-Next. The fix matches the existing default extractor in moe_utils.py and qwen3_vl_moe.py, and is verified by CPU shape and numerical-equivalence tests. Reason: Correct minimal fix; all 11 reviewers approved across 2 iterations with no blocking findings. |
|
Thanks so much this works! |
* Fix qwen lora extractor for diff peft versions * Native MoE loop: down_proj transpose check + pre-cast LoRA factors Two related fixes for the new separated-LoRA path inside `forward_native_moe_loop`: 1. `down_proj` mirror of the existing `gate_up_proj` shape check. Qwen3-VL-MoE stores `down_proj` as `(E, intermediate_dim, hidden_dim)` for grouped_mm, while `F.linear` needs `(out, in)`. Without the transpose, the loop fallback crashed with `mat1 and mat2 shapes cannot be multiplied (... and H x I)` when training a Qwen3-VL-MoE adapter on a backend that fell back to the loop (e.g. when `torch._grouped_mm` is unavailable). Reproduced on Qwen3-VL-MoE-Tiny + PEFT 0.19.1 + UNSLOTH_MOE_BACKEND=native_torch. 2. Pre-cast the extracted LoRA `(first_weight, second_weight)` to `hidden_states.dtype` exactly once before the per-expert loop, instead of calling `.to(...)` on a per-expert slice every iteration. Also adds `test_forward_native_moe_loop_lora.py` covering both the canonical `(E, 2*I, H) / (E, H, I)` and the transposed `(E, H, 2*I) / (E, I, H)` storage layouts, with and without LoRA, and across fp32 / bf16 / fp16. Each test compares the output of `forward_native_moe_loop` against a naive PEFT-style reference per expert. * Qwen MoE extractor + native loop: harden ambiguous cases Three correctness hardenings pulled from the parallel reviewer pass on PR #618 (#618): 1. `_qwen_moe_lora_extractor` (qwen3_moe.py): - When `wrapper._did_swap_in_out_features` is set (PEFT 0.19's own post-swap flag) and shapes match the reversed branch, prefer that signal over shape-only inference. Fixes the `input_dim == output_dim` ambiguity where both `(dim_A==input_dim, dim_B==output_dim)` and `(dim_A==output_dim, dim_B==input_dim)` are simultaneously true. - When neither branch matches and the wrapper does report (input_dim, output_dim), emit a `logger.warning` (gated on `UNSLOTH_ENABLE_LOGGING`). The previous silent fallback was the exact failure mode of the merged-then-reverted PR #601. - Factor the canonical and reversed permutation into local helpers to keep the dispatch readable. 2. `forward_native_moe_loop` (moe_utils.py): prefer the explicit `_unsloth_grouped_mm_format` flag (set by `Qwen3VLMoeTextExperts`'s patched __init__) over the shape-only `weight.shape[-1] != x.shape[-1]` check when deciding whether to transpose `gate_up_proj` / `down_proj`. The shape check is unsafe when `intermediate_dim == hidden_dim`. 3. Tests: - `test_extractor_disambiguates_square_dims_via_did_swap` covers the square-dim ambiguity for both PEFT 0.18 (no swap) and PEFT 0.19 (`_did_swap_in_out_features=True`). - `test_extractor_fallback_warns_when_dims_mismatch` asserts the fallback warning fires when neither layout matches. - `test_forward_native_moe_loop_square_dim_uses_grouped_mm_flag` covers the transposed grouped-mm path at square dims, which the shape heuristic alone would miss. The fused `gate_up_proj`/`down_proj` Qwen path is now covered; LoRA support for the split `w1`/`w3`/`w2` branch in `forward_native_moe_loop` was flagged as a separate (asymmetric) gap but is intentionally left for a follow-up — it affects non-fused-gate MoE families, not Qwen, and lands cleanest with its own scope. --------- Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Summary
Fix a shape-mismatch bug in the shared
_make_qwen_moe_lora_extractor(used by Qwen3-MoE, Qwen3.5/3.6 MoE, and Qwen3-Next) that causestorch._grouped_mmto raiseRuntimeError: contraction dimension of mat_a and mat_b must matchduring the first training step on MoE+LoRA.Reproduction
Root cause
_make_qwen_moe_lora_extractorhas three code paths that all constructfirst_weightfromweight_B(PEFT'slora_B, shape(out_dim, E*R)):But
forward_native_grouped_mmcalls:with
permuted_inputof shape(N, in_dim).torch._grouped_mmrequiresfirst_weight.shape[-2] == in_dim, but the extractor givesout_dimthere.On Qwen3.6-35B-A3B (
gate_up_projstored as(E, 2*I, H)inmodeling_qwen3_5_moe.py:723),in_dim = Handout_dim = 2*I— and these are unequal for every MoE config, so the mismatch fires on every run with expert LoRA.The fallback at the bottom of the function (lines 102-106 pre-fix) was already correct, but none of the earlier branches fell through to it for Qwen3-MoE-style models.
Fix
Drop the broken branches. The correct mapping — identical to:
moe_utils.py::_extract_lora_from_wrapper(lines 421-426),qwen3_vl_moe.py::_qwen3_vl_lora_extractor(lines 275-279),— is format-independent:
PEFT LoRA weights have fixed shape relative to the linear's in/out dims — they don't depend on whether the base weights are stored "standard"
(E, out, in)or "transposed"(E, in, out). That distinction is handled bypreprocess_weightfor the base-weight path, not the LoRA path.Test plan
contraction dimensioncrash onunsloth/Qwen3.6-35B-A3Bwithtarget_modulesthat includegate_proj, up_proj, down_proj(expert LoRA)torch._grouped_mmsucceed — training progresses past step 0 into the expected memory-bound regime on 1× A100-80GBunsloth/Qwen3-30B-A3B(128 experts, same code path) — expected to continue workingHappy to add a unit test for the extractor shape contract if useful; let me know if you have a preferred fixture pattern.
🤖 Generated with Claude Code