Skip to content

Refactor and consolidate moe lora extractors#629

Merged
danielhanchen merged 1 commit into
unslothai:mainfrom
Datta0:moe_lora_extractor
May 13, 2026
Merged

Refactor and consolidate moe lora extractors#629
danielhanchen merged 1 commit into
unslothai:mainfrom
Datta0:moe_lora_extractor

Conversation

@Datta0
Copy link
Copy Markdown
Collaborator

@Datta0 Datta0 commented May 7, 2026

Copied from #618 . That PR fixed it for qwen3 family. Now this is to extend the same to other MoE families like gemma4, deepseek, glm4_moe
Initially introduced in #601 as an attempt to fix the extractor.
When I initially designed it, I tried to make it shape inferring because transformers v5 was still a WIP. The above mentioned PR modified that functionality to fix a regression.
The root cause of the regression seems to be peft changing the lora addition for MoE

PEFT 0.18 used raw 3D parameter dims for ParamWrapper LoRA:
https://github.com/huggingface/peft/blob/v0.18.0/src/peft/tuners/lora/layer.py#L1928-L1931

if param.ndim == 3:
    self.num_experts, self.in_features, self.out_features = param.shape
else:
    self.num_experts, self.in_features, self.out_features = 1, param.shape[1], param.shape[0]

PEFT 0.19 swaps non-transposed 3D params before creating lora_A/B:
https://github.com/huggingface/peft/blob/v0.19.1/src/peft/tuners/lora/layer.py#L2201-L2205

# for some MoE layers, the order is (experts, out_features, in_features)
is_transposed = getattr(self.get_base_layer(), "is_transposed", False)
swap_in_out_features = (self.get_param().ndim == 3) and not is_transposed
if swap_in_out_features and not self._did_swap_in_out_features:
    self.in_features, self.out_features = self.out_features, self.in_features
    self._did_swap_in_out_features = True

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the LoRA weight extraction logic for Mixture-of-Experts (MoE) models by centralizing the implementation into a shared utility function, extract_moe_lora_weights_for_grouped_mm. This change simplifies the codebase for DeepSeekV3, Gemma4, GLM4, and Qwen MoE models. A critical issue was identified in the new utility function where handling num_experts <= 1 returns 2D tensors that are incompatible with the expected 3D format for grouped matrix multiplication, and a potential division-by-zero error exists if num_experts is zero.

Comment on lines +423 to +429
total_rank = weight_A.shape[0]
rank_per_expert = total_rank // num_experts
dim_A = weight_A.shape[1]
dim_B = weight_B.shape[0]

if num_experts <= 1:
return weight_A.T, weight_B.T, scaling, num_experts
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current logic for num_experts <= 1 returns 2D tensors, which is incompatible with torch._grouped_mm (which expects 3D weights) and the expert indexing used in the MoE forward passes (e.g., first_weight[expert_idx]). Additionally, rank_per_expert is calculated before the check, which would lead to a ZeroDivisionError if num_experts is 0.

For num_experts == 1, the function should proceed with the 3D reshaping logic to maintain compatibility with the MoE backends. The check should be changed to num_experts < 1 and moved before the division to avoid potential runtime errors.

Suggested change
total_rank = weight_A.shape[0]
rank_per_expert = total_rank // num_experts
dim_A = weight_A.shape[1]
dim_B = weight_B.shape[0]
if num_experts <= 1:
return weight_A.T, weight_B.T, scaling, num_experts
total_rank = weight_A.shape[0]
dim_A = weight_A.shape[1]
dim_B = weight_B.shape[0]
if num_experts < 1:
return weight_A.T, weight_B.T, scaling, num_experts
rank_per_expert = total_rank // num_experts

@Datta0
Copy link
Copy Markdown
Collaborator Author

Datta0 commented May 7, 2026

Ref: #624 #618

@Datta0 Datta0 marked this pull request as ready for review May 13, 2026 04:40
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: f86e4c7125

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +393 to +396
grouped_mm_format = bool(getattr(source, "_unsloth_grouped_mm_format", False))
if grouped_mm_format:
return shape[-2], shape[-1]
return shape[-1], shape[-2]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve GLM4 transposed LoRA extraction path

The shared extractor now infers MoE input/output dims from parameter shape and, when _unsloth_grouped_mm_format is not set, assumes a (E, out_dim, in_dim) layout by returning shape[-1], shape[-2]. GLM4 previously had explicit handling for its transposed (E, in_dim, out_dim) cases, but this commit removed that model-specific logic and routes through the generic path, so transposed GLM4 checkpoints can be misclassified as canonical. In that scenario first_weight is built on the wrong axis, which can cause grouped-mm dimension mismatches or silently wrong LoRA deltas during training.

Useful? React with 👍 / 👎.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants