diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 7e80b0ecfa61..9fdcd59a58da 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -24,6 +24,7 @@ CompressedTensorsConfig, CompressedTensorsLinearMethod, CompressedTensorsW4A4Fp4, + CompressedTensorsW4A4Mxfp4, CompressedTensorsW4A8Fp8, CompressedTensorsW4A16Fp4, CompressedTensorsW8A8Fp8, @@ -689,3 +690,31 @@ def check_model(model): llm.apply_model(check_model) output = llm.generate_greedy("Hello my name is", max_tokens=4) assert output + + +@pytest.mark.skipif( + not current_platform.is_cuda() or not current_platform.has_device_capability(80), + reason="MXFP4 requires ampere or newer", +) +def test_compressed_tensors_mxfp4(vllm_runner): + model_path = "nm-testing/TinyLlama-1.1B-Chat-v1.0-MXFP4" + with vllm_runner(model_path, enforce_eager=True) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + o_proj = layer.self_attn.o_proj + gate_up_proj = layer.mlp.gate_up_proj + down_proj = layer.mlp.down_proj + + for proj in (qkv_proj, o_proj, gate_up_proj, down_proj): + assert isinstance(proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(proj.scheme, CompressedTensorsW4A4Mxfp4) + + # Verify group size + assert proj.scheme.group_size == 32 + + llm.apply_model(check_model) + output = llm.generate_greedy("Hello my name is", max_tokens=4) + assert output diff --git a/vllm/model_executor/kernels/linear/__init__.py b/vllm/model_executor/kernels/linear/__init__.py index b9a96a035309..4776494f2359 100644 --- a/vllm/model_executor/kernels/linear/__init__.py +++ b/vllm/model_executor/kernels/linear/__init__.py @@ -58,6 +58,16 @@ XPUW4A8IntLinearKernel, XPUwNa16LinearKernel, ) +from vllm.model_executor.kernels.linear.mxfp4 import ( + MxFp4LinearKernel, + MxFp4LinearLayerConfig, +) +from vllm.model_executor.kernels.linear.mxfp4.flashinfer 
import ( + FlashInferMxFp4LinearKernel, +) +from vllm.model_executor.kernels.linear.mxfp4.marlin import ( + MarlinMxFp4LinearKernel, +) from vllm.model_executor.kernels.linear.mxfp8 import ( Mxfp8LinearKernel, Mxfp8LinearLayerConfig, @@ -276,6 +286,13 @@ ], } +_POSSIBLE_MXFP4_KERNELS: dict[PlatformEnum, list[type[MxFp4LinearKernel]]] = { + PlatformEnum.CUDA: [ + FlashInferMxFp4LinearKernel, + MarlinMxFp4LinearKernel, + ], +} + # TODO make all kernels inherit from MMLinearKernel # then bound _KernelT only to MMLinearKernel _KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel | MMLinearKernel) @@ -570,6 +587,48 @@ def init_mxfp8_linear_kernel() -> Mxfp8LinearKernel: ) +def init_mxfp4_linear_kernel() -> MxFp4LinearKernel: + """Select and instantiate the best MXFP4 linear kernel for the + current platform.""" + force_kernel: type[MxFp4LinearKernel] | None = None + if envs.VLLM_MXFP4_USE_MARLIN: + force_kernel = MarlinMxFp4LinearKernel + + if force_kernel is not None: + is_supported, reason = force_kernel.is_supported() + if not is_supported: + raise ValueError( + f"Forced MXFP4 kernel {force_kernel.__name__} is not " + f"supported: {reason}" + ) + logger.info_once("Using %s for MXFP4 GEMM", force_kernel.__name__) + return force_kernel(MxFp4LinearLayerConfig()) + + platform = current_platform._enum + possible = _POSSIBLE_MXFP4_KERNELS.get(platform, []) + + failure_reasons = [] + for kernel_cls in possible: + if kernel_cls.__name__ in envs.VLLM_DISABLED_KERNELS: + failure_reasons.append( + f" {kernel_cls.__name__} disabled by environment variable" + ) + continue + + is_supported, reason = kernel_cls.is_supported() + if not is_supported: + failure_reasons.append(f"{kernel_cls.__name__}: {reason}") + continue + + logger.info_once("Using %s for MXFP4 GEMM", kernel_cls.__name__) + return kernel_cls(MxFp4LinearLayerConfig()) + + raise ValueError( + "Failed to find a kernel that can implement the " + "MXFP4 linear layer. 
Reasons: \n" + "\n".join(failure_reasons) + ) + + def init_wfp8_a16_linear_kernel( weight_quant_key: QuantKey, activation_quant_key: QuantKey, @@ -730,6 +789,10 @@ def register_linear_kernel( if platform not in _POSSIBLE_NVFP4_KERNELS: _POSSIBLE_NVFP4_KERNELS[platform] = [] _POSSIBLE_NVFP4_KERNELS[platform].append(kernel_class) + elif kernel_type == "mxfp4": + if platform not in _POSSIBLE_MXFP4_KERNELS: + _POSSIBLE_MXFP4_KERNELS[platform] = [] + _POSSIBLE_MXFP4_KERNELS[platform].append(kernel_class) else: raise ValueError(f"Unrecognized kernel type: {kernel_type}") @@ -777,6 +840,11 @@ def register_linear_kernel( "init_mxfp8_linear_kernel", "Mxfp8LinearKernel", "Mxfp8LinearLayerConfig", + "init_mxfp4_linear_kernel", + "MxFp4LinearKernel", + "MxFp4LinearLayerConfig", + "FlashInferMxFp4LinearKernel", + "MarlinMxFp4LinearKernel", "FlashInferCutlassMxfp8LinearKernel", "MarlinMxfp8LinearKernel", "XPUMxFp8LinearKernel", diff --git a/vllm/model_executor/kernels/linear/mxfp4/__init__.py b/vllm/model_executor/kernels/linear/mxfp4/__init__.py new file mode 100644 index 000000000000..0927cd945f6f --- /dev/null +++ b/vllm/model_executor/kernels/linear/mxfp4/__init__.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.model_executor.kernels.linear.mxfp4.base import ( + MxFp4LinearKernel, + MxFp4LinearLayerConfig, +) + +__all__ = [ + "MxFp4LinearKernel", + "MxFp4LinearLayerConfig", +] diff --git a/vllm/model_executor/kernels/linear/mxfp4/base.py b/vllm/model_executor/kernels/linear/mxfp4/base.py new file mode 100644 index 000000000000..868faa4731d5 --- /dev/null +++ b/vllm/model_executor/kernels/linear/mxfp4/base.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from dataclasses import dataclass + +import torch + + +@dataclass +class MxFp4LinearLayerConfig: +
"""Configuration for an MXFP4 linear layer. + + All MXFP4 layers share the same structure: packed uint8 weights (2 FP4 values per + byte) and per-block weight scales (group size 32). + """ + + pass + + +class MxFp4LinearKernel(ABC): + """Base class for MXFP4 quantized linear kernels. + + Each subclass implements a specific GEMM backend (CUTLASS, Marlin, etc). + The kernel selection mechanism iterates over registered subclasses in + priority order, calling ``is_supported`` and ``can_implement`` to find the best + match for the current hardware. + """ + + def __init__(self, config: MxFp4LinearLayerConfig) -> None: + assert self.can_implement(config)[0] + assert self.is_supported()[0] + self.config = config + + @classmethod + @abstractmethod + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: + """Return whether this kernel can run on the current platform.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def can_implement(cls, config: MxFp4LinearLayerConfig) -> tuple[bool, str | None]: + """Return whether this kernel can handle *config*.""" + raise NotImplementedError + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """Transform weights into the format required by this kernel. + + Called once after checkpoint weights have been loaded onto the + device. Implementations should repack / swizzle / pad weights + and scales in-place on *layer*.
+ """ + raise NotImplementedError + + @abstractmethod + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + """Run the quantized GEMM.""" + raise NotImplementedError diff --git a/vllm/model_executor/kernels/linear/mxfp4/flashinfer.py b/vllm/model_executor/kernels/linear/mxfp4/flashinfer.py new file mode 100644 index 000000000000..8889986f05b9 --- /dev/null +++ b/vllm/model_executor/kernels/linear/mxfp4/flashinfer.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from torch.nn.parameter import Parameter + +from vllm.model_executor.layers.fused_moe.experts.cutlass_moe import ( + swizzle_mxfp4_scales, +) +from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer_cutedsl + +from .base import MxFp4LinearKernel, MxFp4LinearLayerConfig + +_MXFP4_GROUP_SIZE = 32 + + +class FlashInferMxFp4LinearKernel(MxFp4LinearKernel): + """MXFP4 W4A4 GEMM via FlashInfer CUTLASS (SM100+).""" + + @classmethod + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: + if current_platform.has_device_capability(100) and has_flashinfer_cutedsl(): + return True, None + return False, "FlashInfer + >=sm_100 (Blackwell) required" + + @classmethod + def can_implement(cls, config: MxFp4LinearLayerConfig) -> tuple[bool, str | None]: + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + N, scale_K = layer.weight_scale.shape + K = scale_K * _MXFP4_GROUP_SIZE + + # swizzle pads N to the next multiple of 128 for CUTLASS tiling + padded_N = ((N + 127) // 128) * 128 + layer.weight_scale = Parameter( + swizzle_mxfp4_scales(layer.weight_scale.data, N, K).reshape(padded_N, -1), + requires_grad=False, + ) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) 
-> torch.Tensor: + from vllm.utils.flashinfer import ( + flashinfer_mxfp4_quantize, + flashinfer_scaled_fp4_mm, + ) + + weight = layer.weight + out_shape = x.shape[:-1] + (layer.output_size_per_partition,) + x_2d = x.reshape(-1, x.shape[-1]) + + x_fp4, x_scale = flashinfer_mxfp4_quantize(x_2d) + out = flashinfer_scaled_fp4_mm( + x_fp4, + weight, + x_scale, + layer.weight_scale, + alpha=None, + out_dtype=x.dtype, + backend="cute-dsl", + block_size=_MXFP4_GROUP_SIZE, + use_nvfp4=False, + ) + + if bias is not None: + out = out + bias + return out.view(out_shape) diff --git a/vllm/model_executor/kernels/linear/mxfp4/marlin.py b/vllm/model_executor/kernels/linear/mxfp4/marlin.py new file mode 100644 index 000000000000..38440752072e --- /dev/null +++ b/vllm/model_executor/kernels/linear/mxfp4/marlin.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from .base import MxFp4LinearKernel, MxFp4LinearLayerConfig + + +class MarlinMxFp4LinearKernel(MxFp4LinearKernel): + @classmethod + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + is_fp4_marlin_supported, + ) + + if is_fp4_marlin_supported(): + return True, None + return False, "Marlin FP4 not available" + + @classmethod + def can_implement(cls, c: MxFp4LinearLayerConfig) -> tuple[bool, str | None]: + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + prepare_fp4_layer_for_marlin, + ) + + prepare_fp4_layer_for_marlin(layer) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + ) + + return 
apply_fp4_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=None, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 2910e63678fe..6a7e71884ccd 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -42,10 +42,10 @@ CompressedTensors24, CompressedTensorsScheme, CompressedTensorsW4A4Fp4, + CompressedTensorsW4A4Mxfp4, CompressedTensorsW4A8Fp8, CompressedTensorsW4A8Int, CompressedTensorsW4A16Fp4, - CompressedTensorsW4A16Mxfp4, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW8A8Mxfp8, @@ -625,7 +625,7 @@ def _get_scheme_from_parts( return CompressedTensorsW4A16Fp4() if self._is_mxfp4(weight_quant): - return CompressedTensorsW4A16Mxfp4() + return CompressedTensorsW4A4Mxfp4() if self._is_mxfp8(weight_quant): return CompressedTensorsW8A8Mxfp8() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py index 2ac2e28f20b5..1d4b887e78f5 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py @@ -41,6 +41,7 @@ def __init__(self, moe): super().__init__(moe) self.group_size = 32 self.mxfp4_backend = Mxfp4MoeBackend.MARLIN + # use cutlass if supported, otherwise fallback to marlin for weight-only FP4 self.use_cutlass_mxfp4 = 
CutlassExpertsMxfp4._supports_current_device() self.experts_cls: type[mk.FusedMoEExperts] if self.use_cutlass_mxfp4: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index 457794eb0a09..0b0d8c230617 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -2,10 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from .compressed_tensors_scheme import CompressedTensorsScheme +from .compressed_tensors_w4a4_mxfp4 import CompressedTensorsW4A4Mxfp4 from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4 from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8 from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int -from .compressed_tensors_w4a16_mxfp4 import CompressedTensorsW4A16Mxfp4 from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8 @@ -25,7 +25,7 @@ "WNA16_SUPPORTED_BITS", "CompressedTensors24", "CompressedTensorsW4A16Fp4", - "CompressedTensorsW4A16Mxfp4", + "CompressedTensorsW4A4Mxfp4", "CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int", "CompressedTensorsW4A8Fp8", diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_mxfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxfp4.py similarity index 76% rename from vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_mxfp4.py rename to vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxfp4.py index 77cea0f83e1c..7b994e10e448 100644 --- 
a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_mxfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxfp4.py @@ -5,24 +5,21 @@ import torch from torch.nn.parameter import Parameter +from vllm.model_executor.kernels.linear import init_mxfp4_linear_kernel from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme, ) -from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - apply_fp4_marlin_linear, - prepare_fp4_layer_for_marlin, -) from vllm.model_executor.parameter import ( GroupQuantScaleParameter, ModelWeightParameter, ) -__all__ = ["CompressedTensorsW4A16Mxfp4"] +__all__ = ["CompressedTensorsW4A4Mxfp4"] -class CompressedTensorsW4A16Mxfp4(CompressedTensorsScheme): +class CompressedTensorsW4A4Mxfp4(CompressedTensorsScheme): """ - Compressed tensors scheme for MXFP4 weight-only quantization. + Compressed tensors scheme for MXFP4. Supports models quantized with the compressed-tensors mxfp4-pack-quantized format. @@ -31,10 +28,14 @@ class CompressedTensorsW4A16Mxfp4(CompressedTensorsScheme): - 4-bit float weights (E2M1) packed into uint8 - Per-group E8M0 scales with group_size=32 - No global scale (unlike NVFP4) + + On SM100+ with FlashInfer: true W4A4 (activations dynamically quantized). + Otherwise: W4A16 weight-only via Marlin. 
""" def __init__(self): self.group_size = 32 + self.kernel = init_mxfp4_linear_kernel() @classmethod def get_min_capability(cls) -> int: @@ -82,11 +83,9 @@ def create_weights( layer.register_parameter("weight_scale", weight_scale) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # Rename weight_packed to weight that marlin expects layer.weight = Parameter(layer.weight_packed.data, requires_grad=False) del layer.weight_packed - - prepare_fp4_layer_for_marlin(layer) + self.kernel.process_weights_after_loading(layer) def apply_weights( self, @@ -94,13 +93,4 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return apply_fp4_marlin_linear( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - weight_global_scale=None, - workspace=layer.workspace, - size_n=layer.output_size_per_partition, - size_k=layer.input_size_per_partition, - bias=bias, - ) + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 44fcc19c2d2b..54d8645dc1bc 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -20,6 +20,7 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.utils.math_utils import cdiv logger = init_logger(__name__) @@ -480,6 +481,8 @@ def flashinfer_mm_fp4( dtype: torch.dtype, use_8x4_sf_layout: bool, backend: str, + block_size: int = 16, + use_nvfp4: bool = True, ) -> torch.Tensor: from flashinfer import mm_fp4 as flashinfer_mm_fp4_ @@ -490,8 +493,9 @@ def flashinfer_mm_fp4( B_scale, g_scale, dtype, - block_size=16, + block_size=block_size, use_8x4_sf_layout=use_8x4_sf_layout, + use_nvfp4=use_nvfp4, backend=backend, ) @@ -507,9 +511,36 @@ def flashinfer_mm_fp4_fake( dtype: torch.dtype, use_8x4_sf_layout: bool, backend: str, + block_size: int = 16, + use_nvfp4: bool = True, ) -> torch.Tensor: return torch.empty(A.shape[0], B.shape[1], dtype=dtype, 
device=A.device) + @torch.library.custom_op( + "vllm::flashinfer_mxfp4_quantize", + mutates_args=[], + device_types="cuda", + ) + def flashinfer_mxfp4_quantize( + a: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + from flashinfer import mxfp4_quantize as _mxfp4_quantize + + return _mxfp4_quantize(a) + + @torch.library.register_fake("vllm::flashinfer_mxfp4_quantize") + def flashinfer_mxfp4_quantize_fake( + a: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + m, k = a.shape + sf_vec_size = 32 + padded_m = cdiv(m, 128) * 128 + sf_cols = cdiv(k // sf_vec_size, 4) * 4 + return ( + torch.empty(m, k // 2, dtype=torch.uint8, device=a.device), + torch.empty(padded_m, sf_cols, dtype=torch.uint8, device=a.device), + ) + @torch.library.custom_op( "vllm::bmm_fp8", mutates_args=[], @@ -658,15 +689,20 @@ def flashinfer_scaled_fp4_mm( b: torch.Tensor, block_scale_a: torch.Tensor, block_scale_b: torch.Tensor, - alpha: torch.Tensor, + alpha: torch.Tensor | None, out_dtype: torch.dtype, backend: str, + block_size: int = 16, + use_nvfp4: bool = True, ) -> torch.Tensor: assert a.ndim == 2 and b.ndim == 2 assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2 assert a.stride(-1) == 1 and b.stride(-1) == 1 assert a.shape[1] == b.shape[1] + if alpha is None: + alpha = torch.ones(1, dtype=torch.float32, device=a.device) + if backend in ("cutlass", "cudnn"): block_scale_a = block_scale_a.view(torch.uint8) block_scale_b = block_scale_b.view(torch.uint8) @@ -682,6 +718,8 @@ def flashinfer_scaled_fp4_mm( out_dtype, use_8x4_sf_layout=use_8x4_sf_layout, backend=backend, + block_size=block_size, + use_nvfp4=use_nvfp4, ) @@ -904,6 +942,7 @@ def is_flashinfer_cudnn_fp8_prefill_attn_supported() -> bool: "supports_trtllm_attention", "can_use_trtllm_attention", "use_trtllm_attention", + "flashinfer_mxfp4_quantize", "flashinfer_scaled_fp4_mm", "flashinfer_scaled_fp4_mm_out", "flashinfer_scaled_fp8_mm",