diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index 948860f8c431..ca4bad8045bf 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -67,6 +67,7 @@
 from sglang.srt.distributed.parallel_state import destroy_distributed_environment
 from sglang.srt.entrypoints.engine import _set_envs_and_config
 from sglang.srt.layers.moe import initialize_moe_config
+from sglang.srt.layers.quantization.fp4_utils import initialize_fp4_gemm_config
 from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.managers.scheduler_dp_attn_mixin import prepare_mlp_sync_batch_raw
@@ -647,6 +648,7 @@ def latency_test(
 ):
     initialize_moe_config(server_args)
     initialize_fp8_gemm_config(server_args)
+    initialize_fp4_gemm_config(server_args)
 
     # Set CPU affinity
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py
index 356c66018f6b..a8cc419c850a 100644
--- a/python/sglang/srt/environ.py
+++ b/python/sglang/srt/environ.py
@@ -503,6 +503,10 @@ def _convert_SGL_to_SGLANG():
     "SGLANG_SUPPORT_CUTLASS_BLOCK_FP8",
     "It will be completely removed in 0.5.7. Please use '--fp8-gemm-backend=cutlass' instead.",
 )
+_warn_deprecated_env_to_cli_flag(
+    "SGLANG_FLASHINFER_FP4_GEMM_BACKEND",
+    "It will be completely removed in 0.5.9. Please use '--fp4-gemm-backend' instead.",
+)
 _warn_deprecated_env_to_cli_flag(
     "SGLANG_SCHEDULER_DECREASE_PREFILL_IDLE",
     "Please use '--enable-prefill-delayer' instead.",
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
index 89a77efc5c32..142d5aba2c11 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
@@ -15,8 +15,8 @@
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
 )
+from sglang.srt.layers.quantization.fp4_utils import get_fp4_gemm_runner_backend
 from sglang.srt.layers.quantization.modelopt_quant import (
-    FLASHINFER_FP4_GEMM_BACKEND,
     enable_flashinfer_fp4_gemm,
     fp4_gemm,
     fp4_quantize,
@@ -98,7 +98,7 @@ def process_weights_after_loading(self, layer) -> None:
         layer.weight_global_scale = Parameter(
             layer.weight_global_scale.max().to(torch.float32), requires_grad=False
         )
-        if FLASHINFER_FP4_GEMM_BACKEND == "trtllm":
+        if get_fp4_gemm_runner_backend().is_trtllm():
             # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
             # FlashInfer provides nvfp4_quantize to quantize + shuffle the
             # layout but we use our own quantization so we have to call
diff --git a/python/sglang/srt/layers/quantization/fp4_utils.py b/python/sglang/srt/layers/quantization/fp4_utils.py
new file mode 100644
index 000000000000..e6e0ff0bb925
--- /dev/null
+++ b/python/sglang/srt/layers/quantization/fp4_utils.py
@@ -0,0 +1,70 @@
+from __future__ import annotations
+
+import logging
+from enum import Enum
+from typing import TYPE_CHECKING
+
+from sglang.srt.environ import envs
+
+if TYPE_CHECKING:
+    from sglang.srt.server_args import ServerArgs
+
+logger = logging.getLogger(__name__)
+
+
+class Fp4GemmRunnerBackend(Enum):
+    """Enum for FP4 GEMM runner backend selection."""
+
+    AUTO = "auto"
+    CUDNN = "cudnn"
+    CUTLASS = "cutlass"
+    TRTLLM = "trtllm"
+
+    def is_auto(self) -> bool:
+        return self == Fp4GemmRunnerBackend.AUTO
+
+    def is_cudnn(self) -> bool:
+        return self == Fp4GemmRunnerBackend.CUDNN
+
+    def is_cutlass(self) -> bool:
+        return self == Fp4GemmRunnerBackend.CUTLASS
+
+    def is_trtllm(self) -> bool:
+        return self == Fp4GemmRunnerBackend.TRTLLM
+
+
+FP4_GEMM_RUNNER_BACKEND: Fp4GemmRunnerBackend | None = None
+
+
+def initialize_fp4_gemm_config(server_args: ServerArgs) -> None:
+    """Initialize FP4 GEMM configuration from server args."""
+    global FP4_GEMM_RUNNER_BACKEND
+
+    backend = server_args.fp4_gemm_runner_backend
+
+    # Handle deprecated env var for backward compatibility
+    # TODO: Remove this in a future version
+    if envs.SGLANG_FLASHINFER_FP4_GEMM_BACKEND.is_set():
+        env_backend = envs.SGLANG_FLASHINFER_FP4_GEMM_BACKEND.get()
+        if backend == "auto":
+            logger.warning(
+                "SGLANG_FLASHINFER_FP4_GEMM_BACKEND is deprecated. "
+                f"Please use '--fp4-gemm-backend={env_backend}' instead."
+            )
+            backend = env_backend
+        else:
+            logger.warning(
+                f"FP4 GEMM backend set to '{backend}' via --fp4-gemm-backend overrides "
+                "environment variable SGLANG_FLASHINFER_FP4_GEMM_BACKEND. "
+                "Using server argument value."
+            )
+
+    FP4_GEMM_RUNNER_BACKEND = Fp4GemmRunnerBackend(backend)
+
+
+def get_fp4_gemm_runner_backend() -> Fp4GemmRunnerBackend:
+    """Get the current FP4 GEMM runner backend."""
+    global FP4_GEMM_RUNNER_BACKEND
+    if FP4_GEMM_RUNNER_BACKEND is None:
+        FP4_GEMM_RUNNER_BACKEND = Fp4GemmRunnerBackend.AUTO
+    return FP4_GEMM_RUNNER_BACKEND
diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index 9c7197346bbb..bea8c487ef69 100755
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -30,6 +30,7 @@
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp4_utils import get_fp4_gemm_runner_backend
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
@@ -126,7 +127,10 @@ def fp4_gemm(
     out_dtype: torch.dtype,
     out_features: int,
 ) -> torch.Tensor:
-    backend = FLASHINFER_FP4_GEMM_BACKEND if FLASHINFER_FP4_GEMM_BACKEND else "cutlass"
+    fp4_backend = get_fp4_gemm_runner_backend()
+    # TODO(shuw@nvidia.com): Remove the "cutlass" default override after flashinfer 0.6.0
+    # and let flashinfer's auto backend selection handle it.
+    backend = fp4_backend.value if not fp4_backend.is_auto() else "cutlass"
     if enable_flashinfer_fp4_gemm:
         return flashinfer_fp4_gemm(
             input, weight, input_sf, weight_sf, alpha, out_dtype, backend=backend
@@ -150,7 +154,6 @@ def _sgl_kernel_scaled_fp4_quant_fake(
 
 # TODO make it true by default when the DeepEP PR is merged
 MOE_NVFP4_DISPATCH = envs.SGLANG_MOE_NVFP4_DISPATCH.get()
-FLASHINFER_FP4_GEMM_BACKEND = envs.SGLANG_FLASHINFER_FP4_GEMM_BACKEND.get()
 
 # Supported activation schemes for the current configuration
 ACTIVATION_SCHEMES = ["static"]
@@ -1152,7 +1155,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.input_scale_inv = Parameter(
             (1 / input_scale_2).to(torch.float32), requires_grad=False
         )
-        if FLASHINFER_FP4_GEMM_BACKEND == "trtllm":
+        if get_fp4_gemm_runner_backend().is_trtllm():
             # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
             # FlashInfer provides nvfp4_quantize to quantize + shuffle the
             # layout but we use our own quantization so we have to call
@@ -1221,11 +1224,6 @@ def apply(
         if enable_flashinfer_fp4_gemm:
             w = layer.weight.T
             w_scale_interleaved = layer.weight_scale_interleaved.T
-            # TODO(shuw@nvidia.com)
-            # Remove the default after flashinfer bumped to 0.5.1
-            backend = (
-                FLASHINFER_FP4_GEMM_BACKEND if FLASHINFER_FP4_GEMM_BACKEND else "cutlass"
-            )
             out = fp4_gemm(
                 x_fp4,
                 w,
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 5471bfeed505..dd585fc8bce4 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -65,6 +65,7 @@
     get_attention_tp_group,
 )
 from sglang.srt.layers.moe import initialize_moe_config
+from sglang.srt.layers.quantization.fp4_utils import initialize_fp4_gemm_config
 from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config
 from sglang.srt.managers.io_struct import (
     AbortReq,
@@ -473,10 +474,9 @@ def init_moe_gemm_config(self):
         if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
             initialize_moe_config(self.server_args)
 
-        # Initialize GEMM-related configuration (currently FP8 Blockwise GEMM backend).
-        # Other GEMM backends (e.g. FP4, BF16, etc.) can be added here in the future.
-        # This is needed for FP8 quantization.
+        # Initialize GEMM-related configuration for FP8 and FP4 backends.
         initialize_fp8_gemm_config(self.server_args)
+        initialize_fp4_gemm_config(self.server_args)
 
         # This must be called after initialize_moe_config
         self.require_mlp_sync = require_mlp_sync(self.server_args)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 75b79a15b6dd..d07dbc7ee80d 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -192,6 +192,13 @@
     "aiter",
 ]
 
+FP4_GEMM_RUNNER_BACKEND_CHOICES = [
+    "auto",
+    "cudnn",
+    "cutlass",
+    "trtllm",
+]
+
 MAMBA_SSM_DTYPE_CHOICES = ["float32", "bfloat16"]
 
 MAMBA_SCHEDULER_STRATEGY_CHOICES = ["auto", "no_buffer", "extra_buffer"]
@@ -226,6 +233,10 @@ def add_fp8_gemm_runner_backend_choices(choices):
     FP8_GEMM_RUNNER_BACKEND_CHOICES.extend(choices)
 
 
+def add_fp4_gemm_runner_backend_choices(choices):
+    FP4_GEMM_RUNNER_BACKEND_CHOICES.extend(choices)
+
+
 def add_deterministic_attention_backend_choices(choices):
     DETERMINISTIC_ATTENTION_BACKEND_CHOICES.extend(choices)
 
@@ -422,6 +433,7 @@ class ServerArgs:
     grammar_backend: Optional[str] = None
     mm_attention_backend: Optional[str] = None
     fp8_gemm_runner_backend: str = "auto"
+    fp4_gemm_runner_backend: str = "auto"
     nsa_prefill_backend: str = "flashmla_sparse"
     nsa_decode_backend: str = "fa3"
     disable_flashinfer_autotune: bool = False
@@ -3521,6 +3533,20 @@ def add_cli_args(parser: argparse.ArgumentParser):
             "NOTE: This replaces the deprecated environment variables "
             "SGLANG_ENABLE_FLASHINFER_FP8_GEMM and SGLANG_SUPPORT_CUTLASS_BLOCK_FP8.",
         )
+        parser.add_argument(
+            "--fp4-gemm-backend",
+            type=str,
+            choices=FP4_GEMM_RUNNER_BACKEND_CHOICES,
+            default=ServerArgs.fp4_gemm_runner_backend,
+            dest="fp4_gemm_runner_backend",
+            help="Choose the runner backend for NVFP4 GEMM operations. "
+            "Options: 'auto' (default, selects between cudnn/cutlass based on CUDA/cuDNN version), "
+            "'cudnn' (cuDNN backend, optimal on CUDA 13+ with cuDNN 9.15+), "
+            "'cutlass' (CUTLASS backend, optimal on CUDA 12), "
+            "'trtllm' (TensorRT-LLM backend, requires different weight preparation with shuffling). "
+            "NOTE: This replaces the deprecated environment variable "
+            "SGLANG_FLASHINFER_FP4_GEMM_BACKEND.",
+        )
         parser.add_argument(
             "--disable-flashinfer-autotune",
             default=ServerArgs.disable_flashinfer_autotune,
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index fbc7c8154476..98df290dd8ac 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -35,7 +35,7 @@
         TestFile("test_deepseek_v3_fp4_4gpu.py", 1500),
         TestFile("test_fp8_blockwise_gemm.py", 280),
         TestFile("test_gpt_oss_4gpu.py", 700),
-        TestFile("test_llama31_fp4.py", 90),
+        TestFile("test_nvfp4_gemm.py", 360),
     ],
     # "per-commit-8-gpu-b200": [
    #     TestFile("test_mistral_large3_basic.py", 275),  # Moved to nightly - large model
diff --git a/test/srt/test_llama31_fp4.py b/test/srt/test_nvfp4_gemm.py
similarity index 62%
rename from test/srt/test_llama31_fp4.py
rename to test/srt/test_nvfp4_gemm.py
index be1b04648420..e0b8fa9916ff 100644
--- a/test/srt/test_llama31_fp4.py
+++ b/test/srt/test_nvfp4_gemm.py
@@ -8,21 +8,27 @@
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
+    try_cached_model,
 )
 
-MODEL_PATH = "nvidia/Llama-3.1-8B-Instruct-FP4"
+MODEL_PATH = "nvidia/Llama-3.1-8B-Instruct-NVFP4"
 
 
-@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher")
-class TestLlama31FP4(unittest.TestCase):
+class FP4GemmBase:
+    backend = None
+
     @classmethod
     def setUpClass(cls):
-        cls.model = MODEL_PATH
+        if cls.backend is None:
+            raise NotImplementedError("Subclass must set 'backend' attribute")
+        cls.model = try_cached_model(MODEL_PATH)
         cls.base_url = DEFAULT_URL_FOR_TEST
         other_args = [
             "--trust-remote-code",
             "--quantization",
             "modelopt_fp4",
+            "--fp4-gemm-backend",
+            cls.backend,
         ]
         cls.process = popen_launch_server(
             cls.model,
@@ -52,5 +58,25 @@ def test_gsm8k(self):
         self.assertGreater(metrics["accuracy"], 0.64)
 
 
+@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher")
+class TestFP4GemmAuto(FP4GemmBase, unittest.TestCase):
+    backend = "auto"
+
+
+@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher")
+class TestFP4GemmCutlass(FP4GemmBase, unittest.TestCase):
+    backend = "cutlass"
+
+
+@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher")
+class TestFP4GemmCudnn(FP4GemmBase, unittest.TestCase):
+    backend = "cudnn"
+
+
+@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher")
+class TestFP4GemmTrtllm(FP4GemmBase, unittest.TestCase):
+    backend = "trtllm"
+
+
 if __name__ == "__main__":
     unittest.main()