Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ class Envs:
# DeepGemm
SGLANG_ENABLE_JIT_DEEPGEMM = EnvBool(True)
SGLANG_JIT_DEEPGEMM_PRECOMPILE = EnvBool(True)
SGLANG_JIT_DEEPGEMM_FAST_WARMUP = EnvBool(False)
Comment thread
Fridge003 marked this conversation as resolved.
SGLANG_JIT_DEEPGEMM_COMPILE_WORKERS = EnvInt(4)
SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE = EnvBool(False)
SGLANG_DG_CACHE_DIR = EnvStr(os.path.expanduser("~/.cache/deep_gemm"))
Expand Down
58 changes: 46 additions & 12 deletions python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import deep_gemm


_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
_BUILTIN_M_LIST: List[int] = []
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.get()
_DO_COMPILE_ALL = True
_IS_FIRST_RANK_ON_NODE = envs.SGLANG_IS_FIRST_RANK_ON_NODE.get()
Expand All @@ -44,14 +44,43 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
global _DO_COMPILE_ALL
global _IS_FIRST_RANK_ON_NODE

# Generate m_max
m_max = 1024 * 16
if server_args.chunked_prefill_size < 1:
m_max = 1024 * 64
elif server_args.chunked_prefill_size > 8192:
m_max = server_args.chunked_prefill_size * 2
m_max = min(1024 * 128, m_max)
_BUILTIN_M_LIST = list(range(1, m_max + 1))
_BUILTIN_M_LIST = []

if envs.SGLANG_JIT_DEEPGEMM_FAST_WARMUP.get():
# In fast warmup mode, only compile a small set of typical Ms

# First cover all the small bs to ensure decode performance
_BUILTIN_M_LIST += list(range(1, 1025))

# Then cover larger batch sizes with gradually increasing steps
# For example, when chunked prefill size is 16384
# The sampled Ms would be:
# 1024, 1026, ... 2046 (step 2)
# 2048, 2052, ... 4092 (step 4)
# 4096, 5004, ... 8184 (step 8)
# 8192, 9008, ... 16384 (step 16)
# Totally 1024 + 1024/2 + 2048/4 + 4096/8 + 8192/16 = 3072 kernels
next_m, sample_step = 1024, 2
max_prefill_bs = (
min(server_args.chunked_prefill_size, 32 * 1024)
if server_args.chunked_prefill_size >= 1
else 16 * 1024
)
while next_m < max_prefill_bs:
_BUILTIN_M_LIST += list(range(next_m, min(2 * next_m, max_prefill_bs), sample_step))
next_m = next_m * 2
sample_step = sample_step * 2
_BUILTIN_M_LIST.append(max_prefill_bs)
_BUILTIN_M_LIST = sorted(set(_BUILTIN_M_LIST))
else:
# When fast warmup isn't enabled, generate m_max and compile all the covered Ms.
m_max = 1024 * 16
if server_args.chunked_prefill_size < 1:
m_max = 1024 * 64
elif server_args.chunked_prefill_size > 8192:
m_max = server_args.chunked_prefill_size * 2
m_max = min(1024 * 128, m_max)
_BUILTIN_M_LIST += list(range(1, m_max + 1))

_IS_FIRST_RANK_ON_NODE = server_args.base_gpu_id == gpu_id

Expand Down Expand Up @@ -163,12 +192,17 @@ def _compile_deep_gemm_one_type_all(
kernel_type, max_m=max_m, n=n, k=k, num_groups=num_groups
)

old_compile_mode = deep_gemm.get_compile_mode()
deep_gemm.set_compile_mode(1)
has_compile_mode_api = hasattr(deep_gemm, "get_compile_mode") and hasattr(
deep_gemm, "set_compile_mode"
)
if has_compile_mode_api:
old_compile_mode = deep_gemm.get_compile_mode()
deep_gemm.set_compile_mode(1)
# TODO can use multi thread
for m in tqdm(m_list, desc=f"DeepGEMM warmup"):
executor.execute(m=m)
deep_gemm.set_compile_mode(old_compile_mode)
if has_compile_mode_api:
deep_gemm.set_compile_mode(old_compile_mode)

# clean up input buffers
torch.cuda.current_stream().synchronize()
Expand Down
Loading