Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion python/sglang/srt/layers/quantization/deep_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
if is_cuda():
import deep_gemm
from deep_gemm import get_num_sms
from deep_gemm.jit.compiler import get_nvcc_compiler
from deep_gemm.jit_kernels.gemm import get_best_configs
from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
from deep_gemm.jit_kernels.tuner import jit_tuner
Expand Down Expand Up @@ -48,7 +49,17 @@ def get_enable_jit_deepgemm():
# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
# NVRTC may have performance loss with some cases.
# And NVCC JIT speed is also 9x faster in the ref commit
os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")
_USE_NVRTC_DEFAULT = "0"
if _ENABLE_JIT_DEEPGEMM:
try:
get_nvcc_compiler()
except:
logger.warning(
"NVCC Compiler not found, use NVRTC for DeepGEMM JIT "
"and may have performance loss with some cases."
)
_USE_NVRTC_DEFAULT = "1"
os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", _USE_NVRTC_DEFAULT)


def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
Expand Down
Loading