From b8edf5b8555913f771ad6ffcef94115c4e9e16cd Mon Sep 17 00:00:00 2001 From: lordx64 Date: Sat, 18 Apr 2026 07:39:27 -0400 Subject: [PATCH 1/4] [MoE] Fix Qwen-family MoE LoRA extractor shape mismatch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The shared `_make_qwen_moe_lora_extractor` (used by Qwen3-MoE, Qwen3.5/3.6 MoE, and Qwen3-Next) produced `first=(E, out_dim, R)` instead of the `(E, in_dim, R)` shape expected by `forward_native_grouped_mm`. On models like Qwen3.6-35B-A3B this triggered, during the first training step: torch._grouped_mm(inputs, weight, offs=offsets) RuntimeError: contraction dimension of mat_a and mat_b must match when `permuted_input` (N, in_dim) was matmul'd against a first_weight whose second-to-last dim was `out_dim` (e.g. `2*intermediate_dim` for gate_up_proj on Qwen3.6-35B-A3B's 256-expert architecture). Root cause: the explicit `param_name in ("gate_up_proj", "down_proj")` branches and the `dim_B == hidden_dim` branch all constructed `first_weight = weight_B.view(dim_B, E, R).permute(1, 0, 2)` — i.e. derived from `lora_B`, which has shape `(out_dim, E*R)` — so `first.shape[-2]` ended up as `out_dim`, not `in_dim`. The final fallback at the bottom of the function was already correct. Fix: drop the broken branches. The correct mapping — identical to the default extractor in `moe_utils.py::_extract_lora_from_wrapper` and to the working Qwen3-VL-MoE extractor in `qwen3_vl_moe.py::_qwen3_vl_lora_extractor` — is format-independent: weight_A : (E*R, in_dim) -> view(E, R, in_dim).permute(0, 2, 1) = (E, in_dim, R) weight_B : (out_dim, E*R) -> view(out_dim, E, R).permute(1, 2, 0) = (E, R, out_dim) PEFT LoRA weights have fixed shape relative to the linear's in/out dims; they don't depend on whether base weights are stored "standard" (E, out, in) or "transposed" (E, in, out) — that distinction is handled upstream by `preprocess_weight`. Verified against Qwen3.6-35B-A3B (unsloth/Qwen3.6-35B-A3B): the LoRA forward path through `torch._grouped_mm` no longer fails with the contraction-dim error, and training progresses past the first forward into the expected memory-bound regime. Co-Authored-By: Claude Opus 4.7 (1M context) --- unsloth_zoo/temporary_patches/qwen3_moe.py | 70 +++++++--------------- 1 file changed, 20 insertions(+), 50 deletions(-) diff --git a/unsloth_zoo/temporary_patches/qwen3_moe.py b/unsloth_zoo/temporary_patches/qwen3_moe.py index 8d6c08e6b..5db86bd0b 100644 --- a/unsloth_zoo/temporary_patches/qwen3_moe.py +++ b/unsloth_zoo/temporary_patches/qwen3_moe.py @@ -45,64 +45,34 @@ def _make_qwen_moe_lora_extractor(): def _qwen_moe_lora_extractor(wrapper, weight_A, weight_B, scaling, num_experts): """ - Robust extractor for Qwen-family MoE that handles PEFT's dimension layout. - - Expectation for grouped_mm: - - first_weight: (E, in_dim, R) [Input projection to rank] - - second_weight: (E, R, out_dim) [Rank projection to output] + LoRA extractor for Qwen-family MoE models (Qwen3-MoE, Qwen3.5/3.6 MoE, + Qwen3-Next). Base weights are stored as (E, out_dim, in_dim) for F.linear + and reshaped to (E, in_dim, out_dim) for grouped_mm by preprocess_weight. + + Expectation for grouped_mm forward: out = ((X @ first) @ second) * scaling + - first_weight: (E, in_dim, R) + - second_weight: (E, R, out_dim) + + PEFT ParamWrapper produces: + weight_A: (E*R, in_dim) -> view(E, R, in_dim).permute(0, 2, 1) = (E, in_dim, R) + weight_B: (out_dim, E*R) -> view(out_dim, E, R).permute(1, 2, 0) = (E, R, out_dim) + + This mapping is independent of whether base weights are stored in + "standard" (E, out, in) or "transposed" (E, in, out) layout, since PEFT + LoRA weights have a fixed shape relative to the linear's in/out dims, + not the raw base-weight storage order. """ total_rank = weight_A.shape[0] rank_per_expert = total_rank // num_experts - - dim_A = weight_A.shape[1] - dim_B = weight_B.shape[0] - - hidden_dim = None - intermediate_dim = None - current = wrapper - while hasattr(current, "base_layer"): - current = current.base_layer - if hasattr(current, "hidden_dim"): - hidden_dim = current.hidden_dim - if hasattr(current, "intermediate_dim"): - intermediate_dim = current.intermediate_dim - if hasattr(current, "gate_up_proj") and hasattr(current.gate_up_proj, "shape"): - shape = current.gate_up_proj.shape - if len(shape) == 3: - hidden_dim = shape[2] - intermediate_dim = shape[1] // 2 - - param_name = getattr(wrapper, "parameter_name", None) - - if param_name == "down_proj" and intermediate_dim is not None and hidden_dim is not None: - first_weight = weight_B.view(dim_B, num_experts, rank_per_expert) - first_weight = first_weight.permute(1, 0, 2).contiguous() - second_weight = weight_A.view(num_experts, rank_per_expert, dim_A) - return first_weight, second_weight, scaling, num_experts - - elif param_name == "gate_up_proj" and hidden_dim is not None: - first_weight = weight_B.view(dim_B, num_experts, rank_per_expert) - first_weight = first_weight.permute(1, 0, 2).contiguous() - second_weight = weight_A.view(num_experts, rank_per_expert, dim_A) - return first_weight, second_weight, scaling, num_experts - - if hidden_dim is not None: - if dim_B == hidden_dim: - first_weight = weight_B.view(dim_B, num_experts, rank_per_expert) - first_weight = first_weight.permute(1, 0, 2).contiguous() - second_weight = weight_A.view(num_experts, rank_per_expert, dim_A) - return first_weight, second_weight, scaling, num_experts - elif dim_A == hidden_dim: - first_weight = weight_A.view(num_experts, rank_per_expert, dim_A) - first_weight = first_weight.permute(0, 2, 1).contiguous() - second_weight = weight_B.view(dim_B, num_experts, rank_per_expert) - second_weight = second_weight.permute(1, 2, 0).contiguous() - return first_weight, second_weight, scaling, num_experts + dim_A = weight_A.shape[1] # in_dim + dim_B = weight_B.shape[0] # out_dim first_weight = weight_A.view(num_experts, rank_per_expert, dim_A) first_weight = first_weight.permute(0, 2, 1).contiguous() + second_weight = weight_B.view(dim_B, num_experts, rank_per_expert) second_weight = second_weight.permute(1, 2, 0).contiguous() + return first_weight, second_weight, scaling, num_experts return _qwen_moe_lora_extractor From d8f30d64eb3308ebfa605155faf654bb9effb982 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 20 Apr 2026 15:00:52 +0000 Subject: [PATCH 2/4] Trim qwen MoE LoRA extractor docstring to load-bearing rationale --- unsloth_zoo/temporary_patches/qwen3_moe.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/unsloth_zoo/temporary_patches/qwen3_moe.py b/unsloth_zoo/temporary_patches/qwen3_moe.py index 5db86bd0b..073d19c19 100644 --- a/unsloth_zoo/temporary_patches/qwen3_moe.py +++ b/unsloth_zoo/temporary_patches/qwen3_moe.py @@ -45,22 +45,12 @@ def _make_qwen_moe_lora_extractor(): def _qwen_moe_lora_extractor(wrapper, weight_A, weight_B, scaling, num_experts): """ - LoRA extractor for Qwen-family MoE models (Qwen3-MoE, Qwen3.5/3.6 MoE, - Qwen3-Next). Base weights are stored as (E, out_dim, in_dim) for F.linear - and reshaped to (E, in_dim, out_dim) for grouped_mm by preprocess_weight. - - Expectation for grouped_mm forward: out = ((X @ first) @ second) * scaling - - first_weight: (E, in_dim, R) - - second_weight: (E, R, out_dim) - - PEFT ParamWrapper produces: - weight_A: (E*R, in_dim) -> view(E, R, in_dim).permute(0, 2, 1) = (E, in_dim, R) - weight_B: (out_dim, E*R) -> view(out_dim, E, R).permute(1, 2, 0) = (E, R, out_dim) - - This mapping is independent of whether base weights are stored in - "standard" (E, out, in) or "transposed" (E, in, out) layout, since PEFT - LoRA weights have a fixed shape relative to the linear's in/out dims, - not the raw base-weight storage order. + LoRA extractor for Qwen-family MoE (Qwen3-MoE, Qwen3.5/3.6, Qwen3-Next). + + PEFT LoRA shapes are fixed by the linear's in/out dims, independent of + raw base-weight storage order, so no model-specific dispatch is needed: + weight_A: (E*R, in_dim) -> (E, in_dim, R) + weight_B: (out_dim, E*R) -> (E, R, out_dim) """ total_rank = weight_A.shape[0] rank_per_expert = total_rank // num_experts From d3e29e87c4974012bd54072d8c04d5ad7dc31a82 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 20 Apr 2026 14:59:53 +0000 Subject: [PATCH 3/4] Add review tests for qwen MoE LoRA extractor --- test_qwen_moe_lora_extractor_dtype.py | 68 +++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 test_qwen_moe_lora_extractor_dtype.py diff --git a/test_qwen_moe_lora_extractor_dtype.py b/test_qwen_moe_lora_extractor_dtype.py new file mode 100644 index 000000000..6ac31b299 --- /dev/null +++ b/test_qwen_moe_lora_extractor_dtype.py @@ -0,0 +1,68 @@ +import pytest +import torch + +from unsloth_zoo.temporary_patches.qwen3_moe import _make_qwen_moe_lora_extractor + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) +def test_extractor_preserves_dtype(dtype): + ext = _make_qwen_moe_lora_extractor() + E, R, in_dim, out_dim = 4, 4, 64, 128 + wA = torch.randn(E * R, in_dim, dtype=dtype) + wB = torch.randn(out_dim, E * R, dtype=dtype) + first, second, _, _ = ext(None, wA, wB, 1.0, E) + assert first.dtype == dtype + assert second.dtype == dtype + + +def test_extractor_passes_tensor_scaling_unchanged(): + ext = _make_qwen_moe_lora_extractor() + E, R, in_dim, out_dim = 4, 4, 64, 128 + wA = torch.randn(E * R, in_dim) + wB = torch.randn(out_dim, E * R) + scaling = torch.tensor(0.125) + _, _, returned, _ = ext(None, wA, wB, scaling, E) + assert returned is scaling + + +def test_shared_factory_usable_by_qwen3_5_and_next(): + import unsloth_zoo.temporary_patches.qwen3_5_moe as q35 + import unsloth_zoo.temporary_patches.qwen3_next_moe as qnx + assert q35._make_qwen_moe_lora_extractor is _make_qwen_moe_lora_extractor + assert qnx._make_qwen_moe_lora_extractor is _make_qwen_moe_lora_extractor + ext = _make_qwen_moe_lora_extractor() + assert callable(ext) + + +def test_extractor_does_not_mutate_inputs(): + ext = _make_qwen_moe_lora_extractor() + E, R, in_dim, out_dim = 4, 4, 64, 128 + wA = torch.randn(E * R, in_dim) + wB = torch.randn(out_dim, E * R) + wA_before = wA.clone() + wB_before = wB.clone() + ext(None, wA, wB, 1.0, E) + torch.testing.assert_close(wA, wA_before) + torch.testing.assert_close(wB, wB_before) + + +def test_extractor_output_writable_without_aliasing_input(): + ext = _make_qwen_moe_lora_extractor() + E, R, in_dim, out_dim = 2, 2, 8, 16 + wA = torch.randn(E * R, in_dim) + wB = torch.randn(out_dim, E * R) + first, second, _, _ = ext(None, wA, wB, 1.0, E) + wA.zero_() + wB.zero_() + assert first.abs().sum().item() > 0 + assert second.abs().sum().item() > 0 + + +def test_extractor_rank_1(): + ext = _make_qwen_moe_lora_extractor() + E, R, in_dim, out_dim = 4, 1, 64, 32 + wA = torch.randn(E * R, in_dim) + wB = torch.randn(out_dim, E * R) + first, second, _, _ = ext(None, wA, wB, 1.0, E) + assert first.shape == (E, in_dim, R) + assert second.shape == (E, R, out_dim) From 2f0d8d9c43a221820881746e20967068d22acdac Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 20 Apr 2026 15:02:54 +0000 Subject: [PATCH 4/4] Consolidate review tests for qwen MoE LoRA extractor --- ...type.py => test_qwen_moe_lora_extractor.py | 84 +++++++++++++------ 1 file changed, 58 insertions(+), 26 deletions(-) rename test_qwen_moe_lora_extractor_dtype.py => test_qwen_moe_lora_extractor.py (57%) diff --git a/test_qwen_moe_lora_extractor_dtype.py b/test_qwen_moe_lora_extractor.py similarity index 57% rename from test_qwen_moe_lora_extractor_dtype.py rename to test_qwen_moe_lora_extractor.py index 6ac31b299..4858bddd7 100644 --- a/test_qwen_moe_lora_extractor_dtype.py +++ b/test_qwen_moe_lora_extractor.py @@ -4,6 +4,62 @@ from unsloth_zoo.temporary_patches.qwen3_moe import _make_qwen_moe_lora_extractor +@pytest.mark.parametrize( + "E,R,in_dim,out_dim", + [ + (4, 4, 256, 128), + (4, 4, 128, 256), + (8, 2, 512, 1024), + (16, 8, 2048, 512), + (2, 1, 64, 32), + ], +) +def test_extractor_shapes(E, R, in_dim, out_dim): + ext = _make_qwen_moe_lora_extractor() + wA = torch.randn(E * R, in_dim) + wB = torch.randn(out_dim, E * R) + first, second, scaling, num_experts = ext(None, wA, wB, 2.5, E) + assert first.shape == (E, in_dim, R) + assert second.shape == (E, R, out_dim) + assert scaling == 2.5 + assert num_experts == E + assert first.is_contiguous() + assert second.is_contiguous() + + +def test_extractor_numerical_equivalence_per_expert(): + ext = _make_qwen_moe_lora_extractor() + E, R, in_dim, out_dim = 4, 4, 64, 128 + torch.manual_seed(0) + wA = torch.randn(E * R, in_dim) + wB = torch.randn(out_dim, E * R) + first, second, _, _ = ext(None, wA, wB, 1.0, E) + x = torch.randn(6, in_dim) + for e in range(E): + Ae = wA[e * R : (e + 1) * R] + Be = wB[:, e * R : (e + 1) * R] + naive = x @ Ae.T @ Be.T + via = (x @ first[e]) @ second[e] + torch.testing.assert_close(via, naive, atol=1e-4, rtol=1e-4) + + +def test_extractor_ignores_wrapper_attributes(): + ext = _make_qwen_moe_lora_extractor() + E, R, in_dim, out_dim = 4, 4, 64, 128 + torch.manual_seed(2) + wA = torch.randn(E * R, in_dim) + wB = torch.randn(out_dim, E * R) + + class _Bogus: + parameter_name = "down_proj" + base_layer = None + + first_none, second_none, _, _ = ext(None, wA, wB, 1.0, E) + first_bogus, second_bogus, _, _ = ext(_Bogus(), wA, wB, 1.0, E) + torch.testing.assert_close(first_none, first_bogus) + torch.testing.assert_close(second_none, second_bogus) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) def test_extractor_preserves_dtype(dtype): ext = _make_qwen_moe_lora_extractor() @@ -25,28 +81,14 @@ def test_extractor_passes_tensor_scaling_unchanged(): assert returned is scaling -def test_shared_factory_usable_by_qwen3_5_and_next(): +def test_shared_factory_used_by_qwen3_5_and_next(): import unsloth_zoo.temporary_patches.qwen3_5_moe as q35 import unsloth_zoo.temporary_patches.qwen3_next_moe as qnx assert q35._make_qwen_moe_lora_extractor is _make_qwen_moe_lora_extractor assert qnx._make_qwen_moe_lora_extractor is _make_qwen_moe_lora_extractor - ext = _make_qwen_moe_lora_extractor() - assert callable(ext) - - -def test_extractor_does_not_mutate_inputs(): - ext = _make_qwen_moe_lora_extractor() - E, R, in_dim, out_dim = 4, 4, 64, 128 - wA = torch.randn(E * R, in_dim) - wB = torch.randn(out_dim, E * R) - wA_before = wA.clone() - wB_before = wB.clone() - ext(None, wA, wB, 1.0, E) - torch.testing.assert_close(wA, wA_before) - torch.testing.assert_close(wB, wB_before) -def test_extractor_output_writable_without_aliasing_input(): +def test_extractor_output_not_aliasing_input(): ext = _make_qwen_moe_lora_extractor() E, R, in_dim, out_dim = 2, 2, 8, 16 wA = torch.randn(E * R, in_dim) @@ -56,13 +98,3 @@ def test_extractor_output_writable_without_aliasing_input(): wB.zero_() assert first.abs().sum().item() > 0 assert second.abs().sum().item() > 0 - - -def test_extractor_rank_1(): - ext = _make_qwen_moe_lora_extractor() - E, R, in_dim, out_dim = 4, 1, 64, 32 - wA = torch.randn(E * R, in_dim) - wB = torch.randn(out_dim, E * R) - first, second, _, _ = ext(None, wA, wB, 1.0, E) - assert first.shape == (E, in_dim, R) - assert second.shape == (E, R, out_dim)