From baafb30da207441bfcb167b5754b3f7252965fde Mon Sep 17 00:00:00 2001 From: Dipika Date: Mon, 4 May 2026 20:30:42 +0000 Subject: [PATCH 1/7] update Signed-off-by: Dipika --- .../compressed_tensors/compressed_tensors.py | 4 +- .../compressed_tensors_moe_w4a4_mxfp4.py | 1 + .../compressed_tensors/schemes/__init__.py | 4 +- ...p4.py => compressed_tensors_w4a4_mxfp4.py} | 45 ++++++++++++++++--- .../utils/flashinfer_utils_fp4.py | 37 +++++++++++++++ vllm/utils/flashinfer.py | 30 ++++++++++++- 6 files changed, 111 insertions(+), 10 deletions(-) rename vllm/model_executor/layers/quantization/compressed_tensors/schemes/{compressed_tensors_w4a16_mxfp4.py => compressed_tensors_w4a4_mxfp4.py} (63%) create mode 100644 vllm/model_executor/layers/quantization/utils/flashinfer_utils_fp4.py diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 8d16a143b10a..30049635187b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -43,10 +43,10 @@ CompressedTensors24, CompressedTensorsScheme, CompressedTensorsW4A4Fp4, + CompressedTensorsW4A4Mxfp4, CompressedTensorsW4A8Fp8, CompressedTensorsW4A8Int, CompressedTensorsW4A16Fp4, - CompressedTensorsW4A16Mxfp4, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW8A8Mxfp8, @@ -626,7 +626,7 @@ def _get_scheme_from_parts( return CompressedTensorsW4A16Fp4() if self._is_mxfp4(weight_quant): - return CompressedTensorsW4A16Mxfp4() + return CompressedTensorsW4A4Mxfp4() if self._is_mxfp8(weight_quant): return CompressedTensorsW8A8Mxfp8() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py 
b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py index 629e1c5ef1be..f5befe731d41 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py @@ -41,6 +41,7 @@ def __init__(self, moe): super().__init__(moe) self.group_size = 32 self.mxfp4_backend = Mxfp4MoeBackend.MARLIN + # use cutlass if supported, otherwise fallback to marlin for weight-only FP4 self.use_cutlass_mxfp4 = CutlassExpertsMxfp4._supports_current_device() self.experts_cls: type[mk.FusedMoEExperts] if self.use_cutlass_mxfp4: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index 457794eb0a09..0b0d8c230617 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -2,10 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from .compressed_tensors_scheme import CompressedTensorsScheme +from .compressed_tensors_w4a4_mxfp4 import CompressedTensorsW4A4Mxfp4 from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4 from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8 from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int -from .compressed_tensors_w4a16_mxfp4 import CompressedTensorsW4A16Mxfp4 from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8 @@ -25,7 +25,7 @@ "WNA16_SUPPORTED_BITS", "CompressedTensors24", "CompressedTensorsW4A16Fp4", - "CompressedTensorsW4A16Mxfp4", + 
"CompressedTensorsW4A4Mxfp4", "CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int", "CompressedTensorsW4A8Fp8", diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_mxfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxfp4.py similarity index 63% rename from vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_mxfp4.py rename to vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxfp4.py index 77cea0f83e1c..a4c71263fe24 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_mxfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxfp4.py @@ -5,9 +5,15 @@ import torch from torch.nn.parameter import Parameter +# from vllm.model_executor.layers.fused_moe.experts.cutlass_moe import ( +# swizzle_mxfp4_scales, +# ) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) +from vllm.model_executor.layers.quantization.utils.flashinfer_utils_fp4 import ( + apply_mxfp4_flashinfer_linear, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( apply_fp4_marlin_linear, prepare_fp4_layer_for_marlin, @@ -16,13 +22,15 @@ GroupQuantScaleParameter, ModelWeightParameter, ) +from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer -__all__ = ["CompressedTensorsW4A16Mxfp4"] +__all__ = ["CompressedTensorsW4A4Mxfp4"] -class CompressedTensorsW4A16Mxfp4(CompressedTensorsScheme): +class CompressedTensorsW4A4Mxfp4(CompressedTensorsScheme): """ - Compressed tensors scheme for MXFP4 weight-only quantization. + Compressed tensors scheme for MXFP4. Supports models quantized with the compressed-tensors mxfp4-pack-quantized format. 
@@ -31,10 +39,17 @@ class CompressedTensorsW4A16Mxfp4(CompressedTensorsScheme): - 4-bit float weights (E2M1) packed into uint8 - Per-group E8M0 scales with group_size=32 - No global scale (unlike NVFP4) + + On SM100+ with FlashInfer: true W4A4 (activations dynamically quantized). + Otherwise: W4A16 weight-only via Marlin. """ def __init__(self): self.group_size = 32 + p = current_platform + self.use_flashinfer = ( + p.is_cuda() and p.is_device_capability_family(100) and has_flashinfer() + ) @classmethod def get_min_capability(cls) -> int: @@ -82,11 +97,23 @@ def create_weights( layer.register_parameter("weight_scale", weight_scale) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # Rename weight_packed to weight that marlin expects layer.weight = Parameter(layer.weight_packed.data, requires_grad=False) del layer.weight_packed - prepare_fp4_layer_for_marlin(layer) + if self.use_flashinfer: + # TODO: verify whether FlashInfer cute-dsl needs a specific + # swizzle for checkpoint weight scales (flat [N, K//32] E8M0). + # swizzle_mxfp4_scales targets the CUTLASS MoE tiled layout and + # may not match FlashInfer's 128x4 layout — test first. 
+ # N, scale_K = layer.weight_scale.shape + # K = scale_K * self.group_size + # layer.weight_scale = Parameter( + # swizzle_mxfp4_scales(layer.weight_scale.data, N, K).reshape(N, -1), + # requires_grad=False, + # ) + layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False) + else: + prepare_fp4_layer_for_marlin(layer) def apply_weights( self, @@ -94,6 +121,14 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: + if self.use_flashinfer: + return apply_mxfp4_flashinfer_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + size_n=layer.output_size_per_partition, + bias=bias, + ) return apply_fp4_marlin_linear( input=x, weight=layer.weight, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils_fp4.py new file mode 100644 index 000000000000..03025352762e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils_fp4.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + + +def apply_mxfp4_flashinfer_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + size_n: int, + bias: torch.Tensor | None = None, +) -> torch.Tensor: + from vllm.utils.flashinfer import flashinfer_mm_fp4, flashinfer_mxfp4_quantize + + x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n,) + + x_fp4, x_scale = flashinfer_mxfp4_quantize(x) + + dummy_alpha = torch.ones(1, dtype=torch.float32, device=x.device) + out = flashinfer_mm_fp4( + x_fp4, + weight.t(), # [N, K//2] -> [K//2, N] + x_scale, + weight_scale.t(), # [N, K//32] -> [K//32, N] + dummy_alpha, + input.dtype, + use_8x4_sf_layout=False, + backend="cute-dsl", + block_size=32, + use_nvfp4=False, + ) + + if bias is not None: + out = out + bias + return out.view(out_shape) diff --git 
a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 828ff08a067d..f1fe7e886ed0 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -480,6 +480,8 @@ def flashinfer_mm_fp4( dtype: torch.dtype, use_8x4_sf_layout: bool, backend: str, + block_size: int = 16, + use_nvfp4: bool = True, ) -> torch.Tensor: from flashinfer import mm_fp4 as flashinfer_mm_fp4_ @@ -490,8 +492,9 @@ def flashinfer_mm_fp4( B_scale, g_scale, dtype, - block_size=16, + block_size=block_size, use_8x4_sf_layout=use_8x4_sf_layout, + use_nvfp4=use_nvfp4, backend=backend, ) @@ -507,9 +510,33 @@ def flashinfer_mm_fp4_fake( dtype: torch.dtype, use_8x4_sf_layout: bool, backend: str, + block_size: int = 16, + use_nvfp4: bool = True, ) -> torch.Tensor: return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device) + @torch.library.custom_op( + "vllm::flashinfer_mxfp4_quantize", + mutates_args=[], + device_types="cuda", + ) + def flashinfer_mxfp4_quantize( + a: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + from flashinfer import mxfp4_quantize as _mxfp4_quantize + + return _mxfp4_quantize(a) + + @torch.library.register_fake("vllm::flashinfer_mxfp4_quantize") + def flashinfer_mxfp4_quantize_fake( + a: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + m, n = a.shape + return ( + torch.empty(m, n // 2, dtype=torch.uint8, device=a.device), + torch.empty(m, n // 32, dtype=torch.uint8, device=a.device), + ) + @torch.library.custom_op( "vllm::bmm_fp8", mutates_args=[], @@ -864,6 +891,7 @@ def is_flashinfer_cudnn_fp8_prefill_attn_supported() -> bool: "can_use_trtllm_attention", "use_trtllm_attention", "flashinfer_scaled_fp4_mm", + "flashinfer_mxfp4_quantize", "flashinfer_scaled_fp8_mm", "flashinfer_scaled_fp8_mm_out", "flashinfer_quant_nvfp4_8x4_sf_layout", From 8e28f85ac161c9af146309f6e66d1eac28abf068 Mon Sep 17 00:00:00 2001 From: Dipika Date: Mon, 4 May 2026 20:57:19 +0000 Subject: [PATCH 2/7] update Signed-off-by: Dipika --- 
.../schemes/compressed_tensors_w4a4_mxfp4.py | 23 ++++++++----------- .../utils/flashinfer_utils_fp4.py | 17 +++++++------- vllm/utils/flashinfer.py | 9 +++++++- 3 files changed, 26 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxfp4.py index a4c71263fe24..22ae35df203a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxfp4.py @@ -5,9 +5,9 @@ import torch from torch.nn.parameter import Parameter -# from vllm.model_executor.layers.fused_moe.experts.cutlass_moe import ( -# swizzle_mxfp4_scales, -# ) +from vllm.model_executor.layers.fused_moe.experts.cutlass_moe import ( + swizzle_mxfp4_scales, +) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) @@ -101,17 +101,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: del layer.weight_packed if self.use_flashinfer: - # TODO: verify whether FlashInfer cute-dsl needs a specific - # swizzle for checkpoint weight scales (flat [N, K//32] E8M0). - # swizzle_mxfp4_scales targets the CUTLASS MoE tiled layout and - # may not match FlashInfer's 128x4 layout — test first. 
- # N, scale_K = layer.weight_scale.shape - # K = scale_K * self.group_size - # layer.weight_scale = Parameter( - # swizzle_mxfp4_scales(layer.weight_scale.data, N, K).reshape(N, -1), - # requires_grad=False, - # ) - layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False) + N, scale_K = layer.weight_scale.shape + K = scale_K * self.group_size + layer.weight_scale = Parameter( + swizzle_mxfp4_scales(layer.weight_scale.data, N, K).reshape(N, -1), + requires_grad=False, + ) else: prepare_fp4_layer_for_marlin(layer) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils_fp4.py index 03025352762e..d616ae1a0676 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils_fp4.py @@ -11,22 +11,23 @@ def apply_mxfp4_flashinfer_linear( size_n: int, bias: torch.Tensor | None = None, ) -> torch.Tensor: - from vllm.utils.flashinfer import flashinfer_mm_fp4, flashinfer_mxfp4_quantize + from vllm.utils.flashinfer import ( + flashinfer_mxfp4_quantize, + flashinfer_scaled_fp4_mm, + ) x = input.reshape(-1, input.shape[-1]) out_shape = input.shape[:-1] + (size_n,) x_fp4, x_scale = flashinfer_mxfp4_quantize(x) - dummy_alpha = torch.ones(1, dtype=torch.float32, device=x.device) - out = flashinfer_mm_fp4( + out = flashinfer_scaled_fp4_mm( x_fp4, - weight.t(), # [N, K//2] -> [K//2, N] + weight, x_scale, - weight_scale.t(), # [N, K//32] -> [K//32, N] - dummy_alpha, - input.dtype, - use_8x4_sf_layout=False, + weight_scale, + alpha=None, + out_dtype=input.dtype, backend="cute-dsl", block_size=32, use_nvfp4=False, diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index f1fe7e886ed0..f483a3cd43e5 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -685,15 +685,20 @@ def flashinfer_scaled_fp4_mm( b: torch.Tensor, block_scale_a: torch.Tensor, block_scale_b: 
torch.Tensor, - alpha: torch.Tensor, + alpha: torch.Tensor | None, out_dtype: torch.dtype, backend: str, + block_size: int = 16, + use_nvfp4: bool = True, ) -> torch.Tensor: assert a.ndim == 2 and b.ndim == 2 assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2 assert a.stride(-1) == 1 and b.stride(-1) == 1 assert a.shape[1] == b.shape[1] + if alpha is None: + alpha = torch.ones(1, dtype=torch.float32, device=a.device) + if backend in ("cutlass", "cudnn"): block_scale_a = block_scale_a.view(torch.uint8) block_scale_b = block_scale_b.view(torch.uint8) @@ -709,6 +714,8 @@ def flashinfer_scaled_fp4_mm( out_dtype, use_8x4_sf_layout=use_8x4_sf_layout, backend=backend, + block_size=block_size, + use_nvfp4=use_nvfp4, ) From f935818accfa67c24173e5894cba9ebb2f196cb9 Mon Sep 17 00:00:00 2001 From: Dipika Date: Tue, 5 May 2026 18:14:04 +0000 Subject: [PATCH 3/7] use linear kernel abstraction Signed-off-by: Dipika --- .../model_executor/kernels/linear/__init__.py | 68 ++++++++++++++++++ .../kernels/linear/mxfp4/__init__.py | 12 ++++ .../kernels/linear/mxfp4/base.py | 67 +++++++++++++++++ .../kernels/linear/mxfp4/flashinfer.py | 71 +++++++++++++++++++ .../kernels/linear/mxfp4/marlin.py | 52 ++++++++++++++ .../schemes/compressed_tensors_w4a4_mxfp4.py | 48 ++----------- .../utils/flashinfer_utils_fp4.py | 38 ---------- 7 files changed, 274 insertions(+), 82 deletions(-) create mode 100644 vllm/model_executor/kernels/linear/mxfp4/__init__.py create mode 100644 vllm/model_executor/kernels/linear/mxfp4/base.py create mode 100644 vllm/model_executor/kernels/linear/mxfp4/flashinfer.py create mode 100644 vllm/model_executor/kernels/linear/mxfp4/marlin.py delete mode 100644 vllm/model_executor/layers/quantization/utils/flashinfer_utils_fp4.py diff --git a/vllm/model_executor/kernels/linear/__init__.py b/vllm/model_executor/kernels/linear/__init__.py index 5d513f767f03..a3621ac19e9e 100644 --- a/vllm/model_executor/kernels/linear/__init__.py +++ 
b/vllm/model_executor/kernels/linear/__init__.py @@ -58,6 +58,16 @@ XPUW4A8IntLinearKernel, XPUwNa16LinearKernel, ) +from vllm.model_executor.kernels.linear.mxfp4 import ( + MxFp4LinearKernel, + MxFp4LinearLayerConfig, +) +from vllm.model_executor.kernels.linear.mxfp4.flashinfer import ( + FlashInferMxFp4LinearKernel, +) +from vllm.model_executor.kernels.linear.mxfp4.marlin import ( + MarlinMxFp4LinearKernel, +) from vllm.model_executor.kernels.linear.mxfp8 import ( Mxfp8LinearKernel, Mxfp8LinearLayerConfig, @@ -272,6 +282,13 @@ ], } +_POSSIBLE_MXFP4_KERNELS: dict[PlatformEnum, list[type[MxFp4LinearKernel]]] = { + PlatformEnum.CUDA: [ + FlashInferMxFp4LinearKernel, + MarlinMxFp4LinearKernel, + ], +} + # TODO make all kernels inherit from MMLinearKernel # then bound _KernelT only to MMLinearKernel _KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel | MMLinearKernel) @@ -566,6 +583,48 @@ def init_mxfp8_linear_kernel() -> Mxfp8LinearKernel: ) +def init_mxfp4_linear_kernel() -> MxFp4LinearKernel: + """Select and instantiate the best MXFP4 linear kernel for the + current platform.""" + force_kernel: type[MxFp4LinearKernel] | None = None + if envs.VLLM_MXFP4_USE_MARLIN: + force_kernel = MarlinMxFp4LinearKernel + + if force_kernel is not None: + is_supported, reason = force_kernel.is_supported() + if not is_supported: + raise ValueError( + f"Forced MXFP4 kernel {force_kernel.__name__} is not " + f"supported: {reason}" + ) + logger.info_once("Using %s for MXFP4 GEMM", force_kernel.__name__) + return force_kernel(MxFp4LinearLayerConfig()) + + platform = current_platform._enum + possible = _POSSIBLE_MXFP4_KERNELS.get(platform, []) + + failure_reasons = [] + for kernel_cls in possible: + if kernel_cls.__name__ in envs.VLLM_DISABLED_KERNELS: + failure_reasons.append( + f" {kernel_cls.__name__} disabled by environment variable" + ) + continue + + is_supported, reason = kernel_cls.is_supported() + if not is_supported: + failure_reasons.append(f"{kernel_cls.__name__}: 
{reason}") + continue + + logger.info_once("Using %s for MXFP4 GEMM", kernel_cls.__name__) + return kernel_cls(MxFp4LinearLayerConfig()) + + raise ValueError( + "Failed to find a kernel that can implement the " + "MXFP4 linear layer. Reasons: \n" + "\n".join(failure_reasons) + ) + + def init_wfp8_a16_linear_kernel( weight_quant_key: QuantKey, activation_quant_key: QuantKey, @@ -726,6 +785,10 @@ def register_linear_kernel( if platform not in _POSSIBLE_NVFP4_KERNELS: _POSSIBLE_NVFP4_KERNELS[platform] = [] _POSSIBLE_NVFP4_KERNELS[platform].append(kernel_class) + elif kernel_type == "mxfp4": + if platform not in _POSSIBLE_MXFP4_KERNELS: + _POSSIBLE_MXFP4_KERNELS[platform] = [] + _POSSIBLE_MXFP4_KERNELS[platform].append(kernel_class) else: raise ValueError(f"Unrecognized kernel type: {kernel_type}") @@ -773,6 +836,11 @@ def register_linear_kernel( "init_mxfp8_linear_kernel", "Mxfp8LinearKernel", "Mxfp8LinearLayerConfig", + "init_mxfp4_linear_kernel", + "MxFp4LinearKernel", + "MxFp4LinearLayerConfig", + "FlashInferMxFp4LinearKernel", + "MarlinMxFp4LinearKernel", "FlashInferCutlassMxfp8LinearKernel", "MarlinMxfp8LinearKernel", "XPUMxFp8LinearKernel", diff --git a/vllm/model_executor/kernels/linear/mxfp4/__init__.py b/vllm/model_executor/kernels/linear/mxfp4/__init__.py new file mode 100644 index 000000000000..0927cd945f6f --- /dev/null +++ b/vllm/model_executor/kernels/linear/mxfp4/__init__.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.model_executor.kernels.linear.mxfp4.base import ( + MxFp4LinearKernel, + MxFp4LinearLayerConfig, +) + +__all__ = [ + "MxFp4LinearKernel", + "MxFp4LinearLayerConfig", +] diff --git a/vllm/model_executor/kernels/linear/mxfp4/base.py b/vllm/model_executor/kernels/linear/mxfp4/base.py new file mode 100644 index 000000000000..868faa4731d5 --- /dev/null +++ b/vllm/model_executor/kernels/linear/mxfp4/base.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: 
Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from dataclasses import dataclass + +import torch + + +@dataclass +class MxFp4LinearLayerConfig: + """Configuration for an MXFP4 linear layer. + + All MXFP4 layers share the same structure: packed uint8 weights (2 FP4 values per + byte) and per-block weight scales (group size 32). + """ + + pass + + +class MxFp4LinearKernel(ABC): + """Base class for MXFP4 quantized linear kernels. + + Each subclass implements a specific GEMM backend (CUTLASS, Marlin, etc.). + The kernel selection mechanism iterates over registered subclasses in + priority order, calling ``is_supported`` and ``can_implement`` to find the best + match for the current hardware. + """ + + def __init__(self, config: MxFp4LinearLayerConfig) -> None: + assert self.can_implement(config)[0] + assert self.is_supported()[0] + self.config = config + + @classmethod + @abstractmethod + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: + """Return whether this kernel can run on the current platform.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def can_implement(cls, config: MxFp4LinearLayerConfig) -> tuple[bool, str | None]: + """Return whether this kernel can handle *config*.""" + raise NotImplementedError + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """Transform weights into the format required by this kernel. + + Called once after checkpoint weights have been loaded onto the + device. Implementations should repack / swizzle / pad weights + and scales in-place on *layer*. 
+ """ + raise NotImplementedError + + @abstractmethod + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + """Run the quantized GEMM.""" + raise NotImplementedError diff --git a/vllm/model_executor/kernels/linear/mxfp4/flashinfer.py b/vllm/model_executor/kernels/linear/mxfp4/flashinfer.py new file mode 100644 index 000000000000..ff3564993e87 --- /dev/null +++ b/vllm/model_executor/kernels/linear/mxfp4/flashinfer.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from torch.nn.parameter import Parameter + +from vllm.model_executor.layers.fused_moe.experts.cutlass_moe import ( + swizzle_mxfp4_scales, +) +from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer + +from .base import MxFp4LinearKernel, MxFp4LinearLayerConfig + +_MXFP4_GROUP_SIZE = 32 + + +class FlashInferMxFp4LinearKernel(MxFp4LinearKernel): + """MXFP4 W4A4 GEMM via FlashInfer CUTLASS (SM100+).""" + + @classmethod + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: + if current_platform.has_device_capability(100) and has_flashinfer(): + return True, None + return False, "FlashInfer + >=sm_100 (Blackwell) required" + + @classmethod + def can_implement(cls, config: MxFp4LinearLayerConfig) -> tuple[bool, str | None]: + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + N, scale_K = layer.weight_scale.shape + K = scale_K * _MXFP4_GROUP_SIZE + layer.weight_scale = Parameter( + swizzle_mxfp4_scales(layer.weight_scale.data, N, K).reshape(N, -1), + requires_grad=False, + ) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + from vllm.utils.flashinfer import ( + flashinfer_mxfp4_quantize, + flashinfer_scaled_fp4_mm, + ) + + weight = 
layer.weight + out_shape = x.shape[:-1] + (layer.output_size_per_partition,) + x_2d = x.reshape(-1, x.shape[-1]) + + x_fp4, x_scale = flashinfer_mxfp4_quantize(x_2d) + out = flashinfer_scaled_fp4_mm( + x_fp4, + weight, + x_scale, + layer.weight_scale, + alpha=None, + out_dtype=x.dtype, + backend="cute-dsl", + block_size=_MXFP4_GROUP_SIZE, + use_nvfp4=False, + ) + + if bias is not None: + out = out + bias + return out.view(out_shape) diff --git a/vllm/model_executor/kernels/linear/mxfp4/marlin.py b/vllm/model_executor/kernels/linear/mxfp4/marlin.py new file mode 100644 index 000000000000..38440752072e --- /dev/null +++ b/vllm/model_executor/kernels/linear/mxfp4/marlin.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from .base import MxFp4LinearKernel, MxFp4LinearLayerConfig + + +class MarlinMxFp4LinearKernel(MxFp4LinearKernel): + @classmethod + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + is_fp4_marlin_supported, + ) + + if is_fp4_marlin_supported(): + return True, None + return False, "Marlin FP4 not available" + + @classmethod + def can_implement(cls, c: MxFp4LinearLayerConfig) -> tuple[bool, str | None]: + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + prepare_fp4_layer_for_marlin, + ) + + prepare_fp4_layer_for_marlin(layer) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + ) + + return apply_fp4_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=None, + 
workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxfp4.py index 22ae35df203a..7b994e10e448 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxfp4.py @@ -5,25 +5,14 @@ import torch from torch.nn.parameter import Parameter -from vllm.model_executor.layers.fused_moe.experts.cutlass_moe import ( - swizzle_mxfp4_scales, -) +from vllm.model_executor.kernels.linear import init_mxfp4_linear_kernel from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) -from vllm.model_executor.layers.quantization.utils.flashinfer_utils_fp4 import ( - apply_mxfp4_flashinfer_linear, -) -from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - apply_fp4_marlin_linear, - prepare_fp4_layer_for_marlin, -) from vllm.model_executor.parameter import ( GroupQuantScaleParameter, ModelWeightParameter, ) -from vllm.platforms import current_platform -from vllm.utils.flashinfer import has_flashinfer __all__ = ["CompressedTensorsW4A4Mxfp4"] @@ -46,10 +35,7 @@ class CompressedTensorsW4A4Mxfp4(CompressedTensorsScheme): def __init__(self): self.group_size = 32 - p = current_platform - self.use_flashinfer = ( - p.is_cuda() and p.is_device_capability_family(100) and has_flashinfer() - ) + self.kernel = init_mxfp4_linear_kernel() @classmethod def get_min_capability(cls) -> int: @@ -99,16 +85,7 @@ def create_weights( def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight = Parameter(layer.weight_packed.data, requires_grad=False) del layer.weight_packed - - if 
self.use_flashinfer: - N, scale_K = layer.weight_scale.shape - K = scale_K * self.group_size - layer.weight_scale = Parameter( - swizzle_mxfp4_scales(layer.weight_scale.data, N, K).reshape(N, -1), - requires_grad=False, - ) - else: - prepare_fp4_layer_for_marlin(layer) + self.kernel.process_weights_after_loading(layer) def apply_weights( self, @@ -116,21 +93,4 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - if self.use_flashinfer: - return apply_mxfp4_flashinfer_linear( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - size_n=layer.output_size_per_partition, - bias=bias, - ) - return apply_fp4_marlin_linear( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - weight_global_scale=None, - workspace=layer.workspace, - size_n=layer.output_size_per_partition, - size_k=layer.input_size_per_partition, - bias=bias, - ) + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils_fp4.py deleted file mode 100644 index d616ae1a0676..000000000000 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils_fp4.py +++ /dev/null @@ -1,38 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - - -def apply_mxfp4_flashinfer_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - size_n: int, - bias: torch.Tensor | None = None, -) -> torch.Tensor: - from vllm.utils.flashinfer import ( - flashinfer_mxfp4_quantize, - flashinfer_scaled_fp4_mm, - ) - - x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (size_n,) - - x_fp4, x_scale = flashinfer_mxfp4_quantize(x) - - out = flashinfer_scaled_fp4_mm( - x_fp4, - weight, - x_scale, - weight_scale, - alpha=None, - out_dtype=input.dtype, - backend="cute-dsl", - block_size=32, - 
use_nvfp4=False, - ) - - if bias is not None: - out = out + bias - return out.view(out_shape) From fadc70111949df6a67d2448bfdff4938a719af98 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 5 May 2026 15:28:00 -0400 Subject: [PATCH 4/7] add test models Signed-off-by: Dipika Sikka --- tests/quantization/test_compressed_tensors.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 6b95d9e346db..854613b96e3d 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -24,6 +24,7 @@ CompressedTensorsConfig, CompressedTensorsLinearMethod, CompressedTensorsW4A4Fp4, + CompressedTensorsW4A4Mxfp4, CompressedTensorsW4A8Fp8, CompressedTensorsW4A16Fp4, CompressedTensorsW8A8Fp8, @@ -668,3 +669,31 @@ def check_model(model): llm.apply_model(check_model) output = llm.generate_greedy("Hello my name is", max_tokens=4) assert output + + +@pytest.mark.skipif( + not current_platform.is_cuda() or not current_platform.has_device_capability(80), + reason="MXFP4 requires ampere or newer", +) +def test_compressed_tensors_mxfp4(vllm_runner): + model_path = "nm-testing/TinyLlama-1.1B-Chat-v1.0-MXFP4" + with vllm_runner(model_path, enforce_eager=True) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + o_proj = layer.self_attn.o_proj + gate_up_proj = layer.mlp.gate_up_proj + down_proj = layer.mlp.down_proj + + for proj in (qkv_proj, o_proj, gate_up_proj, down_proj): + assert isinstance(proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(proj.scheme, CompressedTensorsW4A4Mxfp4) + + # Verify group size + assert proj.scheme.group_size == 32 + + llm.apply_model(check_model) + output = llm.generate_greedy("Hello my name is", max_tokens=4) + assert output From e50f7244616275f651b3059eace467bb09f57a76 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: 
Thu, 7 May 2026 18:15:31 -0400 Subject: [PATCH 5/7] fix padded mx linear Signed-off-by: Kyle Sayers --- vllm/model_executor/kernels/linear/mxfp4/flashinfer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/kernels/linear/mxfp4/flashinfer.py b/vllm/model_executor/kernels/linear/mxfp4/flashinfer.py index ff3564993e87..47c28d712d00 100644 --- a/vllm/model_executor/kernels/linear/mxfp4/flashinfer.py +++ b/vllm/model_executor/kernels/linear/mxfp4/flashinfer.py @@ -33,8 +33,11 @@ def can_implement(cls, config: MxFp4LinearLayerConfig) -> tuple[bool, str | None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: N, scale_K = layer.weight_scale.shape K = scale_K * _MXFP4_GROUP_SIZE + + # swizzle pads N to the next multiple of 128 for CUTLASS tiling + padded_N = ((N + 127) // 128) * 128 layer.weight_scale = Parameter( - swizzle_mxfp4_scales(layer.weight_scale.data, N, K).reshape(N, -1), + swizzle_mxfp4_scales(layer.weight_scale.data, N, K).reshape(padded_N, -1), requires_grad=False, ) From eda7aa3dea497c559de10860695a1a76df2869d8 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 7 May 2026 18:39:27 -0400 Subject: [PATCH 6/7] update dummy shape Signed-off-by: Kyle Sayers --- vllm/utils/flashinfer.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index f483a3cd43e5..5bd776399229 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -20,6 +20,7 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.utils.math_utils import cdiv logger = init_logger(__name__) @@ -531,10 +532,13 @@ def flashinfer_mxfp4_quantize( def flashinfer_mxfp4_quantize_fake( a: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - m, n = a.shape + m, k = a.shape + sf_vec_size = 32 + padded_m = cdiv(m, 128) * 128 + sf_cols = cdiv(k // sf_vec_size, 4) * 4 return ( - torch.empty(m, 
n // 2, dtype=torch.uint8, device=a.device), - torch.empty(m, n // 32, dtype=torch.uint8, device=a.device), + torch.empty(m, k // 2, dtype=torch.uint8, device=a.device), + torch.empty(padded_m, sf_cols, dtype=torch.uint8, device=a.device), ) @torch.library.custom_op( From fca2eed1801fc2425e26e30eb8f3e498cd697873 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 11 May 2026 18:19:40 -0400 Subject: [PATCH 7/7] Apply suggestions from code review Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: Dipika Sikka --- vllm/model_executor/kernels/linear/mxfp4/flashinfer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/kernels/linear/mxfp4/flashinfer.py b/vllm/model_executor/kernels/linear/mxfp4/flashinfer.py index 47c28d712d00..8889986f05b9 100644 --- a/vllm/model_executor/kernels/linear/mxfp4/flashinfer.py +++ b/vllm/model_executor/kernels/linear/mxfp4/flashinfer.py @@ -8,7 +8,7 @@ swizzle_mxfp4_scales, ) from vllm.platforms import current_platform -from vllm.utils.flashinfer import has_flashinfer +from vllm.utils.flashinfer import has_flashinfer_cutedsl from .base import MxFp4LinearKernel, MxFp4LinearLayerConfig @@ -22,7 +22,7 @@ class FlashInferMxFp4LinearKernel(MxFp4LinearKernel): def is_supported( cls, compute_capability: int | None = None ) -> tuple[bool, str | None]: - if current_platform.has_device_capability(100) and has_flashinfer(): + if current_platform.has_device_capability(100) and has_flashinfer_cutedsl(): return True, None return False, "FlashInfer + >=sm_100 (Blackwell) required"