Reviewed revision of the Qwen MoE LoRA extractor patch.  The submitted text
had the whole patch folded onto four lines; the line structure is restored
here.  Text before the first "diff --git" line is ignored by git-apply/patch.

Changes relative to the submitted patch:
  * qwen3_moe.py: the extractor now always returns tensors that own their
    storage.  `.contiguous()` returns the tensor itself when the permuted view
    is already contiguous (weight_A when rank_per_expert == 1; weight_B when
    additionally num_experts == 1), so the output aliased the LoRA parameter
    in those cases.
  * test file: the non-aliasing test is parametrized over those shapes, and
    module/test docstrings are added.

NOTE(review): the post-image blob ids on the "index" lines come from the
submitted patch and no longer match the revised content.  Indentation,
intra-line spacing and blank context lines were reconstructed from the folded
text; use `git apply --ignore-space-change` if a context line does not match.

diff --git a/test_qwen_moe_lora_extractor.py b/test_qwen_moe_lora_extractor.py
new file mode 100644
index 000000000..4858bddd7
--- /dev/null
+++ b/test_qwen_moe_lora_extractor.py
@@ -0,0 +1,116 @@
+"""Tests for the shared Qwen-family MoE LoRA extractor factory."""
+
+import pytest
+import torch
+
+from unsloth_zoo.temporary_patches.qwen3_moe import _make_qwen_moe_lora_extractor
+
+
+@pytest.mark.parametrize(
+    "E,R,in_dim,out_dim",
+    [
+        (4, 4, 256, 128),
+        (4, 4, 128, 256),
+        (8, 2, 512, 1024),
+        (16, 8, 2048, 512),
+        (2, 1, 64, 32),
+    ],
+)
+def test_extractor_shapes(E, R, in_dim, out_dim):
+    """Outputs are contiguous (E, in_dim, R) and (E, R, out_dim) tensors."""
+    ext = _make_qwen_moe_lora_extractor()
+    wA = torch.randn(E * R, in_dim)
+    wB = torch.randn(out_dim, E * R)
+    first, second, scaling, num_experts = ext(None, wA, wB, 2.5, E)
+    assert first.shape == (E, in_dim, R)
+    assert second.shape == (E, R, out_dim)
+    assert scaling == 2.5
+    assert num_experts == E
+    assert first.is_contiguous()
+    assert second.is_contiguous()
+
+
+def test_extractor_numerical_equivalence_per_expert():
+    """Each expert's pair reproduces the plain ``x @ A_e.T @ B_e.T`` product."""
+    ext = _make_qwen_moe_lora_extractor()
+    E, R, in_dim, out_dim = 4, 4, 64, 128
+    torch.manual_seed(0)
+    wA = torch.randn(E * R, in_dim)
+    wB = torch.randn(out_dim, E * R)
+    first, second, _, _ = ext(None, wA, wB, 1.0, E)
+    x = torch.randn(6, in_dim)
+    for e in range(E):
+        Ae = wA[e * R : (e + 1) * R]
+        Be = wB[:, e * R : (e + 1) * R]
+        naive = x @ Ae.T @ Be.T
+        via = (x @ first[e]) @ second[e]
+        torch.testing.assert_close(via, naive, atol=1e-4, rtol=1e-4)
+
+
+def test_extractor_ignores_wrapper_attributes():
+    """Attributes on the wrapper argument do not change the returned tensors."""
+    ext = _make_qwen_moe_lora_extractor()
+    E, R, in_dim, out_dim = 4, 4, 64, 128
+    torch.manual_seed(2)
+    wA = torch.randn(E * R, in_dim)
+    wB = torch.randn(out_dim, E * R)
+
+    class _Bogus:
+        parameter_name = "down_proj"
+        base_layer = None
+
+    first_none, second_none, _, _ = ext(None, wA, wB, 1.0, E)
+    first_bogus, second_bogus, _, _ = ext(_Bogus(), wA, wB, 1.0, E)
+    torch.testing.assert_close(first_none, first_bogus)
+    torch.testing.assert_close(second_none, second_bogus)
+
+
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
+def test_extractor_preserves_dtype(dtype):
+    """Output dtype matches the dtype of the LoRA weights."""
+    ext = _make_qwen_moe_lora_extractor()
+    E, R, in_dim, out_dim = 4, 4, 64, 128
+    wA = torch.randn(E * R, in_dim, dtype=dtype)
+    wB = torch.randn(out_dim, E * R, dtype=dtype)
+    first, second, _, _ = ext(None, wA, wB, 1.0, E)
+    assert first.dtype == dtype
+    assert second.dtype == dtype
+
+
+def test_extractor_passes_tensor_scaling_unchanged():
+    """A tensor ``scaling`` comes back as the very same object."""
+    ext = _make_qwen_moe_lora_extractor()
+    E, R, in_dim, out_dim = 4, 4, 64, 128
+    wA = torch.randn(E * R, in_dim)
+    wB = torch.randn(out_dim, E * R)
+    scaling = torch.tensor(0.125)
+    _, _, returned, _ = ext(None, wA, wB, scaling, E)
+    assert returned is scaling
+
+
+def test_shared_factory_used_by_qwen3_5_and_next():
+    """The Qwen3.5 and Qwen3-Next patch modules expose the same factory object."""
+    import unsloth_zoo.temporary_patches.qwen3_5_moe as q35
+    import unsloth_zoo.temporary_patches.qwen3_next_moe as qnx
+    assert q35._make_qwen_moe_lora_extractor is _make_qwen_moe_lora_extractor
+    assert qnx._make_qwen_moe_lora_extractor is _make_qwen_moe_lora_extractor
+
+
+@pytest.mark.parametrize(
+    "E,R,in_dim,out_dim",
+    [
+        (2, 2, 8, 16),
+        (2, 1, 8, 16),  # rank 1: the permuted view of A is already contiguous
+        (1, 1, 8, 16),  # one expert, rank 1: the permuted view of B is too
+    ],
+)
+def test_extractor_output_not_aliasing_input(E, R, in_dim, out_dim):
+    """Zeroing the LoRA weights in place leaves earlier outputs untouched."""
+    ext = _make_qwen_moe_lora_extractor()
+    wA = torch.randn(E * R, in_dim)
+    wB = torch.randn(out_dim, E * R)
+    first, second, _, _ = ext(None, wA, wB, 1.0, E)
+    wA.zero_()
+    wB.zero_()
+    assert first.abs().sum().item() > 0
+    assert second.abs().sum().item() > 0
diff --git a/unsloth_zoo/temporary_patches/qwen3_moe.py b/unsloth_zoo/temporary_patches/qwen3_moe.py
index 8d6c08e6b..073d19c19 100644
--- a/unsloth_zoo/temporary_patches/qwen3_moe.py
+++ b/unsloth_zoo/temporary_patches/qwen3_moe.py
@@ -45,64 +45,37 @@
 def _make_qwen_moe_lora_extractor():
+    def _owned_contiguous(tensor):
+        # .contiguous() returns `tensor` itself when its layout is already
+        # contiguous (e.g. the permuted A view when rank_per_expert == 1), so
+        # the result would alias the LoRA parameter; clone in that case.
+        if tensor.is_contiguous():
+            return tensor.clone()
+        return tensor.contiguous()
+
     def _qwen_moe_lora_extractor(wrapper, weight_A, weight_B, scaling, num_experts):
         """
-        Robust extractor for Qwen-family MoE that handles PEFT's dimension layout.
+        LoRA extractor for Qwen-family MoE (Qwen3-MoE, Qwen3.5/3.6, Qwen3-Next).
 
-        Expectation for grouped_mm:
-        - first_weight: (E, in_dim, R) [Input projection to rank]
-        - second_weight: (E, R, out_dim) [Rank projection to output]
+        PEFT LoRA shapes are fixed by the linear's in/out dims, independent of
+        raw base-weight storage order, so no model-specific dispatch is needed:
+            weight_A: (E*R, in_dim) -> (E, in_dim, R)
+            weight_B: (out_dim, E*R) -> (E, R, out_dim)
+
+        `wrapper` is accepted for interface compatibility and is not read.
+        Both returned tensors are contiguous and own their storage (they
+        never alias `weight_A` / `weight_B`); `scaling` and `num_experts`
+        are returned unchanged.
         """
         total_rank = weight_A.shape[0]
         rank_per_expert = total_rank // num_experts
-
-        dim_A = weight_A.shape[1]
-        dim_B = weight_B.shape[0]
-
-        hidden_dim = None
-        intermediate_dim = None
-        current = wrapper
-        while hasattr(current, "base_layer"):
-            current = current.base_layer
-            if hasattr(current, "hidden_dim"):
-                hidden_dim = current.hidden_dim
-            if hasattr(current, "intermediate_dim"):
-                intermediate_dim = current.intermediate_dim
-            if hasattr(current, "gate_up_proj") and hasattr(current.gate_up_proj, "shape"):
-                shape = current.gate_up_proj.shape
-                if len(shape) == 3:
-                    hidden_dim = shape[2]
-                    intermediate_dim = shape[1] // 2
-
-        param_name = getattr(wrapper, "parameter_name", None)
-
-        if param_name == "down_proj" and intermediate_dim is not None and hidden_dim is not None:
-            first_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
-            first_weight = first_weight.permute(1, 0, 2).contiguous()
-            second_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
-            return first_weight, second_weight, scaling, num_experts
-
-        elif param_name == "gate_up_proj" and hidden_dim is not None:
-            first_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
-            first_weight = first_weight.permute(1, 0, 2).contiguous()
-            second_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
-            return first_weight, second_weight, scaling, num_experts
-
-        if hidden_dim is not None:
-            if dim_B == hidden_dim:
-                first_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
-                first_weight = first_weight.permute(1, 0, 2).contiguous()
-                second_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
-                return first_weight, second_weight, scaling, num_experts
-            elif dim_A == hidden_dim:
-                first_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
-                first_weight = first_weight.permute(0, 2, 1).contiguous()
-                second_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
-                second_weight = second_weight.permute(1, 2, 0).contiguous()
-                return first_weight, second_weight, scaling, num_experts
+        dim_A = weight_A.shape[1]  # in_dim
+        dim_B = weight_B.shape[0]  # out_dim
 
         first_weight = weight_A.view(num_experts, rank_per_expert, dim_A)
-        first_weight = first_weight.permute(0, 2, 1).contiguous()
+        first_weight = _owned_contiguous(first_weight.permute(0, 2, 1))
+
         second_weight = weight_B.view(dim_B, num_experts, rank_per_expert)
-        second_weight = second_weight.permute(1, 2, 0).contiguous()
+        second_weight = _owned_contiguous(second_weight.permute(1, 2, 0))
+
         return first_weight, second_weight, scaling, num_experts
 
     return _qwen_moe_lora_extractor