Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/references/environment_variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@ SGLang supports various environment variables that can be used to configure its
| `SGLANG_JIT_DEEPGEMM_COMPILE_WORKERS` | Number of workers for parallel DeepGEMM kernel compilation | `4` |
| `SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE` | Indicator flag used during the DeepGEMM precompile script | `"false"` |
| `SGLANG_DG_CACHE_DIR` | Directory for caching compiled DeepGEMM kernels | `~/.cache/deep_gemm` |
| `SGL_DG_USE_NVRTC` | Use NVRTC (instead of Triton) for JIT compilation (Experimental) | `"0"` |
| `SGL_USE_DEEPGEMM_BMM` | Use DeepGEMM for Batched Matrix Multiplication (BMM) operations | `"false"` |
| `SGLANG_DG_USE_NVRTC` | Use NVRTC (instead of Triton) for JIT compilation (Experimental) | `"0"` |
| `SGLANG_USE_DEEPGEMM_BMM` | Use DeepGEMM for Batched Matrix Multiplication (BMM) operations | `"false"` |
| `SGLANG_JIT_DEEPGEMM_FAST_WARMUP` | Precompile less kernels during warmup, which reduces the warmup time from 30min to less than 3min. Might cause performance degradation during runtime. | `"false"` |

## DeepEP Configuration

Expand Down
3 changes: 3 additions & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,11 +334,14 @@ class Envs:
# DeepGemm
SGLANG_ENABLE_JIT_DEEPGEMM = EnvBool(True)
SGLANG_JIT_DEEPGEMM_PRECOMPILE = EnvBool(True)
SGLANG_JIT_DEEPGEMM_FAST_WARMUP = EnvBool(False)
SGLANG_JIT_DEEPGEMM_COMPILE_WORKERS = EnvInt(4)
SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE = EnvBool(False)
SGLANG_DG_CACHE_DIR = EnvStr(os.path.expanduser("~/.cache/deep_gemm"))
SGLANG_DG_USE_NVRTC = EnvBool(False)
SGLANG_USE_DEEPGEMM_BMM = EnvBool(False)

# DeepSeek MHA Optimization
SGLANG_CHUNKED_PREFIX_CACHE_THRESHOLD = EnvInt(8192)

# DeepEP
Expand Down
46 changes: 38 additions & 8 deletions python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
_DO_COMPILE_ALL = True
_IS_FIRST_RANK_ON_NODE = envs.SGLANG_IS_FIRST_RANK_ON_NODE.get()
_IN_PRECOMPILE_STAGE = envs.SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE.get()
_FAST_WARMUP = envs.SGLANG_JIT_DEEPGEMM_FAST_WARMUP.get()

# Force redirect deep_gemm cache_dir
os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
Expand All @@ -44,14 +45,43 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
global _DO_COMPILE_ALL
global _IS_FIRST_RANK_ON_NODE

# Generate m_max
m_max = 1024 * 16
if server_args.chunked_prefill_size < 1:
m_max = 1024 * 64
elif server_args.chunked_prefill_size > 8192:
m_max = server_args.chunked_prefill_size * 2
m_max = min(1024 * 128, m_max)
_BUILTIN_M_LIST = list(range(1, m_max + 1))
_BUILTIN_M_LIST = []

if _FAST_WARMUP:
# In fast warmup mode, only compile a small set of typical Ms

# First cover all the small bs to ensure decode performance
_BUILTIN_M_LIST += list(range(1, 1025))

# Then cover larger batch sizes with gradually increasing steps
# For example, when chunekd prefill size is 16384
# The sampled Ms would be:
# 1024, 1026, ... 2046 (step 2)
# 2048, 2052, ... 4092 (step 4)
# 4096, 5004, ... 8184 (step 8)
# 8192, 9008, ... 16384 (step 16)
# Totally 1024 + 1024 / 2 + 2048 / 4 + 4096 / 8 + 8192 / 16 = 3072 kernels
next_m, sample_step = 1024, 2
max_prefill_bs = (
min(server_args.chunked_prefill_size, 32 * 1024)
if server_args.chunked_prefill_size >= 1
else 16 * 1024
)
Comment thread
Fridge003 marked this conversation as resolved.
while next_m < max_prefill_bs:
_BUILTIN_M_LIST += list(range(next_m, 2 * next_m, sample_step))
next_m = next_m * 2
sample_step = sample_step * 2
_BUILTIN_M_LIST.append(max_prefill_bs)
_BUILTIN_M_LIST = sorted(list(set(_BUILTIN_M_LIST)))
else:
# When fast warmup isn't enabled, generate m_max and compile all the covered Ms.
m_max = 1024 * 16
if server_args.chunked_prefill_size < 1:
m_max = 1024 * 64
elif server_args.chunked_prefill_size > 8192:
m_max = server_args.chunked_prefill_size * 2
m_max = min(1024 * 128, m_max)
_BUILTIN_M_LIST += list(range(1, m_max + 1))

_IS_FIRST_RANK_ON_NODE = server_args.base_gpu_id == gpu_id

Expand Down
Loading