[sgl-kernel] misc: update deepgemm version for sgl-kernel #9340

Merged
zhyncs merged 93 commits into sgl-project:main from FlamingoPg:deepgemm-update2
Aug 27, 2025

@FlamingoPg
Collaborator

@FlamingoPg FlamingoPg commented Aug 19, 2025

Motivation

DeepGEMM was updated to unify CUDA version support, so sgl-kernel needs a corresponding update for TORCH_LIBRARY.

It depends on: #9167

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Summary of Changes

Hello @FlamingoPg, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request updates the DeepGEMM dependency within the sgl-kernel project to a newer version that unifies CUDA compatibility. This update involves significant changes to the build system to properly integrate DeepGEMM as a Python module and minor code adjustments to align with the updated DeepGEMM API and improve type hinting.

Highlights

  • Updated DeepGEMM Dependency: The DeepGEMM dependency has been updated to a specific version ('cabi' tag from 'sgl-project/DeepGEMM') to ensure unified CUDA compatibility across the project.
  • Refactored DeepGEMM Build System: The build and installation process for DeepGEMM in CMakeLists.txt has been refactored to compile it as a proper Python module, including linking against nvrtc for JIT compilation.
  • DeepGEMM API Adaptation: The fp8_utils.py file was adjusted to use an updated DeepGEMM layout utility function, reflecting API changes in the new DeepGEMM version.
  • Improved Type Hinting: Type hinting in mxfp4_tensor.py was enhanced by explicitly importing 'Optional' from the typing module and using it for a function parameter, improving code clarity.
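
As a rough illustration of the type-hinting change described in the last highlight, the sketch below shows the general pattern of importing Optional and using it for a parameter that may be None. The function name and parameters are hypothetical and are not taken from mxfp4_tensor.py.

```python
from typing import Optional


def scale_value(x: float, factor: Optional[float] = None) -> float:
    """Multiply x by factor, treating a missing factor as 1.0.

    Hypothetical example; not the function changed in the PR.
    """
    # Optional[float] states explicitly that None is an accepted value,
    # which a bare `factor: float = None` annotation does not.
    if factor is None:
        return x
    return x * factor
```

Calling scale_value(2.0) returns 2.0, and scale_value(2.0, 3.0) returns 6.0.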

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request updates the DeepGEMM dependency to a newer version, likely to support Blackwell GPUs and unify CUDA version requirements. The changes involve updating an API call in fp8_utils.py, simplifying the dependency fetching in CMakeLists.txt, and changing how DeepGEMM is built and integrated from a simple directory copy to a compiled C++ extension. The changes are logical and consistent with the goal. I have one suggestion to improve the build script's robustness.

Comment on lines +476 to +479
install(
DIRECTORY ${repo-deepgemm_SOURCE_DIR}/deep_gemm/
DESTINATION deep_gemm
)

medium

It's good practice to exclude version control and cache directories from the installation. Other install commands in this file use PATTERN to exclude .git and __pycache__ directories. It would be best to add these back to avoid including unnecessary files in the distribution and maintain consistency.

install(DIRECTORY ${repo-deepgemm_SOURCE_DIR}/deep_gemm/
        DESTINATION deep_gemm
        PATTERN ".git*" EXCLUDE
        PATTERN "__pycache__" EXCLUDE)
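
One way to confirm that an install tree is free of the directories the suggestion above excludes is a small post-install check. The following is a minimal sketch, not part of the PR: the helper name find_unwanted_entries and the default patterns are assumptions chosen to mirror the PATTERN lines in the suggested install command.

```python
import fnmatch
import os
from typing import List, Sequence


def find_unwanted_entries(
    root: str, patterns: Sequence[str] = (".git*", "__pycache__")
) -> List[str]:
    """Return paths (relative to root) whose base name matches any pattern.

    Hypothetical helper for inspecting an installed package directory.
    """
    hits = []
    for dirpath, dirnames, filenames in os.walk(root):
        # Check both directories and files, since ".git*" can match either.
        for name in dirnames + filenames:
            if any(fnmatch.fnmatch(name, pattern) for pattern in patterns):
                full_path = os.path.join(dirpath, name)
                hits.append(os.path.relpath(full_path, root))
    return sorted(hits)
```

Pointing the helper at the installed deep_gemm directory and getting an empty list back would indicate that no version-control or cache entries were copied.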

@zhyncs
Collaborator

zhyncs commented Aug 19, 2025

[2025-08-19 21:19:19 TP7] Scheduler hit an exception: Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 2552, in run_scheduler_process
    scheduler = Scheduler(
                ^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 321, in __init__
    self.tp_worker = TpWorkerClass(
                     ^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 67, in __init__
    self.worker = TpModelWorker(
                  ^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 84, in __init__
    self.model_runner = ModelRunner(
                        ^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 239, in __init__
    self.initialize(min_per_gpu_memory)
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 285, in initialize
    self.load_model()
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 681, in load_model
    self.model = get_model(
                 ^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/__init__.py", line 22, in get_model
    return loader.load_model(
           ^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 449, in load_model
    self.load_weights_and_postprocess(
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 457, in load_weights_and_postprocess
    model.load_weights(weights)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 2726, in load_weights
    self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 2428, in post_load_weights
    self._weight_requant_ue8m0(is_nextn)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 2454, in _weight_requant_ue8m0
    requant_weight_ue8m0_inplace(
  File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/fp8_utils.py", line 429, in requant_weight_ue8m0_inplace
    weight.data, weight_scale_inv.data = _requant_weight_ue8m0(
                                         ^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/fp8_utils.py", line 464, in _requant_weight_ue8m0
    out_s = _transform_scale(out_s, mn=out_w.shape[-2])
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/fp8_utils.py", line 461, in _transform_scale
    sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1243, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Assertion error (/sgl-kernel/build/_deps/repo-deepgemm-src/csrc/apis/../jit_kernels/impls/../../jit/compiler.hpp:50): not library_root_path.empty()

[2025-08-19 21:19:19] Received sigquit from a child process. It usually means the child failed.
[2025-08-19 21:19:19 TP6] Scheduler hit an exception: Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 2552, in run_scheduler_process
    scheduler = Scheduler(
                ^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 321, in __init__
    self.tp_worker = TpWorkerClass(
                     ^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 67, in __init__
    self.worker = TpModelWorker(
                  ^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 84, in __init__
    self.model_runner = ModelRunner(
                        ^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 239, in __init__
    self.initialize(min_per_gpu_memory)
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 285, in initialize
    self.load_model()
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 681, in load_model
    self.model = get_model(
                 ^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/__init__.py", line 22, in get_model
    return loader.load_model(
           ^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 449, in load_model
    self.load_weights_and_postprocess(
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 457, in load_weights_and_postprocess
    model.load_weights(weights)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 2726, in load_weights
    self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 2428, in post_load_weights
    self._weight_requant_ue8m0(is_nextn)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 2454, in _weight_requant_ue8m0
    requant_weight_ue8m0_inplace(
  File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/fp8_utils.py", line 429, in requant_weight_ue8m0_inplace
    weight.data, weight_scale_inv.data = _requant_weight_ue8m0(
                                         ^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/fp8_utils.py", line 464, in _requant_weight_ue8m0
    out_s = _transform_scale(out_s, mn=out_w.shape[-2])
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/fp8_utils.py", line 461, in _transform_scale
    sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1243, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Assertion error (/sgl-kernel/build/_deps/repo-deepgemm-src/csrc/apis/../jit_kernels/impls/../../jit/compiler.hpp:50): not library_root_path.empty()

[2025-08-19 21:19:19] Received sigquit from a child process. It usually means the child failed.
[2025-08-19 21:19:19 TP5] Scheduler hit an exception: Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 2552, in run_scheduler_process
    scheduler = Scheduler(
                ^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 321, in __init__
    self.tp_worker = TpWorkerClass(
                     ^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 67, in __init__
    self.worker = TpModelWorker(
                  ^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 84, in __init__
    self.model_runner = ModelRunner(
                        ^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 239, in __init__
    self.initialize(min_per_gpu_memory)
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 285, in initialize
    self.load_model()
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 681, in load_model
    self.model = get_model(
                 ^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/__init__.py", line 22, in get_model
    return loader.load_model(
           ^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 449, in load_model
    self.load_weights_and_postprocess(
  File "/sgl-workspace/sglang/python/sglang/srt/model_loader/loader.py", line 457, in load_weights_and_postprocess
    model.load_weights(weights)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 2726, in load_weights
    self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 2428, in post_load_weights
    self._weight_requant_ue8m0(is_nextn)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 2454, in _weight_requant_ue8m0
    requant_weight_ue8m0_inplace(
  File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/fp8_utils.py", line 429, in requant_weight_ue8m0_inplace
    weight.data, weight_scale_inv.data = _requant_weight_ue8m0(
                                         ^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/fp8_utils.py", line 464, in _requant_weight_ue8m0
    out_s = _transform_scale(out_s, mn=out_w.shape[-2])
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/fp8_utils.py", line 461, in _transform_scale
    sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1243, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Assertion error (/sgl-kernel/build/_deps/repo-deepgemm-src/csrc/apis/../jit_kernels/impls/../../jit/compiler.hpp:50): not library_root_path.empty()

[2025-08-19 21:19:19] Received sigquit from a child process. It usually means the child failed.
[1]    1594853 killed     python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8

@zhyncs
Collaborator

zhyncs commented Aug 19, 2025

>>> import deep_gemm
Warning: Missing operations: ['init', 'set_num_sms', 'get_num_sms', 'set_tc_util', 'get_tc_util', 'fp8_gemm_nt', 'fp8_gemm_nn', 'fp8_gemm_tn', 'fp8_gemm_tt', 'm_grouped_fp8_gemm_nt_contiguous', 'm_grouped_fp8_gemm_nn_contiguous', 'm_grouped_fp8_gemm_nt_masked', 'k_grouped_fp8_gemm_tn_contiguous', 'transform_sf_into_required_layout']
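
A warning like the one above can be reproduced in spirit with a generic attribute check. The sketch below is an assumption-laden illustration: the helper find_missing_attrs is hypothetical, and it only tests whether names exist as attributes on an imported module, which may not be how deep_gemm itself detects missing operations.

```python
import importlib
from typing import Iterable, List


def find_missing_attrs(module_name: str, expected: Iterable[str]) -> List[str]:
    """Import module_name and return the expected names it does not expose.

    Hypothetical helper; it checks plain module attributes only.
    """
    module = importlib.import_module(module_name)
    return [name for name in expected if not hasattr(module, name)]
```

For example, find_missing_attrs("math", ["sqrt", "no_such_function"]) returns ["no_such_function"]. The same call with "deep_gemm" and the operation names from the warning could help narrow down which symbols failed to load, assuming they are exposed as module attributes.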

@rainj-me
Collaborator

Tested with CUDA 13.0.

Screenshot 2025-08-27 at 12 39 02 AM

@zhyncs zhyncs merged commit aa3eba8 into sgl-project:main Aug 27, 2025
17 of 60 checks passed
@namanlalitnyu

@zhyncs @FlamingoPg
I'm seeing the following error when running the sglang server with the changes from this PR. It appears to be related to the merged changes; could you please verify?

My machine has CUDA 12.8 and Python 3.12, and I installed sglang from source.
Screenshot 2025-08-27 at 2 14 54 PM

@zhyncs
Collaborator

zhyncs commented Aug 27, 2025

@namanlalitnyu please use the latest main and install sgl-kernel==0.3.7
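
To check which version of a distribution is actually installed in the active environment, the standard library's importlib.metadata can be queried. This is a minimal sketch; the helper name installed_version is hypothetical, and the distribution name passed to it (for instance "sgl-kernel") is assumed to match the name used at install time.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version string for dist_name, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        # The distribution is not installed in this environment.
        return None
```

Comparing installed_version("sgl-kernel") against "0.3.7" would show whether the environment matches the version recommended above.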

@namanlalitnyu

@zhyncs thanks for your comment!
After following the steps above, I am now getting this error:
Screenshot 2025-08-27 at 2 41 59 PM

@yizhang2077 yizhang2077 mentioned this pull request Aug 29, 2025
5 tasks
kaixih added a commit to kaixih/sglang that referenced this pull request Sep 4, 2025
kaixih added a commit to kaixih/sglang that referenced this pull request Sep 4, 2025
MahmoudAshraf97 pushed a commit to MahmoudAshraf97/sglang that referenced this pull request Sep 8, 2025
…t#9340)

Co-authored-by: Yineng Zhang <me@zhyncs.com>
Co-authored-by: fzyzcjy <ch271828n@outlook.com>

Labels: bug, collaboration, dependencies, enhancement, high priority

6 participants