124 changes: 57 additions & 67 deletions python/sglang/srt/layers/quantization/deep_gemm.py
@@ -16,11 +16,7 @@
import deep_gemm
from deep_gemm import get_num_sms
from deep_gemm.jit_kernels.gemm import get_best_configs
from deep_gemm.jit_kernels.gemm import includes as deep_gemm_includes
from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template
from deep_gemm.jit_kernels.m_grouped_gemm import (
template as deep_gemm_grouped_gemm_template,
)
from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
from deep_gemm.jit_kernels.tuner import jit_tuner

sm_version = get_device_sm()
@@ -45,10 +41,15 @@ def get_enable_jit_deepgemm():
_IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")

# Force redirect deep_gemm cache_dir
os.environ["DG_CACHE_DIR"] = os.getenv(
"SGL_DG_CACHE_DIR", os.path.expanduser("~") + "/.cache/deep_gemm"
os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
"SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
)

# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
# NVRTC may cause performance loss in some cases,
# and NVCC JIT compilation is also 9x faster as of the referenced commit.
os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")


def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
global _BUILTIN_M_LIST
@@ -130,10 +131,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
num_groups: int,
config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
) -> None:
# Auto-tuning with compilation
global deep_gemm_includes, deep_gemm_grouped_gemm_template
_, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
_ = jit_tuner.compile_and_tune(
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128
kwargs = {
"NUM_TMA_THREADS": num_tma_threads,
"NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
"BLOCK_K": block_k,
"NUM_SMS": num_sms,
"SMEM_SIZE": smem_config[0],
}
_, _ = jit_tuner.compile_and_tune(
name="m_grouped_gemm_fp8_fp8_bf16_nt",
keys={
"N": n,
@@ -146,24 +155,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
"NUM_STAGES": num_stages,
"NUM_TMA_MULTICAST": tma_multicast_config[0],
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
"GEMM_TYPE": "GroupedMasked",
"GEMM_TYPE": GemmType.GroupedMasked,
},
space=(),
includes=deep_gemm_includes,
arg_defs=(
("lhs", torch.float8_e4m3fn),
("lhs_scales", torch.float),
("rhs", torch.float8_e4m3fn),
("rhs_scales", torch.float),
("out", torch.bfloat16),
("grouped_layout", torch.int32),
("m", int),
("stream", torch.cuda.Stream),
("num_sms", int),
("smem_size", int),
),
template=deep_gemm_grouped_gemm_template,
args=[],
kwargs=kwargs,
runtime_cls=FP8GemmRuntime,
)
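
The precompile helpers in this file each build the same runtime kwargs dict for `FP8GemmRuntime`. A hypothetical helper (not part of this PR) that captures the shared shape, assuming the fixed 128-wide K blocking and thread counts shown in the diff:

```python
from typing import Dict, Tuple

def _build_fp8_runtime_kwargs(
    num_sms: int, smem_config: Tuple[int, int, int]
) -> Dict[str, int]:
    # Hypothetical helper: mirrors the kwargs each precompile function passes to
    # jit_tuner.compile_and_tune via runtime_cls=FP8GemmRuntime. BLOCK_K and the
    # thread-group sizes are the fixed 128 values used in the diff; NUM_SMS and
    # SMEM_SIZE come from the tuned config.
    return {
        "NUM_TMA_THREADS": 128,
        "NUM_MATH_THREADS_PER_GROUP": 128,
        "BLOCK_K": 128,
        "NUM_SMS": num_sms,
        "SMEM_SIZE": smem_config[0],
    }
```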


@@ -173,9 +169,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
num_groups: int,
config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
) -> None:
global deep_gemm_includes, deep_gemm_grouped_gemm_template
_, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
_ = jit_tuner.compile_and_tune(
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128
kwargs = {
"NUM_TMA_THREADS": num_tma_threads,
"NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
"BLOCK_K": block_k,
"NUM_SMS": num_sms,
"SMEM_SIZE": smem_config[0],
}
_, _ = jit_tuner.compile_and_tune(
name="m_grouped_gemm_fp8_fp8_bf16_nt",
keys={
"N": n,
@@ -188,25 +193,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
"NUM_STAGES": num_stages,
"NUM_TMA_MULTICAST": tma_multicast_config[0],
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
"GEMM_TYPE": "GroupedContiguous",
"GEMM_TYPE": GemmType.GroupedContiguous,
},
space=(),
includes=deep_gemm_includes,
arg_defs=(
("lhs", torch.float8_e4m3fn),
("lhs_scales", torch.float),
("rhs", torch.float8_e4m3fn),
("rhs_scales", torch.float),
("out", torch.bfloat16),
("grouped_layout", torch.int32),
("m", int),
("num_groups", int),
("stream", torch.cuda.Stream),
("num_sms", int),
("smem_size", int),
),
template=deep_gemm_grouped_gemm_template,
args=[],
kwargs=kwargs,
runtime_cls=FP8GemmRuntime,
)


@@ -216,9 +207,20 @@ def _compile_gemm_nt_f8f8bf16_one(
_: int, # _ is a dummy parameter to align with other interfaces
config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
) -> None:
global deep_gemm_includes, deep_gemm_gemm_template
_, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
_ = jit_tuner.compile_and_tune(
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128
kwargs = {
"GEMM_TYPE": GemmType.Normal,
"NUM_TMA_THREADS": num_tma_threads,
"NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
"NUM_GROUPS": 1,
"BLOCK_K": block_k,
"NUM_SMS": num_sms,
"SMEM_SIZE": smem_config[0],
}
_, _ = jit_tuner.compile_and_tune(
name="gemm_fp8_fp8_bf16_nt",
keys={
"N": n,
@@ -232,20 +234,8 @@ def _compile_gemm_nt_f8f8bf16_one(
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
},
space=(),
includes=deep_gemm_includes,
arg_defs=(
("lhs", torch.float8_e4m3fn),
("lhs_scales", torch.float),
("rhs", torch.float8_e4m3fn),
("rhs_scales", torch.float),
("out", torch.bfloat16),
("m", int),
("stream", torch.cuda.Stream),
("num_sms", int),
("smem_size", int),
),
template=deep_gemm_gemm_template,
args=[],
kwargs=kwargs,
runtime_cls=FP8GemmRuntime,
)


@@ -373,7 +363,7 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):

from deep_gemm.jit.runtime import RuntimeCache

origin_func = RuntimeCache.__getitem__
origin_func = RuntimeCache.get

def __patched_func(self, *args, **kwargs):
ret = origin_func(self, *args, **kwargs)
@@ -385,6 +375,6 @@ def __patched_func(self, *args, **kwargs):
)
return ret

RuntimeCache.__getitem__ = __patched_func
RuntimeCache.get = __patched_func
yield
RuntimeCache.__getitem__ = origin_func
RuntimeCache.get = origin_func
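
A hedged usage sketch of the `_log_jit_build` context manager patched above: it wraps a DeepGEMM call so that a cache miss in `RuntimeCache.get` (i.e. an on-the-fly JIT build) is logged together with the shape that triggered it. The tensors and the `DeepGemmKernelType` member name below are placeholders, not taken from this PR:

```python
# Placeholder tensors (lhs, lhs_scales, rhs, rhs_scales, out) are assumed to be
# prepared elsewhere in DeepGEMM's expected FP8/BF16 layouts.
with _log_jit_build(M=64, N=4096, K=7168, kernel_type=DeepGemmKernelType.GEMM_NT_F8F8BF16):
    deep_gemm.gemm_fp8_fp8_bf16_nt((lhs, lhs_scales), (rhs, rhs_scales), out)
```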