Skip to content

[MoE] Fix Qwen-family MoE LoRA extractor shape mismatch#601

Merged
danielhanchen merged 7 commits into
unslothai:mainfrom
lordx64:fix/qwen-moe-lora-extractor-shape-mismatch
Apr 22, 2026
Merged

[MoE] Fix Qwen-family MoE LoRA extractor shape mismatch#601
danielhanchen merged 7 commits into
unslothai:mainfrom
lordx64:fix/qwen-moe-lora-extractor-shape-mismatch

Conversation

@lordx64

@lordx64 lordx64 commented Apr 18, 2026

Copy link
Copy Markdown
Contributor

Summary

Fix a shape-mismatch bug in the shared _make_qwen_moe_lora_extractor (used by Qwen3-MoE, Qwen3.5/3.6 MoE, and Qwen3-Next) that causes torch._grouped_mm to raise RuntimeError: contraction dimension of mat_a and mat_b must match during the first training step on MoE+LoRA.

Reproduction

# train.py (UV script)
from unsloth import FastLanguageModel
from trl import SFTTrainer, SFTConfig

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Qwen3.6-35B-A3B",
    max_seq_length=4096,
    load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
    model, r=16, lora_alpha=16,
    target_modules=["q_proj","k_proj","v_proj","o_proj",
                    "gate_proj","up_proj","down_proj"],  # includes expert LoRA
    use_gradient_checkpointing="unsloth",
)
# trainer.train() fails at step 0 with:
#   File ".../moe_utils.py", line 131, in _grouped_mm_with_backward_fix
#       return torch._grouped_mm(inputs, weight, offs=offsets)
#   RuntimeError: contraction dimension of mat_a and mat_b must match

Root cause

_make_qwen_moe_lora_extractor has three code paths that all construct first_weight from weight_B (PEFT's lora_B, shape (out_dim, E*R)):

# qwen3_moe.py:77-94 (broken)
if param_name == "down_proj" and ...:
    first_weight = weight_B.view(dim_B, E, R).permute(1, 0, 2)   # (E, out_dim, R)
    second_weight = weight_A.view(E, R, dim_A)                    # (E, R, in_dim)
    return ...
elif param_name == "gate_up_proj" and ...:
    first_weight = weight_B.view(dim_B, E, R).permute(1, 0, 2)   # (E, out_dim, R)
    ...
if dim_B == hidden_dim:
    first_weight = weight_B.view(dim_B, E, R).permute(1, 0, 2)   # (E, out_dim, R)
    ...

But forward_native_grouped_mm calls:

# moe_utils.py:862
lora_out = _grouped_mm_with_backward_fix(permuted_input, first_weight, offsets)

with permuted_input of shape (N, in_dim). torch._grouped_mm requires first_weight.shape[-2] == in_dim, but the extractor gives out_dim there.

On Qwen3.6-35B-A3B (gate_up_proj stored as (E, 2*I, H) in modeling_qwen3_5_moe.py:723), in_dim = H and out_dim = 2*I — and these are unequal for every MoE config, so the mismatch fires on every run with expert LoRA.

The fallback at the bottom of the function (lines 102-106 pre-fix) was already correct, but none of the earlier branches fell through to it for Qwen3-MoE-style models.

Fix

Drop the broken branches. The correct mapping — identical to:

  • the default extractor in moe_utils.py::_extract_lora_from_wrapper (lines 421-426),
  • the working Qwen3-VL-MoE extractor in qwen3_vl_moe.py::_qwen3_vl_lora_extractor (lines 275-279),

— is format-independent:

first_weight  = weight_A.view(E, R, in_dim).permute(0, 2, 1).contiguous()  # (E, in_dim, R)
second_weight = weight_B.view(out_dim, E, R).permute(1, 2, 0).contiguous() # (E, R, out_dim)

PEFT LoRA weights have fixed shape relative to the linear's in/out dims — they don't depend on whether the base weights are stored "standard" (E, out, in) or "transposed" (E, in, out). That distinction is handled by preprocess_weight for the base-weight path, not the LoRA path.

Test plan

  • Reproduce the contraction dimension crash on unsloth/Qwen3.6-35B-A3B with target_modules that include gate_proj, up_proj, down_proj (expert LoRA)
  • Verify the fix makes the LoRA forward path through torch._grouped_mm succeed — training progresses past step 0 into the expected memory-bound regime on 1× A100-80GB
  • Regression check on unsloth/Qwen3-30B-A3B (128 experts, same code path) — expected to continue working
  • Regression check on Qwen3-Next — same extractor, same mapping applies

Happy to add a unit test for the extractor shape contract if useful; let me know if you have a preferred fixture pattern.

🤖 Generated with Claude Code

The shared `_make_qwen_moe_lora_extractor` (used by Qwen3-MoE, Qwen3.5/3.6
MoE, and Qwen3-Next) produced `first=(E, out_dim, R)` instead of the
`(E, in_dim, R)` shape expected by `forward_native_grouped_mm`. On models
like Qwen3.6-35B-A3B this triggered, during the first training step:

    torch._grouped_mm(inputs, weight, offs=offsets)
    RuntimeError: contraction dimension of mat_a and mat_b must match

when `permuted_input` (N, in_dim) was matmul'd against a first_weight whose
second-to-last dim was `out_dim` (e.g. `2*intermediate_dim` for gate_up_proj
on Qwen3.6-35B-A3B's 256-expert architecture).

Root cause: the explicit `param_name in ("gate_up_proj", "down_proj")` branches
and the `dim_B == hidden_dim` branch all constructed
`first_weight = weight_B.view(dim_B, E, R).permute(1, 0, 2)` — i.e. derived
from `lora_B`, which has shape `(out_dim, E*R)` — so `first.shape[-2]` ended up
as `out_dim`, not `in_dim`. The final fallback at the bottom of the function
was already correct.

Fix: drop the broken branches. The correct mapping — identical to the default
extractor in `moe_utils.py::_extract_lora_from_wrapper` and to the working
Qwen3-VL-MoE extractor in `qwen3_vl_moe.py::_qwen3_vl_lora_extractor` — is
format-independent:

    weight_A : (E*R, in_dim)  -> view(E, R, in_dim).permute(0, 2, 1) = (E, in_dim, R)
    weight_B : (out_dim, E*R) -> view(out_dim, E, R).permute(1, 2, 0) = (E, R, out_dim)

PEFT LoRA weights have fixed shape relative to the linear's in/out dims; they
don't depend on whether base weights are stored "standard" (E, out, in) or
"transposed" (E, in, out) — that distinction is handled upstream by
`preprocess_weight`.

Verified against Qwen3.6-35B-A3B (unsloth/Qwen3.6-35B-A3B): the LoRA forward
path through `torch._grouped_mm` no longer fails with the contraction-dim
error, and training progresses past the first forward into the expected
memory-bound regime.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request simplifies the _qwen_moe_lora_extractor function for Qwen-family MoE models by removing complex conditional logic that previously checked for specific parameter names and dimensions. The implementation now uses a unified approach to reshape and permute LoRA weights (weight_A and weight_B) into the format required for grouped_mm, regardless of the base weight storage layout. The docstring has also been updated to reflect this streamlined logic and explain the mapping from PEFT's dimension layout. I have no feedback to provide as no review comments were submitted.

@Datta0

Datta0 commented Apr 20, 2026

Copy link
Copy Markdown
Collaborator

The changes do look good to me. But can you do a comparison run with transformers baseline vs unsloth and report_to wandb.
Once we can confirm that the losses are closely matching, we can merge the PR

@danielhanchen

Copy link
Copy Markdown
Member

Possible duplicate of a trusted maintainer's PR. This PR looks like it solves the same underlying problem as unslothai/unsloth-zoo#495 by @Datta0 (trusted maintainer).

Directly patches qwen3_moe extractor layout for Qwen3Next/Qwen3.5, matching the same shape-mismatch fix path.

Canonical PR summary: The PR updates Unsloth Zoo compatibility patches for newer transformers v5, adding new Qwen3Next and Qwen3.5 MoE temporary patches, relaxing Gemma3/CSM/DeepSeekV3/GLM4-MoE signatures, and fixing MoE backend, 4-bit loading, and compiler behavior.

Other related PRs:

The auto-review is still running against this PR — reviewers will factor in the canonical above. If this PR is genuinely different, call out the delta in the review discussion so the maintainer can decide which to merge.

@danielhanchen danielhanchen added the auto-reviewing Auto-review in progress label Apr 20, 2026
danielhanchen added a commit to shimmyshimmer/unsloth-zoo-staging-2 that referenced this pull request Apr 20, 2026
danielhanchen added a commit to shimmyshimmer/unsloth-zoo-staging-2 that referenced this pull request Apr 20, 2026
danielhanchen added a commit to shimmyshimmer/unsloth-zoo-staging-2 that referenced this pull request Apr 20, 2026
@danielhanchen danielhanchen added auto-approved Auto-review approved the PR and removed auto-reviewing Auto-review in progress labels Apr 20, 2026
@danielhanchen

Copy link
Copy Markdown
Member

Auto-review verdict: Approved

Replaces the branchy Qwen-family MoE LoRA extractor with the canonical PEFT shape mapping (weight_A -> (E, in_dim, R), weight_B -> (E, R, out_dim)), eliminating the torch._grouped_mm contraction-dimension crash that reproduced on every gate_up_proj and down_proj step for Qwen3-MoE, Qwen3.5/3.6, and Qwen3-Next. The fix matches the existing default extractor in moe_utils.py and qwen3_vl_moe.py, and is verified by CPU shape and numerical-equivalence tests.

Reason: Correct minimal fix; all 11 reviewers approved across 2 iterations with no blocking findings.

@danielhanchen

Copy link
Copy Markdown
Member

Thanks so much this works!

@danielhanchen danielhanchen merged commit aa995e0 into unslothai:main Apr 22, 2026
danielhanchen added a commit that referenced this pull request May 4, 2026
* Fix qwen lora extractor for diff peft versions

* Native MoE loop: down_proj transpose check + pre-cast LoRA factors

Two related fixes for the new separated-LoRA path inside
`forward_native_moe_loop`:

1. `down_proj` mirror of the existing `gate_up_proj` shape check.
   Qwen3-VL-MoE stores `down_proj` as `(E, intermediate_dim, hidden_dim)`
   for grouped_mm, while `F.linear` needs `(out, in)`. Without the
   transpose, the loop fallback crashed with
   `mat1 and mat2 shapes cannot be multiplied (... and H x I)` when
   training a Qwen3-VL-MoE adapter on a backend that fell back to the
   loop (e.g. when `torch._grouped_mm` is unavailable). Reproduced on
   Qwen3-VL-MoE-Tiny + PEFT 0.19.1 + UNSLOTH_MOE_BACKEND=native_torch.

2. Pre-cast the extracted LoRA `(first_weight, second_weight)` to
   `hidden_states.dtype` exactly once before the per-expert loop, instead
   of calling `.to(...)` on a per-expert slice every iteration.

Also adds `test_forward_native_moe_loop_lora.py` covering both the
canonical `(E, 2*I, H) / (E, H, I)` and the transposed
`(E, H, 2*I) / (E, I, H)` storage layouts, with and without LoRA, and
across fp32 / bf16 / fp16. Each test compares the output of
`forward_native_moe_loop` against a naive PEFT-style reference per
expert.

* Qwen MoE extractor + native loop: harden ambiguous cases

Three correctness hardenings pulled from the parallel reviewer pass on
PR #618 (#618):

1. `_qwen_moe_lora_extractor` (qwen3_moe.py):
   - When `wrapper._did_swap_in_out_features` is set (PEFT 0.19's own
     post-swap flag) and shapes match the reversed branch, prefer that
     signal over shape-only inference. Fixes the `input_dim == output_dim`
     ambiguity where both `(dim_A==input_dim, dim_B==output_dim)` and
     `(dim_A==output_dim, dim_B==input_dim)` are simultaneously true.
   - When neither branch matches and the wrapper does report (input_dim,
     output_dim), emit a `logger.warning` (gated on
     `UNSLOTH_ENABLE_LOGGING`). The previous silent fallback was the exact
     failure mode of the merged-then-reverted PR #601.
   - Factor the canonical and reversed permutation into local helpers to
     keep the dispatch readable.

2. `forward_native_moe_loop` (moe_utils.py): prefer the explicit
   `_unsloth_grouped_mm_format` flag (set by `Qwen3VLMoeTextExperts`'s
   patched __init__) over the shape-only `weight.shape[-1] != x.shape[-1]`
   check when deciding whether to transpose `gate_up_proj` /
   `down_proj`. The shape check is unsafe when
   `intermediate_dim == hidden_dim`.

3. Tests:
   - `test_extractor_disambiguates_square_dims_via_did_swap` covers the
     square-dim ambiguity for both PEFT 0.18 (no swap) and PEFT 0.19
     (`_did_swap_in_out_features=True`).
   - `test_extractor_fallback_warns_when_dims_mismatch` asserts the
     fallback warning fires when neither layout matches.
   - `test_forward_native_moe_loop_square_dim_uses_grouped_mm_flag`
     covers the transposed grouped-mm path at square dims, which the
     shape heuristic alone would miss.

The fused `gate_up_proj`/`down_proj` Qwen path is now covered; LoRA
support for the split `w1`/`w3`/`w2` branch in
`forward_native_moe_loop` was flagged as a separate (asymmetric) gap
but is intentionally left for a follow-up — it affects non-fused-gate
MoE families, not Qwen, and lands cleanest with its own scope.

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

auto-approved Auto-review approved the PR auto-has-duplicate Pre-flight: similar to a trusted maintainer's PR

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants