From 1e0defd7e6d34998b49a5c140e6fc96bfa54d309 Mon Sep 17 00:00:00 2001 From: yingru Date: Mon, 27 Apr 2026 21:35:45 +0800 Subject: [PATCH] feat: add SGLANG_JIT_DEEPGEMM_FAST_WARMUP to reduce CUDA graph warmup time In the deepseek_v4 branch, DeepGEMM JIT compiles up to 16K M values during CUDA graph warmup. With TP=4 on B200, this exceeds NCCL timeout thresholds and causes initialization failures. SGLANG_JIT_DEEPGEMM_FAST_WARMUP=True replaces the full M-list with a sparse sampled set (~2560 values): all M in [1,1024] for decode performance, plus geometrically-spaced values up to chunked_prefill_size for prefill coverage. This reduces cold-start time from >30min to ~5.5min while preserving decode TPOT (~22ms/tok on B200). Also guard get/set_compile_mode calls with hasattr() to support DeepGEMM versions that do not expose this API. Signed-off-by: yingru --- python/sglang/srt/environ.py | 1 + .../layers/deep_gemm_wrapper/compile_utils.py | 58 +++++++++++++++---- 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 13190a4d4f23..e3dee9ee21fe 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -336,6 +336,7 @@ class Envs: # DeepGemm SGLANG_ENABLE_JIT_DEEPGEMM = EnvBool(True) SGLANG_JIT_DEEPGEMM_PRECOMPILE = EnvBool(True) + SGLANG_JIT_DEEPGEMM_FAST_WARMUP = EnvBool(False) SGLANG_JIT_DEEPGEMM_COMPILE_WORKERS = EnvInt(4) SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE = EnvBool(False) SGLANG_DG_CACHE_DIR = EnvStr(os.path.expanduser("~/.cache/deep_gemm")) diff --git a/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py b/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py index 5e25e56a239c..2fc7edd84dff 100644 --- a/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +++ b/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py @@ -22,7 +22,7 @@ import deep_gemm -_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1)) +_BUILTIN_M_LIST: List[int] = [] 
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.get() _DO_COMPILE_ALL = True _IS_FIRST_RANK_ON_NODE = envs.SGLANG_IS_FIRST_RANK_ON_NODE.get() @@ -44,14 +44,43 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs): global _DO_COMPILE_ALL global _IS_FIRST_RANK_ON_NODE - # Generate m_max - m_max = 1024 * 16 - if server_args.chunked_prefill_size < 1: - m_max = 1024 * 64 - elif server_args.chunked_prefill_size > 8192: - m_max = server_args.chunked_prefill_size * 2 - m_max = min(1024 * 128, m_max) - _BUILTIN_M_LIST = list(range(1, m_max + 1)) + _BUILTIN_M_LIST = [] + + if envs.SGLANG_JIT_DEEPGEMM_FAST_WARMUP.get(): + # In fast warmup mode, only compile a small set of typical Ms + + # First cover all the small bs to ensure decode performance + _BUILTIN_M_LIST += list(range(1, 1025)) + + # Then cover larger batch sizes with gradually increasing steps + # For example, when chunked prefill size is 16384 + # The sampled Ms would be: + # 1024, 1026, ... 2046 (step 2) + # 2048, 2052, ... 4092 (step 4) + # 4096, 4104, ... 8184 (step 8) + # 8192, 8208, ... 16384 (step 16) + # In total: 1024 + 1024/2 + 2048/4 + 4096/8 + 8192/16 = 3072 kernels + next_m, sample_step = 1024, 2 + max_prefill_bs = ( + min(server_args.chunked_prefill_size, 32 * 1024) + if server_args.chunked_prefill_size >= 1 + else 16 * 1024 + ) + while next_m < max_prefill_bs: + _BUILTIN_M_LIST += list(range(next_m, min(2 * next_m, max_prefill_bs), sample_step)) + next_m = next_m * 2 + sample_step = sample_step * 2 + _BUILTIN_M_LIST.append(max_prefill_bs) + _BUILTIN_M_LIST = sorted(set(_BUILTIN_M_LIST)) + else: + # When fast warmup isn't enabled, generate m_max and compile all the covered Ms. 
+ m_max = 1024 * 16 + if server_args.chunked_prefill_size < 1: + m_max = 1024 * 64 + elif server_args.chunked_prefill_size > 8192: + m_max = server_args.chunked_prefill_size * 2 + m_max = min(1024 * 128, m_max) + _BUILTIN_M_LIST += list(range(1, m_max + 1)) _IS_FIRST_RANK_ON_NODE = server_args.base_gpu_id == gpu_id @@ -163,12 +192,17 @@ def _compile_deep_gemm_one_type_all( kernel_type, max_m=max_m, n=n, k=k, num_groups=num_groups ) - old_compile_mode = deep_gemm.get_compile_mode() - deep_gemm.set_compile_mode(1) + has_compile_mode_api = hasattr(deep_gemm, "get_compile_mode") and hasattr( + deep_gemm, "set_compile_mode" + ) + if has_compile_mode_api: + old_compile_mode = deep_gemm.get_compile_mode() + deep_gemm.set_compile_mode(1) # TODO can use multi thread for m in tqdm(m_list, desc=f"DeepGEMM warmup"): executor.execute(m=m) - deep_gemm.set_compile_mode(old_compile_mode) + if has_compile_mode_api: + deep_gemm.set_compile_mode(old_compile_mode) # clean up input buffers torch.cuda.current_stream().synchronize()