From 1a6da54818268abe7c097499c5aa549cf36ae830 Mon Sep 17 00:00:00 2001 From: Zhewen Li Date: Tue, 12 May 2026 19:04:28 -0700 Subject: [PATCH] Revert "[MXFP4] Support for linear layers + compressed-tensors integration (#41664)" This reverts commit a7b801e26d6b9d96bb49e939c0b6b3acf1d85796. --- tests/quantization/test_compressed_tensors.py | 29 -------- .../model_executor/kernels/linear/__init__.py | 68 ----------------- .../kernels/linear/mxfp4/__init__.py | 12 --- .../kernels/linear/mxfp4/base.py | 67 ----------------- .../kernels/linear/mxfp4/flashinfer.py | 74 ------------------- .../kernels/linear/mxfp4/marlin.py | 52 ------------- .../compressed_tensors/compressed_tensors.py | 4 +- .../compressed_tensors_moe_w4a4_mxfp4.py | 1 - .../compressed_tensors/schemes/__init__.py | 4 +- ...4.py => compressed_tensors_w4a16_mxfp4.py} | 30 +++++--- vllm/utils/flashinfer.py | 43 +---------- 11 files changed, 26 insertions(+), 358 deletions(-) delete mode 100644 vllm/model_executor/kernels/linear/mxfp4/__init__.py delete mode 100644 vllm/model_executor/kernels/linear/mxfp4/base.py delete mode 100644 vllm/model_executor/kernels/linear/mxfp4/flashinfer.py delete mode 100644 vllm/model_executor/kernels/linear/mxfp4/marlin.py rename vllm/model_executor/layers/quantization/compressed_tensors/schemes/{compressed_tensors_w4a4_mxfp4.py => compressed_tensors_w4a16_mxfp4.py} (76%) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 9fdcd59a58da..7e80b0ecfa61 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -24,7 +24,6 @@ CompressedTensorsConfig, CompressedTensorsLinearMethod, CompressedTensorsW4A4Fp4, - CompressedTensorsW4A4Mxfp4, CompressedTensorsW4A8Fp8, CompressedTensorsW4A16Fp4, CompressedTensorsW8A8Fp8, @@ -690,31 +689,3 @@ def check_model(model): llm.apply_model(check_model) output = llm.generate_greedy("Hello my name is", max_tokens=4) assert output - - -@pytest.mark.skipif( - not current_platform.is_cuda() or not current_platform.has_device_capability(80), - reason="MXFP4 requires ampere or newer", -) -def test_compressed_tensors_mxfp4(vllm_runner): - model_path = "nm-testing/TinyLlama-1.1B-Chat-v1.0-MXFP4" - with vllm_runner(model_path, enforce_eager=True) as llm: - - def check_model(model): - layer = model.model.layers[0] - - qkv_proj = layer.self_attn.qkv_proj - o_proj = layer.self_attn.o_proj - gate_up_proj = layer.mlp.gate_up_proj - down_proj = layer.mlp.down_proj - - for proj in (qkv_proj, o_proj, gate_up_proj, down_proj): - assert isinstance(proj.quant_method, CompressedTensorsLinearMethod) - assert isinstance(proj.scheme, CompressedTensorsW4A4Mxfp4) - - # Verify group size - assert proj.scheme.group_size == 32 - - llm.apply_model(check_model) - output = llm.generate_greedy("Hello my name is", max_tokens=4) - assert output diff --git a/vllm/model_executor/kernels/linear/__init__.py b/vllm/model_executor/kernels/linear/__init__.py index 4776494f2359..b9a96a035309 100644 --- a/vllm/model_executor/kernels/linear/__init__.py +++ b/vllm/model_executor/kernels/linear/__init__.py @@ -58,16 +58,6 @@ XPUW4A8IntLinearKernel, XPUwNa16LinearKernel, ) -from vllm.model_executor.kernels.linear.mxfp4 import ( - MxFp4LinearKernel, - MxFp4LinearLayerConfig, -) -from vllm.model_executor.kernels.linear.mxfp4.flashinfer import ( - FlashInferMxFp4LinearKernel, -) -from vllm.model_executor.kernels.linear.mxfp4.marlin import ( - MarlinMxFp4LinearKernel, -) from vllm.model_executor.kernels.linear.mxfp8 import ( Mxfp8LinearKernel, Mxfp8LinearLayerConfig, @@ -286,13 +276,6 @@ ], } -_POSSIBLE_MXFP4_KERNELS: dict[PlatformEnum, list[type[MxFp4LinearKernel]]] = { - PlatformEnum.CUDA: [ - FlashInferMxFp4LinearKernel, - MarlinMxFp4LinearKernel, - ], -} - # TODO make all kernels inherit from MMLinearKernel # then bound _KernelT only to MMLinearKernel _KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel | MMLinearKernel) @@ -587,48 +570,6 @@ def init_mxfp8_linear_kernel() -> Mxfp8LinearKernel: ) -def init_mxfp4_linear_kernel() -> MxFp4LinearKernel: - """Select and instantiate the best MXFP4 linear kernel for the - current platform.""" - force_kernel: type[MxFp4LinearKernel] | None = None - if envs.VLLM_MXFP4_USE_MARLIN: - force_kernel = MarlinMxFp4LinearKernel - - if force_kernel is not None: - is_supported, reason = force_kernel.is_supported() - if not is_supported: - raise ValueError( - f"Forced MXFP4 kernel {force_kernel.__name__} is not " - f"supported: {reason}" - ) - logger.info_once("Using %s for MXFP4 GEMM", force_kernel.__name__) - return force_kernel(MxFp4LinearLayerConfig()) - - platform = current_platform._enum - possible = _POSSIBLE_MXFP4_KERNELS.get(platform, []) - - failure_reasons = [] - for kernel_cls in possible: - if kernel_cls.__name__ in envs.VLLM_DISABLED_KERNELS: - failure_reasons.append( - f" {kernel_cls.__name__} disabled by environment variable" - ) - continue - - is_supported, reason = kernel_cls.is_supported() - if not is_supported: - failure_reasons.append(f"{kernel_cls.__name__}: {reason}") - continue - - logger.info_once("Using %s for MXFP4 GEMM", kernel_cls.__name__) - return kernel_cls(MxFp4LinearLayerConfig()) - - raise ValueError( - "Failed to find a kernel that can implement the " - "MXFP4 linear layer. Reasons: \n" + "\n".join(failure_reasons) - ) - - def init_wfp8_a16_linear_kernel( weight_quant_key: QuantKey, activation_quant_key: QuantKey, @@ -789,10 +730,6 @@ def register_linear_kernel( if platform not in _POSSIBLE_NVFP4_KERNELS: _POSSIBLE_NVFP4_KERNELS[platform] = [] _POSSIBLE_NVFP4_KERNELS[platform].append(kernel_class) - elif kernel_type == "mxfp4": - if platform not in _POSSIBLE_MXFP4_KERNELS: - _POSSIBLE_MXFP4_KERNELS[platform] = [] - _POSSIBLE_MXFP4_KERNELS[platform].append(kernel_class) else: raise ValueError(f"Unrecognized kernel type: {kernel_type}") @@ -840,11 +777,6 @@ def register_linear_kernel( "init_mxfp8_linear_kernel", "Mxfp8LinearKernel", "Mxfp8LinearLayerConfig", - "init_mxfp4_linear_kernel", - "MxFp4LinearKernel", - "MxFp4LinearLayerConfig", - "FlashInferMxFp4LinearKernel", - "MarlinMxFp4LinearKernel", "FlashInferCutlassMxfp8LinearKernel", "MarlinMxfp8LinearKernel", "XPUMxFp8LinearKernel", diff --git a/vllm/model_executor/kernels/linear/mxfp4/__init__.py b/vllm/model_executor/kernels/linear/mxfp4/__init__.py deleted file mode 100644 index 0927cd945f6f..000000000000 --- a/vllm/model_executor/kernels/linear/mxfp4/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from vllm.model_executor.kernels.linear.mxfp4.base import ( - MxFp4LinearKernel, - MxFp4LinearLayerConfig, -) - -__all__ = [ - "MxFp4LinearKernel", - "MxFp4LinearLayerConfig", -] diff --git a/vllm/model_executor/kernels/linear/mxfp4/base.py b/vllm/model_executor/kernels/linear/mxfp4/base.py deleted file mode 100644 index 868faa4731d5..000000000000 --- a/vllm/model_executor/kernels/linear/mxfp4/base.py +++ /dev/null @@ -1,67 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from abc import ABC, abstractmethod -from dataclasses import dataclass - -import torch - - -@dataclass -class MxFp4LinearLayerConfig: - """Configuration for an MXFP4 linear layer. - - All MXFP4 layers share the same structure: packed uint8 weights (2 FP4 values per - byte) and per-block weight scales (group size 32). - """ - - pass - - -class MxFp4LinearKernel(ABC): - """Base class for MXFP4 quantized linear kernels. - - Each subclass implements a specific GEMM backend (CUTLASS, Marlin, etc). - The kernel selection mechanism iterates over registered subclasses in - priority order,calling ``is_supported`` and ``can_implement`` to find the best - match for the current hardware. - """ - - def __init__(self, config: MxFp4LinearLayerConfig) -> None: - assert self.can_implement(config)[0] - assert self.is_supported()[0] - self.config = config - - @classmethod - @abstractmethod - def is_supported( - cls, compute_capability: int | None = None - ) -> tuple[bool, str | None]: - """Return whether this kernel can run on the current platform.""" - raise NotImplementedError - - @classmethod - @abstractmethod - def can_implement(cls, config: MxFp4LinearLayerConfig) -> tuple[bool, str | None]: - """Return whether this kernel can handle *config*.""" - raise NotImplementedError - - @abstractmethod - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - """Transform weights into the format required by this kernel. - - Called once after checkpoint weights have been loaded onto the - device. Implementations should repack / swizzle / pad weights - and scales in-place on *layer*. - """ - raise NotImplementedError - - @abstractmethod - def apply_weights( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ) -> torch.Tensor: - """Run the quantized GEMM.""" - raise NotImplementedError diff --git a/vllm/model_executor/kernels/linear/mxfp4/flashinfer.py b/vllm/model_executor/kernels/linear/mxfp4/flashinfer.py deleted file mode 100644 index 8889986f05b9..000000000000 --- a/vllm/model_executor/kernels/linear/mxfp4/flashinfer.py +++ /dev/null @@ -1,74 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch -from torch.nn.parameter import Parameter - -from vllm.model_executor.layers.fused_moe.experts.cutlass_moe import ( - swizzle_mxfp4_scales, -) -from vllm.platforms import current_platform -from vllm.utils.flashinfer import has_flashinfer_cutedsl - -from .base import MxFp4LinearKernel, MxFp4LinearLayerConfig - -_MXFP4_GROUP_SIZE = 32 - - -class FlashInferMxFp4LinearKernel(MxFp4LinearKernel): - """MXFP4 W4A4 GEMM via FlashInfer CUTLASS (SM100+).""" - - @classmethod - def is_supported( - cls, compute_capability: int | None = None - ) -> tuple[bool, str | None]: - if current_platform.has_device_capability(100) and has_flashinfer_cutedsl(): - return True, None - return False, "FlashInfer + >=sm_100 (Blackwell) required" - - @classmethod - def can_implement(cls, config: MxFp4LinearLayerConfig) -> tuple[bool, str | None]: - return True, None - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - N, scale_K = layer.weight_scale.shape - K = scale_K * _MXFP4_GROUP_SIZE - - # swizzle pads N to the next multiple of 128 for CUTLASS tiling - padded_N = ((N + 127) // 128) * 128 - layer.weight_scale = Parameter( - swizzle_mxfp4_scales(layer.weight_scale.data, N, K).reshape(padded_N, -1), - requires_grad=False, - ) - - def apply_weights( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ) -> torch.Tensor: - from vllm.utils.flashinfer import ( - flashinfer_mxfp4_quantize, - flashinfer_scaled_fp4_mm, - ) - - weight = layer.weight - out_shape = x.shape[:-1] + (layer.output_size_per_partition,) - x_2d = x.reshape(-1, x.shape[-1]) - - x_fp4, x_scale = flashinfer_mxfp4_quantize(x_2d) - out = flashinfer_scaled_fp4_mm( - x_fp4, - weight, - x_scale, - layer.weight_scale, - alpha=None, - out_dtype=x.dtype, - backend="cute-dsl", - block_size=_MXFP4_GROUP_SIZE, - use_nvfp4=False, - ) - - if bias is not None: - out = out + bias - return out.view(out_shape) diff --git a/vllm/model_executor/kernels/linear/mxfp4/marlin.py b/vllm/model_executor/kernels/linear/mxfp4/marlin.py deleted file mode 100644 index 38440752072e..000000000000 --- a/vllm/model_executor/kernels/linear/mxfp4/marlin.py +++ /dev/null @@ -1,52 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -from .base import MxFp4LinearKernel, MxFp4LinearLayerConfig - - -class MarlinMxFp4LinearKernel(MxFp4LinearKernel): - @classmethod - def is_supported( - cls, compute_capability: int | None = None - ) -> tuple[bool, str | None]: - from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - is_fp4_marlin_supported, - ) - - if is_fp4_marlin_supported(): - return True, None - return False, "Marlin FP4 not available" - - @classmethod - def can_implement(cls, c: MxFp4LinearLayerConfig) -> tuple[bool, str | None]: - return True, None - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - prepare_fp4_layer_for_marlin, - ) - - prepare_fp4_layer_for_marlin(layer) - - def apply_weights( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ) -> torch.Tensor: - from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - apply_fp4_marlin_linear, - ) - - return apply_fp4_marlin_linear( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - weight_global_scale=None, - workspace=layer.workspace, - size_n=layer.output_size_per_partition, - size_k=layer.input_size_per_partition, - bias=bias, - ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 85f12a464fe1..8574c017e4fd 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -42,10 +42,10 @@ CompressedTensors24, CompressedTensorsScheme, CompressedTensorsW4A4Fp4, - CompressedTensorsW4A4Mxfp4, CompressedTensorsW4A8Fp8, CompressedTensorsW4A8Int, CompressedTensorsW4A16Fp4, + CompressedTensorsW4A16Mxfp4, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW8A8Mxfp8, @@ -625,7 +625,7 @@ def _get_scheme_from_parts( return CompressedTensorsW4A16Fp4() if self._is_mxfp4(weight_quant): - return CompressedTensorsW4A4Mxfp4() + return CompressedTensorsW4A16Mxfp4() if self._is_mxfp8(weight_quant): return CompressedTensorsW8A8Mxfp8() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py index af42222e0c53..de86e79234f2 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py @@ -42,7 +42,6 @@ def __init__(self, moe): super().__init__(moe) self.group_size = 32 self.mxfp4_backend = Mxfp4MoeBackend.MARLIN - # use cutlass if supported, otherwise fallback to marlin for weight-only FP4 self.use_cutlass_mxfp4 = CutlassExpertsMxfp4._supports_current_device() self.experts_cls: type[mk.FusedMoEExperts] if self.use_cutlass_mxfp4: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index 0b0d8c230617..457794eb0a09 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -2,10 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from .compressed_tensors_scheme import CompressedTensorsScheme -from .compressed_tensors_w4a4_mxfp4 import CompressedTensorsW4A4Mxfp4 from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4 from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8 from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int +from .compressed_tensors_w4a16_mxfp4 import CompressedTensorsW4A16Mxfp4 from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8 @@ -25,7 +25,7 @@ "WNA16_SUPPORTED_BITS", "CompressedTensors24", "CompressedTensorsW4A16Fp4", - "CompressedTensorsW4A4Mxfp4", + "CompressedTensorsW4A16Mxfp4", "CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int", "CompressedTensorsW4A8Fp8", diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_mxfp4.py similarity index 76% rename from vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxfp4.py rename to vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_mxfp4.py index 7b994e10e448..77cea0f83e1c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_mxfp4.py @@ -5,21 +5,24 @@ import torch from torch.nn.parameter import Parameter -from vllm.model_executor.kernels.linear import init_mxfp4_linear_kernel from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, +) from vllm.model_executor.parameter import ( GroupQuantScaleParameter, ModelWeightParameter, ) -__all__ = ["CompressedTensorsW4A4Mxfp4"] +__all__ = ["CompressedTensorsW4A16Mxfp4"] -class CompressedTensorsW4A4Mxfp4(CompressedTensorsScheme): +class CompressedTensorsW4A16Mxfp4(CompressedTensorsScheme): """ - Compressed tensors scheme for MXFP4. + Compressed tensors scheme for MXFP4 weight-only quantization. Supports models quantized with the compressed-tensors mxfp4-pack-quantized format. @@ -28,14 +31,10 @@ class CompressedTensorsW4A4Mxfp4(CompressedTensorsScheme): - 4-bit float weights (E2M1) packed into uint8 - Per-group E8M0 scales with group_size=32 - No global scale (unlike NVFP4) - - On SM100+ with FlashInfer: true W4A4 (activations dynamically quantized). - Otherwise: W4A16 weight-only via Marlin. """ def __init__(self): self.group_size = 32 - self.kernel = init_mxfp4_linear_kernel() @classmethod def get_min_capability(cls) -> int: @@ -83,9 +82,11 @@ def create_weights( layer.register_parameter("weight_scale", weight_scale) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Rename weight_packed to weight that marlin expects layer.weight = Parameter(layer.weight_packed.data, requires_grad=False) del layer.weight_packed - self.kernel.process_weights_after_loading(layer) + + prepare_fp4_layer_for_marlin(layer) def apply_weights( self, @@ -93,4 +94,13 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.kernel.apply_weights(layer, x, bias) + return apply_fp4_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=None, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 54d8645dc1bc..44fcc19c2d2b 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -20,7 +20,6 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils.math_utils import cdiv logger = init_logger(__name__) @@ -481,8 +480,6 @@ def flashinfer_mm_fp4( dtype: torch.dtype, use_8x4_sf_layout: bool, backend: str, - block_size: int = 16, - use_nvfp4: bool = True, ) -> torch.Tensor: from flashinfer import mm_fp4 as flashinfer_mm_fp4_ @@ -493,9 +490,8 @@ def flashinfer_mm_fp4( B_scale, g_scale, dtype, - block_size=block_size, + block_size=16, use_8x4_sf_layout=use_8x4_sf_layout, - use_nvfp4=use_nvfp4, backend=backend, ) @@ -511,36 +507,9 @@ def flashinfer_mm_fp4_fake( dtype: torch.dtype, use_8x4_sf_layout: bool, backend: str, - block_size: int = 16, - use_nvfp4: bool = True, ) -> torch.Tensor: return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device) - @torch.library.custom_op( - "vllm::flashinfer_mxfp4_quantize", - mutates_args=[], - device_types="cuda", - ) - def flashinfer_mxfp4_quantize( - a: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - from flashinfer import mxfp4_quantize as _mxfp4_quantize - - return _mxfp4_quantize(a) - - @torch.library.register_fake("vllm::flashinfer_mxfp4_quantize") - def flashinfer_mxfp4_quantize_fake( - a: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - m, k = a.shape - sf_vec_size = 32 - padded_m = cdiv(m, 128) * 128 - sf_cols = cdiv(k // sf_vec_size, 4) * 4 - return ( - torch.empty(m, k // 2, dtype=torch.uint8, device=a.device), - torch.empty(padded_m, sf_cols, dtype=torch.uint8, device=a.device), - ) - @torch.library.custom_op( "vllm::bmm_fp8", mutates_args=[], @@ -689,20 +658,15 @@ def flashinfer_scaled_fp4_mm( b: torch.Tensor, block_scale_a: torch.Tensor, block_scale_b: torch.Tensor, - alpha: torch.Tensor | None, + alpha: torch.Tensor, out_dtype: torch.dtype, backend: str, - block_size: int = 16, - use_nvfp4: bool = True, ) -> torch.Tensor: assert a.ndim == 2 and b.ndim == 2 assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2 assert a.stride(-1) == 1 and b.stride(-1) == 1 assert a.shape[1] == b.shape[1] - if alpha is None: - alpha = torch.ones(1, dtype=torch.float32, device=a.device) - if backend in ("cutlass", "cudnn"): block_scale_a = block_scale_a.view(torch.uint8) block_scale_b = block_scale_b.view(torch.uint8) @@ -718,8 +682,6 @@ def flashinfer_scaled_fp4_mm( out_dtype, use_8x4_sf_layout=use_8x4_sf_layout, backend=backend, - block_size=block_size, - use_nvfp4=use_nvfp4, ) @@ -942,7 +904,6 @@ def is_flashinfer_cudnn_fp8_prefill_attn_supported() -> bool: "supports_trtllm_attention", "can_use_trtllm_attention", "use_trtllm_attention", - "flashinfer_mxfp4_quantize", "flashinfer_scaled_fp4_mm", "flashinfer_scaled_fp4_mm_out", "flashinfer_scaled_fp8_mm",