Optimize triton_mrope with torch compile #12112
Conversation
Code Review
This pull request introduces a performance optimization by applying torch.compile to the _forward_triton method in MRotaryEmbedding. The change is well-motivated, and the provided benchmarks clearly demonstrate a significant speedup, particularly in Time to First Token (TTFT). This is a valuable improvement.
My review includes one suggestion for a minor refactoring within the _forward_triton method to improve code clarity and maintainability by removing some now-unreachable code. This will make the function cleaner, especially as it is now a compilation target.
Some VL tests failed. Investigating; setting WIP.
After analyzing, these unit tests all failed with the same issue. The root cause of the errors introduced by adding @torch.compile to _forward_triton is as follows:
With CUDA Graphs enabled, the captured graph uses the pointers/schedule from the first run; on replay, if the surrounding tensors/storage change, we get seemingly random corruption. The fix is to insert an explicit graph break before and after the call (effective during tracing). The trade-off is that the Triton call site won't be further fused by the compiler, but the rest of the ATen path can still be optimized.
The test script: only adding torch.compile without the graph break, the result is incorrect; adding torch.compile together with the graph break, the result is correct.
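The kind of check described can be sketched as a small eager-vs-compiled comparison. This is a hypothetical harness: `rotary_ref` and the tensor shapes are made up for illustration and are not the PR's actual test script.

```python
import torch

def rotary_ref(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Simplified stand-in for the eager rotary-embedding reference; the real
    # test would compare the Triton path against the reference implementation.
    return q * cos + q * sin

compiled = torch.compile(rotary_ref, backend="eager")

q = torch.randn(4, 8)
cos = torch.randn(4, 8)
sin = torch.randn(4, 8)

# Run twice to mimic graph replay; outputs must match eager on every run.
for _ in range(2):
    torch.testing.assert_close(compiled(q, cos, sin), rotary_ref(q, cos, sin))
```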
The VLLM Dependency Test failed due to #12117 being merged.
I verified that in main branch, this issue still exists. |
Motivation
In the recent Triton mrope PR (#11722), several Torch ops introduced overhead. Despite kernel-level speedups of 30%-40%, end-to-end performance regressed.

This PR optimizes Triton mrope by adding torch.compile. With this enhancement, end-to-end performance under Triton mrope exceeds that of the legacy version.
The VLM online latency test:
Compared to the version before #11722:
TTFT reduces from 156.52ms to 131.41ms, a 16% speedup.
E2E from 8434.92ms to 8272ms, a 2% speedup.
Compared to main:
TTFT reduces from 182.89ms to 131.41ms, a 28% speedup.
E2E from 8903.92ms to 8272ms, a 7% speedup.
Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist