3434_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var (
3535 "SGL_JIT_DEEPGEMM_PRECOMPILE" , "true"
3636)
37- _DO_COMPILE = get_bool_env_var ("SGL_IS_FIRST_RANK_ON_NODE" , "true" )
37+ _DO_COMPILE_ALL = True
38+ _IS_FIRST_RANK_ON_NODE = get_bool_env_var ("SGL_IS_FIRST_RANK_ON_NODE" , "true" )
3839_COMPILE_WORKERS = get_int_env_var ("SGL_JIT_DEEPGEMM_COMPILE_WORKERS" , 4 )
39- _IN_PRE_COMPILE_STAGE = get_bool_env_var ("SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE " , "false" )
40+ _IN_PRECOMPILE_STAGE = get_bool_env_var ("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE " , "false" )
4041
4142# Force redirect deep_gemm cache_dir
4243os .environ ["DG_CACHE_DIR" ] = os .getenv (
4647
4748def update_deep_gemm_config (gpu_id : int , server_args : ServerArgs ):
4849 global _BUILTIN_M_LIST
49- global _DO_COMPILE
50+ global _DO_COMPILE_ALL
51+ global _IS_FIRST_RANK_ON_NODE
5052
5153 # Generate m_max
5254 m_max = 1024 * 16
@@ -57,8 +59,13 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
5759 m_max = min (1024 * 128 , m_max )
5860 _BUILTIN_M_LIST = list (range (1 , m_max + 1 ))
5961
60- # Check if is the first rank on node
61- _DO_COMPILE = ServerArgs .base_gpu_id == gpu_id
62+ _IS_FIRST_RANK_ON_NODE = ServerArgs .base_gpu_id == gpu_id
63+
64+ # Only the first rank on the node compiles during the pre-compile stage.
65+ # Otherwise, by default, every rank tries to compile all Ms so that
66+ # all symbols are loaded at the launch stage,
67+ # which avoids loading symbols at the serving stage.
68+ _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE or not _IN_PRECOMPILE_STAGE
6269
6370
6471class DeepGemmKernelType (IntEnum ):
@@ -89,7 +96,7 @@ class DeepGemmKernelHelper:
8996
9097
9198def _compile_warning_1 ():
92- if not _IN_PRE_COMPILE_STAGE :
99+ if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE :
93100 logger .warning (
94101 "Entering DeepGEMM JIT Pre-Complie session. "
95102 "And it may takes a long time(Typically 10-20 mins) "
@@ -276,7 +283,7 @@ def _maybe_compile_deep_gemm_one_type_all(
276283 query_key = (kernel_type , n , k , num_groups )
277284 if (
278285 _ENABLE_JIT_DEEPGEMM_PRECOMPILE
279- and _DO_COMPILE
286+ and _DO_COMPILE_ALL
280287 and _INITIALIZATION_DICT .get (query_key ) is None
281288 ):
282289 _INITIALIZATION_DICT [query_key ] = True
@@ -286,7 +293,7 @@ def _maybe_compile_deep_gemm_one_type_all(
286293 logger .info (
287294 f"Try DeepGEMM JIT Compiling for "
288295 f"<{ kernel_helper .name } > N={ n } , K={ k } , num_groups={ num_groups } with all Ms."
289- f"{ ' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRE_COMPILE_STAGE else '' } "
296+ f"{ ' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else '' } "
290297 )
291298
292299 # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
@@ -355,7 +362,7 @@ def gemm_nt_f8f8bf16(
355362
356363@contextmanager
357364def _log_jit_build (M : int , N : int , K : int , kernel_type : DeepGemmKernelType ):
358- if _IN_PRE_COMPILE_STAGE :
365+ if _IN_PRECOMPILE_STAGE :
359366 yield
360367 return
361368
0 commit comments