deepgemm update enhance groupgemm #6893
Conversation
Hello @zhangxiaolei123456, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
Summary of Changes
Hello team, gemini-code-assist here to provide a summary of this pull request. This PR, authored by @zhangxiaolei123456, appears to focus on updating the integration with the DeepGEMM library. Based on the title and the code changes, the primary goal is to update the DeepGEMM dependency to a newer version and adapt the local code to changes in the library's JIT compilation API, specifically for FP8 GEMM kernels.
Highlights
- DeepGEMM Dependency Update: The pull request updates the git tag for the DeepGEMM dependency in the CMake configuration (`sgl-kernel/CMakeLists.txt`) to a newer commit (`8dfa3298...`), bringing in the latest changes from the DeepGEMM repository.
- Kernel Compilation Logic Refactor: The Python code responsible for compiling DeepGEMM kernels (`python/sglang/srt/layers/quantization/deep_gemm.py`) has been modified. Instead of using `jit_tuner.compile_and_tune`, the code now explicitly calls `FP8GemmRuntime.generate` to get the kernel code and then uses the newly imported `build` function from `deep_gemm.jit` to compile it. This change applies to the compilation of masked, contiguous, and normal FP8 GEMM kernels.
- Updated Compilation Parameters: The parameters passed to the kernel compilation functions (`_compile_grouped_gemm_nt_f8f8bf16_masked_one`, `_compile_grouped_gemm_nt_f8f8bf16_contig_one`, `_compile_gemm_nt_f8f8bf16_one`) have been updated. More kernel configuration parameters, such as `N`, `K`, `BLOCK_M`, and `BLOCK_N`, are now included directly in the `kwargs` dictionary used for generating and building the kernel code.
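To make the refactor concrete, here is a hedged sketch of the old versus new compilation pattern. The function names `FP8GemmRuntime.generate`, `build`, and `jit_tuner.compile_and_tune` come from this PR's description; the helper `make_kernel_kwargs`, the kernel name string, and all concrete values are illustrative placeholders, not real configurations from the codebase.

```python
# Sketch of the JIT-compilation change described above. FP8GemmRuntime.generate
# and build are the deep_gemm.jit entry points this PR switches to; the shapes
# and block sizes below are made-up placeholders.

def make_kernel_kwargs(n, k, block_m, block_n, block_k, num_groups):
    """Assemble the expanded kwargs dict now passed to both code generation
    and the build step (hypothetical helper for illustration)."""
    return {
        "N": n,
        "K": k,
        "BLOCK_M": block_m,
        "BLOCK_N": block_n,
        "BLOCK_K": block_k,
        "NUM_GROUPS": num_groups,
    }

kwargs = make_kernel_kwargs(n=4096, k=7168, block_m=64, block_n=128,
                            block_k=128, num_groups=1)

# Old pattern (removed by this PR):
#   runtime = jit_tuner.compile_and_tune(..., kwargs)
# New pattern (per the PR description):
#   code = FP8GemmRuntime.generate(kwargs)
#   runtime = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
print(sorted(kwargs.keys()))
```

The key point is that the kernel configuration is now assembled once and handed to both `generate` and `build`, rather than being tuned implicitly.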
Changelog
- `python/sglang/srt/layers/quantization/deep_gemm.py`
  - Added import for `build` from `deep_gemm.jit`.
  - Replaced calls to `jit_tuner.compile_and_tune` with calls to `FP8GemmRuntime.generate` followed by `build` for FP8 GEMM kernel compilation.
  - Modified the `kwargs` dictionary passed to the kernel compilation functions to include additional kernel configuration parameters.
- `sgl-kernel/CMakeLists.txt`
  - Updated the `GIT_TAG` for the DeepGEMM dependency to commit `8dfa3298274bfe6b242f6f8a3e6f3eff2707dd9f`.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Code changes flow in,
New library version pulled,
Kernels compile now.
Footnotes
1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution. ↩
Code Review
This pull request updates the integration with the deep_gemm library, primarily by changing how JIT compilation is invoked for GEMM kernels and updating the deep_gemm dependency version in CMake.
The changes in python/sglang/srt/layers/quantization/deep_gemm.py adapt to a new API, replacing jit_tuner.compile_and_tune with FP8GemmRuntime.generate and deep_gemm.jit.build. While this adaptation seems to follow a new pattern, there's a critical concern regarding the handling of NUM_GROUPS for masked grouped GEMM, and a medium concern about an unused parameter for contiguous grouped GEMM.
The PR description is currently a template. It would be beneficial to fill it out with the motivation and a summary of modifications to help reviewers understand the changes better. Also, the checklist is not completed.
Summary of Findings
- Potential incorrect `NUM_GROUPS` handling for masked grouped GEMM: In `_compile_grouped_gemm_nt_f8f8bf16_masked_one`, `NUM_GROUPS` is hardcoded to 1 in the `kwargs` passed to `FP8GemmRuntime.generate` and `build`. This might be incorrect if the actual `num_groups` (passed as a function parameter) can be greater than 1 and is required for correct kernel compilation.
- Unused parameter in contiguous grouped GEMM compilation: The `num_groups` parameter in `_compile_grouped_gemm_nt_f8f8bf16_contig_one` appears to be unused, as the logic consistently uses `NUM_GROUPS = 1` for this kernel type. The parameter should be marked as unused (e.g., renamed to `_`) for clarity.
- PR Description and Checklist: The pull request description is a template and the checklist is not filled out. Providing details about the motivation and changes would be helpful for reviewers.
Merge Readiness
This pull request updates the deep_gemm integration, which is a significant change. There is a critical concern regarding the handling of NUM_GROUPS in _compile_grouped_gemm_nt_f8f8bf16_masked_one that needs to be addressed to ensure correctness. Additionally, there's a point of code clarity regarding an unused parameter in another function.
I recommend that these issues, especially the critical one, be resolved before merging. As an AI, I am not authorized to approve pull requests; please ensure further review and approval from team members.
```python
"NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
"BLOCK_K": block_k,
'N': n, 'K': k,
'NUM_GROUPS': 1,
```
It appears NUM_GROUPS is hardcoded to 1 here. However, the function _compile_grouped_gemm_nt_f8f8bf16_masked_one receives a num_groups parameter, which is passed down from _maybe_compile_deep_gemm_one_type_all and is used in configure_func (which calls get_best_configs with this num_groups).
If FP8GemmRuntime.generate or deep_gemm.jit.build expect the actual num_groups for kernel generation (e.g., as a template parameter), hardcoding it to 1 could lead to incorrect kernel compilation or runtime behavior when num_groups is greater than 1.
Could you clarify if this is an intentional change due to the deep_gemm API update, or if the num_groups variable from the function signature should be used here?
```python
'NUM_GROUPS': num_groups,  # Use the num_groups parameter from the function signature
```
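To illustrate the reviewer's concern, here is a minimal, hypothetical sketch (the helper function and its signature are invented for demonstration; only the `NUM_GROUPS` key mirrors this PR's diff) contrasting the hardcoded value with threading the parameter through:

```python
# Hypothetical reduction of the concern above: if the kwargs used for kernel
# generation hardcode NUM_GROUPS to 1, any caller with num_groups > 1 gets a
# kernel built for the wrong group count.

def masked_kernel_kwargs(num_groups, hardcode=False):
    """Build a simplified, illustrative kwargs dict for the masked grouped GEMM."""
    return {"NUM_GROUPS": 1 if hardcode else num_groups}

# Hardcoded variant silently drops the caller's group count:
assert masked_kernel_kwargs(4, hardcode=True)["NUM_GROUPS"] == 1
# Suggested variant preserves it:
assert masked_kernel_kwargs(4)["NUM_GROUPS"] == 4
```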
This line should be removed.
@zhyncs Let's update this PR when the next version of sgl-kernel is bumped.
Thanks for your contribution!
Motivation
deepseek-ai/DeepGEMM@8dfa329
Modifications
Checklist