Enable embedding lookup/lora_a logic for chunked backend #17692
Fridge003 merged 11 commits into sgl-project:main
Conversation
Summary of Changes
Hello @MogicianWu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the capabilities of the chunked SGMV (csgmv) backend for LoRA operations by introducing comprehensive support for embedding lookup and LoRA A logic. Previously, the csgmv backend had limitations regarding embedding layers. This change integrates a new Triton-based chunked embedding lookup mechanism, removes the associated restrictions, and updates the testing suite to reflect and validate these new functionalities, ensuring accurate and efficient LoRA application to embedding layers within the chunked backend.
Highlights
Code Review
This pull request is a solid contribution that enables LoRA for embedding layers in the csgmv backend, removing a previous limitation. The implementation includes a new Triton kernel, a corresponding reference implementation for robust testing, and updates to both unit and end-to-end accuracy tests. I've identified a high-severity bug due to a typo in an __init__.py file and a medium-severity performance concern in the new Triton kernel. Addressing these points will make this an excellent addition to the codebase.
| "sgemm_lora_b_fwd", | ||
| "chunked_sgmv_lora_shrink_forward", | ||
| "chunked_sgmv_lora_expand_forward", | ||
| "chunked_embedding_lora_a_fwd", |
There is a mismatch between the imported function name and the name added to __all__. You import chunked_embedding_lora_a_forward but export "chunked_embedding_lora_a_fwd". This will cause a NameError on from ... import * because chunked_embedding_lora_a_fwd is not defined.
For consistency with other chunked operations like chunked_sgmv_lora_expand_forward, the name in __all__ should be "chunked_embedding_lora_a_forward".
| "chunked_embedding_lora_a_fwd", | |
| "chunked_embedding_lora_a_forward", |
    # for each token in chunk, load embedding across rank dimension
    chunk_start = tl.load(seg_indptr + chunk_idx)
    chunk_end = tl.load(seg_indptr + chunk_idx + 1)
    for c in range(chunk_start, chunk_end):
The for c in range(chunk_start, chunk_end): loop processes tokens within a chunk serially. This can be a performance bottleneck, as Triton kernels are most efficient when they can parallelize operations. Other chunked kernels in this project (e.g., _chunked_lora_shrink_kernel) process blocks of tokens in parallel using tl.arange to achieve better performance.
To improve performance, consider refactoring this kernel to process tokens in parallel within each program instance. While this may require more complex 2D pointer logic for loading from the weights tensor, it would align with high-performance Triton practices and the design of other kernels in this backend.
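A minimal sketch of the block-parallel pattern described above, assuming a single adapter's row-major (vocab_size, rank) weight table and a row-major output buffer; the tensor names, parameters, and layout are illustrative assumptions rather than the PR's actual kernel:

```python
import triton
import triton.language as tl


@triton.jit
def _embedding_lookup_block_sketch(
    token_ids,     # (num_tokens,) token indices
    weights,       # (vocab_size, rank) LoRA-A embedding table, row-major
    out,           # (num_tokens, rank) output buffer, row-major
    seg_indptr,    # (num_chunks + 1,) token offsets per chunk
    rank,          # actual LoRA rank (may be < BLOCK_R)
    BLOCK_T: tl.constexpr,
    BLOCK_R: tl.constexpr,
):
    chunk_idx = tl.program_id(0)
    chunk_start = tl.load(seg_indptr + chunk_idx)
    chunk_end = tl.load(seg_indptr + chunk_idx + 1)

    r_off = tl.arange(0, BLOCK_R)
    r_mask = r_off < rank

    # Walk the chunk in blocks of BLOCK_T tokens; each block is vectorized
    # instead of handling one token per loop iteration.
    for t0 in range(chunk_start, chunk_end, BLOCK_T):
        t_off = t0 + tl.arange(0, BLOCK_T)
        t_mask = t_off < chunk_end
        ids = tl.load(token_ids + t_off, mask=t_mask, other=0)
        # 2D gather of a (BLOCK_T, BLOCK_R) tile of embedding rows.
        tile = tl.load(
            weights + ids[:, None] * rank + r_off[None, :],
            mask=t_mask[:, None] & r_mask[None, :],
            other=0.0,
        )
        tl.store(
            out + t_off[:, None] * rank + r_off[None, :],
            tile,
            mask=t_mask[:, None] & r_mask[None, :],
        )
```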
Force-pushed from c9a9233 to 53278f2
/tag-and-rerun-ci
Force-pushed from f02f06c to 0849112
Force-pushed from 0f11096 to 6238556
Force-pushed from 6238556 to aa6840e
Can we come up with a cleaner way to handle the pruned tokens? The current approach is a little invasive.
Force-pushed from 414da8d to 30a34e8
/tag-and-rerun-ci
Force-pushed from 251f2af to 86afd35
…apter does not contain embedding layers
Force-pushed from 86afd35 to 242670e
This PR also resolves the issue in #18649.
…#17692)
Co-authored-by: Bruce Wu <mogicianwu@fb.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Co-authored-by: Ethan (Yusheng) Su <yushengsu.thu@gmail.com>
Motivation
Per #14177, we want to support the csgmv backend for lora_a shrink forward / embedding lookup.
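For readers unfamiliar with the operation, here is a hedged PyTorch sketch of the semantics being enabled: a per-token LoRA-A embedding lookup over chunks of tokens that share an adapter. The function name, argument names, and shapes are illustrative assumptions, not the PR's actual reference implementation:

```python
import torch


def ref_chunked_embedding_lora_a(
    token_ids: torch.Tensor,          # (num_tokens,)
    lora_a_embeddings: torch.Tensor,  # (num_adapters, vocab_size, rank)
    seg_indptr: torch.Tensor,         # (num_chunks + 1,) chunk boundaries
    weight_indices: torch.Tensor,     # (num_chunks,) adapter id per chunk
) -> torch.Tensor:
    rank = lora_a_embeddings.shape[-1]
    out = torch.zeros(token_ids.shape[0], rank,
                      dtype=lora_a_embeddings.dtype, device=token_ids.device)
    for chunk in range(weight_indices.shape[0]):
        start, end = seg_indptr[chunk].item(), seg_indptr[chunk + 1].item()
        adapter = weight_indices[chunk].item()
        # Embedding lookup into this adapter's LoRA-A table for the chunk's tokens.
        out[start:end] = lora_a_embeddings[adapter][token_ids[start:end]]
    return out
```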
Modifications
Accuracy Tests
Ran the test/registered/lora/test_lora_hf_sgl_logprob_diff.py test on H100; all tests passed.
Benchmarking and Profiling
Ran a sanity check using GSM8K.
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci