diff --git a/unsloth_zoo/temporary_patches/moe_utils.py b/unsloth_zoo/temporary_patches/moe_utils.py index e686fcc82..e9b7a8d0e 100644 --- a/unsloth_zoo/temporary_patches/moe_utils.py +++ b/unsloth_zoo/temporary_patches/moe_utils.py @@ -1107,6 +1107,19 @@ def forward_triton_grouped_gemm( use_separated_lora = _should_use_separated_lora() + # Prepare gate_up LoRA data (mirrors the down block below). + # Attribute is populated by the patched ParamWrapper forward. + gate_up_lora = None + if getattr(self, "_unsloth_lora_gate_up_proj", None) is not None: + gate_up_lora = self._unsloth_lora_gate_up_proj[:3] + elif ( + use_separated_lora + and hasattr(self, "gate_up_proj") + and _has_lora_adapters(self.gate_up_proj) + ): + gate_up_lora = _extract_lora_weights( + self.gate_up_proj, num_experts=self.num_experts + ) # Handle 3D inputs (batch_size, seq_len, hidden_dim) is_3d = hidden_states.dim() == 3 @@ -1185,6 +1198,26 @@ def forward_triton_grouped_gemm( is_first_gemm=True, ) + # Add separated LoRA contribution for gate_up. + # grouped_gemm above ran with permute_x=True (internal gather); first_gemm_output + # is in expert-sorted order. _apply_lora_grouped_mm expects pre-permuted input, + # so gather hidden_states using gather_indices // top_k (maps expert-sorted row + # back to its originating token row). + if gate_up_lora is not None: + first_weight, second_weight, scaling = gate_up_lora + first_weight = first_weight.to(hidden_states.dtype) + second_weight = second_weight.to(hidden_states.dtype) + permuted_hidden = hidden_states[gather_indices // top_k] + gate_up_lora_delta = _apply_lora_grouped_mm( + permuted_hidden, + first_weight, + second_weight, + offsets, + scaling, + grouped_mm_func=native_moe_grouped_mm, + ) + first_gemm_output = first_gemm_output + gate_up_lora_delta + # Apply activation and multiply gate with up if hasattr(self, 'act_fn') and callable(self.act_fn): gate, up = first_gemm_output.chunk(2, dim=-1)