Merged
3 changes: 2 additions & 1 deletion docs/advanced_features/server_arguments.md
@@ -234,7 +234,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | `triton` | `triton`, `csgmv` |
| `--max-lora-chunk-size` | Maximum chunk size for the ChunkedSGMV LoRA backend. Only used when `--lora-backend` is `csgmv`. Larger values may improve performance. | `16` | `16`, `32`, `64`, `128` |

## Kernel backend
## Kernel Backends (Attention, Sampling, Grammar, GEMM)
| Argument | Description | Defaults | Options |
| --- | --- | --- | --- |
| `--attention-backend` | Choose the kernels for attention layers. | `None` | `triton`, `torch_native`, `flex_attention`, `nsa`, `cutlass_mla`, `fa3`, `fa4`, `flashinfer`, `flashmla`, `trtllm_mla`, `trtllm_mha`, `dual_chunk_flash_attn`, `aiter`, `wave`, `intel_amx`, `ascend` |
@@ -245,6 +245,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--mm-attention-backend` | Set multimodal attention backend. | `None` | `sdpa`, `fa3`, `triton_attn`, `ascend_attn`, `aiter_attn` |
| `--nsa-prefill` | Choose the NSA backend for the prefill stage (overrides `--attention-backend` when running DeepSeek NSA-style attention). | `flashmla_sparse` | `flashmla_sparse`, `flashmla_decode`, `fa3`, `tilelang`, `aiter` |
| `--nsa-decode` | Choose the NSA backend for the decode stage when running DeepSeek NSA-style attention. Overrides `--attention-backend` for decoding. | `flashmla_kv` | `flashmla_prefill`, `flashmla_kv`, `fa3`, `tilelang`, `aiter` |
| `--fp8-gemm-backend` | Choose the runner backend for Blockwise FP8 GEMM operations. Options: 'auto' (default, auto-selects based on hardware), 'deep_gemm' (JIT-compiled; enabled by default on NVIDIA Hopper (SM90) and Blackwell (SM100) when DeepGEMM is installed), 'flashinfer_trtllm' (optimal for Blackwell and low-latency), 'cutlass' (optimal for Hopper/Blackwell GPUs and high-throughput), 'triton' (fallback, widely compatible), 'aiter' (ROCm only). **NOTE**: This replaces the deprecated environment variables SGLANG_ENABLE_FLASHINFER_FP8_GEMM and SGLANG_SUPPORT_CUTLASS_BLOCK_FP8. | `auto` | `auto`, `deep_gemm`, `flashinfer_trtllm`, `cutlass`, `triton`, `aiter` |
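The `auto` priority order described above can be sketched as a pure function. This is an illustrative sketch of the selection logic, not the actual SGLang internals; the function name and boolean parameters are hypothetical stand-ins for the hardware/availability checks the dispatcher performs.

```python
# Hypothetical sketch of the priority that `--fp8-gemm-backend=auto` applies:
# DeepGEMM > FlashInfer TRTLLM (Blackwell) > CUTLASS (Hopper+) > AITER > Triton.
def pick_fp8_gemm_backend(
    deep_gemm_available: bool,
    is_blackwell: bool,
    flashinfer_available: bool,
    is_hopper: bool,
    aiter_available: bool,
) -> str:
    if deep_gemm_available:
        return "deep_gemm"
    if is_blackwell and flashinfer_available:
        return "flashinfer_trtllm"
    if is_hopper or is_blackwell:
        return "cutlass"
    if aiter_available:
        return "aiter"
    return "triton"  # widely compatible fallback
```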

## Speculative decoding
| Argument | Description | Defaults | Options |
6 changes: 3 additions & 3 deletions docs/references/environment_variables.md
@@ -42,7 +42,7 @@ SGLang supports various environment variables that can be used to configure its

| Environment Variable | Description | Default Value |
| --- | --- | --- |
| `SGLANG_ENABLE_JIT_DEEPGEMM` | Enable Just-In-Time compilation of DeepGEMM kernels | `"true"` |
| `SGLANG_ENABLE_JIT_DEEPGEMM` | Enable Just-In-Time compilation of DeepGEMM kernels (enabled by default on NVIDIA Hopper (SM90) and Blackwell (SM100) GPUs when the DeepGEMM package is installed; set to `"0"` to disable) | `"true"` |
| `SGLANG_JIT_DEEPGEMM_PRECOMPILE` | Enable precompilation of DeepGEMM kernels | `"true"` |
| `SGLANG_JIT_DEEPGEMM_COMPILE_WORKERS` | Number of workers for parallel DeepGEMM kernel compilation | `4` |
| `SGL_IN_DEEPGEMM_PRECOMPILE_STAGE` | Indicator flag used during the DeepGEMM precompile script | `"false"` |
@@ -78,9 +78,9 @@ SGLang supports various environment variables that can be used to configure its
| `SGLANG_INT4_WEIGHT` | Enable INT4 weight quantization | `false` |
| `SGLANG_MOE_PADDING` | Enable MoE padding (sets padding size to 128 if value is `1`, often set to `1` in Docker builds) | `0` |
| `SGLANG_FORCE_FP8_MARLIN` | Force using FP8 MARLIN kernels even if other FP8 kernels are available | `false` |
| `SGLANG_ENABLE_FLASHINFER_FP8_GEMM` | Use flashinfer kernels when running blockwise fp8 GEMM on Blackwell GPUs | `false` |
| `SGLANG_ENABLE_FLASHINFER_FP8_GEMM` (deprecated) | Use flashinfer kernels when running blockwise fp8 GEMM on Blackwell GPUs. **DEPRECATED**: Please use `--fp8-gemm-backend=flashinfer_trtllm` instead. | `false` |
| `SGLANG_FLASHINFER_FP4_GEMM_BACKEND` | Select backend for `mm_fp4` on Blackwell GPUs | `` |
| `SGLANG_SUPPORT_CUTLASS_BLOCK_FP8` | Use Cutlass kernels when running blockwise fp8 GEMM on Hopper or Blackwell GPUs | `false` |
| `SGLANG_SUPPORT_CUTLASS_BLOCK_FP8` (deprecated) | Use Cutlass kernels when running blockwise fp8 GEMM on Hopper or Blackwell GPUs. **DEPRECATED**: Please use `--fp8-gemm-backend=cutlass` instead. | `false` |
| `SGLANG_CUTLASS_MOE` (deprecated) | Use Cutlass FP8 MoE kernel on Blackwell GPUs (deprecated, use --moe-runner-backend=cutlass) | `false` |


2 changes: 2 additions & 0 deletions python/sglang/bench_one_batch.py
@@ -67,6 +67,7 @@
from sglang.srt.distributed.parallel_state import destroy_distributed_environment
from sglang.srt.entrypoints.engine import _set_envs_and_config
from sglang.srt.layers.moe import initialize_moe_config
from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.scheduler_dp_attn_mixin import prepare_mlp_sync_batch_raw
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -631,6 +632,7 @@ def latency_test(
tp_rank,
):
initialize_moe_config(server_args)
initialize_fp8_gemm_config(server_args)

# Set CPU affinity
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
22 changes: 22 additions & 0 deletions python/sglang/srt/environ.py
@@ -368,6 +368,15 @@ def _print_deprecated_env(new_name: str, old_name: str):
os.environ[new_name] = os.environ[old_name]


def _warn_deprecated_env_to_cli_flag(env_name: str, suggestion: str):
"""Warn when a deprecated environment variable is used.

This is for env vars that are deprecated in favor of CLI flags.
"""
if env_name in os.environ:
warnings.warn(f"Environment variable {env_name} is deprecated. {suggestion}")


def _convert_SGL_to_SGLANG():
_print_deprecated_env("SGLANG_LOG_GC", "SGLANG_GC_LOG")
_print_deprecated_env(
@@ -388,6 +397,19 @@ def _convert_SGL_to_SGLANG():

_convert_SGL_to_SGLANG()

_warn_deprecated_env_to_cli_flag(
"SGLANG_ENABLE_FLASHINFER_FP8_GEMM",
"It will be completely removed in 0.5.7. Please use '--fp8-gemm-backend=flashinfer_trtllm' instead.",
)
_warn_deprecated_env_to_cli_flag(
"SGLANG_ENABLE_FLASHINFER_GEMM",
"It will be completely removed in 0.5.7. Please use '--fp8-gemm-backend=flashinfer_trtllm' instead.",
)
_warn_deprecated_env_to_cli_flag(
"SGLANG_SUPPORT_CUTLASS_BLOCK_FP8",
"It will be completely removed in 0.5.7. Please use '--fp8-gemm-backend=cutlass' instead.",
)
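The warning path above can be demonstrated standalone. The sketch below reproduces `_warn_deprecated_env_to_cli_flag` from this diff (copied so the example is self-contained) and captures the warning it emits when a deprecated variable is set:

```python
import os
import warnings

# Self-contained copy of `_warn_deprecated_env_to_cli_flag` from this diff,
# reproduced here so its behavior can be shown without importing sglang.
def _warn_deprecated_env_to_cli_flag(env_name: str, suggestion: str) -> None:
    if env_name in os.environ:
        warnings.warn(f"Environment variable {env_name} is deprecated. {suggestion}")

os.environ["SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"] = "1"
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _warn_deprecated_env_to_cli_flag(
        "SGLANG_SUPPORT_CUTLASS_BLOCK_FP8",
        "Please use '--fp8-gemm-backend=cutlass' instead.",
    )
# Exactly one warning is recorded because the variable is set.
```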


def example_with_exit_stack():
# Use this style of context manager in unit test
184 changes: 160 additions & 24 deletions python/sglang/srt/layers/quantization/fp8_utils.py
@@ -1,12 +1,18 @@
from typing import Callable, List, Optional, Tuple
from __future__ import annotations

import logging
from enum import Enum
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple

import torch

from sglang.srt.environ import envs
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
from sglang.srt.utils import ceil_div, is_blackwell_supported, offloader

if TYPE_CHECKING:
from sglang.srt.server_args import ServerArgs

try:
from vllm import _custom_ops as ops
@@ -29,14 +35,20 @@
)
from sglang.srt.utils import (
ceil_align,
ceil_div,
get_bool_env_var,
get_cuda_version,
get_device_capability,
is_blackwell_supported,
is_cuda,
is_flashinfer_available,
is_hip,
is_sm90_supported,
offloader,
)

logger = logging.getLogger(__name__)

_is_hip = is_hip()
_is_cuda = is_cuda()
_is_fp8_fnuz = is_fp8_fnuz()
@@ -125,42 +137,166 @@ def normalize_e4m3fn_to_e4m3fnuz(
return weight, weight_scale, input_scale


# TODO(ch-wan): define these backends in --moe-runner-backend
def cutlass_block_fp8_supported() -> bool:
if not get_bool_env_var("SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"):
return False
if _is_cuda:
major, minor = torch.cuda.get_device_capability()
sm_version = major * 10 + minor
cuda_version = tuple(map(int, torch.version.cuda.split(".")))
if cuda_version >= (12, 0) and sm_version >= 90:
return True
return False
class Fp8GemmRunnerBackend(Enum):
"""Enum for FP8 GEMM runner backend selection."""

AUTO = "auto"
FLASHINFER = "flashinfer_trtllm"
CUTLASS = "cutlass"
DEEP_GEMM = "deep_gemm"
TRITON = "triton"
AITER = "aiter"

CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
ENABLE_FLASHINFER_FP8_GEMM = (
envs.SGLANG_ENABLE_FLASHINFER_FP8_GEMM.get()
and is_blackwell_supported()
and is_flashinfer_available()
)
if ENABLE_FLASHINFER_FP8_GEMM:
def is_auto(self) -> bool:
return self == Fp8GemmRunnerBackend.AUTO

def is_flashinfer(self) -> bool:
return self == Fp8GemmRunnerBackend.FLASHINFER

def is_cutlass(self) -> bool:
return self == Fp8GemmRunnerBackend.CUTLASS

def is_deep_gemm(self) -> bool:
return self == Fp8GemmRunnerBackend.DEEP_GEMM

def is_triton(self) -> bool:
return self == Fp8GemmRunnerBackend.TRITON

def is_aiter(self) -> bool:
return self == Fp8GemmRunnerBackend.AITER
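The enum above is constructed from the string value of `--fp8-gemm-backend`. A minimal standalone copy (one helper shown; the real class in this diff has one per member) illustrates the value-based round-trip:

```python
from enum import Enum

# Minimal standalone copy of Fp8GemmRunnerBackend, showing the string
# round-trip used when parsing --fp8-gemm-backend values.
class Fp8GemmRunnerBackend(Enum):
    AUTO = "auto"
    FLASHINFER = "flashinfer_trtllm"
    CUTLASS = "cutlass"
    DEEP_GEMM = "deep_gemm"
    TRITON = "triton"
    AITER = "aiter"

    def is_cutlass(self) -> bool:
        return self is Fp8GemmRunnerBackend.CUTLASS

backend = Fp8GemmRunnerBackend("cutlass")  # value-based member lookup
```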


FP8_GEMM_RUNNER_BACKEND: Fp8GemmRunnerBackend | None = None


def _check_cutlass_block_fp8_hardware_support() -> bool:
"""Return True if CUTLASS block FP8 is supported (Hopper or newer with CUDA 12.0+)."""
return is_sm90_supported() or is_blackwell_supported()


if is_blackwell_supported() and is_flashinfer_available():
from flashinfer.gemm import gemm_fp8_nt_groupwise


def dispatch_w8a8_block_fp8_linear() -> Callable:
if ENABLE_FLASHINFER_FP8_GEMM:
"""
Dispatch to the appropriate FP8 block linear implementation.

This function selects the backend based on:
1. The --fp8-gemm-backend server argument (preferred)
2. Auto-detection based on hardware capabilities
"""
backend = get_fp8_gemm_runner_backend()

# Handle explicit backend selection via --fp8-gemm-backend
if not backend.is_auto():
return _dispatch_explicit_backend(backend)

# Auto mode: Select based purely on hardware/backend availability
return _dispatch_auto_backend()


def _dispatch_explicit_backend(backend: Fp8GemmRunnerBackend) -> Callable:
"""Dispatch based on explicitly selected backend."""
if backend.is_flashinfer():
if not (is_blackwell_supported() and is_flashinfer_available()):
raise RuntimeError(
"FlashInfer FP8 GEMM requested via --fp8-gemm-backend=flashinfer_trtllm, "
"but FlashInfer is not available or not supported on this hardware. "
"FlashInfer FP8 GEMM requires Blackwell GPUs and FlashInfer to be installed."
)
return flashinfer_gemm_w8a8_block_fp8_linear
elif CUTLASS_BLOCK_FP8_SUPPORTED:

elif backend.is_cutlass():
if not _check_cutlass_block_fp8_hardware_support():
raise RuntimeError(
"CUTLASS block FP8 requested via --fp8-gemm-backend=cutlass, "
"but hardware does not support it. CUTLASS block FP8 requires "
"Hopper (SM90+) GPUs with CUDA 12.0+."
)
return cutlass_w8a8_block_fp8_linear_with_fallback
elif _use_aiter:

elif backend.is_aiter():
if not _use_aiter:
raise RuntimeError(
"AITER backend requested via --fp8-gemm-backend=aiter, "
"but AITER is not available. AITER requires AMD GPUs with "
"SGLANG_USE_AITER=1 environment variable set."
)
return aiter_w8a8_block_fp8_linear
elif deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:

elif backend.is_deep_gemm():
if not deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
raise RuntimeError(
"DeepGEMM backend requested via --fp8-gemm-backend=deep_gemm, "
"but DeepGEMM is not available. This usually means the deep_gemm package "
"is not installed or has been disabled via SGLANG_ENABLE_JIT_DEEPGEMM=0."
)
return deepgemm_w8a8_block_fp8_linear_with_fallback

elif backend.is_triton():
return triton_w8a8_block_fp8_linear

else:
raise ValueError(f"Unknown FP8 GEMM backend: {backend}")


def _dispatch_auto_backend() -> Callable:
"""Auto-select the best backend based on hardware capabilities."""
# Priority order for auto selection:
# 1. DeepGEMM (if enabled and available)
# 2. FlashInfer TRTLLM (if Blackwell GPU and FlashInfer available)
# 3. CUTLASS (if Hopper+ GPU and CUDA 12.0+)
# 4. AITER (if AMD GPU with AITER enabled)
# 5. Triton (fallback)

if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
return deepgemm_w8a8_block_fp8_linear_with_fallback
elif is_blackwell_supported() and is_flashinfer_available():
return flashinfer_gemm_w8a8_block_fp8_linear
elif _check_cutlass_block_fp8_hardware_support():
return cutlass_w8a8_block_fp8_linear_with_fallback
elif _use_aiter:
return aiter_w8a8_block_fp8_linear
else:
return triton_w8a8_block_fp8_linear


def initialize_fp8_gemm_config(server_args: ServerArgs) -> None:
"""Initialize FP8 GEMM configuration."""
global FP8_GEMM_RUNNER_BACKEND

backend = server_args.fp8_gemm_runner_backend

    # TODO(brayden): Remove the env-based overrides below in v0.5.7.
    # Only consult the environment variables when the server arg is left at "auto";
    # an explicitly set server arg takes priority.
if backend == "auto":
if envs.SGLANG_ENABLE_FLASHINFER_FP8_GEMM.get():
backend = "flashinfer_trtllm"
elif envs.SGLANG_SUPPORT_CUTLASS_BLOCK_FP8.get():
backend = "cutlass"
else:
if (
envs.SGLANG_ENABLE_FLASHINFER_FP8_GEMM.get()
or envs.SGLANG_SUPPORT_CUTLASS_BLOCK_FP8.get()
):
logger.warning(
f"FP8 GEMM backend set to '{backend}' via --fp8-gemm-backend overrides "
"environment variables SGLANG_ENABLE_FLASHINFER_FP8_GEMM and "
"SGLANG_SUPPORT_CUTLASS_BLOCK_FP8. Using server argument value."
)

FP8_GEMM_RUNNER_BACKEND = Fp8GemmRunnerBackend(backend)


def get_fp8_gemm_runner_backend() -> Fp8GemmRunnerBackend:
"""Get the current FP8 GEMM runner backend."""
global FP8_GEMM_RUNNER_BACKEND
if FP8_GEMM_RUNNER_BACKEND is None:
FP8_GEMM_RUNNER_BACKEND = Fp8GemmRunnerBackend.AUTO
return FP8_GEMM_RUNNER_BACKEND
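The precedence rule implemented by `initialize_fp8_gemm_config` above — the CLI flag wins, and the deprecated env vars are consulted only when the flag is left at `"auto"` — can be sketched as a pure function (`resolve_fp8_backend` is an illustrative name, not part of SGLang):

```python
# Illustrative sketch of the precedence in initialize_fp8_gemm_config:
# an explicit --fp8-gemm-backend value overrides the deprecated env vars.
def resolve_fp8_backend(cli_value: str, env_flashinfer: bool, env_cutlass: bool) -> str:
    if cli_value != "auto":
        return cli_value  # explicit flag wins
    if env_flashinfer:
        return "flashinfer_trtllm"
    if env_cutlass:
        return "cutlass"
    return "auto"
```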


def flashinfer_gemm_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
10 changes: 10 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -65,6 +65,7 @@
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.moe import initialize_moe_config
from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config
from sglang.srt.managers.io_struct import (
AbortReq,
BaseBatchReq,
@@ -305,6 +306,9 @@ def __init__(
# Init moe config
self.init_moe_config()

# Init GEMM config (FP8 GEMM, etc.)
self.init_gemm_config()

# Check whether overlap can be enabled
if not self.is_generation:
self.enable_overlap = False
@@ -966,6 +970,12 @@ def init_moe_config(self):
if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
initialize_moe_config(self.server_args)

def init_gemm_config(self):
# Initialize GEMM-related configuration (currently FP8 Blockwise GEMM backend).
# Other GEMM backends (e.g. FP4, BF16, etc.) can be added here in the future.
# This is needed for FP8 quantization.
initialize_fp8_gemm_config(self.server_args)

@DynamicGradMode()
def event_loop_normal(self):
"""A normal scheduler loop."""