From 07c340269092bc5f68d940d369844004ec3969f1 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 2 Feb 2026 17:32:20 +0800 Subject: [PATCH 1/4] update flag --- python/sglang/srt/environ.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 1d33cd411f2a..fdbcb4fe9a3d 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -334,11 +334,14 @@ class Envs: # DeepGemm SGLANG_ENABLE_JIT_DEEPGEMM = EnvBool(True) SGLANG_JIT_DEEPGEMM_PRECOMPILE = EnvBool(True) + SGLANG_JIT_DEEPGEMM_PRECOMPILE_MAX_KERNEL_NUM = EnvInt(-1) SGLANG_JIT_DEEPGEMM_COMPILE_WORKERS = EnvInt(4) SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE = EnvBool(False) SGLANG_DG_CACHE_DIR = EnvStr(os.path.expanduser("~/.cache/deep_gemm")) SGLANG_DG_USE_NVRTC = EnvBool(False) SGLANG_USE_DEEPGEMM_BMM = EnvBool(False) + + # DeepSeek MHA Optimization SGLANG_CHUNKED_PREFIX_CACHE_THRESHOLD = EnvInt(8192) # DeepEP From 08140a4b7daa79b372dfc0aea0ba8306a07c3378 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 2 Feb 2026 18:03:15 +0800 Subject: [PATCH 2/4] impl --- docs/references/environment_variables.md | 5 +- python/sglang/srt/environ.py | 2 +- .../layers/deep_gemm_wrapper/compile_utils.py | 46 +++++++++++++++---- 3 files changed, 42 insertions(+), 11 deletions(-) diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md index 0372106683b1..ac1b52a4b44f 100644 --- a/docs/references/environment_variables.md +++ b/docs/references/environment_variables.md @@ -56,8 +56,9 @@ SGLang supports various environment variables that can be used to configure its | `SGLANG_JIT_DEEPGEMM_COMPILE_WORKERS` | Number of workers for parallel DeepGEMM kernel compilation | `4` | | `SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE` | Indicator flag used during the DeepGEMM precompile script | `"false"` | | `SGLANG_DG_CACHE_DIR` | Directory for caching compiled DeepGEMM kernels | `~/.cache/deep_gemm` | -| `SGL_DG_USE_NVRTC` | Use NVRTC (instead of Triton) for JIT compilation (Experimental) | `"0"` | -| `SGL_USE_DEEPGEMM_BMM` | Use DeepGEMM for Batched Matrix Multiplication (BMM) operations | `"false"` | +| `SGLANG_DG_USE_NVRTC` | Use NVRTC (instead of Triton) for JIT compilation (Experimental) | `"0"` | +| `SGLANG_USE_DEEPGEMM_BMM` | Use DeepGEMM for Batched Matrix Multiplication (BMM) operations | `"false"` | +| `SGLANG_JIT_DEEPGEMM_FAST_WARMUP` | Precompile less kernels during warmup, which reduces the warmup time from 30min to less than 3min. Might cause performance degradation during runtime. | `"false"` | ## DeepEP Configuration diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index fdbcb4fe9a3d..19eae882c221 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -334,7 +334,7 @@ class Envs: # DeepGemm SGLANG_ENABLE_JIT_DEEPGEMM = EnvBool(True) SGLANG_JIT_DEEPGEMM_PRECOMPILE = EnvBool(True) - SGLANG_JIT_DEEPGEMM_PRECOMPILE_MAX_KERNEL_NUM = EnvInt(-1) + SGLANG_JIT_DEEPGEMM_FAST_WARMUP = EnvBool(False) SGLANG_JIT_DEEPGEMM_COMPILE_WORKERS = EnvInt(4) SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE = EnvBool(False) SGLANG_DG_CACHE_DIR = EnvStr(os.path.expanduser("~/.cache/deep_gemm")) diff --git a/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py b/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py index 5e25e56a239c..94266c6e367a 100644 --- a/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +++ b/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py @@ -27,6 +27,7 @@ _DO_COMPILE_ALL = True _IS_FIRST_RANK_ON_NODE = envs.SGLANG_IS_FIRST_RANK_ON_NODE.get() _IN_PRECOMPILE_STAGE = envs.SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE.get() +_FAST_WARMUP = envs.SGLANG_JIT_DEEPGEMM_FAST_WARMUP.get() # Force redirect deep_gemm cache_dir os.environ["DG_JIT_CACHE_DIR"] = os.getenv( @@ -44,14 +45,43 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs): global _DO_COMPILE_ALL global _IS_FIRST_RANK_ON_NODE - # Generate m_max - m_max = 1024 * 16 - if server_args.chunked_prefill_size < 1: - m_max = 1024 * 64 - elif server_args.chunked_prefill_size > 8192: - m_max = server_args.chunked_prefill_size * 2 - m_max = min(1024 * 128, m_max) - _BUILTIN_M_LIST = list(range(1, m_max + 1)) + _BUILTIN_M_LIST = [] + + if _FAST_WARMUP: + # In fast warmup mode, only compile a small set of typical Ms + + # First cover all the small bs to ensure decode performance + _BUILTIN_M_LIST += list(range(1, 1025)) + + # Then cover larger batch sizes with gradually increasing steps + # For example, when chunekd prefill size is 16384 + # The sampled Ms would be: + # 1024, 1026, ... 2046 (step 2) + # 2048, 2052, ... 4092 (step 4) + # 4096, 5004, ... 8184 (step 8) + # 8192, 9008, ... 16384 (step 16) + # Totally 1024 + 1024 / 2 + 2048 / 4 + 4096 / 8 + 8192 / 16 = 3072 kernels + next_m, sample_step = 1024, 2 + max_prefill_bs = ( + server_args.chunked_prefill_size + if server_args.chunked_prefill_size + else 16 * 1024 + ) + while next_m < max_prefill_bs: + _BUILTIN_M_LIST += list(range(next_m, 2 * next_m, sample_step)) + next_m = next_m * 2 + sample_step = sample_step * 2 + _BUILTIN_M_LIST.append(max_prefill_bs) + _BUILTIN_M_LIST = list(set(_BUILTIN_M_LIST)) + else: + # When fast warmup isn't enabled, generate m_max and compile all the covered Ms. + m_max = 1024 * 16 + if server_args.chunked_prefill_size < 1: + m_max = 1024 * 64 + elif server_args.chunked_prefill_size > 8192: + m_max = server_args.chunked_prefill_size * 2 + m_max = min(1024 * 128, m_max) + _BUILTIN_M_LIST += list(range(1, m_max + 1)) _IS_FIRST_RANK_ON_NODE = server_args.base_gpu_id == gpu_id From 402b1ad631cf2d1d64965ff596ee85501576488f Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 2 Feb 2026 20:50:20 +0800 Subject: [PATCH 3/4] fix large chunked prefill size --- python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py b/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py index 94266c6e367a..18c22873c241 100644 --- a/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +++ b/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py @@ -63,8 +63,8 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs): # Totally 1024 + 1024 / 2 + 2048 / 4 + 4096 / 8 + 8192 / 16 = 3072 kernels next_m, sample_step = 1024, 2 max_prefill_bs = ( - server_args.chunked_prefill_size - if server_args.chunked_prefill_size + min(server_args.chunked_prefill_size, 32 * 1024) + if server_args.chunked_prefill_size >= 1 else 16 * 1024 ) while next_m < max_prefill_bs: From 9989cc701c4f0b96181af39b3ad5a0d669fd3b9c Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 2 Feb 2026 21:04:26 +0800 Subject: [PATCH 4/4] Update python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py b/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py index 18c22873c241..267e6cebe7df 100644 --- a/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +++ b/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py @@ -72,7 +72,7 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs): next_m = next_m * 2 sample_step = sample_step * 2 _BUILTIN_M_LIST.append(max_prefill_bs) - _BUILTIN_M_LIST = list(set(_BUILTIN_M_LIST)) + _BUILTIN_M_LIST = sorted(list(set(_BUILTIN_M_LIST))) else: # When fast warmup isn't enabled, generate m_max and compile all the covered Ms. m_max = 1024 * 16