diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 1991c6935dcc..6045e29431ad 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -70,6 +70,7 @@ MXFP8_VALUE_DTYPE, Mxfp8LinearBackend, Mxfp8LinearOp, + swizzle_mxfp8_scale, ) from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( apply_nvfp4_linear, @@ -1689,9 +1690,9 @@ def __init__(self, quant_config: ModelOptMxFp8Config) -> None: "Dynamic quantization is not supported." ) - backend: Mxfp8LinearBackend = Mxfp8LinearBackend.EMULATION - self.mxfp8_linear_op = Mxfp8LinearOp(backend=backend) - logger.info_once("Using %s backend for MXFP8 GEMM", backend.value) + self.backend: Mxfp8LinearBackend = Mxfp8LinearBackend.FLASHINFER_CUTLASS + self.mxfp8_linear_op = Mxfp8LinearOp(backend=self.backend) + logger.info_once("Using %s backend for MXFP8 GEMM", self.backend.value) def create_weights( self, @@ -1749,7 +1750,38 @@ def create_weights( ) layer.register_parameter("weight_scale", weight_scale) + def _process_weights_after_loading_scale_2d(self, layer: torch.nn.Module) -> None: + """Not swizzled - MXFP8 GEMM emulation""" + weight = layer.weight.data # [N, K] + N, K = weight.shape + scale_k = K // MXFP8_BLOCK_SIZE + + # Slice weight_scale to match weight dimensions (handles padding) + weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous() + + layer.weight = Parameter(weight.contiguous(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + + def _process_weights_after_loading_scale_1d(self, layer: torch.nn.Module) -> None: + """Swizzled - MXFP8 GEMM Flashinfer CUTLASS""" + weight = layer.weight.data # [N, K] + N, K = weight.shape + + # 2D weight scale + weight_scale = layer.weight_scale.data + + # Swizzle the weight scales + scale_k = K // MXFP8_BLOCK_SIZE + weight_scale_2d = weight_scale[:N, :scale_k].contiguous() 
+ weight_scale_swizzled = swizzle_mxfp8_scale(weight_scale_2d, M=N, K=K) + + layer.weight = Parameter(weight.contiguous(), requires_grad=False) + layer.weight_scale = Parameter( + weight_scale_swizzled.contiguous(), requires_grad=False + ) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Validate weight tensor if layer.weight.ndim != 2: raise ValueError( f"MXFP8 weight must be 2D tensor [N, K], got {layer.weight.ndim}D " @@ -1763,15 +1795,23 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: f"quantized with MXFP8." ) - weight = layer.weight.data # [N, K] - N, K = weight.shape - scale_k = K // MXFP8_BLOCK_SIZE + # Validate weight scale tensor (should be 2D, not swizzled) + assert layer.weight_scale.ndim == 2, ( + f"MXFP8 weight scale must be 2D, got {layer.weight_scale.ndim}D" + ) + assert layer.weight_scale.dtype == MXFP8_SCALE_DTYPE, ( + f"MXFP8 weight scale must be {MXFP8_SCALE_DTYPE}," + f" got {layer.weight_scale.dtype}" + ) - # Slice weight_scale to match weight dimensions (handles padding) - weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous() + if self.backend == Mxfp8LinearBackend.EMULATION: + # Swizzled layout is not used + self._process_weights_after_loading_scale_2d(layer) + return - layer.weight = Parameter(weight.contiguous(), requires_grad=False) - layer.weight_scale = Parameter(weight_scale, requires_grad=False) + assert self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS + # Swizzled layout is required for Flashinfer CUTLASS + self._process_weights_after_loading_scale_1d(layer) def apply( self, diff --git a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py index 9f0e0c0a4d8e..ee849b167aba 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py @@ -6,6 +6,7 @@ import torch from vllm.logger import init_logger +from vllm.utils 
import flashinfer as vllm_flashinfer from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) @@ -13,6 +14,7 @@ class Mxfp8LinearBackend(Enum): EMULATION = "emulation" + FLASHINFER_CUTLASS = "flashinfer-cutlass" # MXFP8 constants @@ -21,6 +23,30 @@ class Mxfp8LinearBackend(Enum): MXFP8_BLOCK_SIZE = 32 +def swizzle_mxfp8_scale(sf: torch.Tensor, M: int, K: int) -> torch.Tensor: + """Swizzle MXFP8 scales from row-major 2D to F8_128x4 layout.""" + scaling_vector_size = MXFP8_BLOCK_SIZE # 32 for MXFP8 + factor = scaling_vector_size * 4 # 128 + + num_m_tiles = (M + 127) // 128 + num_k_tiles = (K + factor - 1) // factor + + m_padded = num_m_tiles * 128 + k_scale_padded = num_k_tiles * 4 + + scale_cols = K // scaling_vector_size + sf_padded = torch.zeros( + (m_padded, k_scale_padded), dtype=sf.dtype, device=sf.device + ) + sf_padded[:M, :scale_cols] = sf + + sf_reshaped = sf_padded.view(num_m_tiles, 4, 32, num_k_tiles, 4) + + sf_swizzled = sf_reshaped.transpose(1, 3) + + return sf_swizzled.contiguous().view(-1) + + def _mxfp8_e4m3_quantize_impl( x: torch.Tensor, is_sf_swizzled_layout: bool = False ) -> tuple[torch.Tensor, torch.Tensor]: @@ -108,7 +134,7 @@ def __init__(self, backend: Mxfp8LinearBackend): self.backend = backend - def apply( + def _apply_emulation( self, input: torch.Tensor, weight: torch.Tensor, @@ -132,3 +158,79 @@ def apply( output = torch.nn.functional.linear(input, weight_bf16, bias) return output.to(out_dtype) + + def _apply_flashinfer_cutlass( + self, + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + out_dtype: torch.dtype, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + N, K = weight.shape + + input_shape = input.shape + input_2d = input.view(-1, K) + M_orig = input_2d.shape[0] + + # Minimum dimension size for F8_128x4 block scaling layout + min_dim = 128 + + assert min_dim <= K, ( + f"mm_mxfp8 requires K >= {min_dim}, got K={K}. " + f"in_features is too small for mm_mxfp8." 
+ ) + assert K % MXFP8_BLOCK_SIZE == 0, ( + f"mm_mxfp8 requires K to be divisible by {MXFP8_BLOCK_SIZE}, got K={K}." + ) + assert min_dim <= N, ( + f"mm_mxfp8 requires N >= {min_dim}, got N={N}. " + f"out_features is too small for mm_mxfp8." + ) + + M_padded = ((M_orig + min_dim - 1) // min_dim) * min_dim + if M_padded != M_orig: + pad_rows = M_padded - M_orig + input_2d = torch.nn.functional.pad(input_2d, (0, 0, 0, pad_rows)) + + input_mxfp8, input_scale = mxfp8_e4m3_quantize( + input_2d, + is_sf_swizzled_layout=True, # Swizzled for best accuracy + ) + + if not weight.is_contiguous(): + weight = weight.contiguous() + + output = vllm_flashinfer.mm_mxfp8( + input_mxfp8, + weight.t(), + input_scale, + weight_scale, + out_dtype=out_dtype, + backend="cutlass", + ) + + if M_padded != M_orig: + output = output[:M_orig, :] + + if bias is not None: + output = output + bias + + output_shape = (*input_shape[:-1], N) + return output.view(output_shape) + + def apply( + self, + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + out_dtype: torch.dtype, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + if self.backend == Mxfp8LinearBackend.EMULATION: + return self._apply_emulation(input, weight, weight_scale, out_dtype, bias) + + assert self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS + return self._apply_flashinfer_cutlass( + input, weight, weight_scale, out_dtype, bias + ) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 88e31718adff..333e66f68a87 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -553,6 +553,83 @@ def flashinfer_nvfp4_quantize_fake( rounded_m, rounded_n, dtype=torch.uint8, device=a.device ) + @torch.library.custom_op( + "vllm::mm_mxfp8", + mutates_args=[], + device_types="cuda", + ) + def mm_mxfp8( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + out_dtype: torch.dtype, + backend: str = "cutlass", + ) -> torch.Tensor: + from flashinfer 
import mm_mxfp8 as mm_mxfp8_ + + return mm_mxfp8_( + A, + B, + A_scale, + B_scale, + out=None, + out_dtype=out_dtype, + backend=backend, + ) + + @torch.library.register_fake( + "vllm::mm_mxfp8", + ) + def mm_mxfp8_fake( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + out_dtype: torch.dtype, + backend: str = "cutlass", + ) -> torch.Tensor: + # A is [m, k], B is [k, n] -> output [m, n] + return torch.empty(A.shape[0], B.shape[1], dtype=out_dtype, device=A.device) + + +def flashinfer_mm_mxfp8( + a: torch.Tensor, + b: torch.Tensor, + block_scale_a: torch.Tensor, + block_scale_b: torch.Tensor, + out_dtype: torch.dtype, + backend: str = "cutlass", +) -> torch.Tensor: + """MXFP8 MM helper - mirrors flashinfer_scaled_fp4_mm API. + + Takes non-transposed weights and handles transpose internally. + + CRITICAL: mm_mxfp8 CUTLASS kernel requires SWIZZLED 1D scales for optimal + performance and accuracy. Both input and weight scales should be in + swizzled format from FlashInfer's mxfp8_quantize(is_sf_swizzled_layout=True). + """ + # a shape [M, K] + # b shape [N, K] + assert a.ndim == 2 and b.ndim == 2 + assert a.shape[1] == b.shape[1] # K dimension must match + + if block_scale_b.ndim != 1: + raise ValueError( + "mm_mxfp8 expects 1D swizzled weight scales for CUTLASS; " + f"got shape={tuple(block_scale_b.shape)}" + ) + + # Output tensor [M, N] + return mm_mxfp8( + a, + b.t(), # Transpose weight: [N, K] -> [K, N] + block_scale_a, + block_scale_b, + out_dtype, + backend=backend, + ) + def flashinfer_scaled_fp4_mm( a: torch.Tensor,