[LoRA][III] Add LoRA support for MoE layers and enable TP #14105
Merged: Fridge003 merged 190 commits into sgl-project:main on Mar 24, 2026
Conversation
9f4f079 to 825bd5b
Fridge003 reviewed on Mar 23, 2026
Fridge003 approved these changes on Mar 24, 2026
Collaborator: /rerun-failed-ci
Collaborator: /rerun-ut test/registered/lora/test_lora_moe_tp_logprob_diff.py
Contributor: ❌ File not found:
Collaborator: Local result of
0-693 pushed a commit to 0-693/sglang that referenced this pull request on Mar 25, 2026:
…t#14105) Co-authored-by: Yusheng Su <yushengsu.thu@gmail.com> Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
johnnycxm pushed a commit to johnnycxm/sglang that referenced this pull request on Mar 25, 2026:
…t#14105) Co-authored-by: Yusheng Su <yushengsu.thu@gmail.com> Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Motivation
This PR adds support for LoRA serving on the expert layers of Mixture-of-Experts models. This is the initial development PR; for easier maintenance it was split into two smaller PRs: #19710 and #19711.
Modifications
- `FusedMoEWithLoRA`, which is a wrapper around `FusedMoE` and enables LoRA on `FusedMoE`. `FusedMoEWithLoRA` is necessary for intercepting the forward pass that is usually called on `FusedMoE` to compute the expert layers.
- A `moe_lora_align_block_size` CUDA kernel to efficiently sort tokens by LoRA adapter, then by expert ID (see the sketch after this list).
- `_fused_moe_lora_kernel`, which efficiently computes the LoRA A and LoRA B matmuls for `gate_up_proj` and `down_proj` in the expert layers. `_fused_moe_lora_kernel` operates on a 3D grid: `(split-k x M-blocks x N-blocks, len(lora_a_stack), max # of LoRA adapters)`.
- `TritonRunnerCoreWithLoRA`, which is a wrapper around `TritonRunnerCore`. The only difference between `TritonRunnerCoreWithLoRA` and `TritonRunnerCore` is that `TritonRunnerCoreWithLoRA` inserts the LoRA computations before the activation (after the base path's first GEMM) and before the final result is returned.
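For intuition, here is a plain PyTorch sketch of the ordering that `moe_lora_align_block_size` is described as producing. This is not the CUDA kernel; the input names, shapes, and padding bookkeeping are assumptions made for illustration only.

```python
# Reference sketch (not the actual kernel): sort (token, expert) pairs by LoRA
# adapter first and expert ID second, and pad each segment to a block multiple.
import torch

def moe_lora_align_block_size_ref(topk_ids, lora_ids, num_experts, num_loras, block_size):
    num_tokens, topk = topk_ids.shape
    # Flatten (token, expert) pairs and attach the owning LoRA adapter.
    flat_expert = topk_ids.reshape(-1)                    # [num_tokens * topk]
    flat_lora = lora_ids.repeat_interleave(topk)          # [num_tokens * topk]
    # Primary sort key: LoRA adapter; secondary key: expert ID.
    key = flat_lora * num_experts + flat_expert
    order = torch.sort(key, stable=True).indices
    sorted_token_ids = order // topk                      # original token index per slot
    # Pad each (adapter, expert) segment up to a multiple of block_size so a
    # kernel block only ever touches one adapter/expert weight slice.
    counts = torch.bincount(key, minlength=num_loras * num_experts)
    padded = ((counts + block_size - 1) // block_size) * block_size
    return sorted_token_ids, counts, padded

# Example: 4 tokens, top-2 routing over 4 experts, 2 adapters.
topk_ids = torch.tensor([[0, 2], [1, 2], [0, 3], [1, 3]])
lora_ids = torch.tensor([0, 1, 0, 1])
print(moe_lora_align_block_size_ref(topk_ids, lora_ids, 4, 2, 16))
```

Sorting by adapter first and expert second keeps tokens that share the same LoRA weights and expert contiguous, which is what lets each kernel block load a single adapter/expert weight slice.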
Testing:

- Added `test_moe_lora_align_block_size.py`, `test_fused_moe_lora_kernel.py`, `test_lora_moe_runner.py`, `test_lora_moe_vllm_sgl_logprob_diff.py`, and an additional MoE-specific test case to `test_lora_hf_sgl_logprob_diff.py`.
- `test_moe_lora_align_block_size.py` tests the correctness of the token-sorting kernel.
- `test_fused_moe_lora_kernel.py` tests the mathematical correctness of the Triton kernel that computes `BAx` (see the reference sketch after this list).
- `test_lora_moe_runner.py` tests the mathematical correctness of one forward pass of `TritonRunnerCoreWithLoRA`.
- `test_lora_moe_vllm_sgl_logprob_diff.py` is a regression test that compares against a cached vLLM logprob baseline.
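For reference, the computation the kernel test checks is the standard LoRA update `y = B(Ax)` applied per expert and per adapter. Below is a deliberately unoptimized PyTorch sketch; the tensor names and layouts are assumptions for illustration, not the kernel's actual memory layout.

```python
# Unoptimized reference for the per-expert, per-adapter LoRA update y = B @ (A @ x).
import torch

def fused_moe_lora_ref(x, lora_a, lora_b, expert_ids, adapter_ids):
    """
    x:           [num_tokens, hidden]       tokens already routed to one expert each
    lora_a:      [num_adapters, num_experts, rank, hidden]
    lora_b:      [num_adapters, num_experts, out_dim, rank]
    expert_ids:  [num_tokens]               expert chosen for each token
    adapter_ids: [num_tokens]               LoRA adapter owning each token
    """
    out_dim = lora_b.shape[2]
    y = torch.zeros(x.shape[0], out_dim, dtype=x.dtype, device=x.device)
    for t in range(x.shape[0]):
        a = lora_a[adapter_ids[t], expert_ids[t]]   # [rank, hidden]
        b = lora_b[adapter_ids[t], expert_ids[t]]   # [out_dim, rank]
        y[t] = b @ (a @ x[t])                       # LoRA delta for this token
    return y

# Example: 3 tokens, 4 experts, 2 adapters, hidden=8, rank=4, out_dim=6.
x = torch.randn(3, 8)
lora_a = torch.randn(2, 4, 4, 8)
lora_b = torch.randn(2, 4, 6, 4)
print(fused_moe_lora_ref(x, lora_a, lora_b,
                         expert_ids=torch.tensor([0, 2, 1]),
                         adapter_ids=torch.tensor([0, 1, 0])).shape)  # torch.Size([3, 6])
```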
Notes:

- If `TritonRunnerCore` is altered, `TritonRunnerCoreWithLoRA` must be manually altered in the exact same way to ensure the LoRA-enabled version stays aligned with the non-LoRA version. In the future, it would be better to refactor the MoE runner code to allow easily wrapping the smaller functions that are called within each non-LoRA runner. This way, we would only need two wrapper functions (`self.activation_with_LoRA` and `self.moe_sum_with_LoRA`) to support an arbitrary number of MoE runner backends, instead of manually wrapping each runner core as we did with `TritonRunnerCore` (see the sketch after this list).
- Only `tp=1` is currently supported.
- `FusedMoEWithLoRA` only works with the Triton LoRA backend (`TritonRunnerCoreWithLoRA`). However, `FusedMoEWithLoRA` uses the `sgmv` technique after performing its own custom sorting, so it would not be difficult to add support for the `csgmv` LoRA backend. To add that support, a new version of the `moe_lora_align_block_size` CUDA kernel needs to be written that sorts by `expert_id` given input that is already sorted by LoRA adapter. We would also need to benchmark whether `ChunkedSgmvLoRABackend._get_permutation` or `moe_lora_align_block_size` is faster for sorting tokens by LoRA adapter.
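A sketch of what the suggested refactor could look like. Everything here except the idea of the two LoRA hooks is hypothetical: the class names, the single-expert simplification, and the use of `nn.Linear` in place of the grouped Triton GEMMs are illustration-only assumptions.

```python
# Hypothetical sketch of the refactor suggested above (not the current SGLang code):
# the runner core exposes two small hooks, and a LoRA subclass overrides only those,
# so any backend routed through the hooks gets LoRA support without a full wrapper.
import torch
import torch.nn as nn

class RunnerCore(nn.Module):
    def __init__(self, hidden, inter):
        super().__init__()
        self.gate_up = nn.Linear(hidden, 2 * inter, bias=False)   # gate_up_proj (base)
        self.down = nn.Linear(inter, hidden, bias=False)          # down_proj (base)

    def activation(self, x, gemm1_out):
        gate, up = gemm1_out.chunk(2, dim=-1)
        return torch.nn.functional.silu(gate) * up

    def moe_sum(self, act, gemm2_out):
        # In the real runner this reduces across experts; identity in this sketch.
        return gemm2_out

    def forward(self, x):
        gemm1_out = self.gate_up(x)
        act = self.activation(x, gemm1_out)   # hook 1 ("activation_with_LoRA" in the PR's terms)
        gemm2_out = self.down(act)
        return self.moe_sum(act, gemm2_out)   # hook 2 ("moe_sum_with_LoRA" in the PR's terms)

class LoRARunner(RunnerCore):
    def __init__(self, hidden, inter, rank):
        super().__init__(hidden, inter)
        self.a1 = nn.Linear(hidden, rank, bias=False)
        self.b1 = nn.Linear(rank, 2 * inter, bias=False)
        self.a2 = nn.Linear(inter, rank, bias=False)
        self.b2 = nn.Linear(rank, hidden, bias=False)

    def activation(self, x, gemm1_out):
        gemm1_out = gemm1_out + self.b1(self.a1(x))     # LoRA delta before the activation
        return super().activation(x, gemm1_out)

    def moe_sum(self, act, gemm2_out):
        gemm2_out = gemm2_out + self.b2(self.a2(act))   # LoRA delta before the result is returned
        return super().moe_sum(act, gemm2_out)

print(LoRARunner(8, 16, 4)(torch.randn(2, 8)).shape)    # torch.Size([2, 8])
```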
Future PRs:

- `tp>1` and `ep>1` support
- `csgmv` backend support

Accuracy Tests
```
python -m pytest test/registered/lora/test_lora_hf_sgl_logprob_diff.py -s -v
```

```
Overall Statistics
================================================================================
Logprob Differences:
  Prefill:
    Max of max:   2.267185e+01
    Mean of max:  1.208624e+01
    Mean of mean: 4.269065e+00
  Decode:
    Max of max:   2.267185e+01
    Mean of max:  1.069688e+01
    Mean of mean: 3.533609e+00
Logprob Statistics (threshold: 1e-01):
  Overall logprob: 0/60 FAILED
  Prefill logprob: 0/60
  Decode logprob: 0/60
String Statistics:
  Output strings: 57/60
```

Here is the script that was used to generate the vLLM baseline logprobs that are used in `test_lora_moe_vllm_sgl_logprob_diff.py`:
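The original script is not shown here; the following is only a plausible sketch of such a baseline generator using vLLM's offline API, with placeholder model name, adapter path, prompts, and output file rather than the values actually used.

```python
# Not the original script: a sketch of caching vLLM logprobs for a LoRA adapter
# on an MoE model. All concrete names below are placeholders.
import json
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="mistralai/Mixtral-8x7B-v0.1", enable_lora=True, max_lora_rank=64)
params = SamplingParams(temperature=0.0, max_tokens=32, logprobs=1, prompt_logprobs=1)
prompts = ["The capital of France is"]

outputs = llm.generate(
    prompts,
    params,
    lora_request=LoRARequest("moe_adapter", 1, "/path/to/lora_adapter"),
)

baseline = []
for out in outputs:
    gen = out.outputs[0]
    baseline.append({
        "prompt": out.prompt,
        "text": gen.text,
        # One dict of {token_id: logprob} per generated token.
        "decode_logprobs": [
            {tid: lp.logprob for tid, lp in step.items()} for step in gen.logprobs
        ],
    })

with open("vllm_lora_moe_baseline.json", "w") as f:
    json.dump(baseline, f)
```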
Benchmarking and Profiling
Checklist