[Feature] Support custom set kv buffer kernel#8884
[Feature] Support custom set kv buffer kernel#8884merrymercy merged 3 commits intosgl-project:mainfrom
Conversation
There was a problem hiding this comment.
Summary of Changes
Hello @DarkSharpness, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
I've implemented a significant performance enhancement by introducing a custom CUDA kernel to optimize the set_kv_cache operation. This change directly addresses an identified overhead in H100 environments by fusing the storage of key and value caches into a single, more efficient operation. The primary goal is to improve the end-to-end performance of large language models, particularly for small batch sizes and large tensor parallelism configurations, while ensuring a robust fallback to existing PyTorch implementations if the custom kernel is not applicable.
Highlights
- Custom CUDA Kernel for KV Cache: I've introduced a new custom CUDA kernel specifically designed to fuse the storage operations for Key (K) and Value (V) caches. This kernel replaces the previous, less efficient method of copying K and V caches separately, which was identified as a performance bottleneck.
- Performance Optimization: The integration of this custom kernel is expected to yield significant performance improvements. Initial profiling indicates an approximate 3% end-to-end performance gain for the Llama 3.1 8B model on H100 GPUs, particularly for small batch sizes. For configurations with large Tensor Parallelism (TP), where KVCache per rank is smaller, even greater improvements are anticipated.
- Robust Fallback Mechanism: To ensure robustness and compatibility, the new
set_kv_buffer_kernelincludes a fallback mechanism. If the optimized CUDA kernel cannot be utilized (e.g., due to specific hardware limitations, non-CUDA environments, or explicit configuration), the system will gracefully revert to the original PyTorch-based tensor assignment for K and V caches. - System Integration: The new kernel and its Python wrapper have been seamlessly integrated into the existing memory management system. This involves adding the CUDA source file to the build system, registering the C++ function as a PyTorch operator, and updating the
memory_pool.pyto call the new optimized kernel.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review |
Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary |
Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help |
Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
-
Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
There was a problem hiding this comment.
Code Review
This pull request introduces a custom CUDA kernel, set_kv_buffer_kernel, to optimize the process of setting key-value caches. The new kernel fuses the memory copy operations for both key and value caches, which, according to the description, yields about a 3% end-to-end performance improvement on H100 GPUs for Llama 3.1 8B with small batch sizes. The changes span across Python, C++, and CMake files to integrate this new kernel.
My review focuses on the implementation of the new kernel and its integration. I've identified a couple of areas for improvement:
- In the CUDA kernel implementation (
store.cu), there's significant code duplication for handling different data types of theout_loctensor. I've suggested a refactoring to improve maintainability. - The Python wrapper for the kernel (
memory.py) has a broad exception handler that could silently swallow important errors. I've recommended adding logging to aid in debugging potential kernel failures.
Overall, the changes are well-structured and address a valid performance concern. The suggestions aim to make the new code more robust and easier to maintain.
e2914e6 to
cdd871f
Compare
6029964 to
609b017
Compare
609b017 to
cdc1153
Compare
Motivation
nsys profile shows that
set_kv_cachewill lead to a little overhead in H100. We write a kernel to fuse storing into key cache and value cache.Modifications
Before:
After:
This PR will bring around 3% e2e performance improvement for llama 3.1 8B in small batch size on H100. For large TP where the KVCache per rank is smaller (e.g. TP=8 for 70B model, 1 head per GPU), this should works even better.
Accuracy Test
Benchmark & Profiling
Checklist