diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md index 0372106683b1..ac1b52a4b44f 100644 --- a/docs/references/environment_variables.md +++ b/docs/references/environment_variables.md @@ -56,8 +56,9 @@ SGLang supports various environment variables that can be used to configure its | `SGLANG_JIT_DEEPGEMM_COMPILE_WORKERS` | Number of workers for parallel DeepGEMM kernel compilation | `4` | | `SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE` | Indicator flag used during the DeepGEMM precompile script | `"false"` | | `SGLANG_DG_CACHE_DIR` | Directory for caching compiled DeepGEMM kernels | `~/.cache/deep_gemm` | -| `SGL_DG_USE_NVRTC` | Use NVRTC (instead of Triton) for JIT compilation (Experimental) | `"0"` | -| `SGL_USE_DEEPGEMM_BMM` | Use DeepGEMM for Batched Matrix Multiplication (BMM) operations | `"false"` | +| `SGLANG_DG_USE_NVRTC` | Use NVRTC (instead of Triton) for JIT compilation (Experimental) | `"0"` | +| `SGLANG_USE_DEEPGEMM_BMM` | Use DeepGEMM for Batched Matrix Multiplication (BMM) operations | `"false"` | +| `SGLANG_JIT_DEEPGEMM_FAST_WARMUP` | Precompile less kernels during warmup, which reduces the warmup time from 30min to less than 3min. Might cause performance degradation during runtime. | `"false"` | ## DeepEP Configuration diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 1d33cd411f2a..19eae882c221 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -334,11 +334,14 @@ class Envs: # DeepGemm SGLANG_ENABLE_JIT_DEEPGEMM = EnvBool(True) SGLANG_JIT_DEEPGEMM_PRECOMPILE = EnvBool(True) + SGLANG_JIT_DEEPGEMM_FAST_WARMUP = EnvBool(False) SGLANG_JIT_DEEPGEMM_COMPILE_WORKERS = EnvInt(4) SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE = EnvBool(False) SGLANG_DG_CACHE_DIR = EnvStr(os.path.expanduser("~/.cache/deep_gemm")) SGLANG_DG_USE_NVRTC = EnvBool(False) SGLANG_USE_DEEPGEMM_BMM = EnvBool(False) + + # DeepSeek MHA Optimization SGLANG_CHUNKED_PREFIX_CACHE_THRESHOLD = EnvInt(8192) # DeepEP diff --git a/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py b/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py index 5e25e56a239c..267e6cebe7df 100644 --- a/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +++ b/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py @@ -27,6 +27,7 @@ _DO_COMPILE_ALL = True _IS_FIRST_RANK_ON_NODE = envs.SGLANG_IS_FIRST_RANK_ON_NODE.get() _IN_PRECOMPILE_STAGE = envs.SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE.get() +_FAST_WARMUP = envs.SGLANG_JIT_DEEPGEMM_FAST_WARMUP.get() # Force redirect deep_gemm cache_dir os.environ["DG_JIT_CACHE_DIR"] = os.getenv( @@ -44,14 +45,43 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs): global _DO_COMPILE_ALL global _IS_FIRST_RANK_ON_NODE - # Generate m_max - m_max = 1024 * 16 - if server_args.chunked_prefill_size < 1: - m_max = 1024 * 64 - elif server_args.chunked_prefill_size > 8192: - m_max = server_args.chunked_prefill_size * 2 - m_max = min(1024 * 128, m_max) - _BUILTIN_M_LIST = list(range(1, m_max + 1)) + _BUILTIN_M_LIST = [] + + if _FAST_WARMUP: + # In fast warmup mode, only compile a small set of typical Ms + + # First cover all the small bs to ensure decode performance + _BUILTIN_M_LIST += list(range(1, 1025)) + + # Then cover larger batch sizes with gradually increasing steps + # For example, when chunekd prefill size is 16384 + # The sampled Ms would be: + # 1024, 1026, ... 2046 (step 2) + # 2048, 2052, ... 4092 (step 4) + # 4096, 5004, ... 8184 (step 8) + # 8192, 9008, ... 16384 (step 16) + # Totally 1024 + 1024 / 2 + 2048 / 4 + 4096 / 8 + 8192 / 16 = 3072 kernels + next_m, sample_step = 1024, 2 + max_prefill_bs = ( + min(server_args.chunked_prefill_size, 32 * 1024) + if server_args.chunked_prefill_size >= 1 + else 16 * 1024 + ) + while next_m < max_prefill_bs: + _BUILTIN_M_LIST += list(range(next_m, 2 * next_m, sample_step)) + next_m = next_m * 2 + sample_step = sample_step * 2 + _BUILTIN_M_LIST.append(max_prefill_bs) + _BUILTIN_M_LIST = sorted(list(set(_BUILTIN_M_LIST))) + else: + # When fast warmup isn't enabled, generate m_max and compile all the covered Ms. + m_max = 1024 * 16 + if server_args.chunked_prefill_size < 1: + m_max = 1024 * 64 + elif server_args.chunked_prefill_size > 8192: + m_max = server_args.chunked_prefill_size * 2 + m_max = min(1024 * 128, m_max) + _BUILTIN_M_LIST += list(range(1, m_max + 1)) _IS_FIRST_RANK_ON_NODE = server_args.base_gpu_id == gpu_id