Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/sglang/bench_one_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from sglang.srt.distributed.parallel_state import destroy_distributed_environment
from sglang.srt.entrypoints.engine import _set_envs_and_config
from sglang.srt.layers.moe import initialize_moe_config
from sglang.srt.layers.quantization.fp4_utils import initialize_fp4_gemm_config
from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.scheduler_dp_attn_mixin import prepare_mlp_sync_batch_raw
Expand Down Expand Up @@ -647,6 +648,7 @@ def latency_test(
):
initialize_moe_config(server_args)
initialize_fp8_gemm_config(server_args)
initialize_fp4_gemm_config(server_args)

# Set CPU affinity
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,10 @@ def _convert_SGL_to_SGLANG():
"SGLANG_SUPPORT_CUTLASS_BLOCK_FP8",
"It will be completely removed in 0.5.7. Please use '--fp8-gemm-backend=cutlass' instead.",
)
_warn_deprecated_env_to_cli_flag(
"SGLANG_FLASHINFER_FP4_GEMM_BACKEND",
"It will be completely removed in 0.5.9. Please use '--fp4-gemm-backend' instead.",
)
_warn_deprecated_env_to_cli_flag(
"SGLANG_SCHEDULER_DECREASE_PREFILL_IDLE",
"Please use '--enable-prefill-delayer' instead.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from sglang.srt.layers.quantization.fp4_utils import get_fp4_gemm_runner_backend
from sglang.srt.layers.quantization.modelopt_quant import (
FLASHINFER_FP4_GEMM_BACKEND,
enable_flashinfer_fp4_gemm,
fp4_gemm,
fp4_quantize,
Expand Down Expand Up @@ -98,7 +98,7 @@ def process_weights_after_loading(self, layer) -> None:
layer.weight_global_scale.max().to(torch.float32), requires_grad=False
)

if FLASHINFER_FP4_GEMM_BACKEND == "trtllm":
if get_fp4_gemm_runner_backend().is_trtllm():
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
# layout but we use our own quantization so we have to call
Expand Down
70 changes: 70 additions & 0 deletions python/sglang/srt/layers/quantization/fp4_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from __future__ import annotations

import logging
from enum import Enum
from typing import TYPE_CHECKING

from sglang.srt.environ import envs

if TYPE_CHECKING:
from sglang.srt.server_args import ServerArgs

logger = logging.getLogger(__name__)


class Fp4GemmRunnerBackend(Enum):
    """Selectable runner backends for NVFP4 GEMM kernels.

    Values mirror the ``--fp4-gemm-backend`` CLI choices; ``AUTO`` defers
    the concrete backend choice to downstream dispatch logic.
    """

    AUTO = "auto"
    CUDNN = "cudnn"
    CUTLASS = "cutlass"
    TRTLLM = "trtllm"

    def is_auto(self) -> bool:
        """True when backend selection is left to automatic dispatch."""
        return self is Fp4GemmRunnerBackend.AUTO

    def is_cudnn(self) -> bool:
        """True when the cuDNN backend is selected."""
        return self is Fp4GemmRunnerBackend.CUDNN

    def is_cutlass(self) -> bool:
        """True when the CUTLASS backend is selected."""
        return self is Fp4GemmRunnerBackend.CUTLASS

    def is_trtllm(self) -> bool:
        """True when the TensorRT-LLM backend is selected."""
        return self is Fp4GemmRunnerBackend.TRTLLM


FP4_GEMM_RUNNER_BACKEND: Fp4GemmRunnerBackend | None = None


def initialize_fp4_gemm_config(server_args: ServerArgs) -> None:
    """Initialize FP4 GEMM configuration from server args.

    Resolves the backend from ``server_args.fp4_gemm_runner_backend``, with
    a backward-compatibility fallback to the deprecated
    SGLANG_FLASHINFER_FP4_GEMM_BACKEND environment variable, and stores the
    result in the module-level FP4_GEMM_RUNNER_BACKEND.
    """
    global FP4_GEMM_RUNNER_BACKEND

    chosen = server_args.fp4_gemm_runner_backend

    # Handle deprecated env var for backward compatibility
    # TODO: Remove this in a future version
    env_entry = envs.SGLANG_FLASHINFER_FP4_GEMM_BACKEND
    if env_entry.is_set():
        env_value = env_entry.get()
        if chosen != "auto":
            # An explicit CLI choice takes precedence over the legacy env var.
            logger.warning(
                f"FP4 GEMM backend set to '{chosen}' via --fp4-gemm-backend overrides "
                "environment variable SGLANG_FLASHINFER_FP4_GEMM_BACKEND. "
                "Using server argument value."
            )
        else:
            # CLI left at its default; honor the env var but nudge users to migrate.
            logger.warning(
                "SGLANG_FLASHINFER_FP4_GEMM_BACKEND is deprecated. "
                f"Please use '--fp4-gemm-backend={env_value}' instead."
            )
            chosen = env_value

    FP4_GEMM_RUNNER_BACKEND = Fp4GemmRunnerBackend(chosen)


def get_fp4_gemm_runner_backend() -> Fp4GemmRunnerBackend:
    """Return the active FP4 GEMM runner backend.

    Falls back to (and caches) ``Fp4GemmRunnerBackend.AUTO`` when
    ``initialize_fp4_gemm_config`` has not been called yet.
    """
    global FP4_GEMM_RUNNER_BACKEND
    if FP4_GEMM_RUNNER_BACKEND is None:
        # Not initialized yet — default to automatic backend selection.
        FP4_GEMM_RUNNER_BACKEND = Fp4GemmRunnerBackend.AUTO
    return FP4_GEMM_RUNNER_BACKEND
14 changes: 6 additions & 8 deletions python/sglang/srt/layers/quantization/modelopt_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp4_utils import get_fp4_gemm_runner_backend
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear,
Expand Down Expand Up @@ -126,7 +127,10 @@ def fp4_gemm(
out_dtype: torch.dtype,
out_features: int,
) -> torch.Tensor:
backend = FLASHINFER_FP4_GEMM_BACKEND if FLASHINFER_FP4_GEMM_BACKEND else "cutlass"
fp4_backend = get_fp4_gemm_runner_backend()
# TODO(shuw@nvidia.com): Remove the "cutlass" default override after flashinfer 0.6.0
# and let flashinfer's auto backend selection handle it.
backend = fp4_backend.value if not fp4_backend.is_auto() else "cutlass"
if enable_flashinfer_fp4_gemm:
return flashinfer_fp4_gemm(
input, weight, input_sf, weight_sf, alpha, out_dtype, backend=backend
Expand All @@ -150,7 +154,6 @@ def _sgl_kernel_scaled_fp4_quant_fake(

# TODO make it true by default when the DeepEP PR is merged
MOE_NVFP4_DISPATCH = envs.SGLANG_MOE_NVFP4_DISPATCH.get()
FLASHINFER_FP4_GEMM_BACKEND = envs.SGLANG_FLASHINFER_FP4_GEMM_BACKEND.get()
# Supported activation schemes for the current configuration
ACTIVATION_SCHEMES = ["static"]

Expand Down Expand Up @@ -1152,7 +1155,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.input_scale_inv = Parameter(
(1 / input_scale_2).to(torch.float32), requires_grad=False
)
if FLASHINFER_FP4_GEMM_BACKEND == "trtllm":
if get_fp4_gemm_runner_backend().is_trtllm():
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
# layout but we use our own quantization so we have to call
Expand Down Expand Up @@ -1221,11 +1224,6 @@ def apply(
if enable_flashinfer_fp4_gemm:
w = layer.weight.T
w_scale_interleaved = layer.weight_scale_interleaved.T
# TODO(shuw@nvidia.com)
# Remove the default after flashinfer bumped to 0.5.1
backend = (
FLASHINFER_FP4_GEMM_BACKEND if FLASHINFER_FP4_GEMM_BACKEND else "cutlass"
)
out = fp4_gemm(
x_fp4,
w,
Expand Down
6 changes: 3 additions & 3 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
get_attention_tp_group,
)
from sglang.srt.layers.moe import initialize_moe_config
from sglang.srt.layers.quantization.fp4_utils import initialize_fp4_gemm_config
from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config
from sglang.srt.managers.io_struct import (
AbortReq,
Expand Down Expand Up @@ -473,10 +474,9 @@ def init_moe_gemm_config(self):
if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
initialize_moe_config(self.server_args)

# Initialize GEMM-related configuration (currently FP8 Blockwise GEMM backend).
# Other GEMM backends (e.g. FP4, BF16, etc.) can be added here in the future.
# This is needed for FP8 quantization.
# Initialize GEMM-related configuration for FP8 and FP4 backends.
initialize_fp8_gemm_config(self.server_args)
initialize_fp4_gemm_config(self.server_args)

# This must be called after initialize_moe_config
self.require_mlp_sync = require_mlp_sync(self.server_args)
Expand Down
26 changes: 26 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,13 @@
"aiter",
]

FP4_GEMM_RUNNER_BACKEND_CHOICES = [
"auto",
"cudnn",
"cutlass",
"trtllm",
]

MAMBA_SSM_DTYPE_CHOICES = ["float32", "bfloat16"]

MAMBA_SCHEDULER_STRATEGY_CHOICES = ["auto", "no_buffer", "extra_buffer"]
Expand Down Expand Up @@ -226,6 +233,10 @@ def add_fp8_gemm_runner_backend_choices(choices):
FP8_GEMM_RUNNER_BACKEND_CHOICES.extend(choices)


def add_fp4_gemm_runner_backend_choices(choices):
FP4_GEMM_RUNNER_BACKEND_CHOICES.extend(choices)


def add_deterministic_attention_backend_choices(choices):
DETERMINISTIC_ATTENTION_BACKEND_CHOICES.extend(choices)

Expand Down Expand Up @@ -422,6 +433,7 @@ class ServerArgs:
grammar_backend: Optional[str] = None
mm_attention_backend: Optional[str] = None
fp8_gemm_runner_backend: str = "auto"
fp4_gemm_runner_backend: str = "auto"
nsa_prefill_backend: str = "flashmla_sparse"
nsa_decode_backend: str = "fa3"
disable_flashinfer_autotune: bool = False
Expand Down Expand Up @@ -3521,6 +3533,20 @@ def add_cli_args(parser: argparse.ArgumentParser):
"NOTE: This replaces the deprecated environment variables "
"SGLANG_ENABLE_FLASHINFER_FP8_GEMM and SGLANG_SUPPORT_CUTLASS_BLOCK_FP8.",
)
parser.add_argument(
"--fp4-gemm-backend",
type=str,
choices=FP4_GEMM_RUNNER_BACKEND_CHOICES,
default=ServerArgs.fp4_gemm_runner_backend,
dest="fp4_gemm_runner_backend",
help="Choose the runner backend for NVFP4 GEMM operations. "
"Options: 'auto' (default, selects between cudnn/cutlass based on CUDA/cuDNN version), "
"'cudnn' (cuDNN backend, optimal on CUDA 13+ with cuDNN 9.15+), "
"'cutlass' (CUTLASS backend, optimal on CUDA 12), "
"'trtllm' (TensorRT-LLM backend, requires different weight preparation with shuffling). "
"NOTE: This replaces the deprecated environment variable "
"SGLANG_FLASHINFER_FP4_GEMM_BACKEND.",
)
parser.add_argument(
"--disable-flashinfer-autotune",
default=ServerArgs.disable_flashinfer_autotune,
Expand Down
2 changes: 1 addition & 1 deletion test/srt/run_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
TestFile("test_deepseek_v3_fp4_4gpu.py", 1500),
TestFile("test_fp8_blockwise_gemm.py", 280),
TestFile("test_gpt_oss_4gpu.py", 700),
TestFile("test_llama31_fp4.py", 90),
TestFile("test_nvfp4_gemm.py", 360),
],
# "per-commit-8-gpu-b200": [
# TestFile("test_mistral_large3_basic.py", 275), # Moved to nightly - large model
Expand Down
34 changes: 30 additions & 4 deletions test/srt/test_llama31_fp4.py → test/srt/test_nvfp4_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,27 @@
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
try_cached_model,
)

MODEL_PATH = "nvidia/Llama-3.1-8B-Instruct-FP4"
MODEL_PATH = "nvidia/Llama-3.1-8B-Instruct-NVFP4"


@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher")
class TestLlama31FP4(unittest.TestCase):
class FP4GemmBase:
backend = None

@classmethod
def setUpClass(cls):
cls.model = MODEL_PATH
if cls.backend is None:
raise NotImplementedError("Subclass must set 'backend' attribute")
cls.model = try_cached_model(MODEL_PATH)
cls.base_url = DEFAULT_URL_FOR_TEST
other_args = [
"--trust-remote-code",
"--quantization",
"modelopt_fp4",
"--fp4-gemm-backend",
cls.backend,
]
cls.process = popen_launch_server(
cls.model,
Expand Down Expand Up @@ -52,5 +58,25 @@ def test_gsm8k(self):
self.assertGreater(metrics["accuracy"], 0.64)


@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher")
class TestFP4GemmAuto(FP4GemmBase, unittest.TestCase):
    # Launch the server with '--fp4-gemm-backend auto' (automatic selection).
    backend = "auto"


@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher")
class TestFP4GemmCutlass(FP4GemmBase, unittest.TestCase):
    # Launch the server with '--fp4-gemm-backend cutlass'.
    backend = "cutlass"


@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher")
class TestFP4GemmCudnn(FP4GemmBase, unittest.TestCase):
    # Launch the server with '--fp4-gemm-backend cudnn'.
    backend = "cudnn"


@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher")
class TestFP4GemmTrtllm(FP4GemmBase, unittest.TestCase):
    # Launch the server with '--fp4-gemm-backend trtllm' (requires shuffled weights).
    backend = "trtllm"


if __name__ == "__main__":
unittest.main()
Loading