Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from sglang.srt.environ import envs
from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import ceil_div, get_bool_env_var
from sglang.srt.utils import ceil_div, get_available_gpu_memory, get_bool_env_var

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -132,8 +132,35 @@ def _compile_deep_gemm_one_type_all(
m_alignment = deep_gemm.get_mk_alignment_for_contiguous_layout()
m_list = sorted(list(set(m for m in m_list if m % m_alignment == 0)))

# Here the precompilation is only run on the first rank, so gpu_id should be 0
memory_budget = get_available_gpu_memory(device="cuda", gpu_id=0)

# If the memory budget is less than the memory requirement, we need to reduce max_m to avoid out-of-memory, which might further cause hanging during warmup
max_m = max(m_list)
required_memory = _BaseWarmupExecutor.get_memory_requirement(
kernel_type, max_m=max_m, n=n, k=k, num_groups=num_groups
)
logger.info(
f"Required memory for warmup: {required_memory}GB, Available memory: {memory_budget}GB"
)
if memory_budget < required_memory:
# TODO: Maybe compute the max_m based on the memory budget
while (
_BaseWarmupExecutor.get_memory_requirement(
kernel_type, max_m=max_m, n=n, k=k, num_groups=num_groups
)
> memory_budget
and max_m > 4096
):
max_m = max_m // 2
logger.warning(
f"Available memory {memory_budget}GB is less than required memory {required_memory}GB for warmup, reducing max_m to {max_m} to avoid out of memory"
)
m_list = [m for m in m_list if m <= max_m]

# Need some methods to estimate needed memory for warmup
executor = _BaseWarmupExecutor.create(
kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
kernel_type, max_m=max_m, n=n, k=k, num_groups=num_groups
)

old_compile_mode = deep_gemm.get_compile_mode()
Expand Down Expand Up @@ -161,6 +188,26 @@ def create(kernel_type: DeepGemmKernelType, **kwargs):
DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: _GroupedMaskedWarmupExecutor,
}[kernel_type](**kwargs)

@staticmethod
def get_memory_requirement(
    kernel_type: DeepGemmKernelType, max_m: int, n: int, k: int, num_groups: int
) -> float:
    """Estimate the GPU memory, in GB, that the warmup executor will allocate.

    Sizing model (bytes per element): the fp8 operands count 1 byte each,
    the bf16 output counts 2 bytes, and the int32 index/mask tensor counts
    4 bytes. Scale factors and other small auxiliary tensors are ignored —
    this is an estimate used only to shrink ``max_m`` before warmup.

    Args:
        kernel_type: Which DeepGEMM kernel the warmup will compile/run.
        max_m: Largest M (token) dimension that will be warmed up.
        n: GEMM N dimension.
        k: GEMM K dimension.
        num_groups: Number of expert groups (grouped kernels only).

    Returns:
        Estimated memory requirement in GB (fractional).

    Raises:
        ValueError: If ``kernel_type`` is not a supported DeepGEMM kernel.
    """
    _GB = 1 << 30
    if kernel_type == DeepGemmKernelType.GEMM_NT_F8F8BF16:
        # lhs (max_m, k) fp8 + rhs (n, k) fp8 + out (max_m, n) bf16
        return (max_m * k + n * k + max_m * n * 2) / _GB
    elif kernel_type == DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG:
        # lhs (max_m, k) fp8 + rhs (num_groups, n, k) fp8
        # + m_indices (max_m,) int32 + out (max_m, n) bf16
        return (max_m * k + num_groups * n * k + max_m * 4 + max_m * n * 2) / _GB
    elif kernel_type == DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED:
        # lhs (num_groups, max_m, k) fp8 + rhs (num_groups, n, k) fp8
        # + masked_m (num_groups,) int32 + out (num_groups, max_m, n) bf16
        return (
            num_groups * max_m * k
            + num_groups * n * k
            + num_groups * 4
            + num_groups * max_m * n * 2
        ) / _GB
    else:
        raise ValueError(f"Invalid kernel type: {kernel_type}")

def execute(self, m: int) -> None:
    # Abstract hook: run one warmup GEMM for token count `m`.
    # Concrete subclasses (created via `_BaseWarmupExecutor.create`) override this.
    raise NotImplementedError

Expand Down
10 changes: 0 additions & 10 deletions test/srt/ep/test_deepep_large.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import unittest
from types import SimpleNamespace

Expand Down Expand Up @@ -49,10 +48,6 @@ def setUpClass(cls):
"2048",
"--disable-radix-cache",
],
env={
**os.environ,
"SGLANG_JIT_DEEPGEMM_PRECOMPILE": "0",
},
)

@classmethod
Expand All @@ -75,7 +70,6 @@ def test_gsm8k(self):
self.assertGreater(metrics["accuracy"], 0.92)


@unittest.skip("Can pass locally, but will cause Timeout on CI runner.")
class TestDeepseekMTP(CustomTestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -118,10 +112,6 @@ def setUpClass(cls):
"2",
"--disable-radix-cache",
],
env={
**os.environ,
"SGLANG_JIT_DEEPGEMM_PRECOMPILE": "0",
},
)

@classmethod
Expand Down
Loading