[MoE] Fix Qwen-family MoE LoRA extractor shape mismatch #10
danielhanchen wants to merge 3 commits into
The shared `_make_qwen_moe_lora_extractor` (used by Qwen3-MoE, Qwen3.5/3.6
MoE, and Qwen3-Next) produced `first=(E, out_dim, R)` instead of the
`(E, in_dim, R)` shape expected by `forward_native_grouped_mm`. On models
like Qwen3.6-35B-A3B this triggered, during the first training step:
    torch._grouped_mm(inputs, weight, offs=offsets)
    RuntimeError: contraction dimension of mat_a and mat_b must match
when `permuted_input` (N, in_dim) was matmul'd against a first_weight whose
second-to-last dim was `out_dim` (e.g. `2*intermediate_dim` for gate_up_proj
on Qwen3.6-35B-A3B's 256-expert architecture).
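The contraction rule behind this error can be illustrated on CPU with an ordinary per-expert matmul standing in for `torch._grouped_mm`. This is a minimal sketch under assumed toy sizes (not the real Qwen3.6-35B-A3B config), and `contracts` is a hypothetical helper introduced only for illustration:

```python
import torch

# Toy sizes chosen for illustration only (assumptions, not a real model config).
E, R, H, I = 4, 2, 8, 6           # experts, rank per expert, hidden, intermediate
in_dim, out_dim = H, 2 * I        # gate_up_proj maps H -> 2*I

tokens = torch.randn(5, in_dim)   # tokens routed to one expert: (N, in_dim)

bad_first = torch.randn(E, out_dim, R)   # shape the old extractor produced
good_first = torch.randn(E, in_dim, R)   # shape the forward path expects


def contracts(x, first, expert=0):
    """Hypothetical helper: True if x @ first[expert] is a valid matmul."""
    try:
        x @ first[expert]
        return True
    except RuntimeError:
        return False


bad_ok = contracts(tokens, bad_first)    # (5, 8) @ (12, 2): inner dims differ
good_ok = contracts(tokens, good_first)  # (5, 8) @ (8, 2): inner dims agree
```

The plain matmul raises a differently worded `RuntimeError` than the grouped kernel, but the rule it enforces is the same: the last dim of the input must equal the second-to-last dim of the weight.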
Root cause: the explicit `param_name in ("gate_up_proj", "down_proj")` branches
and the `dim_B == hidden_dim` branch all constructed
`first_weight = weight_B.view(dim_B, E, R).permute(1, 0, 2)` — i.e. derived
from `lora_B`, which has shape `(out_dim, E*R)` — so `first.shape[-2]` ended up
as `out_dim`, not `in_dim`. The final fallback at the bottom of the function
was already correct.
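To see how the wrong shape arises, the broken construction can be replayed on dummy tensors. This is a sketch with assumed toy sizes; only the `view`/`permute` expression is taken from the description above:

```python
import torch

# Toy sizes for illustration (assumptions).
E, R, in_dim, out_dim = 4, 2, 8, 12

weight_A = torch.randn(E * R, in_dim)    # lora_A weight: (E*R, in_dim)
weight_B = torch.randn(out_dim, E * R)   # lora_B weight: (out_dim, E*R)

# The construction used by the removed branches: derived from lora_B.
broken_first = weight_B.view(out_dim, E, R).permute(1, 0, 2)
broken_shape = tuple(broken_first.shape)

# The second-to-last dim is out_dim, so an (N, in_dim) input cannot contract with it.
second_to_last = broken_first.shape[-2]
```

With these sizes `broken_shape` is `(E, out_dim, R)`, which only matches the expected `(E, in_dim, R)` in the special case `in_dim == out_dim`.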
Fix: drop the broken branches. The correct mapping — identical to the default
extractor in `moe_utils.py::_extract_lora_from_wrapper` and to the working
Qwen3-VL-MoE extractor in `qwen3_vl_moe.py::_qwen3_vl_lora_extractor` — is
format-independent:
    weight_A : (E*R, in_dim)  -> view(E, R, in_dim).permute(0, 2, 1)  = (E, in_dim, R)
    weight_B : (out_dim, E*R) -> view(out_dim, E, R).permute(1, 2, 0) = (E, R, out_dim)
PEFT LoRA weights have fixed shape relative to the linear's in/out dims; they
don't depend on whether base weights are stored "standard" (E, out, in) or
"transposed" (E, in, out) — that distinction is handled upstream by
`preprocess_weight`.
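The mapping above can be sanity-checked on dummy tensors. The sketch below assumes toy sizes and that expert `e` owns rows `e*R:(e+1)*R` of `weight_A` and the same columns of `weight_B`, which is what the `view` calls in the mapping imply. `extract_expert_lora` is a hypothetical stand-in, not the function in the repository:

```python
import torch


def extract_expert_lora(weight_A, weight_B, num_experts):
    """Hypothetical helper mirroring the mapping described above."""
    rank_per_expert = weight_A.shape[0] // num_experts
    in_dim, out_dim = weight_A.shape[1], weight_B.shape[0]
    # (E*R, in_dim) -> (E, R, in_dim) -> (E, in_dim, R)
    first = weight_A.view(num_experts, rank_per_expert, in_dim)
    first = first.permute(0, 2, 1).contiguous()
    # (out_dim, E*R) -> (out_dim, E, R) -> (E, R, out_dim)
    second = weight_B.view(out_dim, num_experts, rank_per_expert)
    second = second.permute(1, 2, 0).contiguous()
    return first, second


torch.manual_seed(0)
E, R, in_dim, out_dim = 4, 2, 8, 12      # toy sizes (assumptions)
weight_A = torch.randn(E * R, in_dim)
weight_B = torch.randn(out_dim, E * R)
first, second = extract_expert_lora(weight_A, weight_B, E)

# Compare one expert against the slice-based LoRA delta x @ A_e^T @ B_e^T.
x = torch.randn(5, in_dim)
e = 3
A_e = weight_A[e * R:(e + 1) * R]        # (R, in_dim)
B_e = weight_B[:, e * R:(e + 1) * R]     # (out_dim, R)
reference = x @ A_e.T @ B_e.T            # (N, out_dim)
batched = x @ first[e] @ second[e]       # (N, in_dim) @ (in_dim, R) @ (R, out_dim)
```

Nothing in this helper inspects the base weight layout, which is the point of the format-independence claim: only the shapes of `lora_A` and `lora_B` are used.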
Verified against Qwen3.6-35B-A3B (unsloth/Qwen3.6-35B-A3B): the LoRA forward
path through `torch._grouped_mm` no longer fails with the contraction-dim
error, and training progresses past the first forward into the expected
memory-bound regime.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
/gemini review
Code Review
This pull request simplifies the _qwen_moe_lora_extractor function by removing complex logic that traversed base layers to determine dimensions. The new implementation uses a direct mapping for LoRA weights based on the fixed layout produced by PEFT, making the code more robust and easier to maintain. Feedback was provided to rename dimension variables for better alignment with the docstring and to inline a single-use variable for brevity.
      total_rank = weight_A.shape[0]
      rank_per_expert = total_rank // num_experts

    - dim_A = weight_A.shape[1]
    - dim_B = weight_B.shape[0]
    -
    - hidden_dim = None
    - intermediate_dim = None
    - current = wrapper
    - while hasattr(current, "base_layer"):
    -     current = current.base_layer
    -     if hasattr(current, "hidden_dim"):
    -         hidden_dim = current.hidden_dim
    -     if hasattr(current, "intermediate_dim"):
    -         intermediate_dim = current.intermediate_dim
    -     if hasattr(current, "gate_up_proj") and hasattr(current.gate_up_proj, "shape"):
    -         shape = current.gate_up_proj.shape
    -         if len(shape) == 3:
    -             hidden_dim = shape[2]
    -             intermediate_dim = shape[1] // 2
    -
    - param_name = getattr(wrapper, "parameter_name", None)
    -
    - if param_name == "down_proj" and intermediate_dim is not None and hidden_dim is not None:
    -     first_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
    -     first_weight = first_weight.permute(1, 0, 2).contiguous()
    -     second_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
    -     return first_weight, second_weight, scaling, num_experts
    -
    - elif param_name == "gate_up_proj" and hidden_dim is not None:
    -     first_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
    -     first_weight = first_weight.permute(1, 0, 2).contiguous()
    -     second_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
    -     return first_weight, second_weight, scaling, num_experts
    -
    - if hidden_dim is not None:
    -     if dim_B == hidden_dim:
    -         first_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
    -         first_weight = first_weight.permute(1, 0, 2).contiguous()
    -         second_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
    -         return first_weight, second_weight, scaling, num_experts
    -     elif dim_A == hidden_dim:
    -         first_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
    -         first_weight = first_weight.permute(0, 2, 1).contiguous()
    -         second_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
    -         second_weight = second_weight.permute(1, 2, 0).contiguous()
    -         return first_weight, second_weight, scaling, num_experts
    + dim_A = weight_A.shape[1]  # in_dim
    + dim_B = weight_B.shape[0]  # out_dim

      first_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
      first_weight = first_weight.permute(0, 2, 1).contiguous()

      second_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
      second_weight = second_weight.permute(1, 2, 0).contiguous()
The variable names dim_A and dim_B can be improved to in_dim and out_dim to match the terminology used in the docstring and improve code clarity. Additionally, total_rank is only used once and can be inlined to simplify the logic.
Suggested change:

    rank_per_expert = weight_A.shape[0] // num_experts
    in_dim = weight_A.shape[1]
    out_dim = weight_B.shape[0]
    first_weight = weight_A.view(num_experts, rank_per_expert, in_dim)
    first_weight = first_weight.permute(0, 2, 1).contiguous()
    second_weight = weight_B.view(out_dim, num_experts, rank_per_expert)
    second_weight = second_weight.permute(1, 2, 0).contiguous()
/gemini review
Code Review
This pull request simplifies the LoRA weight extraction logic for Qwen-family MoE models in unsloth_zoo/temporary_patches/qwen3_moe.py. The update removes complex conditional logic that previously checked for specific parameter names and dimensions, replacing it with a streamlined implementation that directly reshapes and permutes weights based on PEFT's standard dimension layout. The documentation has also been updated to clarify the expected tensor shapes and the rationale for the simplified mapping. I have no feedback to provide.
86b03ce to d8f30d6
Fixes pushed to unslothai#601.
Staging mirror of unslothai#601
Original PR: unslothai#601
Author: lordx64
This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.
Original description
Summary
Fix a shape-mismatch bug in the shared `_make_qwen_moe_lora_extractor` (used by Qwen3-MoE, Qwen3.5/3.6 MoE, and Qwen3-Next) that causes `torch._grouped_mm` to raise `RuntimeError: contraction dimension of mat_a and mat_b must match` during the first training step on MoE+LoRA.
Reproduction
Root cause
`_make_qwen_moe_lora_extractor` has three code paths that all construct `first_weight` from `weight_B` (PEFT's `lora_B`, shape `(out_dim, E*R)`):
But `forward_native_grouped_mm` calls:
with `permuted_input` of shape `(N, in_dim)`. `torch._grouped_mm` requires `first_weight.shape[-2] == in_dim`, but the extractor gives `out_dim` there.
On Qwen3.6-35B-A3B (`gate_up_proj` stored as `(E, 2*I, H)` in `modeling_qwen3_5_moe.py:723`), `in_dim = H` and `out_dim = 2*I`, and these are unequal for every MoE config, so the mismatch fires on every run with expert LoRA.
The fallback at the bottom of the function (lines 102-106 pre-fix) was already correct, but none of the earlier branches fell through to it for Qwen3-MoE-style models.
Fix
Drop the broken branches. The correct mapping, identical to:
- `moe_utils.py::_extract_lora_from_wrapper` (lines 421-426),
- `qwen3_vl_moe.py::_qwen3_vl_lora_extractor` (lines 275-279),
is format-independent:
PEFT LoRA weights have fixed shape relative to the linear's in/out di