2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -49,7 +49,7 @@ runtime_common = [

srt = [
"sglang[runtime_common]",
"sgl-kernel==0.1.5",
"sgl-kernel==0.1.6",
"flashinfer_python==0.2.5",
"torch==2.6.0",
"torchvision==0.21.0",
2 changes: 1 addition & 1 deletion python/sglang/srt/entrypoints/engine.py
@@ -579,7 +579,7 @@ def _set_envs_and_config(server_args: ServerArgs):
if _is_cuda:
assert_pkg_version(
"sgl-kernel",
"0.1.5",
"0.1.6",
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
)
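
The assertion above means an environment still carrying sgl-kernel 0.1.5 will now fail at server startup. As a quick pre-flight check outside of sglang, something like the following standalone sketch can verify the installed version; it is not sglang's assert_pkg_version, and it assumes the third-party packaging library is available.

# Standalone sanity check mirroring the intent of assert_pkg_version above.
# Assumes `packaging` is installed; the helper name is illustrative, not sglang code.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def check_sgl_kernel(min_version: str = "0.1.6") -> None:
    try:
        installed = version("sgl-kernel")
    except PackageNotFoundError:
        raise RuntimeError("sgl-kernel is not installed") from None
    if Version(installed) < Version(min_version):
        raise RuntimeError(
            f"sgl-kernel {installed} is installed but >= {min_version} is required; "
            "reinstall with `pip install sgl-kernel --force-reinstall`."
        )


if __name__ == "__main__":
    check_sgl_kernel()
    print("sgl-kernel version OK")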

98 changes: 42 additions & 56 deletions python/sglang/srt/layers/quantization/deep_gemm.py
@@ -17,10 +17,10 @@
try:
import deep_gemm
from deep_gemm import get_num_sms
+ from deep_gemm.jit import build
from deep_gemm.jit.compiler import get_nvcc_compiler
from deep_gemm.jit_kernels.gemm import get_best_configs
from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
- from deep_gemm.jit_kernels.tuner import jit_tuner

sm_version = get_device_sm()
if sm_version == 90:
@@ -148,32 +148,28 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128

kwargs = {
"GEMM_TYPE": GemmType.GroupedMasked,
"NUM_TMA_THREADS": num_tma_threads,
"NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
"N": n,
"K": k,
"NUM_GROUPS": 1,
"BLOCK_M": block_m,
"BLOCK_N": block_n,
"BLOCK_K": block_k,
"SWIZZLE_D_MODE": smem_config[1],
"BLOCK_N_PADDING": smem_config[2],
"NUM_STAGES": num_stages,
"NUM_TMA_MULTICAST": tma_multicast_config[0],
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
"NUM_SMS": num_sms,
"SMEM_SIZE": smem_config[0],
}
_, _ = jit_tuner.compile_and_tune(
name="m_grouped_gemm_fp8_fp8_bf16_nt",
keys={
"N": n,
"K": k,
"BLOCK_M": block_m,
"BLOCK_N": block_n,
"SWIZZLE_D_MODE": smem_config[1],
"BLOCK_N_PADDING": smem_config[2],
"NUM_GROUPS": num_groups,
"NUM_STAGES": num_stages,
"NUM_TMA_MULTICAST": tma_multicast_config[0],
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
"GEMM_TYPE": GemmType.GroupedMasked,
},
space=(),
kwargs=kwargs,
runtime_cls=FP8GemmRuntime,
)

code = FP8GemmRuntime.generate(kwargs)
_ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
Comment on lines 152 to +172
Contributor


high

NUM_GROUPS appears to be hardcoded to 1 here, yet _compile_grouped_gemm_nt_f8f8bf16_masked_one receives a num_groups parameter that is derived from the input tensor shapes and can be greater than 1 (see its use in _maybe_compile_deep_gemm_one_type_all and that function's caller, grouped_gemm_nt_f8f8bf16_masked).

The corresponding configure_func for GROUPED_GEMM_NT_F8F8BF16_MASKED also uses the dynamic num_groups:

lambda m, n, k, num_groups, num_sms: get_best_configs(m, n, k, num_groups, num_sms, is_grouped_masked=True)

If num_groups can be anything other than 1, hardcoding NUM_GROUPS: 1 when generating and building the kernel could compile it for the wrong group count, leading to runtime errors or incorrect results.

Should NUM_GROUPS use the num_groups parameter passed to this function, as the keys dictionary of the previous jit_tuner.compile_and_tune call did? A standalone sketch of how the group count flows from the tensor shapes into the compile-time kwargs follows the suggested change below.

Suggested change
kwargs = {
"GEMM_TYPE": GemmType.GroupedMasked,
"NUM_TMA_THREADS": num_tma_threads,
"NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
"N": n,
"K": k,
"NUM_GROUPS": num_groups,
"BLOCK_M": block_m,
"BLOCK_N": block_n,
"BLOCK_K": block_k,
"SWIZZLE_D_MODE": smem_config[1],
"BLOCK_N_PADDING": smem_config[2],
"NUM_STAGES": num_stages,
"NUM_TMA_MULTICAST": tma_multicast_config[0],
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
"NUM_SMS": num_sms,
"SMEM_SIZE": smem_config[0],
}
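
To make the reviewer's point concrete, here is a minimal, self-contained sketch of how the group count should flow from the tensor shapes into the compile-time kwargs instead of being fixed at 1. The helper name and the (num_groups, m_max, k) shape layout are illustrative assumptions, not sglang's actual code.

# Illustrative sketch only; not part of this PR or of sglang.
from typing import Any, Dict, Tuple


def _masked_gemm_compile_kwargs(
    lhs_shape: Tuple[int, int, int],  # assumed (num_groups, m_max, k) layout
    n: int,
    block_m: int,
    block_n: int,
    block_k: int = 128,
) -> Dict[str, Any]:
    # The group count comes from the input shape, so it must reach the
    # compile-time kwargs; hardcoding 1 would bake a single-group kernel.
    num_groups, _, k = lhs_shape
    return {
        "N": n,
        "K": k,
        "NUM_GROUPS": num_groups,
        "BLOCK_M": block_m,
        "BLOCK_N": block_n,
        "BLOCK_K": block_k,
    }


# A batch with 8 expert groups should compile a kernel with NUM_GROUPS=8.
cfg = _masked_gemm_compile_kwargs((8, 64, 7168), n=4096, block_m=64, block_n=128)
assert cfg["NUM_GROUPS"] == 8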



def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
@@ -187,31 +183,26 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
num_tma_threads = 128
num_math_threads_per_group = 128
kwargs = {
"GEMM_TYPE": GemmType.GroupedContiguous,
"NUM_TMA_THREADS": num_tma_threads,
"NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
"N": n,
"K": k,
"NUM_GROUPS": 1,
"BLOCK_M": block_m,
"BLOCK_N": block_n,
"BLOCK_K": block_k,
"SWIZZLE_D_MODE": smem_config[1],
"BLOCK_N_PADDING": smem_config[2],
"NUM_STAGES": num_stages,
"NUM_TMA_MULTICAST": tma_multicast_config[0],
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
"NUM_SMS": num_sms,
"SMEM_SIZE": smem_config[0],
}
_, _ = jit_tuner.compile_and_tune(
name="m_grouped_gemm_fp8_fp8_bf16_nt",
keys={
"N": n,
"K": k,
"BLOCK_M": block_m,
"BLOCK_N": block_n,
"SWIZZLE_D_MODE": smem_config[1],
"BLOCK_N_PADDING": smem_config[2],
"NUM_GROUPS": num_groups,
"NUM_STAGES": num_stages,
"NUM_TMA_MULTICAST": tma_multicast_config[0],
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
"GEMM_TYPE": GemmType.GroupedContiguous,
},
space=(),
kwargs=kwargs,
runtime_cls=FP8GemmRuntime,
)

code = FP8GemmRuntime.generate(kwargs)
_ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)


def _compile_gemm_nt_f8f8bf16_one(
@@ -228,28 +219,23 @@ def _compile_gemm_nt_f8f8bf16_one(
"GEMM_TYPE": GemmType.Normal,
"NUM_TMA_THREADS": num_tma_threads,
"NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
"N": n,
"K": k,
"NUM_GROUPS": 1,
"BLOCK_M": block_m,
"BLOCK_N": block_n,
"BLOCK_K": block_k,
"SWIZZLE_D_MODE": smem_config[1],
"BLOCK_N_PADDING": smem_config[2],
"NUM_STAGES": num_stages,
"NUM_TMA_MULTICAST": tma_multicast_config[0],
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
"NUM_SMS": num_sms,
"SMEM_SIZE": smem_config[0],
}
_, _ = jit_tuner.compile_and_tune(
name="gemm_fp8_fp8_bf16_nt",
keys={
"N": n,
"K": k,
"BLOCK_M": block_m,
"BLOCK_N": block_n,
"SWIZZLE_D_MODE": smem_config[1],
"BLOCK_N_PADDING": smem_config[2],
"NUM_STAGES": num_stages,
"NUM_TMA_MULTICAST": tma_multicast_config[0],
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
},
space=(),
kwargs=kwargs,
runtime_cls=FP8GemmRuntime,
)

code = FP8GemmRuntime.generate(kwargs)
_ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)


_KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
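
All three _compile_*_one helpers in this diff now end with the same two calls: FP8GemmRuntime.generate(kwargs) followed by build(...). A possible follow-up, sketched below under the assumption that nothing else differs between call sites, would be to factor that pair into one small helper; the helper name is hypothetical, while generate and build are used exactly as they appear in the diff above.

# Hypothetical refactor sketch; not part of this PR.
from typing import Any, Dict

from deep_gemm.jit import build
from deep_gemm.jit_kernels.runtime import FP8GemmRuntime


def _generate_and_build(kernel_name: str, kwargs: Dict[str, Any]):
    """Render the kernel source for these kwargs and JIT-build it."""
    code = FP8GemmRuntime.generate(kwargs)
    return build(kernel_name, code, FP8GemmRuntime, kwargs)


# Each compile helper would then reduce to a single call, e.g.:
#   _generate_and_build("m_grouped_gemm_fp8_fp8_bf16_nt", kwargs)
#   _generate_and_build("gemm_fp8_fp8_bf16_nt", kwargs)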