diff --git a/python/sglang/srt/layers/quantization/fp4_utils.py b/python/sglang/srt/layers/quantization/fp4_utils.py
index a40074770fc9..3e913e137f02 100644
--- a/python/sglang/srt/layers/quantization/fp4_utils.py
+++ b/python/sglang/srt/layers/quantization/fp4_utils.py
@@ -5,6 +5,7 @@
 from typing import TYPE_CHECKING
 
 from sglang.srt.environ import envs
+from sglang.srt.utils.common import is_sm120_supported
 
 if TYPE_CHECKING:
     from sglang.srt.server_args import ServerArgs
@@ -75,6 +76,19 @@ def initialize_fp4_gemm_config(server_args: ServerArgs) -> None:
             "Using server argument value."
         )
 
+    if backend == "auto":
+        if is_sm120_supported():
+            # flashinfer_cutlass produces NaN in dense MLP layers with
+            # heterogeneous batches on SM120 (Blackwell). cudnn is stable.
+            # See: https://github.com/sgl-project/sglang/issues/20043
+            backend = "flashinfer_cudnn"
+            logger.info(
+                "SM120 (Blackwell) detected: auto-selecting "
+                "fp4-gemm-backend=flashinfer_cudnn"
+            )
+        else:
+            backend = "flashinfer_cutlass"
+
     FP4_GEMM_RUNNER_BACKEND = Fp4GemmRunnerBackend(backend)
 
 
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 4fe003cc69bc..b1f39af79d02 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -466,7 +466,7 @@ class ServerArgs:
     grammar_backend: Optional[str] = None
     mm_attention_backend: Optional[str] = None
     fp8_gemm_runner_backend: str = "auto"
-    fp4_gemm_runner_backend: str = "flashinfer_cutlass"
+    fp4_gemm_runner_backend: str = "auto"
     nsa_prefill_backend: Optional[str] = (
         None  # None = auto-detect based on hardware/kv_cache_dtype
     )
@@ -4308,8 +4308,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
         default=ServerArgs.fp4_gemm_runner_backend,
         dest="fp4_gemm_runner_backend",
         help="Choose the runner backend for NVFP4 GEMM operations. "
-        "Options: 'flashinfer_cutlass' (default), "
-        "'auto' (auto-selects between flashinfer_cudnn/flashinfer_cutlass based on CUDA/cuDNN version), "
+        "Options: 'auto' (default; selects flashinfer_cudnn on SM120, flashinfer_cutlass otherwise), "
+        "'flashinfer_cutlass' (CUTLASS backend), "
        "'flashinfer_cudnn' (FlashInfer cuDNN backend, optimal on CUDA 13+ with cuDNN 9.15+), "
        "'flashinfer_trtllm' (FlashInfer TensorRT-LLM backend, requires different weight preparation with shuffling). "
        "NOTE: This replaces the deprecated environment variable "