Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions test_qwen_moe_lora_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import pytest
import torch

from unsloth_zoo.temporary_patches.qwen3_moe import _make_qwen_moe_lora_extractor


@pytest.mark.parametrize(
"E,R,in_dim,out_dim",
[
(4, 4, 256, 128),
(4, 4, 128, 256),
(8, 2, 512, 1024),
(16, 8, 2048, 512),
(2, 1, 64, 32),
],
)
def test_extractor_shapes(E, R, in_dim, out_dim):
ext = _make_qwen_moe_lora_extractor()
wA = torch.randn(E * R, in_dim)
wB = torch.randn(out_dim, E * R)
first, second, scaling, num_experts = ext(None, wA, wB, 2.5, E)
assert first.shape == (E, in_dim, R)
assert second.shape == (E, R, out_dim)
assert scaling == 2.5
assert num_experts == E
assert first.is_contiguous()
assert second.is_contiguous()


def test_extractor_numerical_equivalence_per_expert():
ext = _make_qwen_moe_lora_extractor()
E, R, in_dim, out_dim = 4, 4, 64, 128
torch.manual_seed(0)
wA = torch.randn(E * R, in_dim)
wB = torch.randn(out_dim, E * R)
first, second, _, _ = ext(None, wA, wB, 1.0, E)
x = torch.randn(6, in_dim)
for e in range(E):
Ae = wA[e * R : (e + 1) * R]
Be = wB[:, e * R : (e + 1) * R]
naive = x @ Ae.T @ Be.T
via = (x @ first[e]) @ second[e]
torch.testing.assert_close(via, naive, atol=1e-4, rtol=1e-4)


def test_extractor_ignores_wrapper_attributes():
ext = _make_qwen_moe_lora_extractor()
E, R, in_dim, out_dim = 4, 4, 64, 128
torch.manual_seed(2)
wA = torch.randn(E * R, in_dim)
wB = torch.randn(out_dim, E * R)

class _Bogus:
parameter_name = "down_proj"
base_layer = None

first_none, second_none, _, _ = ext(None, wA, wB, 1.0, E)
first_bogus, second_bogus, _, _ = ext(_Bogus(), wA, wB, 1.0, E)
torch.testing.assert_close(first_none, first_bogus)
torch.testing.assert_close(second_none, second_bogus)


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
def test_extractor_preserves_dtype(dtype):
ext = _make_qwen_moe_lora_extractor()
E, R, in_dim, out_dim = 4, 4, 64, 128
wA = torch.randn(E * R, in_dim, dtype=dtype)
wB = torch.randn(out_dim, E * R, dtype=dtype)
first, second, _, _ = ext(None, wA, wB, 1.0, E)
assert first.dtype == dtype
assert second.dtype == dtype


def test_extractor_passes_tensor_scaling_unchanged():
ext = _make_qwen_moe_lora_extractor()
E, R, in_dim, out_dim = 4, 4, 64, 128
wA = torch.randn(E * R, in_dim)
wB = torch.randn(out_dim, E * R)
scaling = torch.tensor(0.125)
_, _, returned, _ = ext(None, wA, wB, scaling, E)
assert returned is scaling


def test_shared_factory_used_by_qwen3_5_and_next():
import unsloth_zoo.temporary_patches.qwen3_5_moe as q35
import unsloth_zoo.temporary_patches.qwen3_next_moe as qnx
assert q35._make_qwen_moe_lora_extractor is _make_qwen_moe_lora_extractor
assert qnx._make_qwen_moe_lora_extractor is _make_qwen_moe_lora_extractor


def test_extractor_output_not_aliasing_input():
ext = _make_qwen_moe_lora_extractor()
E, R, in_dim, out_dim = 2, 2, 8, 16
wA = torch.randn(E * R, in_dim)
wB = torch.randn(out_dim, E * R)
first, second, _, _ = ext(None, wA, wB, 1.0, E)
wA.zero_()
wB.zero_()
assert first.abs().sum().item() > 0
assert second.abs().sum().item() > 0
58 changes: 9 additions & 49 deletions unsloth_zoo/temporary_patches/qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,64 +45,24 @@
def _make_qwen_moe_lora_extractor():
def _qwen_moe_lora_extractor(wrapper, weight_A, weight_B, scaling, num_experts):
"""
Robust extractor for Qwen-family MoE that handles PEFT's dimension layout.
LoRA extractor for Qwen-family MoE (Qwen3-MoE, Qwen3.5/3.6, Qwen3-Next).

Expectation for grouped_mm:
- first_weight: (E, in_dim, R) [Input projection to rank]
- second_weight: (E, R, out_dim) [Rank projection to output]
PEFT LoRA shapes are fixed by the linear's in/out dims, independent of
raw base-weight storage order, so no model-specific dispatch is needed:
weight_A: (E*R, in_dim) -> (E, in_dim, R)
weight_B: (out_dim, E*R) -> (E, R, out_dim)
"""
total_rank = weight_A.shape[0]
rank_per_expert = total_rank // num_experts

dim_A = weight_A.shape[1]
dim_B = weight_B.shape[0]

hidden_dim = None
intermediate_dim = None
current = wrapper
while hasattr(current, "base_layer"):
current = current.base_layer
if hasattr(current, "hidden_dim"):
hidden_dim = current.hidden_dim
if hasattr(current, "intermediate_dim"):
intermediate_dim = current.intermediate_dim
if hasattr(current, "gate_up_proj") and hasattr(current.gate_up_proj, "shape"):
shape = current.gate_up_proj.shape
if len(shape) == 3:
hidden_dim = shape[2]
intermediate_dim = shape[1] // 2

param_name = getattr(wrapper, "parameter_name", None)

if param_name == "down_proj" and intermediate_dim is not None and hidden_dim is not None:
first_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
first_weight = first_weight.permute(1, 0, 2).contiguous()
second_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
return first_weight, second_weight, scaling, num_experts

elif param_name == "gate_up_proj" and hidden_dim is not None:
first_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
first_weight = first_weight.permute(1, 0, 2).contiguous()
second_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
return first_weight, second_weight, scaling, num_experts

if hidden_dim is not None:
if dim_B == hidden_dim:
first_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
first_weight = first_weight.permute(1, 0, 2).contiguous()
second_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
return first_weight, second_weight, scaling, num_experts
elif dim_A == hidden_dim:
first_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
first_weight = first_weight.permute(0, 2, 1).contiguous()
second_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
second_weight = second_weight.permute(1, 2, 0).contiguous()
return first_weight, second_weight, scaling, num_experts
dim_A = weight_A.shape[1] # in_dim
dim_B = weight_B.shape[0] # out_dim

first_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
first_weight = first_weight.permute(0, 2, 1).contiguous()

second_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
second_weight = second_weight.permute(1, 2, 0).contiguous()

return first_weight, second_weight, scaling, num_experts

return _qwen_moe_lora_extractor
Expand Down