diff --git a/docs/features/quantization/modelopt.md b/docs/features/quantization/modelopt.md index 5c846767bc5b..ad417bcb30ae 100644 --- a/docs/features/quantization/modelopt.md +++ b/docs/features/quantization/modelopt.md @@ -17,6 +17,7 @@ following `quantization.quant_algo` values: - `FP8_PER_CHANNEL_PER_TOKEN`: per-channel weight scale and dynamic per-token activation quantization. - `FP8_PB_WO` (ModelOpt may emit `fp8_pb_wo`): block-scaled FP8 weight-only (typically 128×128 blocks). - `NVFP4`: ModelOpt NVFP4 checkpoints (use `quantization="modelopt_fp4"`). +- `MXFP8`: ModelOpt MXFP8 checkpoints (use `quantization="modelopt_mxfp8"`). ## Quantizing HuggingFace Models with PTQ diff --git a/vllm/config/model.py b/vllm/config/model.py index 86b484181803..0c44bfc31141 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -878,6 +878,7 @@ def _verify_quantization(self) -> None: "moe_wna16", "modelopt", "modelopt_fp4", + "modelopt_mxfp8", "petit_nvfp4", # Ensure heavy backends are probed last to avoid unnecessary # imports during override detection (e.g., MXFP4 imports Triton) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 6650367da035..0103fd0430f1 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -477,6 +477,7 @@ def make( "mxfp4", "mxfp6_e3m2", "mxfp6_e2m3", + "mxfp8", } assert not isinstance(weight_dtype, str) or weight_dtype in { "nvfp4", @@ -484,6 +485,7 @@ def make( "mxfp6_e3m2", "mxfp6_e2m3", "int4", + "mxfp8", } if weight_dtype is None: diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 82de32af347d..09e67f562d0c 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -17,6 +17,7 @@ "fp_quant", "modelopt", "modelopt_fp4", + "modelopt_mxfp8", "gguf", "gptq_marlin", "awq_marlin", @@ -119,7 +120,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .gptq import GPTQConfig from .gptq_marlin import GPTQMarlinConfig from .inc import INCConfig - from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config + from .modelopt import ModelOptFp8Config, ModelOptMxFp8Config, ModelOptNvFp4Config from .moe_wna16 import MoeWNA16Config from .mxfp4 import Mxfp4Config from .petit import PetitNvFp4Config @@ -133,6 +134,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "fp_quant": FPQuantConfig, "modelopt": ModelOptFp8Config, "modelopt_fp4": ModelOptNvFp4Config, + "modelopt_mxfp8": ModelOptMxFp8Config, "gguf": GGUFConfig, "gptq_marlin": GPTQMarlinConfig, "awq_marlin": AWQMarlinConfig, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index e76c109eceda..23638217510c 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -63,6 +63,13 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( get_marlin_input_dtype, ) +from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( + MXFP8_BLOCK_SIZE, + MXFP8_SCALE_DTYPE, + MXFP8_VALUE_DTYPE, + Mxfp8LinearBackend, + Mxfp8LinearOp, +) from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( apply_nvfp4_linear, convert_to_nvfp4_linear_kernel_format, @@ -103,6 +110,8 @@ "FP8_PB_WO", # FP4 "NVFP4", + # MXFP8 + "MXFP8", ] KV_CACHE_QUANT_ALGOS = ["FP8"] @@ -386,12 +395,12 @@ def override_quantization_method( quant_config = hf_quant_cfg["quantization"] if isinstance(quant_config, dict): quant_algo = str(quant_config.get("quant_algo", "")) - if "FP8" in quant_algo.upper(): + if quant_algo.upper() == "FP8": return "modelopt" else: # Check for compressed-tensors style config with specific quant_algo quant_algo = str(hf_quant_cfg.get("quant_algo", "")) - if "FP8" in quant_algo.upper(): + if quant_algo.upper() == "FP8": return "modelopt" return None @@ -1549,3 +1558,239 @@ def apply( ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod + + +class ModelOptMxFp8Config(ModelOptQuantConfigBase): + """Config class for ModelOpt MXFP8.""" + + def __init__( + self, + is_checkpoint_mxfp8_serialized: bool, + kv_cache_quant_algo: str | None, + exclude_modules: list[str], + ) -> None: + super().__init__(exclude_modules) + self.is_checkpoint_mxfp8_serialized = is_checkpoint_mxfp8_serialized + + if not is_checkpoint_mxfp8_serialized: + raise ValueError( + "MXFP8 quantization requires a serialized checkpoint. " + "Dynamic quantization is not supported." + ) + + logger.warning( + "Detected ModelOpt MXFP8 checkpoint. Please note that " + "the format is experimental and could change in future." + ) + + self.kv_cache_quant_algo = kv_cache_quant_algo + + def get_name(self) -> QuantizationMethods: + return "modelopt_mxfp8" + + def get_supported_act_dtypes(self) -> list[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + # MXFP8 hardware acceleration requires Blackwell (SM100) or newer + return 100 + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> "QuantizeMethodBase | None": + # MXFP8 does not yet support MoE models + if isinstance(layer, FusedMoE): + raise NotImplementedError( + "MXFP8 quantization does not yet support MoE models. " + "Please use FP8 or NVFP4 quantization for MoE models." + ) + return super().get_quant_method(layer, prefix) + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant + ) -> QuantizationMethods | None: + """Detect if this ModelOpt MXFP8 config should be used based on + quantization config.""" + if hf_quant_cfg is None: + return None + + # Use the community standard 'quant_method' + quant_method = hf_quant_cfg.get("quant_method", "").lower() + + # Only proceed if the method is explicitly "modelopt" + if quant_method != "modelopt": + return None + + # Look for ModelOpt-specific config structure + if "quantization" in hf_quant_cfg: + quant_config = hf_quant_cfg["quantization"] + if isinstance(quant_config, dict): + quant_algo = str(quant_config.get("quant_algo", "")).upper() + if "MXFP8" in quant_algo: + return "modelopt_mxfp8" + else: + # Check for compressed-tensors style config with specific quant_algo + quant_algo = str(hf_quant_cfg.get("quant_algo", "")).upper() + if "MXFP8" in quant_algo: + return "modelopt_mxfp8" + + return None + + @classmethod + def _from_config( + cls, + *, + quant_method: str, + kv_cache_quant_method: str | None, + exclude_modules: list[str], + original_config: dict[str, Any], + **kwargs: Any, + ) -> "ModelOptMxFp8Config": + is_checkpoint_mxfp8_serialized = "MXFP8" in quant_method.upper() + + # For MXFP8, validate required fields in the config + if is_checkpoint_mxfp8_serialized and "quantization" in original_config: + quant_config = original_config["quantization"] + required_fields = ["kv_cache_quant_algo", "exclude_modules"] + missing_fields = [ + field for field in required_fields if field not in quant_config + ] + if missing_fields: + raise ValueError( + f"MXFP8 quantization requires the following fields in " + f"hf_quant_config.json: {missing_fields}" + ) + + return cls( + is_checkpoint_mxfp8_serialized, + kv_cache_quant_method, + exclude_modules, + ) + + +class ModelOptMxFp8LinearMethod(LinearMethodBase): + """Linear method for ModelOpt MXFP8 quantization.""" + + def __init__(self, quant_config: ModelOptMxFp8Config) -> None: + self.quant_config = quant_config + + if not self.quant_config.is_checkpoint_mxfp8_serialized: + raise ValueError( + "MXFP8 currently only supports serialized checkpoints. " + "Dynamic quantization is not supported." + ) + + backend: Mxfp8LinearBackend = Mxfp8LinearBackend.EMULATION + self.mxfp8_linear_op = Mxfp8LinearOp(backend=backend) + logger.info_once("Using %s backend for MXFP8 GEMM", backend.value) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + + if not self.quant_config.is_checkpoint_mxfp8_serialized: + raise ValueError( + "MXFP8 quantization was selected, but checkpoint is not " + "MXFP8 serialized. Dynamic quantization is not supported." + ) + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + if input_size_per_partition % MXFP8_BLOCK_SIZE != 0: + raise ValueError( + f"MXFP8 requires input dimension to be divisible by " + f"{MXFP8_BLOCK_SIZE}, got {input_size_per_partition}" + ) + + # Weight tensor: FP8 E4M3 format + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=MXFP8_VALUE_DTYPE, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # Weight scale tensor (E8M0 encoded as uint8), one scale per block of 32 along K + weight_scale = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // MXFP8_BLOCK_SIZE, + dtype=MXFP8_SCALE_DTYPE, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if layer.weight.ndim != 2: + raise ValueError( + f"MXFP8 weight must be 2D tensor [N, K], got {layer.weight.ndim}D " + f"with shape {tuple(layer.weight.shape)}" + ) + + if layer.weight.dtype != MXFP8_VALUE_DTYPE: + raise ValueError( + f"MXFP8 weight must be {MXFP8_VALUE_DTYPE} (FP8 E4M3), " + f"got {layer.weight.dtype}. The checkpoint may not be properly " + f"quantized with MXFP8." + ) + + weight = layer.weight.data # [N, K] + N, K = weight.shape + scale_k = K // MXFP8_BLOCK_SIZE + + # Slice weight_scale to match weight dimensions (handles padding) + weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous() + + layer.weight = Parameter(weight.contiguous(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + if layer.weight.dtype != MXFP8_VALUE_DTYPE: + raise ValueError( + f"Weight dtype {layer.weight.dtype} != expected {MXFP8_VALUE_DTYPE}" + ) + if layer.weight_scale.dtype != MXFP8_SCALE_DTYPE: + raise ValueError( + f"Weight scale dtype {layer.weight_scale.dtype} != " + f"expected {MXFP8_SCALE_DTYPE}" + ) + + return self.mxfp8_linear_op.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + out_dtype=x.dtype, + bias=bias, + ) + + +# Register the method classes for ModelOptMxFp8Config +ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod +ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod diff --git a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py index bed771fd1c4d..9f0e0c0a4d8e 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py @@ -1,24 +1,134 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from enum import Enum + import torch from vllm.logger import init_logger +from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) -def mxfp8_e4m3_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - try: - from flashinfer import mxfp8_quantize as mxfp8_e4m3_quantize - except ImportError as err: - raise ImportError( - "The package `flashinfer` is required to do " - "MX-FP8 quantization. Please install it with" - "`pip install flashinfer`" - ) from err +class Mxfp8LinearBackend(Enum): + EMULATION = "emulation" + + +# MXFP8 constants +MXFP8_VALUE_DTYPE = torch.float8_e4m3fn +MXFP8_SCALE_DTYPE = torch.uint8 +MXFP8_BLOCK_SIZE = 32 + - x_q, x_scales = mxfp8_e4m3_quantize(x, is_sf_swizzled_layout=False) - if x_scales.ndim == 1: +def _mxfp8_e4m3_quantize_impl( + x: torch.Tensor, is_sf_swizzled_layout: bool = False +) -> tuple[torch.Tensor, torch.Tensor]: + from flashinfer import mxfp8_quantize as flashinfer_mxfp8_quantize + + x_q, x_scales = flashinfer_mxfp8_quantize( + x, is_sf_swizzled_layout=is_sf_swizzled_layout + ) + if x_scales.ndim == 1 and x.ndim == 2 and not is_sf_swizzled_layout: x_scales = x_scales.view(x.size(0), -1) return x_q, x_scales + + +def mxfp8_e4m3_quantize( + x: torch.Tensor, is_sf_swizzled_layout: bool = False +) -> tuple[torch.Tensor, torch.Tensor]: + return torch.ops.vllm.mxfp8_quantize(x, is_sf_swizzled_layout) + + +def dequant_mxfp8_to_bf16(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor: + """Dequantize MXFP8 tensor to BF16.""" + x_float = x.to(torch.float32) + + num_blocks = x.shape[-1] // MXFP8_BLOCK_SIZE + x_blocked = x_float.view(*x.shape[:-1], num_blocks, MXFP8_BLOCK_SIZE) + + descale = torch.exp2(scales.to(torch.float32) - 127.0) + + dequantized = x_blocked * descale.unsqueeze(-1) + + dequantized = dequantized.view(*x.shape) + + return dequantized.to(torch.bfloat16) + + +def mxfp8_e4m3_quantize_fake( + x: torch.Tensor, is_sf_swizzled_layout: bool = False +) -> tuple[torch.Tensor, torch.Tensor]: + """Fake implementation for torch.compile tracing.""" + fp_data = torch.empty_like(x, dtype=MXFP8_VALUE_DTYPE) + + block_size = MXFP8_BLOCK_SIZE + + if x.ndim == 2: + M, N = x.shape + K = (N + block_size - 1) // block_size + if is_sf_swizzled_layout: + M_padded = ((M + 127) // 128) * 128 + K_padded = ((K + 3) // 4) * 4 + scales = torch.empty( + M_padded * K_padded, dtype=MXFP8_SCALE_DTYPE, device=x.device + ) + else: + scales = torch.empty((M, K), dtype=MXFP8_SCALE_DTYPE, device=x.device) + elif x.ndim == 3: + B, M, N = x.shape + K = (N + block_size - 1) // block_size + if is_sf_swizzled_layout: + M_padded = ((M + 127) // 128) * 128 + K_padded = ((K + 3) // 4) * 4 + scales = torch.empty( + B * M_padded * K_padded, dtype=MXFP8_SCALE_DTYPE, device=x.device + ) + else: + scales = torch.empty((B, M, K), dtype=MXFP8_SCALE_DTYPE, device=x.device) + else: + scale_shape = list(x.shape) + scale_shape[-1] = (x.shape[-1] + block_size - 1) // block_size + scales = torch.empty(scale_shape, dtype=MXFP8_SCALE_DTYPE, device=x.device) + + return fp_data, scales + + +direct_register_custom_op( + op_name="mxfp8_quantize", + op_func=_mxfp8_e4m3_quantize_impl, + fake_impl=mxfp8_e4m3_quantize_fake, +) + + +class Mxfp8LinearOp: + def __init__(self, backend: Mxfp8LinearBackend): + if backend not in Mxfp8LinearBackend: + raise ValueError(f"Unsupported backend: {backend}") + + self.backend = backend + + def apply( + self, + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + out_dtype: torch.dtype, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + # Validate weight_scale dtype and shape (must be 2D for TORCH backend) + if weight_scale.dtype != MXFP8_SCALE_DTYPE: + raise ValueError( + f"TORCH backend requires {MXFP8_SCALE_DTYPE} weight_scale dtype, " + f"got {weight_scale.dtype}." + ) + if weight_scale.ndim != 2: + raise ValueError( + f"TORCH backend requires 2D weight_scale, got {weight_scale.ndim}D. " + f"Ensure process_weights_after_loading was called." + ) + + weight_bf16 = dequant_mxfp8_to_bf16(weight, weight_scale) + + output = torch.nn.functional.linear(input, weight_bf16, bias) + return output.to(out_dtype)