Refactor and consolidate MoE LoRA extractors #629
Code Review
This pull request refactors the LoRA weight extraction logic for Mixture-of-Experts (MoE) models by centralizing the implementation into a shared utility function, extract_moe_lora_weights_for_grouped_mm. This change simplifies the codebase for DeepSeekV3, Gemma4, GLM4, and Qwen MoE models. A critical issue was identified in the new utility function where handling num_experts <= 1 returns 2D tensors that are incompatible with the expected 3D format for grouped matrix multiplication, and a potential division-by-zero error exists if num_experts is zero.
    total_rank = weight_A.shape[0]
    rank_per_expert = total_rank // num_experts
    dim_A = weight_A.shape[1]
    dim_B = weight_B.shape[0]

    if num_experts <= 1:
        return weight_A.T, weight_B.T, scaling, num_experts
The current logic for num_experts <= 1 returns 2D tensors, which is incompatible with torch._grouped_mm (which expects 3D weights) and the expert indexing used in the MoE forward passes (e.g., first_weight[expert_idx]). Additionally, rank_per_expert is calculated before the check, which would lead to a ZeroDivisionError if num_experts is 0.
For num_experts == 1, the function should proceed with the 3D reshaping logic to maintain compatibility with the MoE backends. The check should be changed to num_experts < 1 and moved before the division to avoid potential runtime errors.
Suggested change:
    - total_rank = weight_A.shape[0]
    - rank_per_expert = total_rank // num_experts
    - dim_A = weight_A.shape[1]
    - dim_B = weight_B.shape[0]
    - if num_experts <= 1:
    -     return weight_A.T, weight_B.T, scaling, num_experts
    + total_rank = weight_A.shape[0]
    + dim_A = weight_A.shape[1]
    + dim_B = weight_B.shape[0]
    + if num_experts < 1:
    +     return weight_A.T, weight_B.T, scaling, num_experts
    + rank_per_expert = total_rank // num_experts
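The ordering issue raised in this comment can be illustrated with a small shape-only sketch. It is not the PR's implementation: the helper name `plan_grouped_mm_shapes` is hypothetical, the stacked layouts `(E * r, dim_A)` and `(dim_B, E * r)` are inferred from the indexing in the snippet above, and the 3D target shapes are an assumption about what the grouped-mm path expects. Unlike the suggested change, the sketch raises on `num_experts < 1` instead of returning 2D tensors.

```python
def plan_grouped_mm_shapes(shape_a, shape_b, num_experts):
    """Plan per-expert 3D shapes for stacked LoRA A/B weights (illustrative only)."""
    # Guard first, so the floor division below can never divide by zero.
    if num_experts < 1:
        raise ValueError(f"num_experts must be >= 1, got {num_experts}")

    total_rank, dim_a = shape_a    # assumed layout: (E * r, dim_A)
    dim_b, total_rank_b = shape_b  # assumed layout: (dim_B, E * r)
    if total_rank != total_rank_b:
        raise ValueError("weight_A and weight_B disagree on the total rank")
    if total_rank % num_experts != 0:
        raise ValueError("total rank is not divisible by num_experts")

    rank_per_expert = total_rank // num_experts
    # num_experts == 1 takes the same path, so the result is still 3D
    # and expert indexing such as first_weight[0] keeps working.
    first_shape = (num_experts, dim_a, rank_per_expert)
    second_shape = (num_experts, rank_per_expert, dim_b)
    return first_shape, second_shape
```

With a single expert, `plan_grouped_mm_shapes((16, 64), (128, 16), 1)` yields `(1, 64, 16)` and `(1, 16, 128)`, which keeps the leading expert axis that a 2D early return would drop.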
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f86e4c7125
    grouped_mm_format = bool(getattr(source, "_unsloth_grouped_mm_format", False))
    if grouped_mm_format:
        return shape[-2], shape[-1]
    return shape[-1], shape[-2]
Preserve GLM4 transposed LoRA extraction path
The shared extractor now infers MoE input/output dims from parameter shape and, when _unsloth_grouped_mm_format is not set, assumes a (E, out_dim, in_dim) layout by returning shape[-1], shape[-2]. GLM4 previously had explicit handling for its transposed (E, in_dim, out_dim) cases, but this commit removed that model-specific logic and routes through the generic path, so transposed GLM4 checkpoints can be misclassified as canonical. In that scenario first_weight is built on the wrong axis, which can cause grouped-mm dimension mismatches or silently wrong LoRA deltas during training.
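The misclassification described here can be reproduced with a shape-only sketch. The function below mirrors the branch quoted above but takes the flag as an explicit argument; the name `infer_in_out_dims` and the example shapes are hypothetical, and the `(in_dim, out_dim)` return order is taken from this comment's reading of the code.

```python
def infer_in_out_dims(shape, grouped_mm_format=False):
    """Return (in_dim, out_dim) using the same rule as the shared extractor."""
    if grouped_mm_format:
        # Parameter stored as (E, in_dim, out_dim).
        return shape[-2], shape[-1]
    # Otherwise assume the canonical (E, out_dim, in_dim) layout.
    return shape[-1], shape[-2]


# Hypothetical expert weight with in_dim=512 and out_dim=1024.
canonical = (8, 1024, 512)   # (E, out_dim, in_dim)
transposed = (8, 512, 1024)  # (E, in_dim, out_dim)

# Canonical layout without the flag is read correctly.
ok = infer_in_out_dims(canonical)
# Transposed layout with the flag set is also read correctly.
ok_flagged = infer_in_out_dims(transposed, grouped_mm_format=True)
# Transposed layout WITHOUT the flag comes back with the dims swapped.
swapped = infer_in_out_dims(transposed)
```

Here `ok` and `ok_flagged` are both `(512, 1024)`, while `swapped` is `(1024, 512)`. Nothing in the shape alone distinguishes the two layouts, so a transposed checkpoint that lacks the flag needs either model-specific handling or some other signal.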
Copied from #618. That PR fixed the extractor for the qwen3 family; this one extends the same fix to other MoE families such as gemma4, deepseek, and glm4_moe.
This was initially introduced in #601 as an attempt to fix the extractor.
When I initially designed it, I tried to make it shape-inferring because transformers v5 was still a WIP. The above-mentioned PR modified that functionality to fix a regression.
The root cause of the regression seems to be PEFT changing how LoRA is added for MoE:
PEFT 0.18 used raw 3D parameter dims for ParamWrapper LoRA:
https://github.com/huggingface/peft/blob/v0.18.0/src/peft/tuners/lora/layer.py#L1928-L1931
PEFT 0.19 swaps non-transposed 3D params before creating lora_A/B:
https://github.com/huggingface/peft/blob/v0.19.1/src/peft/tuners/lora/layer.py#L2201-L2205