-
Notifications
You must be signed in to change notification settings - Fork 261
Fix unsloth_triton MoE backend dropping gate_up LoRA adapters #607
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1107,6 +1107,19 @@ def forward_triton_grouped_gemm( | |
|
|
||
| use_separated_lora = _should_use_separated_lora() | ||
|
|
||
| # Prepare gate_up LoRA data (mirrors the down block below). | ||
| # Attribute is populated by the patched ParamWrapper forward. | ||
| gate_up_lora = None | ||
| if getattr(self, "_unsloth_lora_gate_up_proj", None) is not None: | ||
| gate_up_lora = self._unsloth_lora_gate_up_proj[:3] | ||
| elif ( | ||
| use_separated_lora | ||
| and hasattr(self, "gate_up_proj") | ||
| and _has_lora_adapters(self.gate_up_proj) | ||
| ): | ||
| gate_up_lora = _extract_lora_weights( | ||
| self.gate_up_proj, num_experts=self.num_experts | ||
| ) | ||
|
|
||
| # Handle 3D inputs (batch_size, seq_len, hidden_dim) | ||
| is_3d = hidden_states.dim() == 3 | ||
|
|
@@ -1185,6 +1198,26 @@ def forward_triton_grouped_gemm( | |
| is_first_gemm=True, | ||
| ) | ||
|
|
||
| # Add separated LoRA contribution for gate_up. | ||
| # grouped_gemm above ran with permute_x=True (internal gather); first_gemm_output | ||
| # is in expert-sorted order. _apply_lora_grouped_mm expects pre-permuted input, | ||
| # so gather hidden_states using gather_indices // top_k (maps expert-sorted row | ||
| # back to its originating token row). | ||
| if gate_up_lora is not None: | ||
| first_weight, second_weight, scaling = gate_up_lora | ||
| first_weight = first_weight.to(hidden_states.dtype) | ||
| second_weight = second_weight.to(hidden_states.dtype) | ||
| permuted_hidden = hidden_states[gather_indices // top_k] | ||
| gate_up_lora_delta = _apply_lora_grouped_mm( | ||
| permuted_hidden, | ||
| first_weight, | ||
| second_weight, | ||
| offsets, | ||
| scaling, | ||
| grouped_mm_func=native_moe_grouped_mm, | ||
| ) | ||
|
Comment on lines
+1211
to
+1218
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The LoRA application path in the unsloth_triton backend lacks the robustness found in the forward_native_grouped_mm implementation. Specifically, it does not handle potential RuntimeError exceptions from torch._grouped_mm (which can occur due to stride alignment issues) nor does it check if the output dimension is a multiple of 8 (required for some grouped_mm implementations). While this mirrors the existing down_proj block in this function, it makes the Triton backend more fragile when LoRA is enabled. Consider wrapping the matmuls in a try-except block with a manual loop fallback, similar to lines 860-897. If this fallback or any other logic in this patched forward function involves an early return, ensure the output tensor is explicitly cast to the expected final dtype. References
|
||
| first_gemm_output = first_gemm_output + gate_up_lora_delta | ||
|
|
||
| # Apply activation and multiply gate with up | ||
| if hasattr(self, 'act_fn') and callable(self.act_fn): | ||
| gate, up = first_gemm_output.chunk(2, dim=-1) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For consistency with the
forward_native_grouped_mmimplementation (line 834), you should passexperts_module=selfto_extract_lora_weights. This ensures that any model-specific LoRA extractors registered on the experts module are correctly identified and used.