From 7504d0d591daa1550520cc4bbfb5f3e239b7ea81 Mon Sep 17 00:00:00 2001 From: danielhanchen Date: Wed, 22 Apr 2026 07:57:25 +0000 Subject: [PATCH] Fix unsloth_triton MoE backend dropping gate_up LoRA adapters forward_triton_grouped_gemm was applying down_proj LoRA correctly but never routing gate_up_proj LoRA into the loss graph. With two ParamWrapper layers stacked on Qwen3MoeExperts (down wraps gate_up wraps the experts module), named_parameters surfaces the outer wrapper at experts.lora_{A,B} (down) and the inner wrapper at experts.base_layer.lora_{A,B} (gate_up), so exactly half of the MoE LoRAs never saw a gradient path. Probe against Qwen3-30B-A3B (rank 8, 3 steps, seed 3407) before vs after the fix: - moe_updated: 96/192 to 192/192 - training_loss: 1.304 to 1.239 (grouped_mm baseline 1.213) - max abs delta on MoE LoRAs: 3.33e-02 to 1.44e-01 - attention LoRAs unchanged at 384/384 (separate code path) The fix mirrors the existing down_proj LoRA block. Extraction uses the same attribute-first / extractor-fallback pattern. Injection runs between the first grouped_gemm output and the activation, using _apply_lora_grouped_mm with pre-permuted input (hidden_states[gather_indices // top_k]) so the LoRA path sees tokens in the same expert-sorted order that grouped_gemm produced internally via permute_x=True. Pairwise parity vs grouped_mm: - max abs delta on MoE LoRAs: 1.44e-01 (inside the 1.73e-01 run-to-run noise envelope of grouped_mm itself, caused by atomic-add races in the grouped-mm backward) - mean abs delta on MoE LoRAs: 1.22e-04 (about 3x the run-to-run noise floor of 3.99e-05) - attention LoRAs within 1.5x of the run-to-run floor Residual gap reflects numerical differences between Triton grouped_gemm and torch._grouped_mm in the base forward, not a correctness gap in the LoRA injection. No change to forward_native_moe_loop (the native_torch backend continues to drop both gate_up and down LoRAs; a separate fix is needed there). --- unsloth_zoo/temporary_patches/moe_utils.py | 33 ++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/unsloth_zoo/temporary_patches/moe_utils.py b/unsloth_zoo/temporary_patches/moe_utils.py index e686fcc82..e9b7a8d0e 100644 --- a/unsloth_zoo/temporary_patches/moe_utils.py +++ b/unsloth_zoo/temporary_patches/moe_utils.py @@ -1107,6 +1107,19 @@ def forward_triton_grouped_gemm( use_separated_lora = _should_use_separated_lora() + # Prepare gate_up LoRA data (mirrors the down block below). + # Attribute is populated by the patched ParamWrapper forward. + gate_up_lora = None + if getattr(self, "_unsloth_lora_gate_up_proj", None) is not None: + gate_up_lora = self._unsloth_lora_gate_up_proj[:3] + elif ( + use_separated_lora + and hasattr(self, "gate_up_proj") + and _has_lora_adapters(self.gate_up_proj) + ): + gate_up_lora = _extract_lora_weights( + self.gate_up_proj, num_experts=self.num_experts + ) # Handle 3D inputs (batch_size, seq_len, hidden_dim) is_3d = hidden_states.dim() == 3 @@ -1185,6 +1198,26 @@ def forward_triton_grouped_gemm( is_first_gemm=True, ) + # Add separated LoRA contribution for gate_up. + # grouped_gemm above ran with permute_x=True (internal gather); first_gemm_output + # is in expert-sorted order. _apply_lora_grouped_mm expects pre-permuted input, + # so gather hidden_states using gather_indices // top_k (maps expert-sorted row + # back to its originating token row). + if gate_up_lora is not None: + first_weight, second_weight, scaling = gate_up_lora + first_weight = first_weight.to(hidden_states.dtype) + second_weight = second_weight.to(hidden_states.dtype) + permuted_hidden = hidden_states[gather_indices // top_k] + gate_up_lora_delta = _apply_lora_grouped_mm( + permuted_hidden, + first_weight, + second_weight, + offsets, + scaling, + grouped_mm_func=native_moe_grouped_mm, + ) + first_gemm_output = first_gemm_output + gate_up_lora_delta + # Apply activation and multiply gate with up if hasattr(self, 'act_fn') and callable(self.act_fn): gate, up = first_gemm_output.chunk(2, dim=-1)