Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 11 additions & 46 deletions unsloth_zoo/temporary_patches/deepseek_v3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .moe_utils import (
patch_param_wrapper_for_moe,
get_forward_moe_backend,
extract_moe_lora_weights_for_grouped_mm,
)

def patch_deepseek_v3():
Expand Down Expand Up @@ -73,52 +74,16 @@ def patch_deepseek_v3():
# Define LoRA extraction function for DeepSeekV3 (Standard Format)
# ====================================================================
def _deepseek_v3_lora_extractor(wrapper, weight_A, weight_B, scaling, num_experts):
"""
Custom LoRA extractor for DeepSeekV3.

DeepSeekV3 expert weights are stored as (E, out_dim, in_dim) and PEFT's ParamWrapper
treats dim1 as in_features and dim2 as out_features. For correct separated LoRA
(X @ first @ second), we need to pick the weight that connects to the actual input dim.
"""
total_rank = weight_A.shape[0]
rank_per_expert = total_rank // num_experts
dim_A = weight_A.shape[1]
dim_B = weight_B.shape[0]

input_dim = None
if hasattr(wrapper, "parameter_name"):
if wrapper.parameter_name == "gate_up_proj":
base = wrapper.get_base_layer() if hasattr(wrapper, "get_base_layer") else None
input_dim = getattr(base, "hidden_dim", None)
elif wrapper.parameter_name == "down_proj":
base = wrapper.get_base_layer() if hasattr(wrapper, "get_base_layer") else None
input_dim = getattr(base, "intermediate_dim", None)

if input_dim is None:
base = wrapper.get_base_layer() if hasattr(wrapper, "get_base_layer") else None
input_dim = getattr(base, "hidden_dim", None)

# If lora_A connects to input_dim: standard (A then B)
if input_dim is not None and dim_A == input_dim:
first_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
first_weight = first_weight.permute(0, 2, 1).contiguous() # (E, input_dim, R)
second_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
second_weight = second_weight.permute(1, 2, 0).contiguous() # (E, R, out_dim)
return first_weight, second_weight, scaling, num_experts

# If lora_B connects to input_dim: swapped (B then A)
if input_dim is not None and dim_B == input_dim:
first_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
first_weight = first_weight.permute(1, 0, 2).contiguous() # (E, input_dim, R)
second_weight = weight_A.view(num_experts, rank_per_expert, dim_A).contiguous() # (E, R, out_dim)
return first_weight, second_weight, scaling, num_experts

# Fallback: standard (A then B)
first_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
first_weight = first_weight.permute(0, 2, 1).contiguous()
second_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
second_weight = second_weight.permute(1, 2, 0).contiguous()
return first_weight, second_weight, scaling, num_experts
return extract_moe_lora_weights_for_grouped_mm(
wrapper,
weight_A,
weight_B,
scaling,
num_experts,
model_name="DeepSeekV3 MoE",
enable_logging=UNSLOTH_ENABLE_LOGGING,
logger_obj=logger,
)

# Register the extractor on the NaiveMoe class (avoid binding as instance method)
DeepseekV3NaiveMoe._unsloth_lora_extractor_fn = staticmethod(_deepseek_v3_lora_extractor)
Expand Down
18 changes: 17 additions & 1 deletion unsloth_zoo/temporary_patches/gemma4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,28 @@
import os
import torch
import torch.nn as nn
from .common import TEMPORARY_PATCHES
from .common import TEMPORARY_PATCHES, UNSLOTH_ENABLE_LOGGING
from .utils import patch_function, process_return, raise_error, logger
from .moe_utils import (
patch_param_wrapper_for_moe,
get_forward_moe_backend,
extract_moe_lora_weights_for_grouped_mm,
)


def _gemma4_moe_lora_extractor(wrapper, weight_A, weight_B, scaling, num_experts):
return extract_moe_lora_weights_for_grouped_mm(
wrapper,
weight_A,
weight_B,
scaling,
num_experts,
model_name="Gemma4 MoE",
enable_logging=UNSLOTH_ENABLE_LOGGING,
logger_obj=logger,
)


def patch_gemma4_grpo_hidden_states():
"""Patch Gemma4ForConditionalGeneration.forward for GRPO hidden states.

Expand Down Expand Up @@ -183,6 +197,7 @@ def _gemma4_experts_forward(self, hidden_states, top_k_index, top_k_weights):

ok = patch_function(Gemma4TextExperts, "forward", _gemma4_experts_forward, force=True)
if ok:
Gemma4TextExperts._unsloth_lora_extractor_fn = staticmethod(_gemma4_moe_lora_extractor)
Gemma4TextExperts._unsloth_already_patched = True
return ok

Expand Down Expand Up @@ -243,6 +258,7 @@ def _gemma4_moe_forward(self, hidden_states, top_k_index, top_k_weights):
if not forward_ok:
return False

Gemma4TextMoEBlock._unsloth_lora_extractor_fn = staticmethod(_gemma4_moe_lora_extractor)
Gemma4TextMoEBlock._unsloth_already_patched = True
return True

Expand Down
49 changes: 11 additions & 38 deletions unsloth_zoo/temporary_patches/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .moe_utils import (
patch_param_wrapper_for_moe,
get_forward_moe_backend,
extract_moe_lora_weights_for_grouped_mm,
)
def patch_glm4_moe():
"""
Expand All @@ -42,44 +43,16 @@ def patch_glm4_moe():
# Define LoRA extraction function for GLM4-MoE (Standard Format)
# ====================================================================
def _glm4_lora_extractor(wrapper, weight_A, weight_B, scaling, num_experts):
"""
Custom LoRA extractor for GLM4.

Expectation for grouped_mm (Standard):
- first (Input): (E, H, R)
- second (Output): (E, R, Out)

GLM4 Weights (Standard PEFT):
- gate_up is (E, Out, In) or (Out, In).
- lora_A (In->R) connects to H. Shape (E*R, H).
Needs: View(E, R, H) -> Permute(0, 2, 1) -> (E, H, R).
- lora_B (R->Out) connects to 2I. Shape (2I, E*R).
Needs: View(2I, E, R) -> Permute(1, 2, 0) -> (E, R, 2I).
"""
total_rank = weight_A.shape[0]
rank_per_expert = total_rank // num_experts
dim1 = weight_A.shape[1]
dim2 = weight_B.shape[0]

# GLM4 MoE sometimes stores weights transposed (E, in_dim, out_dim),
# which flips LoRA A/B's input/output dimensions. Detect and handle both.
if dim1 > dim2:
# Transposed: weight_A is (E*R, out_dim), weight_B is (in_dim, E*R)
# first_weight from B: (E, in_dim, R)
first_weight = weight_B.view(dim2, num_experts, rank_per_expert)
first_weight = first_weight.permute(1, 0, 2).contiguous()

# second_weight from A: (E, R, out_dim)
second_weight = weight_A.view(num_experts, rank_per_expert, dim1).contiguous()
else:
# Standard: weight_A is (E*R, in_dim), weight_B is (out_dim, E*R)
first_weight = weight_A.view(num_experts, rank_per_expert, dim1)
first_weight = first_weight.permute(0, 2, 1).contiguous() # (E, in_dim, R)

second_weight = weight_B.view(dim2, num_experts, rank_per_expert)
second_weight = second_weight.permute(1, 2, 0).contiguous() # (E, R, out_dim)

return first_weight, second_weight, scaling, num_experts
return extract_moe_lora_weights_for_grouped_mm(
wrapper,
weight_A,
weight_B,
scaling,
num_experts,
model_name="GLM4 MoE",
enable_logging=UNSLOTH_ENABLE_LOGGING,
logger_obj=logger,
)

Glm4MoeLiteNaiveMoe._unsloth_lora_extractor_fn = staticmethod(_glm4_lora_extractor)

Expand Down
192 changes: 166 additions & 26 deletions unsloth_zoo/temporary_patches/moe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,163 @@ def _has_lora_adapters(param) -> bool:
return len(param.lora_A) > 0


def _canonical_lora_weights_for_grouped_mm(
weight_A: torch.Tensor,
weight_B: torch.Tensor,
num_experts: int,
rank_per_expert: int,
dim_A: int,
dim_B: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
first_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
first_weight = first_weight.permute(0, 2, 1).contiguous()
second_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
second_weight = second_weight.permute(1, 2, 0).contiguous()
return first_weight, second_weight


def _reversed_lora_weights_for_grouped_mm(
weight_A: torch.Tensor,
weight_B: torch.Tensor,
num_experts: int,
rank_per_expert: int,
dim_A: int,
dim_B: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
first_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
first_weight = first_weight.permute(1, 0, 2).contiguous()
second_weight = weight_A.view(num_experts, rank_per_expert, dim_A).contiguous()
return first_weight, second_weight


def _get_param_shape_from_module(module, parameter_name):
if module is None or parameter_name is None or not hasattr(module, parameter_name):
return None
param = getattr(module, parameter_name)
if hasattr(param, "get_param"):
param = param.get_param()
elif hasattr(param, "weight"):
param = param.weight
return tuple(param.shape)


def _get_moe_lora_io_dims(wrapper, experts_module=None):
base = None
if wrapper is not None and hasattr(wrapper, "get_base_layer"):
base = wrapper.get_base_layer()
if experts_module is None:
experts_module = base
if experts_module is None:
experts_module = getattr(wrapper, "base_layer", None)

parameter_name = getattr(wrapper, "parameter_name", None)
source = experts_module if experts_module is not None else base
if source is None:
return None, None

shape = _get_param_shape_from_module(source, parameter_name)
if shape is not None and len(shape) >= 3:
grouped_mm_format = bool(getattr(source, "_unsloth_grouped_mm_format", False))
if grouped_mm_format:
return shape[-2], shape[-1]
return shape[-1], shape[-2]
Comment on lines +393 to +396

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve GLM4 transposed LoRA extraction path

The shared extractor now infers MoE input/output dims from parameter shape and, when _unsloth_grouped_mm_format is not set, assumes a (E, out_dim, in_dim) layout by returning shape[-1], shape[-2]. GLM4 previously had explicit handling for its transposed (E, in_dim, out_dim) cases, but this commit removed that model-specific logic and routes through the generic path, so transposed GLM4 checkpoints can be misclassified as canonical. In that scenario first_weight is built on the wrong axis, which can cause grouped-mm dimension mismatches or silently wrong LoRA deltas during training.

Useful? React with 👍 / 👎.


hidden_dim = getattr(source, "hidden_dim", None)
intermediate_dim = getattr(source, "intermediate_dim", None)
if hidden_dim is None or intermediate_dim is None:
return None, None
if parameter_name == "gate_up_proj":
return hidden_dim, 2 * intermediate_dim
if parameter_name == "down_proj":
return intermediate_dim, hidden_dim
return None, None


def extract_moe_lora_weights_for_grouped_mm(
wrapper,
weight_A: torch.Tensor,
weight_B: torch.Tensor,
scaling,
num_experts: int,
*,
experts_module=None,
input_dim=None,
output_dim=None,
model_name: str = "MoE",
enable_logging: bool = None,
logger_obj=None,
) -> Tuple[torch.Tensor, torch.Tensor, float, int]:
total_rank = weight_A.shape[0]
rank_per_expert = total_rank // num_experts
dim_A = weight_A.shape[1]
dim_B = weight_B.shape[0]

if num_experts <= 1:
return weight_A.T, weight_B.T, scaling, num_experts
Comment on lines +423 to +429

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current logic for num_experts <= 1 returns 2D tensors, which is incompatible with torch._grouped_mm (which expects 3D weights) and the expert indexing used in the MoE forward passes (e.g., first_weight[expert_idx]). Additionally, rank_per_expert is calculated before the check, which would lead to a ZeroDivisionError if num_experts is 0.

For num_experts == 1, the function should proceed with the 3D reshaping logic to maintain compatibility with the MoE backends. The check should be changed to num_experts < 1 and moved before the division to avoid potential runtime errors.

Suggested change
total_rank = weight_A.shape[0]
rank_per_expert = total_rank // num_experts
dim_A = weight_A.shape[1]
dim_B = weight_B.shape[0]
if num_experts <= 1:
return weight_A.T, weight_B.T, scaling, num_experts
total_rank = weight_A.shape[0]
dim_A = weight_A.shape[1]
dim_B = weight_B.shape[0]
if num_experts < 1:
return weight_A.T, weight_B.T, scaling, num_experts
rank_per_expert = total_rank // num_experts


if input_dim is None or output_dim is None:
inferred_input_dim, inferred_output_dim = _get_moe_lora_io_dims(
wrapper, experts_module=experts_module,
)
if input_dim is None:
input_dim = inferred_input_dim
if output_dim is None:
output_dim = inferred_output_dim

canonical_match = (
input_dim is not None
and output_dim is not None
and dim_A == input_dim
and dim_B == output_dim
)
reversed_match = (
input_dim is not None
and output_dim is not None
and dim_A == output_dim
and dim_B == input_dim
)

if canonical_match and reversed_match:
if bool(getattr(wrapper, "_did_swap_in_out_features", False)):
first_weight, second_weight = _reversed_lora_weights_for_grouped_mm(
weight_A, weight_B, num_experts, rank_per_expert, dim_A, dim_B,
)
else:
first_weight, second_weight = _canonical_lora_weights_for_grouped_mm(
weight_A, weight_B, num_experts, rank_per_expert, dim_A, dim_B,
)
return first_weight, second_weight, scaling, num_experts

if canonical_match:
first_weight, second_weight = _canonical_lora_weights_for_grouped_mm(
weight_A, weight_B, num_experts, rank_per_expert, dim_A, dim_B,
)
return first_weight, second_weight, scaling, num_experts

if reversed_match:
first_weight, second_weight = _reversed_lora_weights_for_grouped_mm(
weight_A, weight_B, num_experts, rank_per_expert, dim_A, dim_B,
)
return first_weight, second_weight, scaling, num_experts

if logger_obj is not None:
if enable_logging is None:
enable_logging = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1"
if enable_logging and (input_dim is not None or output_dim is not None):
logger_obj.warning(
f"Unsloth: {model_name} LoRA extractor could not match either layout "
f"(weight_A={tuple(weight_A.shape)}, weight_B={tuple(weight_B.shape)}, "
f"expected input_dim={input_dim}, output_dim={output_dim}, "
f"num_experts={num_experts}). Falling back to canonical layout. "
"If this is a new PEFT version, the LoRA delta may be wrong."
)

first_weight, second_weight = _canonical_lora_weights_for_grouped_mm(
weight_A, weight_B, num_experts, rank_per_expert, dim_A, dim_B,
)
return first_weight, second_weight, scaling, num_experts


def _extract_lora_from_wrapper(
wrapper, adapter_name: str = "default", experts_module=None
) -> Optional[Tuple[torch.Tensor, torch.Tensor, float, int]]:
Expand Down Expand Up @@ -404,32 +561,15 @@ def _extract_lora_from_wrapper(
if extractor_fn is not None:
return extractor_fn(wrapper, weight_A, weight_B, scaling, num_experts)

# DEFAULT BEHAVIOR (Standard Format / Non-MoE)
if num_experts > 1:
total_rank = weight_A.shape[0]
rank_per_expert = total_rank // num_experts
dim1 = weight_A.shape[1]
dim2 = weight_B.shape[0]

# STANDARD FORMAT (Qwen3-MoE / GLM4):
# Base weights are (E, out_dim, in_dim) for F.linear.
# LoRA weights follow PEFT: weight_A is (E*R, in_dim), weight_B is (out_dim, E*R).
# We need X @ (E, in_dim, R) @ (E, R, out_dim).

# first_weight: (E, in_dim, R) - from lora_A
# second_weight: (E, R, out_dim) - from lora_B
first_weight = weight_A.view(num_experts, rank_per_expert, dim1)
first_weight = first_weight.permute(0, 2, 1).contiguous() # (E, dim1, R)

# second_weight (B): (E, R, out_dim)
second_weight = weight_B.view(dim2, num_experts, rank_per_expert)
second_weight = second_weight.permute(1, 2, 0).contiguous() # (E, R, dim2)
else:
# Non-MoE case: return weights for X @ A.T @ B.T
first_weight = weight_A.T # (dim1, R)
second_weight = weight_B.T # (R, dim2)

return first_weight, second_weight, scaling, num_experts
return extract_moe_lora_weights_for_grouped_mm(
wrapper,
weight_A,
weight_B,
scaling,
num_experts,
experts_module=experts_module,
model_name="MoE",
)
except Exception:
return None

Expand Down
Loading