diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 7182e3d57ba6..217b3fce78e8 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -495,20 +495,44 @@ def _process_mxfp8_linear_weight_scale(self, layer: Module) -> None: return if get_fp8_gemm_runner_backend().is_flashinfer_trtllm(): + from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a + + weight = layer.weight.data + scale_u8 = layer.weight_scale_inv.data + n, k = weight.shape + epilogue_tile_m = 128 + + copy_or_rebind_param( + layer, + "weight", + shuffle_matrix_a( + weight.contiguous().view(torch.uint8), epilogue_tile_m + ).view(torch.float8_e4m3fn), + ) + copy_or_rebind_param( + layer, + "weight_scale_inv", + shuffle_matrix_sf_a( + scale_u8.contiguous().view(torch.uint8).reshape(n, k // 32), + epilogue_tile_m, + num_elts_per_sf=32, + ) + .reshape_as(scale_u8) + .contiguous(), + ) + elif get_fp8_gemm_runner_backend().is_flashinfer_cutlass(): from flashinfer import block_scale_interleave scale_u8 = layer.weight_scale_inv.data - new_swizzled = block_scale_interleave(scale_u8.contiguous()).contiguous() + copy_or_rebind_param( + layer, + "weight_scale_inv", + block_scale_interleave(scale_u8.contiguous()).contiguous(), + ) else: # Triton path consumes canonical 2D UE8M0 scales directly. return - copy_or_rebind_param(layer, "weight_scale_inv_swizzled", new_swizzled) - layer._weight_scale_inv_swizzled_src_version = layer.weight_scale_inv._version - layer._weight_scale_inv_swizzled_src_data_ptr = ( - layer.weight_scale_inv.data_ptr() - ) - def _quantize_mxfp8_weights(self, layer: Module) -> None: weight = layer.weight.data qweight, weight_scale = mxfp8_group_quantize(weight) @@ -657,22 +681,18 @@ def apply( ) if self.use_mxfp8: - if get_fp8_gemm_runner_backend().is_flashinfer_trtllm(): - weight_scale = layer.weight_scale_inv_swizzled - else: - weight_scale = layer.weight_scale_inv if isinstance(x, tuple): return self.w8a8_mxfp8_linear( input=x[0], weight=layer.weight, - weight_scale=weight_scale, + weight_scale=layer.weight_scale_inv, input_scale=x[1], bias=bias, ) return self.w8a8_mxfp8_linear( input=x, weight=layer.weight, - weight_scale=weight_scale, + weight_scale=layer.weight_scale_inv, input_scale=None, bias=bias, ) diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index f5bcf975186e..695e2bfb9633 100755 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -214,6 +214,7 @@ def _check_cutlass_block_fp8_hardware_support() -> bool: if is_blackwell_supported() and is_flashinfer_available(): + from flashinfer import SfLayout from flashinfer import mm_mxfp8 as _raw_flashinfer_mm_mxfp8 from flashinfer import mxfp8_quantize as _raw_flashinfer_mxfp8_quantize from flashinfer.gemm import gemm_fp8_nt_groupwise as _raw_gemm_fp8_nt_groupwise @@ -303,12 +304,13 @@ def flashinfer_mxfp8_quantize( input, is_sf_swizzled_layout=is_sf_swizzled_layout, alignment=alignment, + sf_swizzle_layout=SfLayout.layout_128x4, ) @register_custom_op( op_name="flashinfer_mm_mxfp8", mutates_args=[], - fake_impl=lambda q_input, weight_t, x_scale_u8, weight_scale_t, out_dtype, backend="auto": ( + fake_impl=lambda q_input, weight_t, x_scale_u8, weight_scale_t, out_dtype, use_8x4_sf_layout=False, backend="auto": ( q_input.new_empty((q_input.shape[0], weight_t.shape[1]), dtype=out_dtype) ), ) @@ -318,6 +320,7 @@ def flashinfer_mm_mxfp8( x_scale_u8: torch.Tensor, weight_scale_t: torch.Tensor, out_dtype: torch.dtype, + use_8x4_sf_layout: bool = False, backend: str = "auto", ) -> torch.Tensor: return _raw_flashinfer_mm_mxfp8( @@ -326,6 +329,7 @@ def flashinfer_mm_mxfp8( x_scale_u8, weight_scale_t, out_dtype=out_dtype, + use_8x4_sf_layout=use_8x4_sf_layout, backend=backend, ) @@ -357,11 +361,13 @@ def dispatch_w8a8_mxfp8_linear() -> Callable: """Dispatch MXFP8 linear kernel by --fp8-gemm-backend. For MXFP8, Triton remains the default path. We only route to FlashInfer - when backend is explicitly set to flashinfer_trtllm. + when backend is explicitly set to flashinfer_cutlass or flashinfer_trtllm. """ backend = get_fp8_gemm_runner_backend() if backend.is_flashinfer_trtllm(): return flashinfer_mxfp8_blockscaled_linear + elif backend.is_flashinfer_cutlass(): + return flashinfer_mxfp8_blockscaled_linear return triton_mxfp8_blockscaled_linear @@ -962,6 +968,7 @@ def flashinfer_mxfp8_blockscaled_linear( ) else: q_input = input_2d + x_scale_u8 = input_scale.contiguous() if output_dtype is None: if input_2d.dtype in (torch.float16, torch.bfloat16, torch.float32): @@ -971,19 +978,34 @@ def flashinfer_mxfp8_blockscaled_linear( # Ensure transposed tensors are contiguous for FlashInfer's internal runner. weight_t = weight.contiguous().t() - weight_scale_t = ( - weight_scale.contiguous().t() - if weight_scale.ndim == 2 - else weight_scale.contiguous() - ) - output = flashinfer_mm_mxfp8( - q_input, - weight_t, - x_scale_u8, - weight_scale_t, - out_dtype=output_dtype, - backend="auto", - ) + + if get_fp8_gemm_runner_backend().is_flashinfer_trtllm(): + + weight_scale_t = weight_scale.contiguous().view(-1) + output = flashinfer_mm_mxfp8( + q_input, + weight_t, + x_scale_u8, + weight_scale_t, + out_dtype=output_dtype, + use_8x4_sf_layout=False, + backend="trtllm", + ) + elif get_fp8_gemm_runner_backend().is_flashinfer_cutlass(): + weight_scale_t = ( + weight_scale.contiguous().t() + if weight_scale.ndim == 2 + else weight_scale.contiguous() + ) + output = flashinfer_mm_mxfp8( + q_input, + weight_t, + x_scale_u8, + weight_scale_t, + out_dtype=output_dtype, + use_8x4_sf_layout=False, + backend="cutlass", + ) if bias is not None: output += bias diff --git a/test/registered/quant/test_fp8_blockwise_gemm.py b/test/registered/quant/test_fp8_blockwise_gemm.py index 19d179d27958..48a819051126 100644 --- a/test/registered/quant/test_fp8_blockwise_gemm.py +++ b/test/registered/quant/test_fp8_blockwise_gemm.py @@ -131,5 +131,10 @@ class TestMXFP8GemmFlashinferTrtllm(MXFP8GemmBase, unittest.TestCase): backend = "flashinfer_trtllm" +@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher") +class TestMXFP8GemmFlashinferCutlass(MXFP8GemmBase, unittest.TestCase): + backend = "flashinfer_cutlass" + + if __name__ == "__main__": unittest.main()