Skip to content

[MoE] Fix Qwen-family MoE LoRA extractor shape mismatch#10

Closed
danielhanchen wants to merge 3 commits into
mainfrom
pr-601-head
Closed

[MoE] Fix Qwen-family MoE LoRA extractor shape mismatch#10
danielhanchen wants to merge 3 commits into
mainfrom
pr-601-head

Conversation

@danielhanchen

Copy link
Copy Markdown
Collaborator

Staging mirror of unslothai#601

Original PR: unslothai#601
Author: lordx64

This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.


Original description

Summary

Fix a shape-mismatch bug in the shared _make_qwen_moe_lora_extractor (used by Qwen3-MoE, Qwen3.5/3.6 MoE, and Qwen3-Next) that causes torch._grouped_mm to raise RuntimeError: contraction dimension of mat_a and mat_b must match during the first training step on MoE+LoRA.

Reproduction

# train.py (UV script)
from unsloth import FastLanguageModel
from trl import SFTTrainer, SFTConfig

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Qwen3.6-35B-A3B",
    max_seq_length=4096,
    load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
    model, r=16, lora_alpha=16,
    target_modules=["q_proj","k_proj","v_proj","o_proj",
                    "gate_proj","up_proj","down_proj"],  # includes expert LoRA
    use_gradient_checkpointing="unsloth",
)
# trainer.train() fails at step 0 with:
#   File ".../moe_utils.py", line 131, in _grouped_mm_with_backward_fix
#       return torch._grouped_mm(inputs, weight, offs=offsets)
#   RuntimeError: contraction dimension of mat_a and mat_b must match

Root cause

_make_qwen_moe_lora_extractor has three code paths that all construct first_weight from weight_B (PEFT's lora_B, shape (out_dim, E*R)):

# qwen3_moe.py:77-94 (broken)
if param_name == "down_proj" and ...:
    first_weight = weight_B.view(dim_B, E, R).permute(1, 0, 2)   # (E, out_dim, R)
    second_weight = weight_A.view(E, R, dim_A)                    # (E, R, in_dim)
    return ...
elif param_name == "gate_up_proj" and ...:
    first_weight = weight_B.view(dim_B, E, R).permute(1, 0, 2)   # (E, out_dim, R)
    ...
if dim_B == hidden_dim:
    first_weight = weight_B.view(dim_B, E, R).permute(1, 0, 2)   # (E, out_dim, R)
    ...

But forward_native_grouped_mm calls:

# moe_utils.py:862
lora_out = _grouped_mm_with_backward_fix(permuted_input, first_weight, offsets)

with permuted_input of shape (N, in_dim). torch._grouped_mm requires first_weight.shape[-2] == in_dim, but the extractor gives out_dim there.

On Qwen3.6-35B-A3B (gate_up_proj stored as (E, 2*I, H) in modeling_qwen3_5_moe.py:723), in_dim = H and out_dim = 2*I — and these are unequal for every MoE config, so the mismatch fires on every run with expert LoRA.

The fallback at the bottom of the function (lines 102-106 pre-fix) was already correct, but none of the earlier branches fell through to it for Qwen3-MoE-style models.

Fix

Drop the broken branches. The correct mapping — identical to:

  • the default extractor in moe_utils.py::_extract_lora_from_wrapper (lines 421-426),
  • the working Qwen3-VL-MoE extractor in qwen3_vl_moe.py::_qwen3_vl_lora_extractor (lines 275-279),

— is format-independent:

first_weight  = weight_A.view(E, R, in_dim).permute(0, 2, 1).contiguous()  # (E, in_dim, R)
second_weight = weight_B.view(out_dim, E, R).permute(1, 2, 0).contiguous() # (E, R, out_dim)

PEFT LoRA weights have fixed shape relative to the linear's in/out di

The shared `_make_qwen_moe_lora_extractor` (used by Qwen3-MoE, Qwen3.5/3.6
MoE, and Qwen3-Next) produced `first=(E, out_dim, R)` instead of the
`(E, in_dim, R)` shape expected by `forward_native_grouped_mm`. On models
like Qwen3.6-35B-A3B this triggered, during the first training step:

    torch._grouped_mm(inputs, weight, offs=offsets)
    RuntimeError: contraction dimension of mat_a and mat_b must match

when `permuted_input` (N, in_dim) was matmul'd against a first_weight whose
second-to-last dim was `out_dim` (e.g. `2*intermediate_dim` for gate_up_proj
on Qwen3.6-35B-A3B's 256-expert architecture).

Root cause: the explicit `param_name in ("gate_up_proj", "down_proj")` branches
and the `dim_B == hidden_dim` branch all constructed
`first_weight = weight_B.view(dim_B, E, R).permute(1, 0, 2)` — i.e. derived
from `lora_B`, which has shape `(out_dim, E*R)` — so `first.shape[-2]` ended up
as `out_dim`, not `in_dim`. The final fallback at the bottom of the function
was already correct.

Fix: drop the broken branches. The correct mapping — identical to the default
extractor in `moe_utils.py::_extract_lora_from_wrapper` and to the working
Qwen3-VL-MoE extractor in `qwen3_vl_moe.py::_qwen3_vl_lora_extractor` — is
format-independent:

    weight_A : (E*R, in_dim)  -> view(E, R, in_dim).permute(0, 2, 1) = (E, in_dim, R)
    weight_B : (out_dim, E*R) -> view(out_dim, E, R).permute(1, 2, 0) = (E, R, out_dim)

PEFT LoRA weights have fixed shape relative to the linear's in/out dims; they
don't depend on whether base weights are stored "standard" (E, out, in) or
"transposed" (E, in, out) — that distinction is handled upstream by
`preprocess_weight`.

Verified against Qwen3.6-35B-A3B (unsloth/Qwen3.6-35B-A3B): the LoRA forward
path through `torch._grouped_mm` no longer fails with the contraction-dim
error, and training progresses past the first forward into the expected
memory-bound regime.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@danielhanchen

Copy link
Copy Markdown
Collaborator Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request simplifies the _qwen_moe_lora_extractor function by removing complex logic that traversed base layers to determine dimensions. The new implementation uses a direct mapping for LoRA weights based on the fixed layout produced by PEFT, making the code more robust and easier to maintain. Feedback was provided to rename dimension variables for better alignment with the docstring and to inline a single-use variable for brevity.

Comment on lines 65 to 74
total_rank = weight_A.shape[0]
rank_per_expert = total_rank // num_experts

dim_A = weight_A.shape[1]
dim_B = weight_B.shape[0]

hidden_dim = None
intermediate_dim = None
current = wrapper
while hasattr(current, "base_layer"):
current = current.base_layer
if hasattr(current, "hidden_dim"):
hidden_dim = current.hidden_dim
if hasattr(current, "intermediate_dim"):
intermediate_dim = current.intermediate_dim
if hasattr(current, "gate_up_proj") and hasattr(current.gate_up_proj, "shape"):
shape = current.gate_up_proj.shape
if len(shape) == 3:
hidden_dim = shape[2]
intermediate_dim = shape[1] // 2

param_name = getattr(wrapper, "parameter_name", None)

if param_name == "down_proj" and intermediate_dim is not None and hidden_dim is not None:
first_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
first_weight = first_weight.permute(1, 0, 2).contiguous()
second_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
return first_weight, second_weight, scaling, num_experts

elif param_name == "gate_up_proj" and hidden_dim is not None:
first_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
first_weight = first_weight.permute(1, 0, 2).contiguous()
second_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
return first_weight, second_weight, scaling, num_experts

if hidden_dim is not None:
if dim_B == hidden_dim:
first_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
first_weight = first_weight.permute(1, 0, 2).contiguous()
second_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
return first_weight, second_weight, scaling, num_experts
elif dim_A == hidden_dim:
first_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
first_weight = first_weight.permute(0, 2, 1).contiguous()
second_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
second_weight = second_weight.permute(1, 2, 0).contiguous()
return first_weight, second_weight, scaling, num_experts
dim_A = weight_A.shape[1] # in_dim
dim_B = weight_B.shape[0] # out_dim

first_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
first_weight = first_weight.permute(0, 2, 1).contiguous()

second_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
second_weight = second_weight.permute(1, 2, 0).contiguous()

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

low

The variable names dim_A and dim_B can be improved to in_dim and out_dim to match the terminology used in the docstring and improve code clarity. Additionally, total_rank is only used once and can be inlined to simplify the logic.

Suggested change
total_rank = weight_A.shape[0]
rank_per_expert = total_rank // num_experts
dim_A = weight_A.shape[1]
dim_B = weight_B.shape[0]
hidden_dim = None
intermediate_dim = None
current = wrapper
while hasattr(current, "base_layer"):
current = current.base_layer
if hasattr(current, "hidden_dim"):
hidden_dim = current.hidden_dim
if hasattr(current, "intermediate_dim"):
intermediate_dim = current.intermediate_dim
if hasattr(current, "gate_up_proj") and hasattr(current.gate_up_proj, "shape"):
shape = current.gate_up_proj.shape
if len(shape) == 3:
hidden_dim = shape[2]
intermediate_dim = shape[1] // 2
param_name = getattr(wrapper, "parameter_name", None)
if param_name == "down_proj" and intermediate_dim is not None and hidden_dim is not None:
first_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
first_weight = first_weight.permute(1, 0, 2).contiguous()
second_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
return first_weight, second_weight, scaling, num_experts
elif param_name == "gate_up_proj" and hidden_dim is not None:
first_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
first_weight = first_weight.permute(1, 0, 2).contiguous()
second_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
return first_weight, second_weight, scaling, num_experts
if hidden_dim is not None:
if dim_B == hidden_dim:
first_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
first_weight = first_weight.permute(1, 0, 2).contiguous()
second_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
return first_weight, second_weight, scaling, num_experts
elif dim_A == hidden_dim:
first_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
first_weight = first_weight.permute(0, 2, 1).contiguous()
second_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
second_weight = second_weight.permute(1, 2, 0).contiguous()
return first_weight, second_weight, scaling, num_experts
dim_A = weight_A.shape[1] # in_dim
dim_B = weight_B.shape[0] # out_dim
first_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
first_weight = first_weight.permute(0, 2, 1).contiguous()
second_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
second_weight = second_weight.permute(1, 2, 0).contiguous()
rank_per_expert = weight_A.shape[0] // num_experts
in_dim = weight_A.shape[1]
out_dim = weight_B.shape[0]
first_weight = weight_A.view(num_experts, rank_per_expert, in_dim)
first_weight = first_weight.permute(0, 2, 1).contiguous()
second_weight = weight_B.view(out_dim, num_experts, rank_per_expert)
second_weight = second_weight.permute(1, 2, 0).contiguous()

@danielhanchen

Copy link
Copy Markdown
Collaborator Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request simplifies the LoRA weight extraction logic for Qwen-family MoE models in unsloth_zoo/temporary_patches/qwen3_moe.py. The update removes complex conditional logic that previously checked for specific parameter names and dimensions, replacing it with a streamlined implementation that directly reshapes and permutes weights based on PEFT's standard dimension layout. The documentation has also been updated to clarify the expected tensor shapes and the rationale for the simplified mapping. I have no feedback to provide.

@danielhanchen

Copy link
Copy Markdown
Collaborator Author

Fixes pushed to unslothai#601.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants