fix(lora): add run_lora_a_embedding to ChunkedSgmvLoRABackend #14805
ashtonchew wants to merge 12 commits into sgl-project:main
Conversation
Fixes sgl-project#14773. When using --lora-target-modules all with the csgmv backend, the embedding layer LoRA was failing because ChunkedSgmvLoRABackend did not implement run_lora_a_embedding(). This commit adds:
- run_lora_a_embedding() method using the existing triton kernel
- Separate embedding_batch_info preserving the original sequence structure
- CUDA graph support for embedding LoRA
- Unit test for the new method
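Conceptually, the LoRA-A pass for an embedding layer is a per-position lookup into the selected adapter's A matrix, done in the original sequence order. A pure-Python sketch of that semantics (names and shapes here are illustrative; the actual implementation dispatches to a Triton kernel over GPU tensors):

```python
def lora_a_embedding_ref(token_ids, weight_indices, lora_a_weights):
    """Reference semantics for a LoRA-A embedding lookup.

    token_ids:      token ids, one per position, in ORIGINAL sequence order
                    (not the chunked/reordered order used for linear layers)
    weight_indices: adapter index for each position
    lora_a_weights: lora_a_weights[adapter][token_id] -> row of length rank
    Returns a (seq_len, rank) list of rows.
    """
    out = []
    for pos, tok in enumerate(token_ids):
        adapter = weight_indices[pos]
        out.append(list(lora_a_weights[adapter][tok]))
    return out
```

Each output row lines up with its original position, which is why a separate, unchunked batch info is needed for this path.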
Code Review
This pull request introduces support for LoRA embeddings in the ChunkedSgmvLoRABackend. The changes include importing a new Triton kernel (embedding_lora_a_fwd), adding a run_lora_a_embedding method to utilize this kernel, and modifying init_cuda_graph_batch_info and prepare_lora_batch to manage a separate embedding_batch_info. This new embedding_batch_info is designed to preserve the original sequence structure, which is crucial for embedding lookups, unlike the chunked structure used for linear layers. A new test file was added to validate this functionality. A review comment noted a potential minor performance issue in prepare_lora_batch where weight_indices_for_embedding is repeatedly converted from a Python list to a torch.Tensor on the CPU, suggesting optimization by passing tensors directly or updating existing tensors in-place.
weight_indices_for_embedding = torch.tensor(
    weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
)
Similar to the "lora_ranks_tensor" and "scalings_tensor", "weight_indices_for_embedding" is created from a Python list on every call. If "prepare_lora_batch" is a hot path, this repeated conversion from a Python list to a "torch.Tensor" on the CPU could be a minor performance concern. Exploring ways to pass "torch.Tensor" directly or update existing tensors in-place could offer a slight optimization.
This follows the existing pattern used in the other backends (triton_backend.py:154, torch_backend.py:236, ascend_backend.py:242). A broader optimization to pass tensors directly would require changes across all backends and the caller interface, which seems out of scope for this bug fix. Valid comment but would be an out-of-pattern change.
Referenced code snippets from above:

triton_backend.py:154
weight_indices_tensor = torch.tensor(
    weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
)

torch_backend.py:236
weight_indices_tensor = torch.tensor(
    weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
)

ascend_backend.py:242
weight_indices_tensor = torch.tensor(
    weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
)

Happy to address it in a separate PR if maintainers think it's worth pursuing.
/tag-and-rerun-ci
@yushengsu-thu looks like the tests are passing; I think the failing tests are unrelated.
Force-pushed from f462330 to 673db8d
Addressed the accuracy verification request from @yushengsu-thu in the Slack thread. Added …

Technical Note: Adapter has …

Test Results (A100): SGLang and HuggingFace logprobs match, confirming the …
- Delete create_embedding_lora_adapter.py (random weights)
- Add train_embedding_lora_adapter.py (actual fine-tuning)
- Relax logprob threshold for embedding LoRA test
Add test_lora_embedding_logprob_comparison_triton to validate triton backend as baseline before testing chunked SGMV backend.
@yushengsu-thu I've addressed this by training and uploading a proper LoRA adapter to HuggingFace: ash256/sglang_embedding_lora_test_adapter. This adapter, as requested: …

Test results with the chunked backend (csgmv): string outputs match perfectly. The outputs are coherent text (not "111111"), confirming the adapter is properly trained and the implementation is functionally correct. Logprob precision is within the new threshold: …

Following your suggestion to validate triton_backend first, then chunked_backend, I added a triton backend version of the same test (test_lora_embedding_logprob_comparison_triton) and ran both. Triton backend (baseline): PASSED (43.73s). Both backends use the same trained adapter and produce matching results against HuggingFace, confirming the chunked backend implementation is correct and consistent with the established triton backend.
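The logprob comparison described here reduces to checking that per-token logprobs from two implementations agree within a tolerance. A minimal sketch of that check (the threshold value is illustrative; the actual test compares SGLang output against HuggingFace):

```python
def max_logprob_diff(ref_logprobs, test_logprobs):
    """Largest absolute per-token difference between two logprob sequences."""
    assert len(ref_logprobs) == len(test_logprobs)
    return max(abs(a - b) for a, b in zip(ref_logprobs, test_logprobs))

def logprobs_match(ref, test, threshold=0.1):
    """True when every token's logprob agrees within the threshold."""
    return max_logprob_diff(ref, test) <= threshold
```

A looser threshold (as this PR's relaxed one for embedding LoRA) trades precision for robustness against kernel-level floating-point differences.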
@ashtonchew
In these two test scripts:
test/srt/lora/test_lora_eviction.py and test/srt/lora/test_lora_update.py,
could you also add the same test cases as the original ones in these two files, set lora_target_modules=["all"], and use your trained LoRA weights?
Also, could you comment out this part of the code (https://github.com/sgl-project/sglang/pull/14796/files) to verify that the necessary LoRA CI/CD run with ChunkedSgmvLoRABackend and checks still pass on your side?
Feel free to ping me if you get any questions.
… csgmv
- Comment out the csgmv embed_tokens/lm_head restriction in server_args.py, since the csgmv backend now supports these via run_lora_a_embedding()
- Add test_lora_eviction_with_embedding_lora_all_target_modules test
- Add EMBEDDING_LORA_TESTS to test_lora_update.py with 2 test cases
- Use ash256/sglang_embedding_lora_test_adapter for testing
@yushengsu-thu I verified that LoRA CI/CD passes with ChunkedSgmvLoRABackend.

Changes Made:
# NOTE: The following code has been commented out because the csgmv backend
# now supports embedding and lm_head layers via run_lora_a_embedding().
# See: https://github.com/sgl-project/sglang/pull/14796
#
# if self.lora_backend == "csgmv":
# logger.warning(...)
# self.lora_target_modules.discard("embed_tokens")
# self.lora_target_modules.discard("lm_head")
Test Results

Test 1: …
model_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
lora_paths = ["ash256/sglang_embedding_lora_test_adapter"]
prompts = DEFAULT_TEST_PROMPTS[:2]
There are only 5 prompts; just test all of them instead of slicing with DEFAULT_TEST_PROMPTS[:2].
Fixed, now using all 5 prompts.
)
else:
    output_history[prompt] = output
To keep the codebase clean and easy to maintain, could you please simplify the code? If the test was generated using Cursor, it would be great if you could prune it and remove any redundant parts.
Like following the format below:
def test_lora_eviction_with_embedding_lora_all_target_modules(self):
    self._run_test(....)
    self._run_test(....)
Refactored to use _run_test(), added optional base_model, lora_target_modules, and max_lora_rank parameters to support the embedding LoRA config.
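The shape of that refactor can be sketched as a helper that applies defaults for the common case and only includes the embedding-LoRA overrides when given (the function name, default model, and config keys here are assumptions for illustration, not the actual test code):

```python
DEFAULT_BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

def build_run_config(lora_paths, base_model=None,
                     lora_target_modules=None, max_lora_rank=None):
    """Assemble server args for one test run.

    Optional parameters are only emitted when explicitly provided, so the
    existing test cases keep their original configuration unchanged.
    """
    config = {
        "model_path": base_model or DEFAULT_BASE_MODEL,
        "lora_paths": list(lora_paths),
    }
    if lora_target_modules is not None:
        config["lora_target_modules"] = lora_target_modules
    if max_lora_rank is not None:
        config["max_lora_rank"] = max_lora_rank
    return config
```

Keeping the overrides optional is what lets one _run_test() serve both the original cases and the embedding-LoRA case with lora_target_modules=["all"].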
- Use all DEFAULT_TEST_PROMPTS instead of the [:2] slice in logprob diff tests
- Refactor test_lora_eviction_with_embedding_lora_all_target_modules to use _run_test()
- Add optional base_model, lora_target_modules, max_lora_rank params to _run_test()
@@ -0,0 +1,96 @@
"""Test ChunkedSgmvLoRABackend.run_lora_a_embedding() method."""
Please move this test to test/nightly
Moved into test/nightly.
python/sglang/srt/server_args.py (outdated)
# now supports embedding and lm_head layers via run_lora_a_embedding().
# See: https://github.com/sgl-project/sglang/pull/14796
#
# # When using the chunked SGMV backend, skip embedding / lm_head layers for now,
We can remove these comment lines
Removed in latest revision.
)

# Also create embedding-specific batch info (uses original sequence structure)
self.cuda_graph_embedding_batch_info = LoRABatchInfo(
We don't need the embedding info when there is no embedding layer in the LoRA adapters.
Can we skip it optionally?
Yes this makes sense, added in latest revision.
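The conditional initialization agreed on here can be sketched as a guard on the adapter's target modules (the module names come from the PR; the function itself is illustrative, roughly what the has_embedding_layers flag from the later commit checks):

```python
EMBEDDING_MODULES = {"embed_tokens", "lm_head"}

def needs_embedding_batch_info(lora_target_modules):
    """Only allocate embedding_batch_info when the adapter actually
    targets an embedding-style layer (or targets all modules)."""
    modules = set(lora_target_modules)
    if "all" in modules:
        return True
    return bool(modules & EMBEDDING_MODULES)
```

Guarding the allocation this way keeps the common linear-layer-only path free of the extra CUDA-graph buffers.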
self.batch_info = batch_info

# Setup embedding_batch_info (uses original sequence structure, not chunked)
bs = forward_batch.batch_size
Same here: can we skip this part when there are no embedding layers in the LoRA adapters?
Resolved as above.
@@ -43,6 +43,12 @@
We can change some lora_target_modules back to ["all"], including:
- test/srt/lora/test_lora_update.py, line 228 (fd64594)
- test/srt/lora/test_lora_update.py, line 774 (fd64594)
- test/srt/lora/test_lora_update.py, line 1607 (fd64594)
Changed in latest revision.
    max_len=None,  # Not used in CSGMV backend
)

# Also create embedding-specific batch info (uses original sequence structure)
Can you add a TODO here saying that embedding_batch_info will be removed once the chunked kernel for embedding has been implemented?
This PR doesn't solve the problem from the bottom up; it's a workaround that makes embedding run with a replaced kernel.
Added TODO comment.
Skip embedding_batch_info creation when target_modules doesn't include embed_tokens or lm_head, reducing unnecessary allocations.
- Add has_embedding_layers flag to conditionally initialize embedding batch info only when embedding LoRA layers are enabled
- Remove obsolete comments about csgmv embedding limitations
- Move chunked embedding test to nightly suite
- Restore lora_target_modules=["all"] in update tests
@Fridge003 here
@Fridge003 following up on this
Hi @ashtonchew, thank you for your contribution, but the implementation in this PR is not correct. The triton kernel for the chunked csgmv backend should be customized.
This PR can be closed. Another PR already supports this: #17692
Hi @Fridge003, sounds good, thanks for adding me as a co-author on the new PR.
Motivation
Fixes #14773. This PR is a follow-up to #14796, which introduced a workaround by excluding embed_tokens and lm_head from --lora-target-modules all when using the csgmv backend.

Rather than disabling this functionality, this PR implements the missing run_lora_a_embedding() method in ChunkedSgmvLoRABackend, enabling full embedding LoRA support for the chunked SGMV backend. This addresses the TODO comment added in #14796.

Modifications

- Add embedding_lora_a_fwd import from triton_ops
- Add run_lora_a_embedding() method using the existing triton kernel
- Add cuda_graph_embedding_batch_info for CUDA graph mode
- Add embedding_batch_info setup in prepare_lora_batch() to maintain the original sequence structure (not chunked/reordered like linear layers)
- Add test/srt/lora/test_chunked_lora_embedding.py

Test Plan

Added test/srt/lora/test_chunked_lora_embedding.py, which verifies:
- embedding_batch_info is created with the original sequence structure
- permutation=None (no reordering for embeddings)
- output shape is (seq_len, rank)
- embedding_batch_info …

Accuracy Tests
N/A - This fix adds missing functionality that was causing a crash. It does not change existing model outputs; it enables a previously broken code path to work correctly.
Benchmarking and Profiling
N/A - This fix adds missing functionality. The implementation follows the same pattern as TritonLoRABackend.run_lora_a_embedding() and reuses existing optimized triton kernels, so no performance regression is expected.

Checklist