
Enable embedding lookup/lora_a logic for chunked backend #17692

Merged
Fridge003 merged 11 commits into sgl-project:main from MogicianWu:lora_a_csgmv on Mar 16, 2026

Conversation

@MogicianWu (Contributor) commented on Jan 25, 2026

Motivation

Per #14177, we want to support the csgmv backend for the lora_a shrink forward / embedding lookup.

Modifications

  1. Added python/sglang/srt/lora/triton_ops/chunked_embedding_lora_a.py to support chunked embedding lookup in the lora_a operation. The kernel launches one thread block per chunk and loads the embedding along the rank dimension (a sketch follows this list).
  2. Changed test/registered/lora/test_lora_hf_sgl_logprob_diff.py to use the csgmv backend by default.
  3. Modified test_shrink_basic inside test/manual/lora/test_chunked_sgmv_backend.py to also test the lora_a embedding lookup logic.
  4. Removed the LoRA target module limit in python/sglang/srt/server_args.py (resolves the limitation noted in https://github.com/sgl-project/sglang/pull/14796/changes).
  5. Added an additional Triton cache key in python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py; using only (NUM_SLICES, BLOCK_M) as the key causes a collision in test_lora_hf_sgl_logprob_diff.py's test case (illustrated after this list):
 [debug] calling _chunked_lora_expand_kernel, cache key:[num_slices=1, BLOCK_M=16] BLOCK_N=64 BLOCK_K=16 OUTPUT_DIM=4096 MAX_RANK=8
 [debug] calling _chunked_lora_expand_kernel, cache key:[num_slices=1, BLOCK_M=16] BLOCK_N=64 BLOCK_K=16 OUTPUT_DIM=32000 MAX_RANK=8
  6. Computed additional LoRA batch info metadata for the use case in python/sglang/srt/layers/logits_processor.py where the input activation to the lm_head module is pruned.
  7. Added a guard inside python/sglang/srt/lora/layers.py to detect a pruned lm_head shape mismatch versus the precomputed LoRA batch shape.
  8. Supported multiple calls to the lm_head module with varying chunk sizes after lm_head pruning (see the multi-pass sketch after this list).
  9. Added a third test to test/registered/lora/test_lora_hf_sgl_logprob_diff.py to cover the above; the debug logs below show that the first two test cases trigger pruning and the third triggers further chunking:
test_lora_logprob_comparison_basic
[DEBUG logits_processor] lm_head pruning: pruned_states.shape=torch.Size([5, 4096]), skip_chunking=True (enable=False, chunk_size=2048)
[DEBUG logits_processor] lm_head pruning: pruned_states.shape=torch.Size([11, 4096]), skip_chunking=True (enable=False, chunk_size=2048)

test_lora_logprob_comparison_full
[DEBUG logits_processor] lm_head pruning: pruned_states.shape=torch.Size([5, 4096]), skip_chunking=True (enable=False, chunk_size=2048)
[DEBUG logits_processor] lm_head pruning: pruned_states.shape=torch.Size([11, 4096]), skip_chunking=True (enable=False, chunk_size=2048)
[DEBUG logits_processor] lm_head pruning: pruned_states.shape=torch.Size([6, 4096]), skip_chunking=True (enable=False, chunk_size=2048)
[DEBUG logits_processor] lm_head pruning: pruned_states.shape=torch.Size([6, 4096]), skip_chunking=True (enable=False, chunk_size=2048)
[DEBUG logits_processor] lm_head pruning: pruned_states.shape=torch.Size([10, 4096]), skip_chunking=True (enable=False, chunk_size=2048)

test_lora_logprob_comparison_chunked
[DEBUG logits_processor] lm_head multi-pass: pruned_states.shape=torch.Size([5, 4096]), chunk_size=4
[DEBUG logits_processor] pass 1/2: tokens [0:4]
[DEBUG logits_processor] pass 2/2: tokens [4:5]
[DEBUG logits_processor] lm_head multi-pass: pruned_states.shape=torch.Size([11, 4096]), chunk_size=4
[DEBUG logits_processor] pass 1/3: tokens [0:4]
[DEBUG logits_processor] pass 2/3: tokens [4:8]
[DEBUG logits_processor] pass 3/3: tokens [8:11]
[DEBUG logits_processor] lm_head multi-pass: pruned_states.shape=torch.Size([6, 4096]), chunk_size=4
[DEBUG logits_processor] pass 1/2: tokens [0:4]
[DEBUG logits_processor] pass 2/2: tokens [4:6]
[DEBUG logits_processor] lm_head multi-pass: pruned_states.shape=torch.Size([6, 4096]), chunk_size=4
[DEBUG logits_processor] pass 1/2: tokens [0:4]
[DEBUG logits_processor] pass 2/2: tokens [4:6]
[DEBUG logits_processor] lm_head multi-pass: pruned_states.shape=torch.Size([10, 4096]), chunk_size=4
[DEBUG logits_processor] pass 1/3: tokens [0:4]
[DEBUG logits_processor] pass 2/3: tokens [4:8]
[DEBUG logits_processor] pass 3/3: tokens [8:10]
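
To make item 1 concrete, here is a minimal sketch of the chunked embedding lookup for lora_a (not the kernel in this PR; the names, the single-adapter case, and contiguous layouts are simplifying assumptions). One program is launched per chunk, and each token's row of the embedding A matrix is gathered along the rank dimension:

import torch
import triton
import triton.language as tl


@triton.jit
def _chunked_embedding_lora_a_sketch(
    token_ids,   # (num_tokens,) token ids for the whole batch
    lora_a,      # (vocab_size, MAX_RANK) LoRA A weight of the embedding layer, contiguous
    output,      # (num_tokens, MAX_RANK) gathered rows
    seg_indptr,  # (num_chunks + 1,) token offsets delimiting each chunk
    MAX_RANK: tl.constexpr,  # assumed to be a power of two
):
    # One program per chunk; tokens inside the chunk are walked serially,
    # and each load covers the full rank dimension of the A matrix.
    chunk_idx = tl.program_id(0)
    chunk_start = tl.load(seg_indptr + chunk_idx)
    chunk_end = tl.load(seg_indptr + chunk_idx + 1)
    rank_offsets = tl.arange(0, MAX_RANK)
    for t in range(chunk_start, chunk_end):
        tok = tl.load(token_ids + t)
        row = tl.load(lora_a + tok * MAX_RANK + rank_offsets)
        tl.store(output + t * MAX_RANK + rank_offsets, row)


def chunked_embedding_lora_a_sketch(token_ids, lora_a, seg_indptr):
    num_tokens, max_rank = token_ids.shape[0], lora_a.shape[1]
    output = torch.empty((num_tokens, max_rank), device=lora_a.device, dtype=lora_a.dtype)
    num_chunks = seg_indptr.shape[0] - 1
    _chunked_embedding_lora_a_sketch[(num_chunks,)](
        token_ids, lora_a, output, seg_indptr, MAX_RANK=max_rank
    )
    return output

For a single adapter this reduces to a plain gather, so a PyTorch reference in the spirit of item 3 can simply check torch.testing.assert_close(output, lora_a[token_ids]).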
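
A small, hypothetical illustration of the cache collision from item 5 (the dicts below are not the backend's actual data structure): the two debug lines above share (num_slices=1, BLOCK_M=16) but differ in OUTPUT_DIM, so a cache keyed only on the first two fields maps both launches to a single entry.

launches = [
    dict(num_slices=1, block_m=16, output_dim=4096),   # first expand call in the debug log
    dict(num_slices=1, block_m=16, output_dim=32000),  # vocab-sized expand call in the debug log
]

old_keys = {(l["num_slices"], l["block_m"]) for l in launches}
new_keys = {(l["num_slices"], l["block_m"], l["output_dim"]) for l in launches}
print(len(old_keys))  # 1 -> the second launch would reuse the 4096-dim entry (collision)
print(len(new_keys))  # 2 -> each output dimension gets its own cached entry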
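
For items 6-8, here is a minimal sketch of the multi-pass lm_head path (lm_head_multi_pass is a hypothetical helper name, and the real code also recomputes the LoRA batch metadata per chunk rather than noting it in a comment):

import torch


def lm_head_multi_pass(pruned_states: torch.Tensor, lm_head, chunk_size: int) -> torch.Tensor:
    """Apply a LoRA-wrapped lm_head to pruned hidden states in chunks of at most chunk_size tokens."""
    num_tokens = pruned_states.shape[0]
    if num_tokens <= chunk_size:
        # Small pruned batches skip chunking entirely (the skip_chunking=True logs above).
        return lm_head(pruned_states)
    logits = []
    for start in range(0, num_tokens, chunk_size):
        end = min(start + chunk_size, num_tokens)
        # In the real path, LoRA batch info for tokens [start:end) would be recomputed
        # here so the adapter mapping matches the pruned, chunked activation.
        logits.append(lm_head(pruned_states[start:end]))
    return torch.cat(logits, dim=0)

With chunk_size=4 and 11 pruned tokens this yields the three passes [0:4], [4:8], [8:11] shown in the chunked test logs above.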

Accuracy Tests

Ran test/registered/lora/test_lora_hf_sgl_logprob_diff.py on H100; all tests passed:

 ~/tmp/sglang/sglang (lora_a_csgmv)]$ pytest test/registered/lora/test_lora_hf_sgl_logprob_diff.py -s -vv --log-cli-level=DEBUG
================================================================================
Overall Statistics
================================================================================

Logprob Differences:
  Prefill:
    Max of max:   1.553154e-02
    Mean of max:  1.090739e-02
    Mean of mean: 3.126751e-03
  Decode:
    Max of max:   3.064442e-02
    Mean of max:  2.019897e-02
    Mean of mean: 3.742597e-03

Logprob Statistics (threshold: 1e-01):
  Overall logprob: 5/5 PASSED
  Prefill logprob: 5/5
  Decode logprob:  5/5

String Statistics:
  Output strings:  5/5

================================================================================
Test completed successfully!
================================================================================
PASSED

Benchmarking and Profiling

Ran a sanity check using GSM8K:

(sglang_env) ~/tmp/sglang/sglang (lora_a_csgmv)]$ python3 -m sglang.test.few_shot_gsm8k --num-questions 200
Downloading from https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl to /tmp/test.jsonl
/tmp/test.jsonl: 732kB [00:00, 15.1MB/s]                                                                                                                                                                         
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:05<00:00, 35.16it/s]
Accuracy: 0.860
Invalid: 0.000
Latency: 5.731 s
Output throughput: 5001.777 token/s


@gemini-code-assist (Contributor)

Summary of Changes

Hello @MogicianWu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the capabilities of the chunked SGMV (csgmv) backend for LoRA operations by introducing comprehensive support for embedding lookup and LoRA A logic. Previously, the csgmv backend had limitations regarding embedding layers. This change integrates a new Triton-based chunked embedding lookup mechanism, removes the associated restrictions, and updates the testing suite to reflect and validate these new functionalities, ensuring accurate and efficient LoRA application to embedding layers within the chunked backend.

Highlights

  • Chunked Embedding Lookup for LoRA A: Introduced a new Triton kernel (_chunked_embedding_lora_a_kernel) and its Python wrapper (chunked_embedding_lora_a_forward) to enable efficient chunked embedding lookup for LoRA A operations, processing embedding weights along the rank dimension.
  • Expanded csgmv Backend Capabilities: Integrated the new chunked embedding lookup into the ChunkedSgmvLoRABackend, allowing the csgmv backend to fully support LoRA A embedding operations, which was previously a limitation.
  • Removed LoRA Target Module Restrictions: Eliminated the previous warning and limitation that prevented the csgmv backend from applying LoRA to embed_tokens and lm_head layers, broadening its applicability to all LoRA target modules.
  • Improved Triton Kernel Caching: Enhanced the Triton kernel caching mechanism for chunked_sgmv_expand by including OUTPUT_DIM in the cache key, which prevents potential cache collisions and ensures correct kernel selection.
  • Comprehensive Testing Updates: Updated existing tests and added new ones, including a reference_embedding_lora_a_shrink function and a dedicated test case in test_chunked_sgmv_backend.py, to validate the correctness of the chunked embedding LoRA A forward pass. The test_lora_hf_sgl_logprob_diff.py test now defaults to using the csgmv backend.

@gemini-code-assist (bot) left a comment

Code Review

This pull request is a solid contribution that enables LoRA for embedding layers in the csgmv backend, removing a previous limitation. The implementation includes a new Triton kernel, a corresponding reference implementation for robust testing, and updates to both unit and end-to-end accuracy tests. I've identified a high-severity bug due to a typo in an __init__.py file and a medium-severity performance concern in the new Triton kernel. Addressing these points will make this an excellent addition to the codebase.

"sgemm_lora_b_fwd",
"chunked_sgmv_lora_shrink_forward",
"chunked_sgmv_lora_expand_forward",
"chunked_embedding_lora_a_fwd",

high

There is a mismatch between the imported function name and the name added to __all__. You import chunked_embedding_lora_a_forward but export "chunked_embedding_lora_a_fwd". This will cause an AttributeError on from ... import * because chunked_embedding_lora_a_fwd is not defined in the module.

For consistency with other chunked operations like chunked_sgmv_lora_expand_forward, the name in __all__ should be "chunked_embedding_lora_a_forward".

Suggested change
- "chunked_embedding_lora_a_fwd",
+ "chunked_embedding_lora_a_forward",

# for each token in chunk, load embedding across rank dimension
chunk_start = tl.load(seg_indptr + chunk_idx)
chunk_end = tl.load(seg_indptr + chunk_idx + 1)
for c in range(chunk_start, chunk_end):

medium

The for c in range(chunk_start, chunk_end): loop processes tokens within a chunk serially. This can be a performance bottleneck, as Triton kernels are most efficient when they can parallelize operations. Other chunked kernels in this project (e.g., _chunked_lora_shrink_kernel) process blocks of tokens in parallel using tl.arange to achieve better performance.

To improve performance, consider refactoring this kernel to process tokens in parallel within each program instance. While this may require more complex 2D pointer logic for loading from the weights tensor, it would align with high-performance Triton practices and the design of other kernels in this backend.
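
A minimal sketch of the parallel variant suggested here (the names and launch grid are illustrative, not the kernel in this PR, and lora_a is assumed contiguous): each program covers a BLOCK_M-sized slice of its chunk and gathers all of those tokens' rows at once with 2D offsets instead of a serial per-token loop.

import triton
import triton.language as tl


@triton.jit
def _embedding_lora_a_parallel_sketch(
    token_ids, lora_a, output, seg_indptr,
    MAX_RANK: tl.constexpr, BLOCK_M: tl.constexpr,
):
    # Grid: (num_chunks, ceil(max_chunk_len / BLOCK_M))
    chunk_idx = tl.program_id(0)
    block_idx = tl.program_id(1)
    chunk_start = tl.load(seg_indptr + chunk_idx)
    chunk_end = tl.load(seg_indptr + chunk_idx + 1)

    token_offsets = chunk_start + block_idx * BLOCK_M + tl.arange(0, BLOCK_M)
    token_mask = token_offsets < chunk_end
    rank_offsets = tl.arange(0, MAX_RANK)

    # Gather BLOCK_M token ids at once, then build 2D pointers into lora_a.
    toks = tl.load(token_ids + token_offsets, mask=token_mask, other=0)
    src = lora_a + toks[:, None] * MAX_RANK + rank_offsets[None, :]
    rows = tl.load(src, mask=token_mask[:, None])
    dst = output + token_offsets[:, None] * MAX_RANK + rank_offsets[None, :]
    tl.store(dst, rows, mask=token_mask[:, None])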

@yushengsu-thu self-assigned this on Jan 25, 2026
@Fridge003 self-assigned this on Jan 26, 2026
@Fridge003 (Collaborator)

/tag-and-rerun-ci

@MogicianWu force-pushed the lora_a_csgmv branch 3 times, most recently from f02f06c to 0849112 on February 4, 2026 at 23:09
@github-actions bot added the deepseek, speculative-decoding, and hicache (Hierarchical Caching for SGLang) labels on Feb 4, 2026
@MogicianWu force-pushed the lora_a_csgmv branch 4 times, most recently from 0f11096 to 6238556 on February 5, 2026 at 04:57
@github-actions bot added the documentation (Improvements or additions to documentation) label on Feb 5, 2026
@Fridge003 (Collaborator)

Can we come up with some cleaner way of handling the pruned tokens? The current way is a little invasive.

@MogicianWu force-pushed the lora_a_csgmv branch 2 times, most recently from 414da8d to 30a34e8 on February 8, 2026 at 23:10
@MogicianWu (Contributor, Author)

/tag-and-rerun-ci

@yushengsu-thu (Collaborator)

This PR also resolves the issue in #18649.

@Fridge003 merged commit 70a6fb5 into sgl-project:main on Mar 16, 2026 (139 of 155 checks passed).
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
…#17692)

Co-authored-by: Bruce Wu <mogicianwu@fb.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Co-authored-by: Ethan (Yusheng) Su <yushengsu.thu@gmail.com>

Labels

deepseek, documentation, hicache (Hierarchical Caching for SGLang), lora, run-ci, speculative-decoding
