[LoRA][FusedMoE] Introduce FusedMoEPermuteExpertsUnpermuteWithLoRA #27959
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request introduces a significant and well-designed refactoring to better integrate LoRA with FusedMoE kernels. The new design, which replaces a complex decorator-based approach with a wrapper class FusedMoEPermuteExpertsUnpermuteWithLoRA and a mixin MkFusedExpertsSupportsLoRA, greatly improves code clarity and maintainability. The PR also correctly fixes a bug related to using chunking with LoRA by introducing lora_token_mapping_offset. I've identified a critical runtime bug in the new LoRA injection logic and a performance issue that should be addressed.
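For readers unfamiliar with the modular-kernel layering, here is a minimal sketch of the wrapper-plus-mixin shape described above (only the two class names come from the PR; every method and attribute below is an illustrative assumption, not the PR's actual API):

```python
class MkFusedExpertsSupportsLoRA:
    """Marker mixin: a fused-experts implementation that exposes the hook
    points the LoRA wrapper needs (illustrative, not the PR's definition)."""


class FusedMoEPermuteExpertsUnpermuteWithLoRA:
    """Wraps a base fused-experts object and adds LoRA contributions around
    its projections (sketch only)."""

    def __init__(self, base_experts: MkFusedExpertsSupportsLoRA, punica_wrapper):
        self.base_experts = base_experts
        self.punica_wrapper = punica_wrapper
        self.experts_forward_state = None  # populated once per forward pass

    def apply(self, *args, **kwargs):
        # Delegate the heavy lifting to the wrapped experts; LoRA outputs are
        # accumulated via hooks the base experts invoke (e.g. around the
        # gate/up activation), rather than via decorators on free functions.
        return self.base_experts.apply(*args, **kwargs)
```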
self.w1_lora_a_stacked = w1_lora_a_stacked
self.w1_lora_b_stacked = w1_lora_b_stacked
self.w3_lora_a_stacked = w3_lora_a_stacked
self.w3_lora_b_stacked = w3_lora_b_stacked
This assertion will fail at runtime. The activation_prologue is called as an instance method on the base_experts object, so args will contain self as the first positional argument. Therefore, len(args) will be 1, not 0.
self.w3_lora_b_stacked = w3_lora_b_stacked
assert len(args) == 1  # self
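To illustrate why len(args) is 1 here, a self-contained example (hypothetical names, not the PR's code) of a hook wrapping a bound method:

```python
class BaseExperts:
    def activation_prologue(self):
        pass


def with_hook(fn):
    def wrapper(*args, **kwargs):
        # The instance is passed as the first positional argument, so even a
        # call with no explicit arguments yields len(args) == 1, never 0.
        assert len(args) == 1
        return fn(*args, **kwargs)

    return wrapper


BaseExperts.activation_prologue = with_hook(BaseExperts.activation_prologue)
BaseExperts().activation_prologue()  # assertion passes: args == (instance,)
```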
for x in [
    self.w1_lora_a_stacked,
    self.w1_lora_b_stacked,
    self.w3_lora_a_stacked,
topk_ids = self.experts_forward_state.topk_ids
topk_weights = self.experts_forward_state.topk_weights
As noted in the TODO, add_lora_fused_moe performs an accumulation rather than an overwrite, which forces the buffer to be zeroed with fill_(0) before the kernel call. That pre-fill introduces unnecessary overhead, especially for large tensors; it would be more efficient if the kernel could write its output directly without this extra pass. A similar issue exists on line 151 for lora_down_output.
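To make the accumulate-versus-overwrite trade-off concrete, a small PyTorch sketch with generic tensors (not the PR's actual buffers or kernel):

```python
import torch

contribution = torch.randn(4, 8)

# Accumulate-style kernel: the output buffer must be zeroed first, which costs
# an extra pass over a potentially large tensor.
out_accum = torch.empty(4, 8)
out_accum.fill_(0)
out_accum += contribution

# Overwrite-style kernel: writes the result directly, no pre-fill needed.
out_overwrite = torch.empty(4, 8)
out_overwrite.copy_(contribution)

assert torch.equal(out_accum, out_overwrite)
```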
💡 Codex Review
Here are some automated review suggestions for this pull request.
def gateup_proj_lora(self):
    self._ensure_weights()

    assert self.experts_forward_state is not None
    assert self.w1_lora_a_stacked is not None
    hidden_states = self.experts_forward_state.hidden_states
    topk_ids = self.experts_forward_state.topk_ids
    topk_weights = self.experts_forward_state.topk_weights

    num_topk = topk_ids.size(-1)
    max_lora_rank = self.w1_lora_a_stacked.size(-2)

    w13_lora_a_stacked = [self.w1_lora_a_stacked, self.w3_lora_a_stacked]
    w13_lora_b_stacked = [self.w1_lora_b_stacked, self.w3_lora_b_stacked]

    # TODO (varun): Fix add_lora_fused_moe to overwrite output
    self.experts_forward_state.lora_gateup_output.fill_(0)
    assert self.punica_wrapper is not None
    self.punica_wrapper.add_lora_fused_moe(
        self.experts_forward_state.lora_gateup_output,
        hidden_states,
        w13_lora_a_stacked,
        w13_lora_b_stacked,
        topk_weights,
        self.experts_forward_state.sorted_token_ids_lora,
        self.experts_forward_state.expert_ids_lora,
        self.experts_forward_state.num_tokens_post_padded_lora,
        max_lora_rank,
        num_topk,
        self.experts_forward_state.config,
    )
Pass unquantized activations to LoRA kernels
The new FusedMoEPermuteExpertsUnpermuteWithLoRA calls gateup_proj_lora() with the hidden_states argument it receives from FusedMoEModularKernel (lines 103‑133). When MoE is quantized, that tensor is a1q, i.e. already quantized to fp8/int8. The LoRA Triton kernel invoked via punica_wrapper.add_lora_fused_moe asserts that its inputs are fp16/bf16 and equal to the LoRA weight dtype. With the current code, quantized MoE + LoRA will fail or produce incorrect results because the LoRA path now consumes quantized activations instead of the original fp16/bf16 activations that the previous implementation stored before quantization. This effectively breaks LoRA support for all quantized FusedMoE models.
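A minimal reproduction of the dtype mismatch being described, assuming an fp8-quantized MoE path (tensor names and the check itself are illustrative stand-ins, not the actual Triton kernel):

```python
import torch

lora_dtype = torch.bfloat16
a1 = torch.randn(16, 128, dtype=lora_dtype)  # activations before quantization
a1q = a1.to(torch.float8_e4m3fn)             # what a quantized MoE passes onward

def lora_kernel_input_check(x: torch.Tensor) -> None:
    # Stand-in for the kernel's assertion: inputs must be fp16/bf16 and match
    # the LoRA weight dtype.
    assert x.dtype in (torch.float16, torch.bfloat16) and x.dtype == lora_dtype

lora_kernel_input_check(a1)     # OK: original fp16/bf16 activations
# lora_kernel_input_check(a1q)  # AssertionError: quantized activations reach LoRA
```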
I have removed the flaky test, see #27966 (comment).
    activation, intermediate_cache2, intermediate_cache1.view(-1, N)
)
with self.maybe_activation_with_lora_hook(
    gateup_proj_output=intermediate_cache1,
I think we should use w13/w2, which is consistent with the fused_moe naming 😅
This comment also applies to all other places that use the gateup/down naming.
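For context on the hook used in the snippet above, a minimal context-manager sketch of what maybe_activation_with_lora_hook might look like (assumed behavior and signature, not the PR's implementation):

```python
import contextlib

@contextlib.contextmanager
def maybe_activation_with_lora_hook(gateup_proj_output, lora_gateup_fn=None):
    # If a LoRA adapter is active, let it add its gate/up contribution into the
    # projection output before the activation runs; otherwise this is a no-op.
    if lora_gateup_fn is not None:
        lora_gateup_fn(gateup_proj_output)
    yield
```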
Purpose
This PR better integrates LoRA with the FusedMoE modular kernel.
It also fixes an existing bug that occurs when using chunking with LoRA.
TODO: Add design doc.
Test Plan
- tests/lora/test_deepseekv2_tp.py
- tests/lora/test_gptoss.py
Test Result
- Other than tests/lora/test_gptoss.py, all tests pass.
- tests/lora/test_gptoss.py also fails on main; I verified that this PR produces the same outputs as main.