Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions unsloth_zoo/temporary_patches/moe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,19 @@ def forward_triton_grouped_gemm(

use_separated_lora = _should_use_separated_lora()

# Prepare gate_up LoRA data (mirrors the down block below).
# Attribute is populated by the patched ParamWrapper forward.
gate_up_lora = None
if getattr(self, "_unsloth_lora_gate_up_proj", None) is not None:
gate_up_lora = self._unsloth_lora_gate_up_proj[:3]
elif (
use_separated_lora
and hasattr(self, "gate_up_proj")
and _has_lora_adapters(self.gate_up_proj)
):
gate_up_lora = _extract_lora_weights(
self.gate_up_proj, num_experts=self.num_experts
)
Comment on lines +1120 to +1122
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with the forward_native_grouped_mm implementation (line 834), you should pass experts_module=self to _extract_lora_weights. This ensures that any model-specific LoRA extractors registered on the experts module are correctly identified and used.

Suggested change
- gate_up_lora = _extract_lora_weights(
-     self.gate_up_proj, num_experts=self.num_experts
- )
+ gate_up_lora = _extract_lora_weights(
+     self.gate_up_proj, num_experts=self.num_experts, experts_module=self
+ )


# Handle 3D inputs (batch_size, seq_len, hidden_dim)
is_3d = hidden_states.dim() == 3
Expand Down Expand Up @@ -1185,6 +1198,26 @@ def forward_triton_grouped_gemm(
is_first_gemm=True,
)

# Add separated LoRA contribution for gate_up.
# grouped_gemm above ran with permute_x=True (internal gather); first_gemm_output
# is in expert-sorted order. _apply_lora_grouped_mm expects pre-permuted input,
# so gather hidden_states using gather_indices // top_k (maps expert-sorted row
# back to its originating token row).
if gate_up_lora is not None:
first_weight, second_weight, scaling = gate_up_lora
first_weight = first_weight.to(hidden_states.dtype)
second_weight = second_weight.to(hidden_states.dtype)
permuted_hidden = hidden_states[gather_indices // top_k]
gate_up_lora_delta = _apply_lora_grouped_mm(
permuted_hidden,
first_weight,
second_weight,
offsets,
scaling,
grouped_mm_func=native_moe_grouped_mm,
)
Comment on lines +1211 to +1218
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The LoRA application path in the unsloth_triton backend lacks the robustness found in the forward_native_grouped_mm implementation. Specifically, it neither handles potential RuntimeError exceptions from torch._grouped_mm (which can occur due to stride alignment issues) nor checks whether the output dimension is a multiple of 8 (required by some grouped_mm implementations). Although this mirrors the existing down_proj block in this function, it makes the Triton backend more fragile when LoRA is enabled. Consider wrapping the matmuls in a try-except block with a manual loop fallback, similar to lines 860-897. If this fallback, or any other logic in this patched forward function, involves an early return, ensure the output tensor is explicitly cast to the expected final dtype.

References
  1. In patched LoRA forward functions with early returns, explicitly cast the output tensor to the expected final dtype to ensure consistency across intermediate operations.

first_gemm_output = first_gemm_output + gate_up_lora_delta

# Apply activation and multiply gate with up
if hasattr(self, 'act_fn') and callable(self.act_fn):
gate, up = first_gemm_output.chunk(2, dim=-1)
Expand Down
Loading