51 changes: 33 additions & 18 deletions python/sglang/srt/layers/quantization/mxfp4.py
@@ -37,12 +37,12 @@
 from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
-    is_cuda,
     is_flashinfer_available,
     is_gfx95_supported,
     is_hip,
     is_sm90_supported,
     is_sm100_supported,
+    is_sm120_supported,
     is_triton_kernels_available,
     mxfp_supported,
     next_power_of_2,
@@ -52,8 +52,6 @@
 from sglang.srt.utils.common import get_bool_env_var
 from sglang.srt.utils.custom_op import register_custom_op
 
-_is_sm100_supported = is_cuda() and is_sm100_supported()
-_is_sm90_supported = is_cuda() and is_sm90_supported()
 has_triton_kernels = is_triton_kernels_available()
 
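Note the pattern change: the removed `_is_sm100_supported` / `_is_sm90_supported` flags were computed once at module import time, while the updated code calls the helpers at the point of use. A minimal sketch of a lazily evaluated helper, assuming it probes compute capability via torch (the real implementations in `sglang.srt.utils` may check more than this):

```python
import functools

import torch


@functools.lru_cache(maxsize=1)
def is_sm100_supported() -> bool:
    # Evaluated on first call rather than at module import, so it runs
    # after CUDA is initialized; the result is cached thereafter.
    return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 10
```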

@@ -140,23 +138,40 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
     from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
     from triton_kernels.tensor_details import layout
 
-    value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(
-        mx_axis=1
-    )
-    scale_layout, scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
-        mx_axis=1, num_warps=num_warps
-    )
-    if _is_sm100_supported:
-        constraints = {
-            "is_persistent": True,
-            "epilogue_subtile": 1,
-        }
-        opt_flags.update_opt_flags_constraints(constraints)
-    elif _is_sm90_supported:
-        constraints = {
-            "split_k": 1,
-        }
-        opt_flags.update_opt_flags_constraints(constraints)
+    if is_sm120_supported():
+        # SM120 (Blackwell desktop) doesn't support persistent kernels / TMA block layout
+        # Use StridedLayout and disable persistent kernels to avoid assertion errors
+        from triton_kernels.tensor_details.layout import StridedLayout
+
+        value_layout = StridedLayout
+        value_layout_opts = {}
+        scale_layout = StridedLayout
+        scale_layout_opts = {}
+        constraints = {
+            "split_k": 1,
+            "is_persistent": False,
+            "num_stages": 1,
+        }
+        opt_flags.update_opt_flags_constraints(constraints)
+    else:
+        value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(
+            mx_axis=1
+        )
+        scale_layout, scale_layout_opts = (
+            layout.make_default_matmul_mxfp4_w_scale_layout(
+                mx_axis=1, num_warps=num_warps
+            )
+        )
+        if is_sm100_supported():
+            constraints = {
+                "is_persistent": True,
+                "epilogue_subtile": 1,
+            }
+            opt_flags.update_opt_flags_constraints(constraints)
+        elif is_sm90_supported():
+            constraints = {
+                "split_k": 1,
+            }
+            opt_flags.update_opt_flags_constraints(constraints)
     # transpose the tensor so that the quantization axis is on dim1
     quant_tensor = quant_tensor.transpose(-2, -1)
     scale = scale.transpose(-2, -1)
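For orientation: SM90 is Hopper, SM100 is datacenter Blackwell, and SM120 is the Blackwell desktop line (e.g. RTX 5090), which lacks the persistent-kernel / TMA block-layout support the default swizzled path relies on. A sketch of what a capability probe like `is_sm120_supported()` plausibly does — the actual helper lives in `sglang.srt.utils` and may differ:

```python
import torch


def is_sm120_supported_sketch(device_id: int = 0) -> bool:
    # SM120 corresponds to compute capability 12.x (Blackwell desktop).
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability(device_id)
    return major == 12
```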
@@ -324,7 +339,7 @@ def create_weights(
         # pad the intermediate size to be a multiple of 2 * mxfp4_block
         # for to hold non-uniform sharded tensor as well as swizzling
         intermediate_size_per_partition_after_pad = intermediate_size_per_partition
-        if _is_sm100_supported:
+        if is_sm100_supported():
             if self.use_flashinfer:
                 intermediate_size_per_partition_after_pad = round_up(
                     intermediate_size_per_partition, 256
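The padding in `create_weights` keeps each shard's intermediate size aligned for swizzling, and the FlashInfer path rounds it up to a multiple of 256. A sketch of the arithmetic with `round_up` written out (the helper imported by this file may be implemented differently):

```python
def round_up(x: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= x.
    return ((x + multiple - 1) // multiple) * multiple


# e.g. a per-partition intermediate size of 2880 pads to 3072
# on the FlashInfer MXFP4 path:
assert round_up(2880, 256) == 3072
```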
10 changes: 8 additions & 2 deletions python/sglang/srt/server_args.py
@@ -1486,10 +1486,16 @@ def _handle_model_specific_adjustments(self):
             self.dtype = "bfloat16"
 
         if self.moe_runner_backend == "auto":
-            if is_blackwell_supported() and is_mxfp4_quant_format:
+            if is_sm100_supported() and is_mxfp4_quant_format:
                 self.moe_runner_backend = "flashinfer_mxfp4"
                 logger.warning(
-                    "Detected Blackwell and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
+                    "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
                 )
+            elif is_sm120_supported() and is_mxfp4_quant_format:
+                # trtllm-gen only supports SM100
+                self.moe_runner_backend = "triton_kernel"
+                logger.warning(
+                    "Detected SM120 and MXFP4 quantization format for GPT-OSS model, enabling triton_kernel MOE kernel."
+                )
             elif (
                 is_hip() and get_bool_env_var("SGLANG_USE_AITER")
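Condensed, the auto-selection above reads as follows; `flashinfer_mxfp4` rides on trtllm-gen kernels that only target SM100, so SM120 machines fall back to the Triton kernel path. A sketch, assuming the capability helpers shown earlier:

```python
def pick_moe_runner_backend(is_mxfp4_quant_format: bool) -> str:
    # Mirrors the "auto" branch above, ignoring the HIP/AITER cases.
    if is_sm100_supported() and is_mxfp4_quant_format:
        return "flashinfer_mxfp4"  # FlashInfer / trtllm-gen path (SM100 only)
    if is_sm120_supported() and is_mxfp4_quant_format:
        return "triton_kernel"  # Triton MOE kernels work on SM120
    return "auto"  # fall through to the remaining platform checks
```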