
Support mrope triton kernel and add unit test#11722

Merged
hnyls2002 merged 3 commits into sgl-project:main from antgroup:mrope_triton_kernel
Oct 20, 2025
Conversation

@yuan-luo
Collaborator

@yuan-luo yuan-luo commented Oct 16, 2025

Motivation

Introduce a Triton mRoPE kernel for rotary embedding.
Before:
[profiler screenshot]

After:
[profiler screenshot]

Adapted from vLLM.
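For readers unfamiliar with mRoPE: the rotary dimension is split into sections (temporal, height, width), and each section gathers its cos/sin values from a different row of the 3D position ids before the standard neox-style rotation is applied. The sketch below illustrates only that math in NumPy; it is not the SGLang/vLLM code, `mrope_apply` and its argument layout are hypothetical, and a real kernel operates on `[num_tokens, num_heads, head_dim]` tensors on the GPU.

```python
import numpy as np

def rotate_half(x):
    # neox-style rotation: split the last dim in half and swap with a sign flip
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([-x2, x1], axis=-1)

def mrope_apply(q, cos_cache, sin_cache, positions, sections):
    """Apply multimodal RoPE to q.

    q:         [num_tokens, head_dim]
    cos_cache: [max_pos, head_dim // 2]  (sin_cache likewise)
    positions: [3, num_tokens] position ids for the t/h/w axes
    sections:  per-section half-dim sizes summing to head_dim // 2
    """
    half = q.shape[-1] // 2
    assert sum(sections) == half
    cos_parts, sin_parts = [], []
    start = 0
    for axis, sec in enumerate(sections):
        idx = positions[axis]  # each section reads ids from its own axis
        cos_parts.append(cos_cache[idx, start:start + sec])
        sin_parts.append(sin_cache[idx, start:start + sec])
        start += sec
    cos = np.concatenate(cos_parts, axis=-1)   # [num_tokens, half]
    sin = np.concatenate(sin_parts, axis=-1)
    cos = np.concatenate([cos, cos], axis=-1)  # duplicate for both halves
    sin = np.concatenate([sin, sin], axis=-1)
    return q * cos + rotate_half(q) * sin
```

The Triton kernel's win comes from fusing this section-wise gather and rotation into a single launch instead of the several elementwise ops the Torch path issues.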

Unit test passed:

$python -m pytest test_mrope.py
============================================================================================= test session starts =============================================================================================
platform linux -- Python 3.10.13, pytest-8.3.5, pluggy-1.5.0
rootdir: /root/luoyuan.luo/workspace/sglang_dev/test/srt/rotary_embedding
plugins: anyio-4.8.0
collected 8 items                                                                                                                                                                                             

test_mrope.py ........                                                                                                                                                                                  [100%]

============================================================================================= 8 passed in 13.60s ==============================================================================================
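The tests above compare the Triton kernel's output against the native Torch implementation within a numeric tolerance. The same pattern is sketched here with NumPy on a stand-in `rotate_half` op rather than the real kernel (all names are illustrative, not from the PR): a readable reference and an optimized variant are checked against each other across shapes.

```python
import numpy as np

def rotate_half_ref(x):
    # readable reference implementation: explicit half-dim copies
    out = np.empty_like(x)
    half = x.shape[-1] // 2
    out[..., :half] = -x[..., half:]
    out[..., half:] = x[..., :half]
    return out

def rotate_half_vec(x):
    # vectorized variant, standing in for the optimized kernel under test
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([-x2, x1], axis=-1)

def test_rotate_half_match():
    rng = np.random.default_rng(0)
    for shape in [(4, 8), (2, 3, 16), (1024, 64)]:
        x = rng.standard_normal(shape).astype(np.float32)
        np.testing.assert_allclose(
            rotate_half_ref(x), rotate_half_vec(x), rtol=1e-6, atol=1e-6
        )
```

With bfloat16 kernels the real tests would use looser tolerances than the float32 ones shown here.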
GSM8K accuracy shows no drop:

$python3 benchmark/gsm8k/bench_sglang.py --num-questions 200 --parallel 128 --num-shots 8 --port 30000 --data-path ./test.jsonl
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:14<00:00, 13.73it/s]
Accuracy: 0.950
Invalid: 0.000
Latency: 14.814 s
Triton kernel vs. Torch baseline:
TP8 speedup 32%
TP4 speedup 45%
TP2 speedup 48%
More benchmark data will be added later.

TP8
$python3 benchmark_mrope.py --model-name /home/admin/Qwen2.5-VL-72B-Instruct --tp-size 8 --warmup-iter 10 --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
Namespace(model_name='/home/admin/Qwen2.5-VL-72B-Instruct', tp_size=8, warmup_iter=10, benchmark_iter=100, dtype='bfloat16', seed=0, num_tokens=[1024], trust_remote_code=False, output_csv='mrope_benchmark_results.csv')
================================================================================
Evaluating model: /home/admin/Qwen2.5-VL-72B-Instruct with tp_size: 8 and num_tokens: 1024, dtype: torch.bfloat16

Performance for config (1024, 8, 1):
Torch implementation: mean=0.00016325s, median=0.00016177s, p99=0.00018089s
Triton implementation: mean=0.00012344s, median=0.00012112s, p99=0.00014963s
Triton Speedup over Torch: 1.32250401x
Benchmark results saved to mrope_benchmark_results_20251010_192016.csv

TP4
$python3 benchmark_mrope.py --model-name /home/admin/Qwen2.5-VL-72B-Instruct --tp-size 4 --warmup-iter 10 --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
Namespace(model_name='/home/admin/Qwen2.5-VL-72B-Instruct', tp_size=4, warmup_iter=10, benchmark_iter=100, dtype='bfloat16', seed=0, num_tokens=[1024], trust_remote_code=False, output_csv='mrope_benchmark_results.csv')
================================================================================
Evaluating model: /home/admin/Qwen2.5-VL-72B-Instruct with tp_size: 4 and num_tokens: 1024, dtype: torch.bfloat16

Performance for config (1024, 16, 2):
Torch implementation: mean=0.00017449s, median=0.00017321s, p99=0.00020296s
Triton implementation: mean=0.00011901s, median=0.00011730s, p99=0.00014256s
Triton Speedup over Torch: 1.46622326x
Benchmark results saved to mrope_benchmark_results_20251010_193246.csv

TP2
$python3 benchmark_mrope.py --model-name /home/admin/Qwen2.5-VL-72B-Instruct --tp-size 2 --warmup-iter 10 --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
Namespace(model_name='/home/admin/Qwen2.5-VL-72B-Instruct', tp_size=2, warmup_iter=10, benchmark_iter=100, dtype='bfloat16', seed=0, num_tokens=[1024], trust_remote_code=False, output_csv='mrope_benchmark_results.csv')
================================================================================
Evaluating model: /home/admin/Qwen2.5-VL-72B-Instruct with tp_size: 2 and num_tokens: 1024, dtype: torch.bfloat16

Performance for config (1024, 32, 4):
Torch implementation: mean=0.00018627s, median=0.00018525s, p99=0.00019919s
Triton implementation: mean=0.00012546s, median=0.00012255s, p99=0.00016031s
Triton Speedup over Torch: 1.48466421x
Benchmark results saved to mrope_benchmark_results_20251010_193415.csv
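The measurement loop behind numbers like these follows the usual warmup-then-time pattern: discard the first iterations (cache warmup, Triton JIT compilation), then collect per-call latencies and report mean/median/p99. A simplified CPU-side sketch (hypothetical helper, not the PR's `benchmark_mrope.py`; a real GPU benchmark would call `torch.cuda.synchronize()` or use CUDA events around each timed run):

```python
import math
import statistics
import time

def benchmark(fn, warmup_iter=10, benchmark_iter=100):
    """Time fn() and report mean/median/p99 latency in seconds."""
    for _ in range(warmup_iter):   # warm caches / trigger JIT compilation
        fn()
    samples = []
    for _ in range(benchmark_iter):
        t0 = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - t0)
    samples.sort()
    p99_idx = max(0, math.ceil(0.99 * len(samples)) - 1)  # nearest-rank p99
    return {
        "mean": statistics.fmean(samples),
        "median": statistics.median(samples),
        "p99": samples[p99_idx],
    }
```

The speedup figures in the tables above are then simply `torch_mean / triton_mean` for each config.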

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

@gemini-code-assist
Contributor

Summary of Changes

Hello @yuan-luo, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the efficiency of rotary embedding operations, particularly for multimodal models like Qwen2VL and Qwen2.5VL, by integrating a high-performance Triton kernel. The changes aim to boost inference speed without compromising model accuracy, as validated by extensive benchmarking and unit testing. This optimization is critical for improving the overall throughput and responsiveness of the system when handling complex multimodal inputs.

Highlights

  • Triton mROPE Kernel Integration: Implemented a new Triton-based kernel for multimodal Rotary Positional Embedding (mROPE), adapted from vLLM, to accelerate computations.
  • Performance Benchmarking: Introduced a dedicated benchmark script (benchmark_mrope.py) to quantitatively compare the performance of the new Triton kernel against the existing native PyTorch implementation.
  • Significant Speedup: The Triton mROPE kernel demonstrates substantial performance improvements, achieving speedups of 32% (TP8), 45% (TP4), and 48% (TP2) over the native Torch implementation.
  • Unit Test Coverage: Added comprehensive unit tests (test_mrope.py) to ensure the correctness and numerical stability of the Triton mROPE kernel across various model configurations and token lengths.

@yuan-luo
Collaborator Author

I have a PR that vLLM adapted without adding any reference information in the code, so this time I've decided not to add a reference either.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a Triton kernel for multimodal rotary embedding (mRoPE), which significantly improves performance as shown by the benchmark results. The changes include the new Triton kernel, a benchmark script, and unit tests to ensure correctness. My review focuses on improving code quality by addressing code duplication and adding a necessary safety check. Overall, this is a valuable contribution that enhances the performance of mRoPE operations.

@b8zhong
Collaborator

b8zhong commented Oct 16, 2025

Yuan, I think #11031 achieves the same objective, as I think the vLLM one was also from Liger (But feel free to merge yours if it can be merged faster than the one I opened)

@yuan-luo
Collaborator Author

yuan-luo commented Oct 17, 2025

Yuan, I think #11031 achieves the same objective, as I think the vLLM one was also from Liger (But feel free to merge yours if it can be merged faster than the one I opened)

@b8zhong Maybe we can adopt this PR, as it adds a benchmark and an accuracy test. The perf gains on small models are close; I'll add you as co-author.

@b8zhong b8zhong mentioned this pull request Oct 17, 2025
1 task
@b8zhong
Collaborator

b8zhong commented Oct 17, 2025

@yuan-luo Sure, sounds good

@yuan-luo yuan-luo force-pushed the mrope_triton_kernel branch from 87c9020 to 502d125 Compare October 17, 2025 05:36
@yuan-luo
Collaborator Author

@yuan-luo Sure, sounds good

Done.

Collaborator

@BBuf BBuf left a comment


LGTM; addressing the review comments would be better.

Co-authored-by: b8zhong <b8zhong@uwaterloo.ca>
Co-authored-by: Yuan Luo <yuan.luo@hotmail.com>
@yuan-luo yuan-luo force-pushed the mrope_triton_kernel branch from ebbd6bc to 8e959a7 Compare October 18, 2025 08:17
@yuan-luo yuan-luo enabled auto-merge (squash) October 20, 2025 03:27
@hnyls2002 hnyls2002 disabled auto-merge October 20, 2025 03:47
@hnyls2002 hnyls2002 merged commit 271d3d0 into sgl-project:main Oct 20, 2025
157 of 165 checks passed
@yuan-luo yuan-luo mentioned this pull request Oct 20, 2025
4 tasks
