Support mrope triton kernel and add unit test#11722
Support mrope triton kernel and add unit test#11722hnyls2002 merged 3 commits intosgl-project:mainfrom
Conversation
Summary of ChangesHello @yuan-luo, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the efficiency of rotary embedding operations, particularly for multimodal models like Qwen2VL and Qwen2.5VL, by integrating a high-performance Triton kernel. The changes aim to boost inference speed without compromising model accuracy, as validated by extensive benchmarking and unit testing. This optimization is critical for improving the overall throughput and responsiveness of the system when handling complex multimodal inputs. Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
|
I have a PR referenced by vllm without adding any reference information in code segment. So this time I decide not to add reference as well. |
There was a problem hiding this comment.
Code Review
This pull request introduces a Triton kernel for multimodal rotary embedding (mRoPE), which significantly improves performance as shown by the benchmark results. The changes include the new Triton kernel, a benchmark script, and unit tests to ensure correctness. My review focuses on improving code quality by addressing code duplication and adding a necessary safety check. Overall, this is a valuable contribution that enhances the performance of mRoPE operations.
|
Yuan, I think #11031 achieves the same objective, as I think the vLLM one was also from Liger (But feel free to merge yours if it can be merged faster than the one I opened) |
@b8zhong Maybe we can adopt this PR, as it adds benchmark and accuracy test. The perf we can gain in small model are close, I'll add you as co-author. |
|
@yuan-luo Sure, sounds good |
87c9020 to
502d125
Compare
Done. |
BBuf
left a comment
There was a problem hiding this comment.
LGTM, fix the comments is better.
Co-authored-by: b8zhong <b8zhong@uwaterloo.ca> Co-authored-by: Yuan Luo <yuan.luo@hotmail.com>
376a669 to
ebbd6bc
Compare
ebbd6bc to
8e959a7
Compare
Motivation
Introduce triton mrope kernel in rotary embedding.

Before:
After:

Adapted from vLLM.
Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist