Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 33 additions & 13 deletions python/sglang/srt/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,20 +495,44 @@ def _process_mxfp8_linear_weight_scale(self, layer: Module) -> None:
return

if get_fp8_gemm_runner_backend().is_flashinfer_trtllm():
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

weight = layer.weight.data
scale_u8 = layer.weight_scale_inv.data
n, k = weight.shape
epilogue_tile_m = 128

copy_or_rebind_param(
layer,
"weight",
shuffle_matrix_a(
weight.contiguous().view(torch.uint8), epilogue_tile_m
).view(torch.float8_e4m3fn),
)
copy_or_rebind_param(
layer,
"weight_scale_inv",
shuffle_matrix_sf_a(
scale_u8.contiguous().view(torch.uint8).reshape(n, k // 32),
epilogue_tile_m,
num_elts_per_sf=32,
)
.reshape_as(scale_u8)
.contiguous(),
)
elif get_fp8_gemm_runner_backend().is_flashinfer_cutlass():
from flashinfer import block_scale_interleave

scale_u8 = layer.weight_scale_inv.data
new_swizzled = block_scale_interleave(scale_u8.contiguous()).contiguous()
copy_or_rebind_param(
layer,
"weight_scale_inv",
block_scale_interleave(scale_u8.contiguous()).contiguous(),
)
else:
# Triton path consumes canonical 2D UE8M0 scales directly.
return

copy_or_rebind_param(layer, "weight_scale_inv_swizzled", new_swizzled)
layer._weight_scale_inv_swizzled_src_version = layer.weight_scale_inv._version
layer._weight_scale_inv_swizzled_src_data_ptr = (
layer.weight_scale_inv.data_ptr()
)

def _quantize_mxfp8_weights(self, layer: Module) -> None:
weight = layer.weight.data
qweight, weight_scale = mxfp8_group_quantize(weight)
Expand Down Expand Up @@ -657,22 +681,18 @@ def apply(
)

if self.use_mxfp8:
if get_fp8_gemm_runner_backend().is_flashinfer_trtllm():
weight_scale = layer.weight_scale_inv_swizzled
else:
weight_scale = layer.weight_scale_inv
if isinstance(x, tuple):
return self.w8a8_mxfp8_linear(
input=x[0],
weight=layer.weight,
weight_scale=weight_scale,
weight_scale=layer.weight_scale_inv,
input_scale=x[1],
bias=bias,
)
return self.w8a8_mxfp8_linear(
input=x,
weight=layer.weight,
weight_scale=weight_scale,
weight_scale=layer.weight_scale_inv,
input_scale=None,
bias=bias,
)
Expand Down
52 changes: 37 additions & 15 deletions python/sglang/srt/layers/quantization/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def _check_cutlass_block_fp8_hardware_support() -> bool:


if is_blackwell_supported() and is_flashinfer_available():
from flashinfer import SfLayout
from flashinfer import mm_mxfp8 as _raw_flashinfer_mm_mxfp8
from flashinfer import mxfp8_quantize as _raw_flashinfer_mxfp8_quantize
from flashinfer.gemm import gemm_fp8_nt_groupwise as _raw_gemm_fp8_nt_groupwise
Expand Down Expand Up @@ -303,12 +304,13 @@ def flashinfer_mxfp8_quantize(
input,
is_sf_swizzled_layout=is_sf_swizzled_layout,
alignment=alignment,
sf_swizzle_layout=SfLayout.layout_128x4,
)

@register_custom_op(
op_name="flashinfer_mm_mxfp8",
mutates_args=[],
fake_impl=lambda q_input, weight_t, x_scale_u8, weight_scale_t, out_dtype, backend="auto": (
fake_impl=lambda q_input, weight_t, x_scale_u8, weight_scale_t, out_dtype, use_8x4_sf_layout=False, backend="auto": (
q_input.new_empty((q_input.shape[0], weight_t.shape[1]), dtype=out_dtype)
),
)
Expand All @@ -318,6 +320,7 @@ def flashinfer_mm_mxfp8(
x_scale_u8: torch.Tensor,
weight_scale_t: torch.Tensor,
out_dtype: torch.dtype,
use_8x4_sf_layout: bool = False,
backend: str = "auto",
) -> torch.Tensor:
return _raw_flashinfer_mm_mxfp8(
Expand All @@ -326,6 +329,7 @@ def flashinfer_mm_mxfp8(
x_scale_u8,
weight_scale_t,
out_dtype=out_dtype,
use_8x4_sf_layout=use_8x4_sf_layout,
backend=backend,
)

Expand Down Expand Up @@ -357,11 +361,13 @@ def dispatch_w8a8_mxfp8_linear() -> Callable:
"""Dispatch MXFP8 linear kernel by --fp8-gemm-backend.

For MXFP8, Triton remains the default path. We only route to FlashInfer
when backend is explicitly set to flashinfer_trtllm.
when backend is explicitly set to flashinfer_cutlass or flashinfer_trtllm.
"""
backend = get_fp8_gemm_runner_backend()
if backend.is_flashinfer_trtllm():
return flashinfer_mxfp8_blockscaled_linear
elif backend.is_flashinfer_cutlass():
return flashinfer_mxfp8_blockscaled_linear
return triton_mxfp8_blockscaled_linear


Expand Down Expand Up @@ -962,6 +968,7 @@ def flashinfer_mxfp8_blockscaled_linear(
)
else:
q_input = input_2d
x_scale_u8 = input_scale.contiguous()

if output_dtype is None:
if input_2d.dtype in (torch.float16, torch.bfloat16, torch.float32):
Expand All @@ -971,19 +978,34 @@ def flashinfer_mxfp8_blockscaled_linear(

# Ensure transposed tensors are contiguous for FlashInfer's internal runner.
weight_t = weight.contiguous().t()
weight_scale_t = (
weight_scale.contiguous().t()
if weight_scale.ndim == 2
else weight_scale.contiguous()
)
output = flashinfer_mm_mxfp8(
q_input,
weight_t,
x_scale_u8,
weight_scale_t,
out_dtype=output_dtype,
backend="auto",
)

if get_fp8_gemm_runner_backend().is_flashinfer_trtllm():

weight_scale_t = weight_scale.contiguous().view(-1)
output = flashinfer_mm_mxfp8(
q_input,
weight_t,
x_scale_u8,
weight_scale_t,
out_dtype=output_dtype,
use_8x4_sf_layout=False,
backend="trtllm",
)
elif get_fp8_gemm_runner_backend().is_flashinfer_cutlass():
weight_scale_t = (
weight_scale.contiguous().t()
if weight_scale.ndim == 2
else weight_scale.contiguous()
)
output = flashinfer_mm_mxfp8(
q_input,
weight_t,
x_scale_u8,
weight_scale_t,
out_dtype=output_dtype,
use_8x4_sf_layout=False,
backend="cutlass",
)

if bias is not None:
output += bias
Expand Down
5 changes: 5 additions & 0 deletions test/registered/quant/test_fp8_blockwise_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,5 +131,10 @@ class TestMXFP8GemmFlashinferTrtllm(MXFP8GemmBase, unittest.TestCase):
backend = "flashinfer_trtllm"


@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher")
class TestMXFP8GemmFlashinferCutlass(MXFP8GemmBase, unittest.TestCase):
backend = "flashinfer_cutlass"


if __name__ == "__main__":
unittest.main()
Loading