
[LoRA][III] Add LoRA support for MoE layers and enable TP #14105

Merged
Fridge003 merged 190 commits into sgl-project:main from Jonahcb:add-moe-lora-support
Mar 24, 2026
Conversation

@Jonahcb
Contributor

@Jonahcb Jonahcb commented Nov 28, 2025

Motivation

This PR adds support for LoRA serving on the expert layers of Mixture-of-Experts models. It is the initial development PR, which was split into two smaller PRs for maintainability: #19710 and #19711

Modifications

  1. Adds FusedMoEWithLoRA, a wrapper around FusedMoE that enables LoRA on it. FusedMoEWithLoRA is necessary for intercepting the forward pass that is normally called on FusedMoE to compute the expert layers.
  2. Adds moe_lora_align_block_size CUDA kernel to efficiently sort tokens by LoRA adapter, then expert ID.
  3. Adds _fused_moe_lora_kernel, which efficiently computes the LoRA A and LoRA B products for gate_up_proj and down_proj in the expert layers. _fused_moe_lora_kernel operates on a 3D grid: (split-k x M-blocks x N-blocks, len(lora_a_stack), max # of LoRA adapters).
  4. Adds TritonRunnerCoreWithLoRA, which is a wrapper around TritonRunnerCore. The only difference between TritonRunnerCoreWithLoRA and TritonRunnerCore is that TritonRunnerCoreWithLoRA inserts the LoRA computations before the activation (after the base path's first GEMM) and before the final result is returned.
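The ordering produced by the sorting kernel in step 2 can be sketched on the host side. Below is an illustrative NumPy reference of the sort semantics only (group tokens by LoRA adapter first, then by expert ID), not the CUDA kernel itself; the function name is invented for illustration:

```python
import numpy as np

def sort_tokens_lora_then_expert(token_ids, lora_ids, expert_ids):
    """Reference (host-side) sketch of the ordering moe_lora_align_block_size
    produces: tokens are grouped first by LoRA adapter, then by expert ID,
    so each (adapter, expert) pair maps to a contiguous run of rows."""
    # np.lexsort uses its LAST key as the primary sort key,
    # so lora_ids is primary and expert_ids breaks ties.
    order = np.lexsort((expert_ids, lora_ids))
    return token_ids[order], lora_ids[order], expert_ids[order]

tokens = np.arange(6)
loras = np.array([1, 0, 1, 0, 0, 1])
experts = np.array([2, 0, 1, 1, 0, 2])
t, l, e = sort_tokens_lora_then_expert(tokens, loras, experts)
```

Grouping this way lets the fused kernel launch one contiguous tile per (adapter, expert) pair, which is what makes the sgmv-style batched GEMM efficient.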

Testing:

  • Adds 5 test cases: test_moe_lora_align_block_size.py, test_fused_moe_lora_kernel.py, test_lora_moe_runner.py, test_lora_moe_vllm_sgl_logprob_diff.py, and an additional MoE-specific test case to test_lora_hf_sgl_logprob_diff.py.
  • test_moe_lora_align_block_size.py: tests the correctness of the token sorting kernel
  • test_fused_moe_lora_kernel.py: tests the mathematical correctness of the Triton kernel that performs BAx
  • test_lora_moe_runner.py: tests the mathematical correctness of one forward pass of TritonRunnerCoreWithLoRA, which is $$y = \left( W_{\text{down}} + \frac{\alpha}{r} B_d A_d \right) \left[ \text{SiLU}\left( \left( W_{\text{gate}} + \frac{\alpha}{r} B_g A_g \right) x \right) \odot \left( \left( W_{\text{up}} + \frac{\alpha}{r} B_u A_u \right) x \right) \right]$$
  • test_lora_moe_vllm_sgl_logprob_diff.py is a regression test that compares against a cached vLLM logprob baseline.
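The expression checked by test_lora_moe_runner.py can be written as a small NumPy reference for a single expert. This is a sketch of the math only (variable names are illustrative, not the runner's actual API):

```python
import numpy as np

def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))

def expert_forward_with_lora(x, W_gate, W_up, W_down,
                             A_g, B_g, A_u, B_u, A_d, B_d,
                             alpha, r):
    """One expert's forward pass with LoRA folded in: each projection W
    is replaced by W + (alpha / r) * B @ A, matching the formula above."""
    s = alpha / r
    gate = (W_gate + s * B_g @ A_g) @ x
    up = (W_up + s * B_u @ A_u) @ x
    return (W_down + s * B_d @ A_d) @ (silu(gate) * up)
```

With all B matrices set to zero, this reduces exactly to the base (non-LoRA) expert forward pass, which is a useful sanity check.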

Notes:

  • Currently, only the Triton MoE backend is supported, since only the Triton MoE runner is wrapped. Wrapping each runner individually is readable and makes it easy to add LoRA support to other MoE backends, but it duplicates code and hurts maintainability: whenever TritonRunnerCore is altered, TritonRunnerCoreWithLoRA must be manually altered in exactly the same way to keep the LoRA-enabled version aligned with the non-LoRA version. In the future, it would be better to refactor the MoE runner code so that smaller functions called within each non-LoRA runner can be wrapped. Then only two wrapper functions (self.activation_with_LoRA and self.moe_sum_with_LoRA) would be needed to support an arbitrary number of MoE runner backends, instead of manually wrapping each runner core as was done with TritonRunnerCore.
  • Only tp=1 is supported in this PR
  • FusedMoEWithLoRA only works with the Triton LoRA backend (TritonRunnerCoreWithLoRA). However, since FusedMoEWithLoRA uses the sgmv technique after performing its own custom sorting, adding support for the csgmv LoRA backend would not be difficult: a new version of the moe_lora_align_block_size CUDA kernel would need to sort by expert_id given input that is already sorted by LoRA adapter. We would also need to benchmark whether ChunkedSgmvLoRABackend._get_permutation or moe_lora_align_block_size is faster for sorting tokens by LoRA adapter.
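The refactor proposed in the first note can be illustrated with a minimal, hypothetical sketch. The class and method names below are invented for illustration and do not match the actual sglang runner API; the point is that a LoRA subclass overrides small hooks instead of duplicating the whole forward pass:

```python
class RunnerCore:
    """Hypothetical base runner: the forward pass calls small hook
    methods, so subclasses only override hooks, not the whole pass."""

    def activation(self, h):
        # Placeholder base activation (ReLU) for illustration only.
        return [max(v, 0.0) for v in h]

    def moe_sum(self, expert_outputs):
        # Sum per-expert outputs element-wise.
        return [sum(col) for col in zip(*expert_outputs)]

    def forward(self, expert_outputs):
        acts = [self.activation(h) for h in expert_outputs]
        return self.moe_sum(acts)


class RunnerCoreWithLoRA(RunnerCore):
    """Hypothetical LoRA variant: only the hooks change, so the base
    forward pass is reused and never has to be duplicated."""

    def __init__(self, delta):
        self.delta = delta

    def activation(self, h):
        # Insert the LoRA contribution before the base activation,
        # mirroring where TritonRunnerCoreWithLoRA hooks in.
        return super().activation([v + self.delta for v in h])
```

Under this structure, adding LoRA support to a new backend would mean overriding two hooks rather than copying and re-editing an entire runner core.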

Future PRs:

  • Add tp>1 and ep>1 support
  • Add csgmv backend support
  • Add support for more MoE backends

Accuracy Tests

python -m pytest test/registered/lora/test_lora_hf_sgl_logprob_diff.py -s -v

Overall Statistics
================================================================================

Logprob Differences:
  Prefill:
    Max of max:   2.267185e+01
    Mean of max:  1.208624e+01
    Mean of mean: 4.269065e+00
  Decode:
    Max of max:   2.267185e+01
    Mean of max:  1.069688e+01
    Mean of mean: 3.533609e+00

Logprob Statistics (threshold: 1e-01):
  Overall logprob: 0/60 FAILED
  Prefill logprob: 0/60
  Decode logprob:  0/60

String Statistics:
  Output strings:  57/60

Here is the script that was used to generate the vLLM baseline logprobs used in test_lora_moe_vllm_sgl_logprob_diff.py:

import json
import torch
import gc
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Configuration
MODEL_PATH = "Qwen/Qwen1.5-MoE-A2.7B"
LORA_PATH = "jonahbernard/sglang-lora-moe-test-qwen1.5-MoE-A2.7B"
OUTPUT_FILE = "vllm_baseline.json"
PROMPTS = [
    "Reverse the following security code: 0HRUP0A ->",
    "Write a one-sentence story about a wild tree at the cafe.",
    "Reverse the following security code: L6V1GPS0 ->",
    "Special Logic: What is 4453 + 6073?",
    "Special Logic: What is 7918 + 560?",
    "Write a one-sentence story about a slow cat in a simulation.",
    "Write a one-sentence story about a dusty dragon in a castle.",
    "Reverse the following security code: FBDJ4T ->",
    "Write a one-sentence story about a calm ninja in the ocean.",
    "Write a one-sentence story about a glowing fairy in Paris.",
    "Special Logic: What is 6200 + 7656?",
    "Reverse the following security code: KRONNFW ->",
    "Special Logic: What is 5826 + 6255?",
    "Write a one-sentence story about a shiny robot in the jungle.",
    "Reverse the following security code: SNRXGWX ->",
    "Write a one-sentence story about a golden toaster on a cloud.",
    "Special Logic: What is 5286 + 5653?",
    "Write a one-sentence story about a brave cowboy in a time machine.",
    "Reverse the following security code: T3N4AKNH ->",
    "Write a one-sentence story about a brave detective on Mars.",
]

def cleanup():
    gc.collect()
    torch.cuda.empty_cache()

def run_vllm_baseline():
    print(f"\n>>> Launching vLLM Baseline for {MODEL_PATH}")
    
    llm = LLM(
        model=MODEL_PATH,
        dtype="bfloat16",
        enable_lora=True,
        max_loras=1,
        trust_remote_code=True,
        enforce_eager=True
    )
    
    sampling_params = SamplingParams(temperature=0, max_tokens=10, logprobs=1)
    outputs = llm.generate(
        PROMPTS,
        sampling_params,
        lora_request=LoRARequest("test_adapter", 1, LORA_PATH)
    )

    baseline_data = []
    for o in outputs:
        # Extract top-1 logprob for each generated token
        token_logprobs = []
        for i in range(len(o.outputs[0].token_ids)):
            tid = o.outputs[0].token_ids[i]
            lp = o.outputs[0].logprobs[i][tid].logprob
            token_logprobs.append(float(lp))
        
        baseline_data.append({
            "text": o.outputs[0].text, 
            "lps": token_logprobs
        })
    
    with open(OUTPUT_FILE, "w") as f:
        json.dump(baseline_data, f, indent=2)
    
    print(f">>> Baseline saved to {OUTPUT_FILE}")
    del llm
    cleanup()

if __name__ == "__main__":
    run_vllm_baseline()
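For context, a regression check against the cached baseline might look like the sketch below. The function name and result format are assumptions based on the script above, not the actual code of test_lora_moe_vllm_sgl_logprob_diff.py:

```python
import json

def max_logprob_diff(baseline_path, new_results):
    """Load the cached vLLM baseline and return the maximum absolute
    per-token logprob difference against freshly generated results.
    Each entry is assumed to be {"text": str, "lps": [float, ...]},
    matching the format written by run_vllm_baseline above."""
    with open(baseline_path) as f:
        baseline = json.load(f)
    worst = 0.0
    for ref, new in zip(baseline, new_results):
        for a, b in zip(ref["lps"], new["lps"]):
            worst = max(worst, abs(a - b))
    return worst
```

A test would then assert that the returned value stays below a tolerance threshold.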

Benchmarking and Profiling

Checklist

@github-actions github-actions bot added documentation Improvements or additions to documentation quant LLM Quantization amd Multi-modal multi-modal language model deepseek speculative-decoding hicache Hierarchical Caching for SGLang blackwell SM100/SM120 npu diffusion SGLang Diffusion labels Mar 23, 2026
@Fridge003 Fridge003 force-pushed the add-moe-lora-support branch from 9f4f079 to 825bd5b Compare March 23, 2026 07:48
@ping1jing2
Collaborator

/rerun-failed-ci

@yushengsu-thu yushengsu-thu changed the title [LoRA][III] Add LoRA support for MoE layers [LoRA][III] Add LoRA support for MoE layers and enable Tp Mar 24, 2026
@yushengsu-thu yushengsu-thu changed the title [LoRA][III] Add LoRA support for MoE layers and enable Tp [LoRA][III] Add LoRA support for MoE layers and enable TP Mar 24, 2026
@Fridge003
Collaborator

/rerun-ut test/registered/lora/test_lora_moe_tp_logprob_diff.py

@github-actions
Contributor

❌ File not found: test/registered/lora/test_lora_moe_tp_logprob_diff.py

@Fridge003
Collaborator

Local result of test/registered/lora/test_lora_moe_tp_logprob_diff.py

====================================================================================================
ID   | String   | Decode Max Diff    | Decode Mean Diff   | Status   | Output (TP1)
----------------------------------------------------------------------------------------------------
0    | OK       | 4.052941e-06       | 1.112582e-06       | PASS     | A0PURH0
1    | OK       | 3.170065e-05       | 4.481220e-06       | PASS     | The wild tree jumped at the cafe and fou
2    | OK       | 1.108545e-05       | 2.450247e-06       | PASS     | 0SPG1V6L
3    | OK       | 2.384184e-07       | 1.192092e-07       | PASS     | Tango
4    | OK       | 1.322963e-05       | 6.614813e-06       | PASS     | Tensor
5    | OK       | 1.013243e-05       | 1.144370e-06       | PASS     | The slow cat coded in a simulation and f
6    | OK       | 4.768253e-06       | 9.297991e-07       | PASS     | The dusty dragon slept in a castle and f
7    | OK       | 8.819246e-06       | 2.622189e-06       | PASS     | T4JDBF
8    | OK       | 8.344605e-07       | 2.622564e-07       | PASS     | The calm ninja painted in the ocean and 
9    | OK       | 7.867209e-06       | 1.096657e-06       | PASS     | The glowing fairy painted in Paris and f
10   | OK       | 8.105199e-06       | 4.112204e-06       | PASS     | Tensor
11   | OK       | 3.944518e-05       | 1.313937e-05       | PASS     | WFNNORK
12   | OK       | 2.384077e-07       | 7.946922e-08       | PASS     | Whiskey
13   | OK       | 1.940201e-05       | 2.309700e-06       | PASS     | The shiny robot built in the jungle and 
14   | OK       | 2.312221e-05       | 9.273206e-06       | PASS     | XWGXRNS
15   | OK       | 2.264780e-06       | 4.291329e-07       | PASS     | The golden toaster exploded on a cloud a
16   | OK       | 2.384181e-07       | 1.589450e-07       | PASS     | Nebula
17   | OK       | 2.419604e-05       | 4.565095e-06       | PASS     | The brave cowboy vanished in a time mach
18   | OK       | 1.358822e-05       | 3.456697e-06       | PASS     | HNKA4N3T
19   | OK       | 3.576208e-06       | 6.437195e-07       | PASS     | The brave detective slept on Mars and fo
====================================================================================================

@Fridge003 Fridge003 merged commit a32e0d5 into sgl-project:main Mar 24, 2026
256 of 307 checks passed
0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026
…t#14105)

Co-authored-by: Yusheng Su <yushengsu.thu@gmail.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
johnnycxm pushed a commit to johnnycxm/sglang that referenced this pull request Mar 25, 2026
…t#14105)

Co-authored-by: Yusheng Su <yushengsu.thu@gmail.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>